prithivMLmods commited on
Commit
435e30e
·
verified ·
1 Parent(s): 2c6ba27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -20
app.py CHANGED
@@ -1,12 +1,13 @@
 
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
- from collections.abc import Iterator
5
- from threading import Thread
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
- from pyvis.network import Network
8
- import networkx as nx
9
- import os
10
 
11
  DESCRIPTION = """
12
  # GWQ PREV
@@ -28,19 +29,20 @@ model = AutoModelForCausalLM.from_pretrained(
28
  model.config.sliding_window = 4096
29
  model.eval()
30
 
31
- def create_knowledge_graph(text):
32
- # Simple example: Create a graph from the text
33
  G = nx.Graph()
34
  words = text.split()
35
  for i in range(len(words) - 1):
36
  G.add_edge(words[i], words[i + 1])
37
- return G
38
-
39
- def visualize_knowledge_graph(graph):
40
- net = Network(notebook=True, cdn_resources='in_line')
41
- net.from_nx(graph)
42
- net.show("knowledge_graph.html")
43
- return "knowledge_graph.html"
 
44
 
45
  @spaces.GPU(duration=120)
46
  def generate(
@@ -51,7 +53,6 @@ def generate(
51
  top_p: float = 0.9,
52
  top_k: int = 50,
53
  repetition_penalty: float = 1.2,
54
- visualize_graph: bool = False,
55
  ) -> Iterator[str]:
56
  conversation = chat_history.copy()
57
  conversation.append({"role": "user", "content": message})
@@ -82,10 +83,9 @@ def generate(
82
  outputs.append(text)
83
  yield "".join(outputs)
84
 
85
- if visualize_graph:
86
- graph = create_knowledge_graph("".join(outputs))
87
- graph_file = visualize_knowledge_graph(graph)
88
- yield f"Knowledge graph saved to {graph_file}"
89
 
90
  demo = gr.ChatInterface(
91
  fn=generate,
@@ -125,7 +125,6 @@ demo = gr.ChatInterface(
125
  step=0.05,
126
  value=1.2,
127
  ),
128
- gr.Checkbox(label="Visualize Knowledge Graph", value=False),
129
  ],
130
  stop_btn=None,
131
  examples=[
@@ -139,6 +138,7 @@ demo = gr.ChatInterface(
139
  description=DESCRIPTION,
140
  css_paths="style.css",
141
  fill_height=True,
 
142
  )
143
 
144
  if __name__ == "__main__":
 
1
+ import os
2
+ from collections.abc import Iterator
3
+ from threading import Thread
4
+ import networkx as nx
5
+ import matplotlib.pyplot as plt
6
+
7
  import gradio as gr
8
  import spaces
9
  import torch
 
 
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 
 
11
 
12
  DESCRIPTION = """
13
  # GWQ PREV
 
29
  model.config.sliding_window = 4096
30
  model.eval()
31
 
32
+ def create_knowledge_graph_image(text):
33
+ # Example: Create a simple knowledge graph from the text
34
  G = nx.Graph()
35
  words = text.split()
36
  for i in range(len(words) - 1):
37
  G.add_edge(words[i], words[i + 1])
38
+
39
+ # Draw the graph
40
+ plt.figure(figsize=(8, 6))
41
+ pos = nx.spring_layout(G)
42
+ nx.draw(G, pos, with_labels=True, node_color='lightblue', edge_color='gray', node_size=2000, font_size=10)
43
+ plt.savefig("knowledge_graph.png")
44
+ plt.close()
45
+ return "knowledge_graph.png"
46
 
47
  @spaces.GPU(duration=120)
48
  def generate(
 
53
  top_p: float = 0.9,
54
  top_k: int = 50,
55
  repetition_penalty: float = 1.2,
 
56
  ) -> Iterator[str]:
57
  conversation = chat_history.copy()
58
  conversation.append({"role": "user", "content": message})
 
83
  outputs.append(text)
84
  yield "".join(outputs)
85
 
86
+ # After generating the text, create the knowledge graph image
87
+ knowledge_graph_image = create_knowledge_graph_image("".join(outputs))
88
+ yield knowledge_graph_image
 
89
 
90
  demo = gr.ChatInterface(
91
  fn=generate,
 
125
  step=0.05,
126
  value=1.2,
127
  ),
 
128
  ],
129
  stop_btn=None,
130
  examples=[
 
138
  description=DESCRIPTION,
139
  css_paths="style.css",
140
  fill_height=True,
141
+ additional_outputs=[gr.Image(label="Knowledge Graph")]
142
  )
143
 
144
  if __name__ == "__main__":