sitammeur committed on
Commit
2fd914d
·
verified ·
1 Parent(s): 7f731e7

Update src/app/response.py

Browse files
Files changed (1) hide show
  1. src/app/response.py +59 -59
src/app/response.py CHANGED
@@ -1,59 +1,59 @@
1
- # Necessary imports
2
- import sys
3
- import PIL.Image
4
- import torch
5
- import gradio as gr
6
- import spaces
7
-
8
- # Local imports
9
- from src.config import device, model_name, sampling
10
- from src.app.model import load_model_and_tokenizer
11
- from src.logger import logging
12
- from src.exception import CustomExceptionHandling
13
-
14
-
15
- # Model, tokenizer and processor
16
- model, tokenizer, processor = load_model_and_tokenizer(model_name, device)
17
-
18
-
19
- @spaces.GPU
20
- def describe_image(text: str, image: PIL.Image.Image, max_new_tokens: int) -> str:
21
- """
22
- Generates a response based on the given text and image using the model.
23
-
24
- Args:
25
- - text (str): The input text to be processed.
26
- - image (PIL.Image.Image): The input image to be processed.
27
- - max_new_tokens (int): The maximum number of new tokens to generate.
28
-
29
- Returns:
30
- str: The generated response text.
31
- """
32
- try:
33
- # Check if image or text is None
34
- if not image or not text:
35
- gr.Warning("Please provide an image and a question.")
36
-
37
- # Prepare the inputs
38
- text = "answer en " + text
39
- inputs = processor(text=text, images=image, return_tensors="pt").to(device)
40
-
41
- # Generate the response
42
- with torch.inference_mode():
43
- generated_ids = model.generate(
44
- **inputs, max_new_tokens=max_new_tokens, do_sample=sampling
45
- )
46
-
47
- # Decode the generated response
48
- result = processor.batch_decode(generated_ids, skip_special_tokens=True)
49
-
50
- # Log the successful generation of the answer
51
- logging.info("Answer generated successfully.")
52
-
53
- # Return the generated response
54
- return result[0][len(text) :].lstrip("\n")
55
-
56
- # Handle exceptions that may occur during answer generation
57
- except Exception as e:
58
- # Custom exception handling
59
- raise CustomExceptionHandling(e, sys) from e
 
1
+ # Necessary imports
2
+ import sys
3
+ import PIL.Image
4
+ import torch
5
+ import gradio as gr
6
+ import spaces
7
+
8
+ # Local imports
9
+ from src.config import device, model_name, sampling
10
+ from src.app.model import load_model_and_processor
11
+ from src.logger import logging
12
+ from src.exception import CustomExceptionHandling
13
+
14
+
15
+ # Model and processor
16
+ model, processor = load_model_and_processor(model_name, device)
17
+
18
+
19
+ @spaces.GPU
20
+ def describe_image(text: str, image: PIL.Image.Image, max_new_tokens: int) -> str:
21
+ """
22
+ Generates a response based on the given text and image using the model.
23
+
24
+ Args:
25
+ - text (str): The input text to be processed.
26
+ - image (PIL.Image.Image): The input image to be processed.
27
+ - max_new_tokens (int): The maximum number of new tokens to generate.
28
+
29
+ Returns:
30
+ str: The generated response text.
31
+ """
32
+ try:
33
+ # Check if image or text is None
34
+ if not image or not text:
35
+ gr.Warning("Please provide an image and a question.")
36
+
37
+ # Prepare the inputs
38
+ text = "answer en " + text
39
+ inputs = processor(text=text, images=image, return_tensors="pt").to(device)
40
+
41
+ # Generate the response
42
+ with torch.inference_mode():
43
+ generated_ids = model.generate(
44
+ **inputs, max_new_tokens=max_new_tokens, do_sample=sampling
45
+ )
46
+
47
+ # Decode the generated response
48
+ result = processor.batch_decode(generated_ids, skip_special_tokens=True)
49
+
50
+ # Log the successful generation of the answer
51
+ logging.info("Answer generated successfully.")
52
+
53
+ # Return the generated response
54
+ return result[0][len(text) :].lstrip("\n")
55
+
56
+ # Handle exceptions that may occur during answer generation
57
+ except Exception as e:
58
+ # Custom exception handling
59
+ raise CustomExceptionHandling(e, sys) from e