DebateKG / app.py
Hellisotherpeople's picture
Update app.py
4133681
raw
history blame
4.61 kB
import pysbd
from txtai.embeddings import Embeddings
import networkx as nx
from tqdm import tqdm
from txtai.graph import GraphFactory
from datasets import load_dataset
import streamlit as st
import streamlit.components.v1 as components
# ---------------------------------------------------------------------------
# Page header. set_page_config is issued before any other Streamlit call.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="DebateKG")
st.title("DebateKG - Automatic Policy Debate Case Creation")
st.write("WIP, give me a few more days before reviewing!")
st.caption("github: https://github.com/Hellisotherpeople/DebateKG")

# ---------------------------------------------------------------------------
# Sidebar form: every widget below belongs to the "Main Settings" form and
# is created in the same top-to-bottom order in which it is displayed.
# ---------------------------------------------------------------------------
form = st.sidebar.form("Main Settings")
form.header("Main Settings")

number_of_paths = form.number_input(
    "Enter the cutoff number of paths for all shortest path search",
    value=4,
)
# NOTE(review): nothing in the code shown here reads highlight_threshold;
# highlight() compares against its own constant -- confirm the intended wiring.
highlight_threshold = form.number_input(
    "Enter the minimum similarity value needed to highlight",
    value=4,
)

# Four display toggles, all unchecked by default.
show_extract, show_abstract, show_full_doc, show_citation = (
    form.checkbox(toggle_label, value=False)
    for toggle_label in (
        "Show extracts",
        "Show abstract",
        "Show full doc",
        "Show citation",
    )
)

# Free-text filters; both start out as the literal "Full-Document".
rerank_word = form.text_area("Enter the word", value="Full-Document")
rerank_topic = form.text_area("Enter the topic", value="Full-Document")

form.form_submit_button("Submit")
# Source corpus and sentence segmenter.
dataset = load_dataset("Hellisotherpeople/DebateSum", split="train")
seg = pysbd.Segmenter(language="en", clean=False)

# Settings for the graph component of the index, including topic settings.
_graph_settings = {
    "limit": 100,
    "minscore": 0.10,
    "topics": {
        "terms": 4,
        "resolution": 100,
    },
}

# Registers a `graph` function (mapped to "graph.attribute") and two named
# expressions, `topic` and `topicrank`, that call it with the row's indexid.
_embeddings_settings = {
    "path": "sentence-transformers/all-mpnet-base-v2",
    "content": True,
    "functions": [
        {"name": "graph", "function": "graph.attribute"},
    ],
    "expressions": [
        {"name": "topic", "expression": "graph(indexid, 'topic')"},
        {"name": "topicrank", "expression": "graph(indexid, 'topicrank')"},
    ],
    "graph": _graph_settings,
}

# NOTE(review): this runs at module top level, i.e. on every execution of the
# script -- consider whether the load should be cached.
embeddings = Embeddings(_embeddings_settings)
embeddings.load("DebateSum_SemanticGraph_mpnet_extract.tar.gz")

# Shortcut to the graph attached to the loaded index.
graph = embeddings.graph
def david_distance(source, target, attrs):
    """Edge-cost callback: turn an edge's ``weight`` attribute into a distance.

    The distance is ``1.0 - attrs["weight"]``. Any distance below 0.15
    (which includes every negative value) is replaced by the flat cost 1.00.

    Args:
        source: Start node of the edge (not used in the computation).
        target: End node of the edge (not used in the computation).
        attrs: Edge attribute mapping; must contain ``"weight"``.

    Returns:
        The distance when it is at least 0.15, otherwise 1.00.
    """
    gap = 1.0 - attrs["weight"]
    # Negative gaps would clamp to 0.0, which is also below the 0.15 cut-off,
    # so they fall through to the same flat cost as other too-small gaps.
    if gap >= 0.15:
        return gap
    return 1.00
def david_showpath(source, target, the_graph):
    """Return the shortest path from ``source`` to ``target`` in ``the_graph``.

    Edge costs are supplied by ``david_distance`` rather than read directly
    from an edge attribute.
    """
    return nx.shortest_path(
        the_graph,
        source=source,
        target=target,
        weight=david_distance,
    )
import string
def highlight(index, result, threshold=0.01):
    """Render one explained sentence as an HTML fragment.

    Each ``(token, score)`` pair in ``result["tokens"]`` is emitted followed
    by a single space. Tokens scoring above ``threshold`` are wrapped in a
    yellow ``<span>``. The fragment starts with ``"{index}. "``.

    Tokens are HTML-escaped before being inserted, so characters such as
    ``<`` or ``&`` in the text cannot break or inject markup.

    Args:
        index: Number printed in front of the sentence.
        result: Mapping with a ``"tokens"`` entry holding ``(token, score)``
            pairs.
        threshold: Score a token must exceed to be highlighted. Defaults to
            0.01.

    Returns:
        The HTML fragment as a string; ``"{index}. "`` when there are no
        tokens.
    """
    # Imported here so the function does not depend on a module-level import.
    from html import escape

    pieces = [f"{index}. "]
    for token, score in result["tokens"]:
        safe_token = escape(str(token))
        if score > threshold:
            pieces.append(
                f"<span style='background-color: #fff59d'>{safe_token}</span> "
            )
        else:
            pieces.append(f"{safe_token} ")
    # One join instead of repeated string concatenation.
    return "".join(pieces)
def showpath_any(list_of_arguments, strip_punctuation = True, the_graph=graph.backend):
    """Chain shortest paths through a sequence of nodes and render them as HTML.

    For every consecutive pair in ``list_of_arguments`` a path is computed
    with ``david_showpath`` and all path nodes are concatenated. The text of
    each node is then shown: the first node as plain text, every following
    node with token highlighting from ``embeddings.explain`` against the
    node before it.

    Args:
        list_of_arguments: Sequence of graph node ids to connect, in order.
        strip_punctuation: When true, ``string.punctuation`` characters are
            removed from each text before it is placed in the lookup query.
        the_graph: Graph handed to ``david_showpath``. The default,
            ``graph.backend``, is evaluated once, when the function is defined.

    Returns:
        The result of ``components.html`` for the joined sections.
    """
    list_of_paths = []
    for x, y in zip(list_of_arguments, list_of_arguments[1:]):
        a_path = david_showpath(x, y, the_graph)
        # NOTE(review): each leg is appended whole, so a node that ends one leg
        # and starts the next appears twice in list_of_paths -- confirm intended.
        list_of_paths.extend(a_path)
    #print(list_of_paths)
    # Node ids -> the "text" attribute stored on each graph node.
    path = [graph.attribute(p, "text") for p in list_of_paths]
    list_of_evidence_ids = []
    for text in path:
        if strip_punctuation:
            text = text.translate(str.maketrans("","", string.punctuation))
        # Look the text up in the index and keep the id of the top hit.
        # NOTE(review): the text is interpolated into the query unescaped; with
        # strip_punctuation=False a quote character in it may break the query.
        list_of_evidence_ids.append(int(embeddings.search(f"select id from txtai where similar('{text}') limit 1")[0]['id']))
    # The ids are only printed here; the lines that would use them to pull
    # dataset columns are commented out below.
    print(list_of_evidence_ids)
    sections = []
    for x, p in enumerate(path):
        if x == 0:
            # Print start node
            sections.append(f"{x + 1}. {p}")
            #sections.append(dataset["Abstract"][list_of_evidence_ids[x]])
            #sections.append(dataset["Citation"][list_of_evidence_ids[x+1]])
            #sections.append(dataset["Full-Document"][list_of_evidence_ids[x]])
        if x < len(path) - 1:
            # Explain and highlight next path element
            results = embeddings.explain(p, [path[x + 1]], limit=1)[0]
            sections.append(highlight(x + 2, results))
            #sections.append(dataset["Abstract"][list_of_evidence_ids[x+1]])
            #sections.append(dataset["Citation"][list_of_evidence_ids[x+1]])
            #sections.append(dataset["Full-Document"][list_of_evidence_ids[x+1]])
    # Sections are separated by two line breaks inside a scrollable component.
    return components.html("<br/><br/>".join(sections), scrolling = True, width = 800, height = 1000)
def question(text, rerank_word = "", rerank_topic = "", limit = 100):
    """Run a similarity query, filtered by a word and a topic substring.

    Args:
        text: Query text placed inside ``similar(...)``.
        rerank_word: Substring the ``text`` column must contain; the empty
            default yields the pattern ``'%%'``.
        rerank_topic: Substring the ``topic`` column must contain; the empty
            default yields the pattern ``'%%'``.
        limit: Value placed in the query's ``limit`` clause.

    Returns:
        Whatever ``embeddings.search`` returns for the assembled query.
    """
    # NOTE(review): all three text arguments are interpolated as-is; a single
    # quote in any of them ends up unescaped inside the SQL string.
    sql = (
        "select id, text, topic, evidence_id, score from txtai "
        f"where similar('{text}') "
        f"and text like '%{rerank_word}%' "
        f"and topic like '%{rerank_topic}%' "
        f"limit {limit}"
    )
    return embeddings.search(sql)
# Main-area form. So far it only holds a caption and a submit button.
query_form = st.form("Query the Index:")
query_form.write("Write a SQL query")
query_form_submitted = query_form.form_submit_button("Click me to get ")

# Example call, currently disabled:
#showpath_any([3, 12, 15])

# Collapsed panel showing the raw rows of a fixed sample query.
with st.expander("mine", expanded=False):
    sample_rows = embeddings.search(
        "select * from txtai where similar('you') and text like '%the%' limit 10"
    )
    st.write(sample_rows)