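"""Gradio demo: chat with LLaVA-Onevision about uploaded images and videos, or
about frames captured from the local camera (start your message with "camera").
"""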
import gradio as gr
from transformers import LlavaOnevisionProcessor, LlavaOnevisionForConditionalGeneration, TextIteratorStreamer
from threading import Thread
import time
from PIL import Image
import torch
import cv2
import spaces

model_id = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"

processor = LlavaOnevisionProcessor.from_pretrained(model_id)

model = LlavaOnevisionForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)
model.to("cuda")

# Function to capture frames from the camera
def capture_camera_frames(num_frames):
    camera = cv2.VideoCapture(0)  # Accessing the camera (0 is the default camera)
    frames = []
    for _ in range(num_frames):
        ret, frame = camera.read()
        if not ret:
            break
        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frames.append(pil_img)
    camera.release()
    return frames
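
# The handlers below call a sample_frames helper that is not defined in this
# file. A minimal sketch, assuming uniformly spaced frames are wanted and that
# OpenCV can open the video at `video_path`:
def sample_frames(video_path, num_frames):
    video = cv2.VideoCapture(video_path)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    interval = max(total_frames // num_frames, 1)
    frames = []
    for i in range(total_frames):
        ret, frame = video.read()
        if not ret:
            break
        if i % interval == 0 and len(frames) < num_frames:
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    video.release()
    return frames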

@spaces.GPU
def bot_streaming(message, history):
    txt = message.text
    # The streamer also emits the decoded prompt; this approximate prefix is
    # stripped from the accumulated buffer in the streaming loop below.
    ext_buffer = f"user\n{txt} assistant"
    
    # Collect the paths of any uploaded files.
    if message.files:
        image = [f.path for f in message.files]
    else:
        image = None

    # If the message starts with "camera", capture frames from the local
    # camera instead of using uploaded files.
    camera_frames = None
    if txt.lower().startswith("camera"):
        camera_frames = capture_camera_frames(5)  # Capture 5 frames

    if not image and not camera_frames:
        raise gr.Error("You need to upload an image or video, or access the camera for LLaVA to work.")

    video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg")
    image_extensions = tuple(Image.registered_extensions())

    if camera_frames:
        # Camera frames are already PIL images, so skip the path-based
        # handling and add one <image> token per captured frame.
        image = camera_frames
        video = None
        toks = "<image>" * len(image)
        prompt = "<|im_start|>user" + toks + f"\n{txt}<|im_end|><|im_start|>assistant"
    elif len(image) == 1:
        if image[0].endswith(video_extensions):
            video = sample_frames(image[0], 32)
            image = None
            prompt = f"<|im_start|>user <video>\n{txt}<|im_end|><|im_start|>assistant"
        elif image[0].endswith(image_extensions):
            image = Image.open(image[0]).convert("RGB")
            video = None
            prompt = f"<|im_start|>user <image>\n{txt}<|im_end|><|im_start|>assistant"
        else:
            raise gr.Error("Unsupported file type; please upload an image or a video.")
    else:
        # Multiple files: load images directly, sample a few frames from each
        # video, then add one <image> token per collected frame.
        image_list = []
        for img in image:
            if img.endswith(image_extensions):
                image_list.append(Image.open(img).convert("RGB"))
            elif img.endswith(video_extensions):
                image_list.extend(sample_frames(img, 6))

        toks = "<image>" * len(image_list)
        prompt = "<|im_start|>user" + toks + f"\n{txt}<|im_end|><|im_start|>assistant"

        image = image_list
        video = None

    inputs = processor(text=prompt, images=image, videos=video, return_tensors="pt").to("cuda", torch.float16)
    streamer = TextIteratorStreamer(processor, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        generated_text_without_prompt = buffer[len(ext_buffer):]
        time.sleep(0.01)
        yield generated_text_without_prompt


# Integrate camera access into Gradio demo
demo = gr.ChatInterface(fn=bot_streaming, title="LLaVA Onevision with Camera", examples=[
    {"text": "Take a picture with the camera and describe what is in it.", "files":[]},
    {"text": "Do the cats in these two videos have the same breed? What breed is each cat?", "files":["./cats_1.mp4", "./cats_2.mp4"]},
    {"text": "Here are several images from a cooking book, showing how to prepare a meal step by step. Can you write a recipe for the meal?", "files":["./step0.png", "./step1.png", "./step2.png", "./step3.png"]}, 
], 
    textbox=gr.MultimodalTextbox(file_count="multiple"), 
    description="Upload an image or video, or try capturing frames with the camera and chat about it.",
    stop_btn="Stop Generation", multimodal=True)

demo.launch(debug=True)