prithivMLmods commited on
Commit
d9f4030
·
verified ·
1 Parent(s): 90dd2c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -39
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import os
2
  from collections.abc import Iterator
3
  from threading import Thread
4
- import networkx as nx
5
- import matplotlib.pyplot as plt
6
  import gradio as gr
7
  import spaces
8
  import torch
@@ -28,29 +27,8 @@ model = AutoModelForCausalLM.from_pretrained(
28
  model.config.sliding_window = 4096
29
  model.eval()
30
 
31
- def generate_knowledge_graph():
32
- # Create a simple knowledge graph
33
- G = nx.Graph()
34
- G.add_node("AI", title="Artificial Intelligence")
35
- G.add_node("ML", title="Machine Learning")
36
- G.add_node("DL", title="Deep Learning")
37
- G.add_edge("AI", "ML")
38
- G.add_edge("ML", "DL")
39
-
40
- # Draw the graph using matplotlib
41
- plt.figure(figsize=(8, 6))
42
- pos = nx.spring_layout(G)
43
- nx.draw(G, pos, with_labels=True, node_size=3000, node_color="lightblue", font_size=10, font_weight="bold")
44
- plt.title("Knowledge Graph")
45
-
46
- # Save the graph as a PDF
47
- pdf_path = "knowledge_graph.pdf"
48
- plt.savefig(pdf_path, format="pdf")
49
- plt.close()
50
-
51
- return pdf_path
52
 
53
- @spaces.GPU(duration=120)
54
  def generate(
55
  message: str,
56
  chat_history: list[dict],
@@ -63,17 +41,15 @@ def generate(
63
  conversation = chat_history.copy()
64
  conversation.append({"role": "user", "content": message})
65
 
66
- # Tokenize the input
67
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
68
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
69
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
70
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
71
  input_ids = input_ids.to(model.device)
72
 
73
- # Set up the streamer
74
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
75
  generate_kwargs = dict(
76
- input_ids=input_ids,
77
  streamer=streamer,
78
  max_new_tokens=max_new_tokens,
79
  do_sample=True,
@@ -83,22 +59,14 @@ def generate(
83
  num_beams=1,
84
  repetition_penalty=repetition_penalty,
85
  )
86
-
87
- # Start the generation in a separate thread
88
  t = Thread(target=model.generate, kwargs=generate_kwargs)
89
  t.start()
90
 
91
- # Stream the output
92
  outputs = []
93
  for text in streamer:
94
  outputs.append(text)
95
  yield "".join(outputs)
96
 
97
- # Ensure the thread is joined after completion
98
- t.join()
99
-
100
- # Generate the knowledge graph PDF file
101
- pdf_path = generate_knowledge_graph()
102
 
103
  demo = gr.ChatInterface(
104
  fn=generate,
@@ -138,15 +106,13 @@ demo = gr.ChatInterface(
138
  step=0.05,
139
  value=1.2,
140
  ),
141
- gr.File(label="Download Knowledge Graph (PDF)", value=pdf_path, visible=True),
142
  ],
143
  stop_btn=None,
144
  examples=[
145
  ["Write a Python function to reverses a string if it's length is a multiple of 4. def reverse_string(str1): if len(str1) % 4 == 0: return ''.join(reversed(str1)) return str1 print(reverse_string('abcd')) print(reverse_string('python')) "],
146
  ["Rectangle $ABCD$ is the base of pyramid $PABCD$. If $AB = 10$, $BC = 5$, $\overline{PA}\perp \text{plane } ABCD$, and $PA = 8$, then what is the volume of $PABCD$?"],
147
  ["Difference between List comprehension and Lambda in Python lst = [x ** 2 for x in range (1, 11) if x % 2 == 1] print(lst)"],
148
- ["How many hours does it take a man to eat a Helicopter?"],
149
- ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
150
  ],
151
  cache_examples=False,
152
  type="messages",
@@ -155,5 +121,6 @@ demo = gr.ChatInterface(
155
  fill_height=True,
156
  )
157
 
 
158
  if __name__ == "__main__":
159
- demo.queue(max_size=20).launch(share=True) # Set share=True for a public link
 
1
  import os
2
  from collections.abc import Iterator
3
  from threading import Thread
4
+
 
5
  import gradio as gr
6
  import spaces
7
  import torch
 
27
  model.config.sliding_window = 4096
28
  model.eval()
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ @spaces.GPU()
32
  def generate(
33
  message: str,
34
  chat_history: list[dict],
 
41
  conversation = chat_history.copy()
42
  conversation.append({"role": "user", "content": message})
43
 
 
44
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
45
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
46
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
47
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
48
  input_ids = input_ids.to(model.device)
49
 
 
50
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
51
  generate_kwargs = dict(
52
+ {"input_ids": input_ids},
53
  streamer=streamer,
54
  max_new_tokens=max_new_tokens,
55
  do_sample=True,
 
59
  num_beams=1,
60
  repetition_penalty=repetition_penalty,
61
  )
 
 
62
  t = Thread(target=model.generate, kwargs=generate_kwargs)
63
  t.start()
64
 
 
65
  outputs = []
66
  for text in streamer:
67
  outputs.append(text)
68
  yield "".join(outputs)
69
 
 
 
 
 
 
70
 
71
  demo = gr.ChatInterface(
72
  fn=generate,
 
106
  step=0.05,
107
  value=1.2,
108
  ),
 
109
  ],
110
  stop_btn=None,
111
  examples=[
112
  ["Write a Python function to reverses a string if it's length is a multiple of 4. def reverse_string(str1): if len(str1) % 4 == 0: return ''.join(reversed(str1)) return str1 print(reverse_string('abcd')) print(reverse_string('python')) "],
113
  ["Rectangle $ABCD$ is the base of pyramid $PABCD$. If $AB = 10$, $BC = 5$, $\overline{PA}\perp \text{plane } ABCD$, and $PA = 8$, then what is the volume of $PABCD$?"],
114
  ["Difference between List comprehension and Lambda in Python lst = [x ** 2 for x in range (1, 11) if x % 2 == 1] print(lst)"],
115
+ ["How Many R's in the Word 'STRAWBERRY' ?"],
 
116
  ],
117
  cache_examples=False,
118
  type="messages",
 
121
  fill_height=True,
122
  )
123
 
124
+
125
  if __name__ == "__main__":
126
+ demo.queue(max_size=20).launch()