Spaces:
Sleeping
Sleeping
import pysbd | |
from txtai.embeddings import Embeddings | |
import networkx as nx | |
from tqdm import tqdm | |
from txtai.graph import GraphFactory | |
from datasets import load_dataset | |
import streamlit as st | |
import streamlit.components.v1 as components | |
import string | |
# --- Page header --------------------------------------------------------------
st.set_page_config(page_title="DebateKG")
st.title("DebateKG - Automatic Policy Debate Case Creation")
st.caption("github: https://github.com/Hellisotherpeople/DebateKG")

# --- Sidebar settings form ----------------------------------------------------
# Each widget value below is stored in a module-level global that the functions
# and top-level code further down read directly, so these names must not change.
form = st.sidebar.form("Main Settings")
form.header("Main Settings")
highlight_threshold = form.number_input(
    "Enter the minimum similarity value needed to highlight",
    value=0.05,
)
show_extract = form.checkbox("Show extracts", value=True)
show_abstract = form.checkbox("Show abstract", value=False)
show_full_doc = form.checkbox("Show full doc", value=False)
show_citation = form.checkbox("Show citation", value=True)
rerank_word = form.text_input(
    "(Optional) Constrain all evidence in the case to have this word within its text",
    value="",
)
form.caption(
    "Doing this may create graphs which are so constrained that DebateKG can't find a valid path in the graph to build a case"
)
html_window_width = form.number_input(
    "Enter the pixel width of the output debate case window",
    value=1000,
)
html_window_height = form.number_input(
    "Enter the pixel height of the output debate case window",
    value=1000,
)

# Knowledge-graph archives offered to the user; the selected file name is passed
# to Embeddings.load(). Position 2 (mpnet abstract) is preselected.
knowledge_graph_archives = (
    "DebateSum_SemanticGraph_longformer_extract.tar.gz",
    "DebateSum_SemanticGraph_longformer_abstract.tar.gz",
    "DebateSum_SemanticGraph_mpnet_abstract.tar.gz",
    "DebateSum_SemanticGraph_legalbert_abstract.tar.gz",
    "DebateSum_SemanticGraph_legalbert_extract.tar.gz",
    "DebateSum_SemanticGraph_mpnet_extract.tar.gz",
    "DebateSum_SemanticGraph_mpnet_sentence.tar.gz",
)
option = form.selectbox(
    "Which Knowledge Graph do you want to use?",
    knowledge_graph_archives,
    index=2,
)
form.form_submit_button("Change Settings")
def load_my_dataset():
    """Load the ``train`` split of the ``Hellisotherpeople/DebateSum`` dataset.

    Returns:
        Whatever ``datasets.load_dataset`` returns for that split; the rest of
        the script indexes it by the row id found through the embeddings index.
    """
    return load_dataset(path="Hellisotherpeople/DebateSum", split="train")
def load_embeddings():
    """Create a txtai ``Embeddings`` instance and load the archive chosen in the sidebar.

    The archive path comes from the module-level ``option`` selectbox value.

    Returns:
        The loaded ``Embeddings`` object; its ``graph`` attribute is used for
        path finding elsewhere in the script.
    """
    # NOTE(review): Embeddings.load() may replace parts of this config with the
    # one saved inside the archive — confirm which settings actually take effect.
    graph_settings = {
        "limit": 100,
        "minscore": 0.10,
        "topics": {"terms": 4, "resolution": 100},
    }
    settings = {
        "path": "sentence-transformers/all-mpnet-base-v2",
        "content": True,
        # Register a SQL function "graph" bound to graph.attribute so queries can
        # read node attributes.
        "functions": [{"name": "graph", "function": "graph.attribute"}],
        # Expose the node's topic / topicrank as SQL columns (used by e.g.
        # "select topic ... where topic like ..." in the query form).
        "expressions": [
            {"name": "topic", "expression": "graph(indexid, 'topic')"},
            {"name": "topicrank", "expression": "graph(indexid, 'topicrank')"},
        ],
        "graph": graph_settings,
    }
    index = Embeddings(settings)
    index.load(option)
    return index
# Load the DebateSum rows first, then the selected index, and expose the index's
# graph for the path-finding helpers below (left-to-right evaluation keeps order).
dataset, embeddings = load_my_dataset(), load_embeddings()
graph = embeddings.graph
def david_distance(source, target, attrs):
    """Edge cost used for shortest-path search over the similarity graph.

    Args:
        source: Edge source node (unused; required by the NetworkX weight-callable signature).
        target: Edge target node (unused; required by the NetworkX weight-callable signature).
        attrs: Edge attribute mapping; its ``"weight"`` entry is the similarity.

    Returns:
        ``1 - weight`` when that is at least 0.15, otherwise ``1.00``. Very similar
        pairs (weight above 0.85) therefore cost as much as unrelated ones, which
        steers paths away from near-duplicate hops.
    """
    cost = 1.0 - attrs["weight"]
    # Anything below 0.15 (including negatives from weight > 1, and NaN) maps to 1.00.
    if cost >= 0.15:
        return cost
    return 1.00
def david_showpath(source, target, the_graph):
    """Return one cheapest node path from ``source`` to ``target`` in ``the_graph``.

    Edge costs come from ``david_distance``. NetworkX raises if either node is
    missing or no path exists.
    """
    return nx.shortest_path(
        the_graph,
        source=source,
        target=target,
        weight=david_distance,
    )
def david_show_all_paths(source, target, the_graph):
    """Return every equally-cheapest path from ``source`` to ``target`` in ``the_graph``.

    Edge costs come from ``david_distance``; the result is whatever
    ``nx.all_shortest_paths`` yields (lazily produced node lists).
    """
    return nx.all_shortest_paths(
        the_graph,
        source=source,
        target=target,
        weight=david_distance,
    )
def highlight(index, result, threshold=None):
    """Render an explain() result as one numbered HTML line with salient tokens highlighted.

    Args:
        index: Number printed before the line (``2`` renders as ``"2. "``).
        result: Mapping whose ``"tokens"`` entry is a sequence of ``(token, score)``
            pairs, as produced by ``embeddings.explain`` in ``showpath_any``.
        threshold: Tokens scoring strictly above this value get a yellow background.
            Defaults to the sidebar's ``highlight_threshold`` setting.

    Returns:
        The HTML string; every token is followed by a single space.
    """
    import html

    if threshold is None:
        threshold = highlight_threshold
    pieces = []
    for token, score in result["tokens"]:
        # Tokens come from raw evidence text, so escape them; otherwise "<" or "&"
        # in the evidence would be parsed as markup by the HTML component.
        text = html.escape(str(token))
        if score > threshold:
            text = f"<span style='background-color: #fff59d'>{text}</span>"
        pieces.append(f"{text} ")
    return f"{index}. " + "".join(pieces)
def _path_nodes(list_of_arguments, the_graph):
    """Join the cheapest paths between consecutive arguments into one node list."""
    nodes = []
    for start, end in zip(list_of_arguments, list_of_arguments[1:]):
        leg = david_showpath(start, end, the_graph)
        # Each leg starts at the node the previous leg ended on; skip it so that
        # argument is not rendered twice (and explained against itself).
        nodes.extend(leg[1:] if nodes else leg)
    return nodes


def _evidence_id(text, strip_punctuation):
    """Return the ``id`` of the indexed document most similar to ``text``."""
    if strip_punctuation:
        text = text.translate(str.maketrans("", "", string.punctuation))
    # A single quote would terminate the similar('...') literal early, so drop any
    # that remain (only possible when strip_punctuation is False).
    text = text.replace("'", "")
    # NOTE(review): assumes the index's ``id`` is the DebateSum row number.
    return int(embeddings.search(f"select id from txtai where similar('{text}') limit 1")[0]["id"])


def _evidence_sections(evidence_id):
    """Return the escaped dataset fields enabled in the sidebar for one DebateSum row."""
    import html

    # Fetch a single row; dataset["Column"][i] would materialise the whole column
    # on every access.
    row = dataset[evidence_id]
    fields = (
        (show_abstract, "Abstract"),
        (show_citation, "Citation"),
        (show_extract, "Extract"),
        (show_full_doc, "Full-Document"),
    )
    return [html.escape(str(row[column])) for enabled, column in fields if enabled]


