prithivMLmods committed
Commit 0097859 · verified · 1 parent(s): 07c3da8

Update app.py

Files changed (1)
  1. app.py +126 -126
app.py CHANGED
@@ -1,127 +1,127 @@
-import os
-from collections.abc import Iterator
-from threading import Thread
-
-import gradio as gr
-import spaces
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
-DESCRIPTION = """
-# GWQ PREV
-"""
-
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-model_id = "prithivMLmods/GWQ2b"
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-)
-model.config.sliding_window = 4096
-model.eval()
-
-
-@spaces.GPU()
-def generate(
-    message: str,
-    chat_history: list[dict],
-    max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
-) -> Iterator[str]:
-    conversation = chat_history.copy()
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
-
-
-demo = gr.ChatInterface(
-    fn=generate,
-    additional_inputs=[
-        gr.Slider(
-            label="Max new tokens",
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        ),
-        gr.Slider(
-            label="Temperature",
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.6,
-        ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
-    ],
-    stop_btn=None,
-    examples=[
-        ["Hello there! How are you doing?"],
-        ["Can you explain briefly to me what is the Python programming language?"],
-        ["Explain the plot of Cinderella in a sentence."],
-        ["How many hours does it take a man to eat a Helicopter?"],
-        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-    ],
-    cache_examples=False,
-    type="messages",
-    description=DESCRIPTION,
-    css_paths="style.css",
-    fill_height=True,
-)
-
-
-if __name__ == "__main__":
+import os
+from collections.abc import Iterator
+from threading import Thread
+
+import gradio as gr
+import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+DESCRIPTION = """
+# GWQ PREV
+"""
+
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+model_id = "prithivMLmods/GWQ2b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+model.config.sliding_window = 4096
+model.eval()
+
+
+@spaces.GPU(duration=120)
+def generate(
+    message: str,
+    chat_history: list[dict],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    conversation = chat_history.copy()
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
+
+demo = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+    ],
+    cache_examples=False,
+    type="messages",
+    description=DESCRIPTION,
+    css_paths="style.css",
+    fill_height=True,
+)
+
+
+if __name__ == "__main__":
     demo.queue(max_size=20).launch()
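
Although every line is marked changed (+126 −126), the one functional difference in this commit is the ZeroGPU decorator on line 31: `@spaces.GPU()` becomes `@spaces.GPU(duration=120)`, which asks Hugging Face ZeroGPU for up to 120 seconds of GPU time per call. A minimal sketch of the pattern (the function body is elided; the default window when `duration` is omitted is an assumption about the `spaces` package's current behavior):

import spaces

@spaces.GPU(duration=120)  # reserve a ZeroGPU slot for up to 120 s per call
def generate(...):         # streaming generation as in app.py above
    ...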