import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and tokenizer
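# bfloat16 roughly halves weight memory versus fp32 (about 16 GB for an 8B model),
# and device_map="auto" lets accelerate place the weights on the available GPU.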
model_name = "akjindal53244/Llama-3.1-Storm-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

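# On ZeroGPU Spaces, @spaces.GPU requests a GPU only for the duration of each call,
# here capped at 120 seconds; the rest of the app runs on CPU.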
@spaces.GPU(duration=120)
def generate_text(prompt, max_length, temperature):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
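    # Render the messages with the tokenizer's chat template; add_generation_prompt=True
    # appends the assistant header so the model starts a fresh reply.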
    formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

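    # Sample with temperature plus top-k/top-p filtering; max_new_tokens bounds only
    # the generated continuation, not the prompt.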
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        top_k=100,
        top_p=0.95,
    )

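    # Slice off the prompt tokens so only the newly generated text is decoded.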
    return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)


# Custom CSS
css = """
body {
    background-color: #1a1a2e;
    color: #e0e0e0;
    font-family: 'Arial', sans-serif;
}
.container {
    max-width: 900px;
    margin: auto;
    padding: 20px;
}
.gradio-container {
    background-color: #16213e;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.header {
    background-color: #0f3460;
    padding: 20px;
    border-radius: 15px 15px 0 0;
    text-align: center;
    margin-bottom: 20px;
}
.header h1 {
    color: #e94560;
    font-size: 2.5em;
    margin-bottom: 10px;
}
.header p {
    color: #a0a0a0;
}
.header img {
    max-width: 300px;
    border-radius: 10px;
    margin: 15px auto;
    display: block;
}
.input-group, .output-group {
    background-color: #1a1a2e;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.input-group label, .output-group label {
    color: #e94560;
    font-weight: bold;
}
.generate-btn {
    background-color: #e94560 !important;
    color: white !important;
    border: none !important;
    border-radius: 5px !important;
    padding: 10px 20px !important;
    font-size: 16px !important;
    cursor: pointer !important;
    transition: background-color 0.3s ease !important;
}
.generate-btn:hover {
    background-color: #c81e45 !important;
}
.example-prompts {
    background-color: #1f2b47;
    padding: 15px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.example-prompts h3 {
    color: #e94560;
    margin-bottom: 10px;
}
.example-prompts ul {
    list-style-type: none;
    padding-left: 0;
}
.example-prompts li {
    margin-bottom: 5px;
    cursor: pointer;
    transition: color 0.3s ease;
}
.example-prompts li:hover {
    color: #e94560;
}
"""

# Example prompts
example_prompts = [
    "Write a Python function to find the n-th Fibonacci number.",
    "Explain the concept of recursion in programming.",
    "What are the key differences between Python and JavaScript?",
    "Tell me a short story about a time-traveling robot.",
    "Describe the process of photosynthesis in simple terms."
]

# Gradio interface
with gr.Blocks(css=css) as iface:
    gr.HTML(
        """
        <div class="header">
            <h1>Llama-3.1-Storm-8B Text Generation</h1>
            <p>Generate text using the powerful Llama-3.1-Storm-8B model. Enter a prompt and let the AI create!</p>
            <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama">
        </div>
        """
    )

    with gr.Group():
        with gr.Group(elem_classes="example-prompts"):
            gr.HTML("<h3>Example Prompts:</h3>")
            example_buttons = [gr.Button(prompt) for prompt in example_prompts]

        with gr.Group(elem_classes="input-group"):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=5)
            max_length = gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max New Tokens")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
            generate_btn = gr.Button("Generate", elem_classes="generate-btn")

        with gr.Group(elem_classes="output-group"):
            output = gr.Textbox(label="Generated Text", lines=10)

    generate_btn.click(generate_text, inputs=[prompt, max_length, temperature], outputs=output)

    # Each example button fills the prompt box with its own label text.
    # (A gr.Button used as an input passes its label string to the handler.)
    for button in example_buttons:
        button.click(lambda x: x, inputs=[button], outputs=[prompt])

# Launch the app
iface.launch()