Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,749 Bytes
cc1bdc1 7cb9de2 cc1bdc1 7cb9de2 cc1bdc1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Markdown header rendered at the top of the demo page.
# The emoji in this string were repaired from mojibake: UTF-8 bytes had been
# mis-decoded as Thai (TIS-620), with the C1-range bytes dropped. The flag
# (🇮🇹) and the ID card (🪪) decode unambiguously from the surviving bytes.
# NOTE(review): 💎, 🌍 and 📓 are best-effort reconstructions of glyphs whose
# distinguishing bytes were lost; confirm them against the model card.
DESCRIPTION = """\
# Gemma 2 2B Neogenesis ITA 💎🌍🇮🇹
Fine-tuned version of Google/gemma-2-2b-it to improve the performance on the Italian language.
Small (2.6 B parameters) but good model, with 8k context length.
[🪪 **Model card**](https://huggingface.co/anakin87/gemma-2-2b-neogenesis-ita)
[📓 **Kaggle notebook**](https://www.kaggle.com/code/anakin87/post-training-gemma-for-italian-and-beyond) - Learn how this model was trained.
"""
# Generation-length limits surfaced in the UI.
MAX_MAX_NEW_TOKENS = 2048  # upper bound of the "Max new tokens" slider
DEFAULT_MAX_NEW_TOKENS = 1024  # initial value of that slider
# Prompts longer than this many tokens are left-trimmed before generation.
# The limit can be overridden through the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
# First CUDA device when one is available, otherwise the CPU.
# NOTE(review): `device` is not referenced elsewhere in this file; `generate`
# moves its inputs to `model.device` instead.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load the fine-tuned checkpoint and its tokenizer. Weights are loaded in
# bfloat16, with placement delegated to `device_map="auto"`.
model_id = "anakin87/gemma-2-2b-neogenesis-ita"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# NOTE(review): the reason for overriding `sliding_window` is not stated here.
# Presumably it pins the local-attention window to 4096 tokens; confirm it
# matches the checkpoint's own config.
model.config.sliding_window = 4096
# Inference only: put modules into evaluation mode.
model.eval()
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a sampled model reply to ``message`` given the prior conversation.

    Args:
        message: The new user turn.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts,
            the format produced by ``gr.ChatInterface(type="messages")``.
            The list itself is not mutated.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling probability mass forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far. Each yield is the full text to date, not
        a delta, so the UI can simply replace the displayed message.
    """
    # Build a new list so the caller's history is left untouched.
    conversation = [*chat_history, {"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the newest turns survive truncation.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(
            f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
        )
    input_ids = input_ids.to(model.device)

    # The streamer waits at most 20 s for each new chunk of text.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # `model.generate` blocks until it finishes, so run it in a worker thread
    # and consume the streamer here to yield text as it is produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # The streamer is exhausted once generation ends. Join the worker so it has
    # fully finished before this (GPU-decorated) call returns.
    worker.join()
# Chat UI wired to `generate`. Gradio passes the additional inputs to `fn`
# positionally after (message, history), so the slider order below must
# match generate's parameters: max_new_tokens, temperature, top_p, top_k,
# repetition_penalty.
chat_interface = gr.ChatInterface(
    fn=generate,
    type="messages",
    additional_inputs=[
        gr.Slider(minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, step=0.1, value=0.6, label="Temperature"),
        gr.Slider(minimum=0.05, maximum=1.0, step=0.05, value=0.9, label="Top-p (nucleus sampling)"),
        gr.Slider(minimum=1, maximum=1000, step=1, value=50, label="Top-k"),
        gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.2, label="Repetition penalty"),
    ],
    # Each example is a single-element row holding only the user message.
    examples=[
        [prompt]
        for prompt in (
            "Ciao! Come stai?",
            "Pro e contro di una relazione a lungo termine. Elenco puntato con max 3 pro e 3 contro sintetici.",
            "Quante ore impiega un uomo per mangiare un elicottero?",
            "Come si apre un file JSON in Python?",
            "Fammi un elenco puntato dei pro e contro di vivere in Italia. Massimo 2 pro e 2 contro.",
            "Inventa una breve storia con animali sul valore dell'amicizia.",
            "Scrivi un articolo di 100 parole sui 'Benefici dell'open-source nella ricerca sull'intelligenza artificiale'",
            "Can you explain briefly to me what is the Python programming language?",
            "How many hours does it take a man to eat a Helicopter?",
            "Write a 100-word article on 'Benefits of Open-Source in AI research'",
        )
    ],
    cache_examples=False,
    stop_btn=None,
)
# Theme fonts: a Google-hosted face first, then generic system fallbacks.
fonts = dict(
    font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
    font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "Consolas", "monospace"],
)

# Page layout: header markdown, a duplicate-Space button, then the chat UI.
with gr.Blocks(theme=gr.themes.Soft(**fonts), fill_height=True, css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
if __name__ == "__main__":
    # Bound the request queue to 20 waiting events, then start the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()
|