import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
from threading import Thread
import re
import time
from PIL import Image
import torch
import spaces
processor = AutoProcessor.from_pretrained("ucsahin/TraVisionLM-DPO", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("ucsahin/TraVisionLM-base", trust_remote_code=True)
model_dpo = AutoModelForCausalLM.from_pretrained("ucsahin/TraVisionLM-DPO", trust_remote_code=True)
model.to("cuda:0")
model_dpo.to("cuda:0")
@spaces.GPU
def bot_streaming(message, history, max_tokens, temperature, top_p, top_k, repetition_penalty):
print(max_tokens, temperature, top_p, top_k, repetition_penalty)
print(message)
if message['files']:
image = message['files'][-1]['path']
else:
# if there's no image uploaded for this turn, look for images in the past turns
for hist in history:
if type(hist[0])==tuple:
image = hist[0][0]
if image is None:
gr.Error("Lütfen önce bir resim yükleyin.")
prompt = f"{message['text']}"
image = Image.open(image).convert("RGB")
inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0")
generation_kwargs = dict(
inputs, max_new_tokens=max_tokens,
do_sample=True, temperature=temperature, top_p=top_p,
top_k=top_k, repetition_penalty=repetition_penalty
)
generated_text = ""
model_outputs = model.generate(**generation_kwargs)
dpo_outputs = model_dpo.generate(**generation_kwargs)
ref_text = processor.batch_decode(model_outputs, skip_special_tokens=True)[0]
dpo_text = processor.batch_decode(dpo_outputs, skip_special_tokens=True)[0]
generated_text = f"
Base model cevabı:
\n{ref_text[len(prompt)+1:]}\nDPO model cevabı:
\n{dpo_text[len(prompt)+1:]}\n"
return generated_text
gr.set_static_paths(paths=["static/images/"])
logo_path = "static/images/logo-color-v2.png"
PLACEHOLDER = f"""
Resim yükleyin ve bir soru sorun!
Örnek resim ve soruları kullanabilirsiniz.
"""
DESCRIPTION = f"""
### 875M parametreli küçük ama çok hızlı bir Türkçe Görsel Dil Modeli 🇹🇷🌟⚡️⚡️🇹🇷
Yüklediğiniz resimleri açıklatabilir ve onlarla ilgili ucu açık sorular sorabilirsiniz 🖼️🤖
**Bu demoda sorularınıza iki TraVisionLM modeli tarafından cevap verilmektedir:** ♊♊
- **Base model:** [TraVisionLM-base](https://huggingface.co/ucsahin/TraVisionLM-base) 🌛
- **DPO model:** [TraVisionLM-DPO](https://huggingface.co/ucsahin/TraVisionLM-DPO) 🌜
"""
with gr.Accordion("Generation parameters", open=False) as parameter_accordion:
max_tokens_item = gr.Slider(64, 1024, value=512, step=64, label="Max tokens")
temperature_item = gr.Slider(0.1, 2, value=0.6, step=0.1, label="Temperature")
top_p_item = gr.Slider(0, 1.0, value=0.9, step=0.05, label="Top_p")
top_k_item = gr.Slider(0, 100, value=50, label="Top_k")
repeat_penalty_item = gr.Slider(0, 2, value=1.2, label="Repeat penalty")
demo = gr.ChatInterface(
title="TraVisionLM - Demo",
description=DESCRIPTION,
fn=bot_streaming,
chatbot=gr.Chatbot(placeholder=PLACEHOLDER, scale=1),
examples=[
[{"text": "Resimde kaç kişi var?", "files":["./family.jpg"]}],
[{"text": "Açıkla", "files":["./anitkabir2.jpg"]}],
[{"text": "Kısaca açıkla", "files":["./at.jpg"]}],
[{"text": "Görüntüdeki otobüsün görünümü nasıldır?", "files":["./bus.jpg"]}],
[{"text": "Görüntüde hava durumu nasıl?", "files":["./plane.jpg"]}],
[{"text": "Resimdeki ilginç unsurlar nelerdir?", "files":["./dog.jpg"]}],
[{"text": "Tren istasyonu kalabalık mı yoksa boş mu?", "files":["./train.jpg"]}],
[{"text": "Resimdeki araba hangi renk?", "files":["./car.jpg"]}],
[{"text": "Görüntünün odak noktası nedir?", "files":["./mandog.jpg"]}],
[{"text": "Resimde neresi görünüyor?", "files":["./galata.jpg"]}],
[{"text": "Resim nasıl bir tarza sahip?", "files":["./suluboya.jpg"]}],
[{"text": "Detaylı açıkla", "files":["./tren.jpg"]}],
[{"text": "Resimde nasıl bir etkinlik var?", "files":["./paris1.jpg"]}],
[{"text": "Görsel nasıl bir ortamı gösteriyor?", "files":["./lamba.jpg"]}],
],
additional_inputs=[max_tokens_item, temperature_item, top_p_item, top_k_item, repeat_penalty_item],
additional_inputs_accordion=parameter_accordion,
stop_btn="Stop Generation",
multimodal=True
)
demo.launch(debug=True, max_file_size="5mb")