import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """\
# Gemma 2 9B Neogenesis ITA 💎🌍🇮🇹

Fine-tuned version of VAGOsolutions/SauerkrautLM-gemma-2-9b-it, aimed at improving performance in Italian.
A solid model with 9.24 billion parameters and an 8k context length.

[🪪 **Model card**](https://huggingface.co/anakin87/gemma-2-9b-neogenesis-ita)
"""

# Generation limits; MAX_INPUT_TOKEN_LENGTH can be overridden via environment variable.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load tokenizer and model; bfloat16 keeps GPU memory usage manageable.
model_id = "anakin87/gemma-2-9b-neogenesis-ita"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# Match Gemma 2's sliding-window attention size.
model.config.sliding_window = 4096
model.eval()


@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = chat_history.copy()
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the prompt fits the context budget.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Run generation in a background thread and stream partial text to the UI.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Ciao! Come stai?"],
        ["Scrivi l'incipit di un racconto che inizia con: 'Era una notte buia e tempestosa, ma Anna non aveva paura del temporale...'"],
        ["Cos'è uno static method in python? Fornisci un esempio"],
        ["Fammi un elenco puntato dei pro e contro di vivere in Italia. Massimo 2 pro e 2 contro."],
        ["Risolvere 9x^2+2x=-5"],
        [
            "Immagina di essere il capo di una missione spaziale su un pianeta sconosciuto. Durante l'esplorazione, scopri una civiltà "
            "aliena che sembra essere un pericolo per l'umanità. Come ti comporti con loro, e quali azioni intraprendi per proteggere "
            "il futuro dell'umanità, pur rispettando le leggi universali della non-interferenza?"
        ],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    cache_examples=False,
    type="messages",
)

fonts = {
    "font": [gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
    "font_mono": [gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "Consolas", "monospace"],
}

with gr.Blocks(css_paths="style.css", fill_height=True, theme=gr.themes.Soft(**fonts)) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()