import subprocess
import sys
import os

# Helper to install a package if it is not already present
def install_package(package_name):
    subprocess.run([sys.executable, "-m", "pip", "install", package_name], check=True)

# Check whether torch is installed; install it if not
try:
    import torch
except ImportError:
    print("Torch n'est pas installé. Installation de torch...")
    install_package("torch")
    import torch

# Check whether transformers is installed; install it if not
try:
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TextIteratorStreamer,
    )
except ImportError:
    print("Transformers n'est pas installé. Installation de transformers...")
    install_package("transformers")
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TextIteratorStreamer,
    )

# Install flash-attn; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes its setup skip
# compiling the CUDA extension at install time. Merge os.environ so PATH and the
# rest of the environment remain available to the shell.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

import gradio as gr
from threading import Thread

# Get the Hugging Face authentication token
token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("Le token d'authentification HF_TOKEN n'est pas défini.")

# Load the model and the tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "HaitameLaf/Phi3_StoryGenerator",
    token=token,
    trust_remote_code=True,
)
tok = AutoTokenizer.from_pretrained("HaitameLaf/Phi3_StoryGenerator", token=token)

terminators = [tok.eos_token_id]
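# These token IDs end generation; they are passed to model.generate via
# eos_token_id in chat() below.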

# Check GPU availability
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

model = model.to(device)

# Chat function: rebuild the conversation history, then stream the model's reply
def chat(message, history, temperature, do_sample, max_tokens):
    # Interleave user/assistant turns so the chat template sees them in order
    chat = []
    for user_msg, assistant_msg in history:
        chat.append({"role": "user", "content": user_msg})
        if assistant_msg:
            chat.append({"role": "assistant", "content": assistant_msg})
    chat.append({"role": "user", "content": message})
    messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    
    generate_kwargs = {
        "input_ids": model_inputs.input_ids,
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "eos_token_id": terminators,
    }

    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    yield partial_text

# Gradio UI configuration
demo = gr.ChatInterface(
    fn=chat,
    examples=[["Write me a poem about Machine Learning."]],
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature"),
        gr.Checkbox(label="Sampling", value=True),
        gr.Slider(minimum=128, maximum=4096, step=1, value=512, label="Max new tokens"),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now Running [HaitameLaf/Phi3_StoryGenerator](https://huggingface.co/HaitameLaf/Phi3_StoryGenerator)",
)

if __name__ == "__main__":
    demo.launch()
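
# A minimal local-run sketch (assumes the account behind HF_TOKEN can access
# HaitameLaf/Phi3_StoryGenerator; the token value below is a placeholder):
#   export HF_TOKEN=hf_xxxxxxxx
#   python app.py
# Gradio then serves the chat UI on http://127.0.0.1:7860 by default.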