prithivMLmods committed
Commit f3475ee · verified · 1 Parent(s): 69471fa

Upload app.py

Files changed (1):
  app.py  +127  -0
app.py ADDED
@@ -0,0 +1,127 @@
+ import os
+ from collections.abc import Iterator
+ from threading import Thread
+
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+ DESCRIPTION = """
+ # GWQ PREV
+ """
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ model_id = "prithivMLmods/GWQ2b"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ model.config.sliding_window = 4096
+ model.eval()
+
+
+ @spaces.GPU()
+ def generate(
+     message: str,
+     chat_history: list[dict],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     conversation = chat_history.copy()
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ demo = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["Hello there! How are you doing?"],
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["Explain the plot of Cinderella in a sentence."],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+     ],
+     cache_examples=False,
+     type="messages",
+     description=DESCRIPTION,
+     css_paths="style.css",
+     fill_height=True,
+ )
+
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()