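"""Gradio demo app: an Italian chatbot powered by anakin87/gemma-2-2b-neogenesis-ita, with streamed responses."""
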
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
DESCRIPTION = """\
# Gemma 2 2B Neogenesis ITA 💎🌍🇮🇹
A fine-tuned version of google/gemma-2-2b-it that improves performance on the Italian language.
Small (2.6B parameters) but capable, with an 8k context length.
[🪪 **Model card**](https://huggingface.co/anakin87/gemma-2-2b-neogenesis-ita)
[📓 **Kaggle notebook**](https://www.kaggle.com/code/anakin87/post-training-gemma-for-italian-and-beyond) - Learn how this model was trained.
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
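# Conversations longer than the input budget are trimmed from the left, keeping the most recent tokens.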
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_id = "anakin87/gemma-2-2b-neogenesis-ita"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
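# Gemma 2 interleaves local sliding-window attention with global attention; cap the local window at 4096 tokens.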
model.config.sliding_window = 4096
model.eval()
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Build the conversation in the format expected by the chat template.
    conversation = chat_history.copy()
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Generation runs in a background thread; the streamer yields decoded tokens as they arrive.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield the accumulated text so the chat UI updates incrementally.
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
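
# A minimal sketch of how the streaming generator could be consumed outside Gradio
# (hypothetical usage, not part of the app):
#
#     for partial_text in generate("Ciao! Come stai?", chat_history=[]):
#         print(partial_text)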
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Ciao! Come stai?"],
        ["Pro e contro di una relazione a lungo termine. Elenco puntato con max 3 pro e 3 contro sintetici."],
        ["Quante ore impiega un uomo per mangiare un elicottero?"],
        ["Come si apre un file JSON in Python?"],
        ["Fammi un elenco puntato dei pro e contro di vivere in Italia. Massimo 2 pro e 2 contro."],
        ["Inventa una breve storia con animali sul valore dell'amicizia."],
        ["Scrivi un articolo di 100 parole sui 'Benefici dell'open-source nella ricerca sull'intelligenza artificiale'"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    cache_examples=False,
    type="messages",
)
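
# With type="messages", Gradio passes chat history to `generate` as a list of
# {"role": ..., "content": ...} dicts, matching the chat-template format used above.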
fonts = {
    "font": [gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
    "font_mono": [gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "Consolas", "monospace"],
}
with gr.Blocks(css_paths="style.css", fill_height=True, theme=gr.themes.Soft(**fonts)) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()