File size: 4,084 Bytes
70303d6
 
390b2d8
70303d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390b2d8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from logging import disable
from pkg_resources import EggMetadata
import streamlit as st
import streamlit.components.v1 as components
import networkx as nx
import matplotlib.pyplot as plt
from pyvis.network import Network
from streamlit.state.session_state import SessionState
from streamlit.type_util import Key
import rebel
import wikipedia

network_filename = "test.html"

# Default values for the app's session-state keys. Any key missing from
# st.session_state is (re-)initialised from here at the top of every run.
state_variables = {
    'has_run': False,        # becomes True once generate_graph has produced a graph
    'wiki_suggestions': [],  # page titles offered as buttons (a list, like wikipedia.search returns)
    'wiki_text': [],         # clipped page summaries passed to rebel.generate_knowledge_graph
    'nodes': [],             # values returned by rebel.generate_knowledge_graph, shown as "expand" buttons
}

for k, v in state_variables.items():
    if k not in st.session_state:
        # Copy list defaults so session state never shares a list object with this dict.
        st.session_state[k] = list(v) if isinstance(v, list) else v

def clip_text(t, lenght = 5):
    """Keep roughly the first ``lenght`` sentences of ``t``.

    Sentences are found by naively splitting on ".", so abbreviations such as
    "U.S." also count as breaks. A "." is appended only when the clipped text
    does not already end with one, so short texts that end on a period no
    longer get a doubled "..". Trailing whitespace is dropped before that check.

    Args:
        t: text to clip.
        lenght: number of "."-separated pieces to keep (the misspelled name is
            kept for backward compatibility with keyword callers).

    Returns:
        The clipped text ending in ".", or "" when nothing is left to keep.
    """
    clipped = ".".join(t.split(".")[:lenght]).rstrip()
    if not clipped:
        return ""
    return clipped if clipped.endswith(".") else clipped + "."



def generate_graph():
    """Build the knowledge graph from the collected Wikipedia summaries.

    Callback for the "generate" buttons. Returns silently when ``wiki_text`` is
    absent from session state and shows an error when it is empty. Otherwise it
    hands the summaries and ``network_filename`` to
    ``rebel.generate_knowledge_graph``, stores the returned nodes, sets
    ``has_run`` and shows a success message.
    """
    state = st.session_state
    if 'wiki_text' not in state:
        return

    collected = state['wiki_text']
    if len(collected) < 1:
        st.error("please enter a topic and select a wiki page first")
        return

    with st.spinner(text="Generating graph..."):
        state['nodes'] = rebel.generate_knowledge_graph(collected, network_filename)
        state['has_run'] = True
    st.success('Done!')

def show_suggestion():
    """Fetch up to three Wikipedia page titles matching the search box text.

    ``on_change`` callback of the search ``text_input`` (key ``"text"``). The
    titles returned by ``wikipedia.search`` are stored in
    ``st.session_state['wiki_suggestions']``. Nothing happens unless the input
    method is "wikipedia" and the box contains non-blank text.
    """
    if st.session_state['input_method'] != "wikipedia":
        return
    text = st.session_state.text
    # st.text_input yields "" (not None) when cleared; skip blank or whitespace-only
    # queries instead of sending a pointless search request.
    if text is None or not text.strip():
        return
    with st.spinner(text="fetching wiki topics..."):
        st.session_state['wiki_suggestions'] = wikipedia.search(text, results = 3)

