sitammeur commited on
Commit
e05119b
·
verified ·
1 Parent(s): df22269

Update src/app/response.py

Browse files
Files changed (1) hide show
  1. src/app/response.py +13 -8
src/app/response.py CHANGED
@@ -17,13 +17,14 @@ model, processor = load_model_and_processor(model_name, device)
17
 
18
 
19
  @spaces.GPU
20
- def caption_image(image: PIL.Image.Image, max_new_tokens: int) -> str:
21
  """
22
  Generates a caption based on the given image using the model.
23
 
24
  Args:
25
  - image (PIL.Image.Image): The input image to be processed.
26
  - max_new_tokens (int): The maximum number of new tokens to generate.
 
27
 
28
  Returns:
29
  str: The generated caption text.
@@ -35,22 +36,26 @@ def caption_image(image: PIL.Image.Image, max_new_tokens: int) -> str:
35
 
36
  # Prepare the inputs
37
  prompt = "caption en"
38
- inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
 
 
 
 
 
39
 
40
  # Generate the response
41
  with torch.inference_mode():
42
- generated_ids = model.generate(
43
- **inputs, max_new_tokens=max_new_tokens, do_sample=sampling
44
  )
45
-
46
- # Decode the generated response
47
- result = processor.batch_decode(generated_ids, skip_special_tokens=True)
48
 
49
  # Log the successful generation of the caption
50
  logging.info("Caption generated successfully.")
51
 
52
  # Return the generated caption
53
- return result[0][len(prompt) :].lstrip("\n")
54
 
55
  # Handle exceptions that may occur during caption generation
56
  except Exception as e:
 
17
 
18
 
19
  @spaces.GPU
20
+ def caption_image(image: PIL.Image.Image, max_new_tokens: int, sampling: bool) -> str:
21
  """
22
  Generates a caption based on the given image using the model.
23
 
24
  Args:
25
  - image (PIL.Image.Image): The input image to be processed.
26
  - max_new_tokens (int): The maximum number of new tokens to generate.
27
+ - sampling (bool): Whether to use sampling or not.
28
 
29
  Returns:
30
  str: The generated caption text.
 
36
 
37
  # Prepare the inputs
38
  prompt = "caption en"
39
+ model_inputs = (
40
+ processor(text=prompt, images=image, return_tensors="pt")
41
+ .to(torch.bfloat16)
42
+ .to(device)
43
+ )
44
+ input_len = model_inputs["input_ids"].shape[-1]
45
 
46
  # Generate the response
47
  with torch.inference_mode():
48
+ generation = model.generate(
49
+ **model_inputs, max_new_tokens=max_new_tokens, do_sample=sampling
50
  )
51
+ generation = generation[0][input_len:]
52
+ decoded = processor.decode(generation, skip_special_tokens=True)
 
53
 
54
  # Log the successful generation of the caption
55
  logging.info("Caption generated successfully.")
56
 
57
  # Return the generated caption
58
+ return decoded
59
 
60
  # Handle exceptions that may occur during caption generation
61
  except Exception as e: