sagar007 commited on
Commit
2848e2c
·
verified ·
1 Parent(s): 15967e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -149
app.py CHANGED
@@ -1,18 +1,26 @@
1
  import gradio as gr
2
- import spaces
3
  import torch
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
  # Load the model and tokenizer
7
  model_name = "akjindal53244/Llama-3.1-Storm-8B"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForCausalLM.from_pretrained(
10
- model_name,
 
11
  torch_dtype=torch.bfloat16,
12
  device_map="auto"
13
  )
14
 
15
- @spaces.GPU(duration=120)
 
 
 
 
 
 
 
 
16
  def generate_text(prompt, max_length, temperature):
17
  messages = [
18
  {"role": "system", "content": "You are a helpful assistant."},
@@ -20,10 +28,8 @@ def generate_text(prompt, max_length, temperature):
20
  ]
21
  formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
22
 
23
- inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
24
-
25
- outputs = model.generate(
26
- **inputs,
27
  max_new_tokens=max_length,
28
  do_sample=True,
29
  temperature=temperature,
@@ -31,151 +37,33 @@ def generate_text(prompt, max_length, temperature):
31
  top_p=0.95,
32
  )
33
 
34
- return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
35
 
36
- # Custom CSS
37
- css = """
38
- body {
39
- background-color: #1a1a2e;
40
- color: #e0e0e0;
41
- font-family: 'Arial', sans-serif;
42
- }
43
- .container {
44
- max-width: 900px;
45
- margin: auto;
46
- padding: 20px;
47
- }
48
- .gradio-container {
49
- background-color: #16213e;
50
- border-radius: 15px;
51
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
52
- }
53
- .header {
54
- background-color: #0f3460;
55
- padding: 20px;
56
- border-radius: 15px 15px 0 0;
57
- text-align: center;
58
- margin-bottom: 20px;
59
- }
60
- .header h1 {
61
- color: #e94560;
62
- font-size: 2.5em;
63
- margin-bottom: 10px;
64
- }
65
- .header p {
66
- color: #a0a0a0;
67
- }
68
- .header img {
69
- max-width: 300px;
70
- border-radius: 10px;
71
- margin: 15px auto;
72
- display: block;
73
- }
74
- .input-group, .output-group {
75
- background-color: #1a1a2e;
76
- padding: 20px;
77
- border-radius: 10px;
78
- margin-bottom: 20px;
79
- }
80
- .input-group label, .output-group label {
81
- color: #e94560;
82
- font-weight: bold;
83
- }
84
- .generate-btn {
85
- background-color: #e94560 !important;
86
- color: white !important;
87
- border: none !important;
88
- border-radius: 5px !important;
89
- padding: 10px 20px !important;
90
- font-size: 16px !important;
91
- cursor: pointer !important;
92
- transition: background-color 0.3s ease !important;
93
- }
94
- .generate-btn:hover {
95
- background-color: #c81e45 !important;
96
- }
97
- .example-prompts {
98
- background-color: #1f2b47;
99
- padding: 15px;
100
- border-radius: 10px;
101
- margin-bottom: 20px;
102
- }
103
- .example-prompts h3 {
104
- color: #e94560;
105
- margin-bottom: 10px;
106
- }
107
- .example-prompts ul {
108
- list-style-type: none;
109
- padding-left: 0;
110
- }
111
- .example-prompts li {
112
- margin-bottom: 5px;
113
- cursor: pointer;
114
- transition: color 0.3s ease;
115
- }
116
- .example-prompts li:hover {
117
- color: #e94560;
118
- }
119
- """
120
-
121
- # Example prompts
122
- example_prompts = [
123
- "Write a Python function to find the n-th Fibonacci number.",
124
- "Explain the concept of recursion in programming.",
125
- "What are the key differences between Python and JavaScript?",
126
- "Tell me a short story about a time-traveling robot.",
127
- "Describe the process of photosynthesis in simple terms."
128
  ]
129
 