def show_wiki_text(page_title):
    """Add the clipped summary of a Wikipedia page to the collected texts.

    Callback for the suggestion buttons.

    - If the title is ambiguous, it is removed from ``wiki_suggestions`` and
      up to three of the disambiguation options are added instead.
    - If no page exists, an error is shown.
    - A summary that is already collected is not added again.

    Args:
        page_title: exact page title (looked up with ``auto_suggest=False``).
    """
    with st.spinner(text="fetching wiki page..."):
        try:
            page = wikipedia.page(title=page_title, auto_suggest=False)
            summary = clip_text(page.summary)
        except wikipedia.DisambiguationError as e:
            with st.spinner(text="Woops, ambigious term, recalculating options..."):
                suggestions = list(st.session_state['wiki_suggestions'])
                if page_title in suggestions:
                    suggestions.remove(page_title)
                # dict.fromkeys de-duplicates while preserving order (set() would
                # shuffle the suggestion buttons).
                st.session_state['wiki_suggestions'] = list(
                    dict.fromkeys(suggestions + e.options[:3])
                )
            return
        except wikipedia.PageError:
            st.error(f"could not find a wikipedia page titled '{page_title}'")
            return

        # Repeated clicks on the same suggestion should not feed the same text twice.
        if summary not in st.session_state['wiki_text']:
            st.session_state['wiki_text'].append(summary)

def add_text(term):
    """Add the Wikipedia summary for a graph node to the collected texts.

    Callback for the per-node "expand" buttons. The term is looked up with
    ``auto_suggest=True`` and its clipped summary is appended to ``wiki_text``,
    unless that summary is already there. This function does not rebuild the
    graph itself.

    If the term is ambiguous or has no page, it is removed from
    ``st.session_state["nodes"]`` so its button disappears.

    Args:
        term: node label to look up on Wikipedia.
    """
    try:
        extra_text = clip_text(wikipedia.page(title=term, auto_suggest=True).summary)
    except (wikipedia.DisambiguationError, wikipedia.PageError):
        nodes = st.session_state["nodes"]
        if term in nodes:
            nodes.remove(term)
        return

    if extra_text not in st.session_state['wiki_text']:
        st.session_state['wiki_text'].append(extra_text)


def reset_session():
    """Remove the app's state keys so they are re-initialised on the next run.

    Keys that are already absent are skipped instead of raising ``KeyError``.
    """
    for k in state_variables:
        if k in st.session_state:
            del st.session_state[k]

st.title('REBELious knowledge graph generation')

# Only the Wikipedia workflow is enabled. The commented-out selectbox would let
# the user switch the input method to free text.
st.session_state['input_method'] = "wikipedia"
# st.selectbox(
#      'input method',
#      ('wikipedia', 'free text'),  key="input_method")
wiki_mode = st.session_state['input_method'] == "wikipedia"

if wiki_mode:
    st.text_input("wikipedia search term", on_change=show_suggestion, key="text")
else:
    st.text_area("Your text", key="text")

# One button per suggested page title, in equal-width columns.
suggestions = st.session_state['wiki_suggestions']
if len(suggestions) > 0:
    for idx, (column, title) in enumerate(zip(st.columns([1] * len(suggestions)), suggestions)):
        with column:
            st.button(title, on_click=show_wiki_text, args=(title,), key=idx)

# A collapsible preview of each collected summary, labelled by its first 30 characters.
for summary in st.session_state['wiki_text']:
    with st.expander(label=summary[:30] + "..."):
        st.markdown(summary)

if wiki_mode:
    st.button("generate", on_click=generate_graph, key="gen_graph2")
else:
    st.button("find wiki pages")
    if "wiki_suggestions" in st.session_state:
        st.button("generate", on_click=generate_graph, key="gen_graph")


if st.session_state['has_run']:
    graph_col, expand_col = st.columns([4, 1])
    with graph_col:
        # Embed the HTML file at network_filename. generate_graph passes this path to
        # rebel.generate_knowledge_graph, presumably so it writes the graph there.
        # The with-block closes the file handle after reading.
        with open(network_filename, 'r', encoding='utf-8') as html_file:
            source_code = html_file.read()
        components.html(source_code, height=1500, width=1500)
    with expand_col:
        st.text("expand")
        # Clicking a node button adds that node's Wikipedia summary to wiki_text.
        # Explicit index-based keys keep Streamlit from raising DuplicateWidgetID
        # if two nodes happen to share a label.
        for i, s in enumerate(st.session_state["nodes"]):
            st.button(s, on_click=add_text, args=(s,), key=f"expand_{i}")