sagar007 committed
Commit 7e9dd79 · verified · 1 Parent(s): b834bb3

Update app.py: load Llama-3.1-Storm-8B with AutoModelForCausalLM instead of a transformers pipeline, run generation under ZeroGPU via @spaces.GPU, and rebuild the UI as a gr.Interface

Files changed (1):
  1. app.py +52 -43
app.py CHANGED
@@ -1,26 +1,45 @@
 import gradio as gr
+import spaces
 import torch
-from transformers import AutoTokenizer, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
+# HTML template for custom UI
+HTML_TEMPLATE = """
+<style>
+    .llama-image {
+        display: flex;
+        justify-content: center;
+        margin-bottom: 20px;
+    }
+    .llama-image img {
+        max-width: 300px;
+        border-radius: 10px;
+        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+    }
+    .llama-description {
+        text-align: center;
+        font-weight: bold;
+        margin-top: 10px;
+    }
+</style>
+<div class="llama-image">
+    <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama">
+    <div class="llama-description">Llama-3.1-Storm-8B Model</div>
+</div>
+<h1>Llama-3.1-Storm-8B Text Generation</h1>
+<p>Generate text using the powerful Llama-3.1-Storm-8B model. Enter a prompt and let the AI create!</p>
+"""
+
 # Load the model and tokenizer
 model_name = "akjindal53244/Llama-3.1-Storm-8B"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-pipe = pipeline(
-    "text-generation",
-    model=model_name,
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
     torch_dtype=torch.bfloat16,
     device_map="auto"
 )
 
-# HTML content
-HTML_CONTENT = """
-<h1>Llama-3.1-Storm-8B Text Generation</h1>
-<p>Generate text using the powerful Llama-3.1-Storm-8B model. Enter a prompt and let the AI create!</p>
-<div class="llama-image">
-    <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama" style="width:200px; border-radius:10px;">
-</div>
-"""
-
+@spaces.GPU(duration=120)
 def generate_text(prompt, max_length, temperature):
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
@@ -28,8 +47,10 @@ def generate_text(prompt, max_length, temperature):
     ]
     formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
 
-    outputs = pipe(
-        formatted_prompt,
+    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+
+    outputs = model.generate(
+        **inputs,
         max_new_tokens=max_length,
         do_sample=True,
         temperature=temperature,
@@ -37,33 +58,21 @@
         top_p=0.95,
     )
 
-    return outputs[0]['generated_text'][len(formatted_prompt):]
+    return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
 
-# Define examples
-examples = [
-    ["Tell me a short story about a brave astronaut.", 150, 0.7],
-    ["Explain quantum computing in simple terms.", 200, 0.5],
-    ["Write a haiku about spring.", 50, 1.0],
-]
-
-with gr.Blocks() as demo:
-    gr.HTML(HTML_CONTENT)
-    with gr.Row():
-        with gr.Column(scale=2):
-            prompt = gr.Textbox(label="Prompt", lines=5)
-            max_length = gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max Length")
-            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
-            submit_button = gr.Button("Generate")
-        with gr.Column(scale=2):
-            output = gr.Textbox(label="Generated Text", lines=10)
-
-    submit_button.click(generate_text, inputs=[prompt, max_length, temperature], outputs=[output])
-
-    # Add examples
-    gr.Examples(
-        examples=examples,
-        inputs=[prompt, max_length, temperature],
-    )
+# Create Gradio interface
+iface = gr.Interface(
+    fn=generate_text,
+    inputs=[
+        gr.Textbox(lines=5, label="Prompt"),
+        gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max Length"),
+        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
+    ],
+    outputs=gr.Textbox(lines=10, label="Generated Text"),
+    title="Llama-3.1-Storm-8B Text Generation",
+    description="Enter a prompt to generate text using the Llama-3.1-Storm-8B model.",
+    article=HTML_TEMPLATE,
+    css=".gradio-container {max-width: 800px; margin: auto;}",
+)
 
-if __name__ == "__main__":
-    demo.launch()
+iface.launch()
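
The @spaces.GPU(duration=120) decorator is the ZeroGPU hook on Hugging Face Spaces: each call to generate_text is allotted a GPU for up to 120 seconds and the device is released when the function returns. Outside a Space the spaces package is not available, so a local smoke test has to drop the decorator. Below is a minimal sketch of such a test, assuming a CUDA machine with enough memory for the bfloat16 8B weights; the user turn in messages, which the diff view elides, is assumed to be the raw prompt, and the test prompt is taken from the examples list this commit removed.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Same checkpoint the Space loads.
model_name = "akjindal53244/Llama-3.1-Storm-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

def generate_text(prompt, max_length, temperature):
    # Mirrors the Space's generate_text, minus the @spaces.GPU decorator.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        # Assumed user turn; this line is elided in the diff view above.
        {"role": "user", "content": prompt},
    ]
    formatted_prompt = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        top_p=0.95,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    return tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )

# One of the example rows the old gr.Blocks UI shipped with.
print(generate_text("Write a haiku about spring.", 50, 1.0))

Slicing the output at inputs["input_ids"].shape[1] serves the same purpose as the old len(formatted_prompt) string slice, but in token space, which is more robust than counting characters of the templated prompt.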