130
- # Gradio interface
131
- with gr.Blocks(css=css) as iface:
132
- gr.HTML(
133
- """
134
- <div class="header">
135
- <h1>Llama-3.1-Storm-8B Text Generation</h1>
136
- <p>Generate text using the powerful Llama-3.1-Storm-8B model. Enter a prompt and let the AI create!</p>
137
- <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama">
138
- </div>
139
- """
140
- )
141
-
142
- with gr.Group():
143
- gr.HTML(
144
- """
145
- <div class="example-prompts">
146
- <h3>Example Prompts:</h3>
147
- <ul>
148
- """ + "".join([f"<li>{prompt}</li>" for prompt in example_prompts]) + """
149
- </ul>
150
- </div>
151
- """
152
- )
153
-
154
- with gr.Group(elem_classes="input-group"):
155
- prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=5)
156
  max_length = gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max Length")
157
  temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
158
- generate_btn = gr.Button("Generate", elem_classes="generate-btn")
159
-
160
- with gr.Group(elem_classes="output-group"):
161
  output = gr.Textbox(label="Generated Text", lines=10)
162
-
163
- generate_btn.click(generate_text, inputs=[prompt, max_length, temperature], outputs=output)
164
-
165
- # JavaScript to make example prompts clickable
166
- gr.HTML(
167
- """
168
- <script>
169
- document.addEventListener('DOMContentLoaded', (event) => {
170
- document.querySelectorAll('.example-prompts li').forEach(item => {
171
- item.addEventListener('click', event => {
172
- document.querySelector('textarea[data-testid="textbox"]').value = event.target.textContent;
173
- });
174
- });
175
- });
176
- </script>
177
- """
178
  )
 
 
179
 
180
- # Launch the app
181
- iface.launch()
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import AutoTokenizer, pipeline
4
 
5
  # Load the model and tokenizer
6
  model_name = "akjindal53244/Llama-3.1-Storm-8B"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ pipe = pipeline(
9
+ "text-generation",
10
+ model=model_name,
11
  torch_dtype=torch.bfloat16,
12
  device_map="auto"
13
  )
14
 
15
+ # HTML content
16
+ HTML_CONTENT = """
17
+ <h1 style="text-align: center;">Llama-3.1-Storm-8B Text Generation</h1>
18
+ <p style="text-align: center;">Generate text using the powerful Llama-3.1-Storm-8B model. Enter a prompt or select an example, and let the AI create!</p>
19
+ <div style="display: flex; justify-content: center; margin-bottom: 20px;">
20
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama" style="width:200px; border-radius:10px;">
21
+ </div>
22
+ """
23
+
24
  def generate_text(prompt, max_length, temperature):
25
  messages = [
26
  {"role": "system", "content": "You are a helpful assistant."},
 
28
  ]
29
  formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
30
 
31
+ outputs = pipe(
32
+ formatted_prompt,
 
 
33
  max_new_tokens=max_length,
34
  do_sample=True,
35
  temperature=temperature,
 
37
  top_p=0.95,
38
  )
39
 
40
+ return outputs[0]['generated_text'][len(formatted_prompt):]
41
 
42
+ examples = [
43
+ "Write a short story about a magical llama.",
44
+ "Explain the concept of machine learning to a 10-year-old.",
45
+ "Describe the process of making the perfect cup of coffee.",
46
+ "What are the main differences between Python and JavaScript?"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  ]
48
 
49
+ with gr.Blocks() as demo:
50
+ gr.HTML(HTML_CONTENT)
51
+ with gr.Row():
52
+ with gr.Column(scale=2):
53
+ prompt = gr.Textbox(label="Prompt", lines=5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  max_length = gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max Length")
55
  temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
56
+ submit_button = gr.Button("Generate")
57
+ with gr.Column(scale=2):
 
58
  output = gr.Textbox(label="Generated Text", lines=10)
59
+
60
+ gr.Examples(
61
+ examples=examples,
62
+ inputs=prompt,
63
+ label="Click on an example to load it into the prompt box:"
 
 
 
 
 
 
 
 
 
 
 
64
  )
65
+
66
+ submit_button.click(generate_text, inputs=[prompt, max_length, temperature], outputs=[output])
67
 
68
+ if __name__ == "__main__":
69
+ demo.launch()