import os
import io
import uuid
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Qwen2VLForConditionalGeneration,
    TextIteratorStreamer,
)
# Provides process_vision_info(), used below to prepare Qwen2-VL image/video inputs
from qwen_vl_utils import process_vision_info

# Text-only model setup
DESCRIPTION = """
# GWQ PREV
"""

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "prithivMLmods/GWQ2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.config.sliding_window = 4096
model.eval()

# Multimodal model setup
MULTIMODAL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
multimodal_model = Qwen2VLForConditionalGeneration.from_pretrained(
    MULTIMODAL_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda").eval()
multimodal_processor = AutoProcessor.from_pretrained(MULTIMODAL_MODEL_ID, trust_remote_code=True)

# Mapping of known image file extensions (e.g. ".png") to PIL format names
image_extensions = Image.registered_extensions()
video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")


def identify_and_save_blob(blob_path):
    """Identify whether the blob is an image or a video and save it to a temporary file."""
    try:
        with open(blob_path, "rb") as file:
            blob_content = file.read()

        # Try to identify whether it's an image
        try:
            Image.open(io.BytesIO(blob_content)).verify()  # Check if it's a valid image
            extension = ".png"  # Default to PNG for saving
            media_type = "image"
        except (IOError, SyntaxError):
            # If it's not a valid image, assume it's a video
            extension = ".mp4"  # Default to MP4 for saving
            media_type = "video"

        # Create a unique filename
        filename = f"temp_{uuid.uuid4()}_media{extension}"
        with open(filename, "wb") as f:
            f.write(blob_content)

        return filename, media_type
    except FileNotFoundError:
        raise ValueError(f"The file {blob_path} was not found.")
    except Exception as e:
        raise ValueError(f"An error occurred while processing the file: {e}")


@spaces.GPU()
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    files: list | None = None,
) -> Iterator[str]:
    if files and len(files) > 0:
        # Multimodal input: route the first uploaded file to Qwen2-VL
        media_path = files[0]
        if media_path.endswith(tuple(image_extensions.keys())):
            media_type = "image"
        elif media_path.endswith(video_extensions):
            media_type = "video"
        else:
            try:
                media_path, media_type = identify_and_save_blob(media_path)
            except Exception:
                raise ValueError(
                    "Unsupported media type. "
Please upload an image or video.") messages = [ { "role": "user", "content": [ { "type": media_type, media_type: media_path, **({"fps": 8.0} if media_type == "video" else {}), }, {"type": "text", "text": message}, ], } ] text = multimodal_processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) image_inputs, video_inputs = process_vision_info(messages) inputs = multimodal_processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ).to("cuda") streamer = TextIteratorStreamer( multimodal_processor, skip_prompt=True, **{"skip_special_tokens": True} ) generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens) thread = Thread(target=multimodal_model.generate, kwargs=generation_kwargs) thread.start() buffer = "" for new_text in streamer: buffer += new_text yield buffer else: # Text-only input conversation = chat_history.copy() conversation.append({"role": "user", "content": message}) input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt") if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.") input_ids = input_ids.to(model.device) streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( {"input_ids": input_ids}, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p, top_k=top_k, temperature=temperature, num_beams=1, repetition_penalty=repetition_penalty, ) t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() outputs = [] for text in streamer: outputs.append(text) yield "".join(outputs) demo = gr.ChatInterface( fn=generate, additional_inputs=[ gr.Slider( label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS, ), gr.Slider( label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6, ), gr.Slider( label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9, ), gr.Slider( label="Top-k", minimum=1, maximum=1000, step=1, value=50, ), gr.Slider( label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2, ), ], stop_btn=None, examples=[ ["Hello there! How are you doing?"], ["Can you explain briefly to me what is the Python programming language?"], ["Explain the plot of Cinderella in a sentence."], ["How many hours does it take a man to eat a Helicopter?"], ["Write a 100-word article on 'Benefits of Open-Source in AI research'"], ], cache_examples=False, type="messages", description=DESCRIPTION, css_paths="style.css", fill_height=True, multimodal=True, textbox=gr.MultimodalTextbox(), ) if __name__ == "__main__": demo.queue(max_size=20).launch()