def showpath_any(list_of_arguments, strip_punctuation = True, the_graph=graph.backend):
    """Build and render a debate case that links the given graph nodes in order.

    For each consecutive pair in ``list_of_arguments`` the cheapest path (edge
    cost ``david_distance``) is found in ``the_graph``, and the legs are joined
    into one node sequence. Each node's text is read from ``graph``, mapped back
    to a DebateSum row by semantic search, and rendered with the dataset fields
    enabled in the sidebar. Every step after the first is highlighted with
    ``embeddings.explain`` against the previous step.

    Args:
        list_of_arguments: Graph node ids (indexids) to link, in order.
        strip_punctuation: Remove punctuation from node text before the row lookup.
        the_graph: NetworkX graph to search; defaults to the full index graph,
            but a filtered subgraph may be passed instead.

    Returns:
        The result of ``components.html`` for the rendered case, or ``None`` when
        no path can be built (an error message is shown instead).
    """
    import html

    try:
        nodes = _path_nodes(list_of_arguments, the_graph)
    except (nx.NetworkXNoPath, nx.NodeNotFound) as err:
        # Expected when the sidebar word filter leaves the graph disconnected or
        # removes one of the requested nodes.
        st.error(f"DebateKG could not build a case from these indexids: {err}")
        return None

    path = [graph.attribute(node, "text") for node in nodes]
    evidence_ids = [_evidence_id(text, strip_punctuation) for text in path]

    sections = []
    if path:
        # Start node: plain text plus its evidence fields.
        sections.append(f"1. {html.escape(str(path[0]))}")
        sections.extend(_evidence_sections(evidence_ids[0]))
    for step, (current, following) in enumerate(zip(path, path[1:])):
        # Explain and highlight the next element relative to the current one.
        results = embeddings.explain(current, [following], limit=1)[0]
        sections.append(highlight(step + 2, results))
        sections.extend(_evidence_sections(evidence_ids[step + 1]))
    return components.html("<br/><br/>".join(sections), scrolling = True, width = html_window_width, height = html_window_height)
def question(text, rerank_word = "", rerank_topic = "", limit = 100):
    """Semantic search over the index, filtered by substrings of text and topic.

    Args:
        text: Natural-language query passed to ``similar(...)``.
        rerank_word: Substring that each result's ``text`` must contain.
        rerank_topic: Substring that each result's ``topic`` must contain.
        limit: Maximum number of rows; coerced to ``int``.

    Returns:
        The result of ``embeddings.search`` for the selected columns
        (id, text, topic, evidence_id, score).
    """
    def _literal(value):
        # A single quote would close the SQL string literal early and let the
        # value rewrite the query (a plain apostrophe already breaks it), so drop it.
        return str(value).replace("'", "")

    query = (
        "select id, text, topic, evidence_id, score from txtai "
        f"where similar('{_literal(text)}') "
        f"and text like '%{_literal(rerank_word)}%' "
        f"and topic like '%{_literal(rerank_topic)}%' "
        f"limit {int(limit)}"
    )
    return embeddings.search(query)
# Step 1: free-form semantic SQL against the index, so users can discover indexids.
query_form = st.form("Query the Index:")
query_form.write("Step 1: Find Arguments")
query_form.write("Use semantic SQL from txtai to find some arguments, we use indexids to keep track of them.")
query_form.caption("You can use the semantic SQL to explore the dataset too! The possibilities are limitless!")
query_sql = query_form.text_area("Enter a semantic SQL statement", value = "select topic, * from txtai where similar('Trump and US relations with China') and topic like '%trump%' and text like '%china%' limit 1")
query_form_submitted = query_form.form_submit_button("Query")
if query_form_submitted:
    with st.expander("Output (Open Me)", expanded = False):
        st.write(embeddings.search(query_sql))

# Step 2: link user-chosen indexids into a single debate case.
paths_form = st.form("Build the Arguments")
paths_form.write("Step 2: Build a Policy Debate Case")
paths_form.write("Enter any number of indexids (arguments), DebateKG will build a debate case out of it which links them all together")
user_paths_string = paths_form.text_area("Enter a list of indexids seperated by whitespace", value = "250 10000 2405")
user_paths_list_of_strings = user_paths_string.split()
# Parse leniently: int() on a stray non-integer used to raise here, crashing the
# whole page (submit button included) before the user could correct the input.
user_paths_list = []
invalid_path_tokens = []
for path_token in user_paths_list_of_strings:
    try:
        user_paths_list.append(int(path_token))
    except ValueError:
        invalid_path_tokens.append(path_token)
paths_form_submitted = paths_form.form_submit_button("Build a Policy Debate Case")
if paths_form_submitted:
    if invalid_path_tokens:
        st.error("These entries are not integer indexids: " + ", ".join(invalid_path_tokens))
    elif len(user_paths_list) < 2:
        # showpath_any links consecutive pairs, so fewer than two ids builds nothing.
        st.error("Enter at least two indexids to link into a case.")
    elif rerank_word:
        # Keep only nodes whose text contains rerank_word (case-sensitive substring
        # match; the same filter also works on the 'topic' attribute).
        selected_nodes = [n for n, v in graph.backend.nodes(data=True) if rerank_word in (v.get("text") or "")]
        H = graph.backend.subgraph(selected_nodes)
        showpath_any(user_paths_list, the_graph = H)
    else:
        showpath_any(user_paths_list)