Reduce outliers is now more efficient and relabels topics with the correct vectoriser. Default topic labels are now tidier. Hierarchical topics outputs are more useful for joining to a dataframe afterwards. Switched the low-resource dimensionality reduction algorithm to UMAP, as the previous default performed poorly.
- app.py +5 -5
- funcs/bertopic_hierarchical_documents.py +0 -336
- funcs/bertopic_hierarchical_documents_to_df.py +0 -250
- funcs/bertopic_vis_documents.py +377 -13
- funcs/clean_funcs.py +7 -6
- funcs/topic_core_funcs.py +60 -31
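The low-resource UMAP switch mentioned in the commit message lands in funcs/topic_core_funcs.py, which this page does not expand. As a rough, hypothetical sketch of what such a reduction step typically looks like in a BERTopic-style pipeline (parameter values are illustrative, not taken from this Space):

```python
# Hypothetical sketch of a low-resource dimensionality-reduction step using
# UMAP; the n_neighbors/n_components values are common BERTopic-style
# defaults, not this Space's actual settings.
import numpy as np
from umap import UMAP

def reduce_embeddings_low_resource(embeddings: np.ndarray) -> np.ndarray:
    reducer = UMAP(n_neighbors=15, n_components=5, min_dist=0.0,
                   metric="cosine", random_state=42)
    # fit_transform returns the low-dimensional embeddings ready for clustering
    return reducer.fit_transform(embeddings)
```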
app.py
CHANGED
@@ -7,7 +7,7 @@ import pandas as pd
 import numpy as np
 
 from funcs.topic_core_funcs import pre_clean, extract_topics, reduce_outliers, represent_topics, visualise_topics, save_as_pytorch_model
-from funcs.helper_functions import
+from funcs.helper_functions import initial_file_load, custom_regex_load
 from sklearn.feature_extraction.text import CountVectorizer
 
 # Gradio app
@@ -20,6 +20,7 @@ with block:
     embeddings_state = gr.State(np.array([]))
     embeddings_type_state = gr.State("")
     topic_model_state = gr.State()
+    assigned_topics_state = gr.State([])
     custom_regex_state = gr.State(pd.DataFrame())
     docs_state = gr.State()
     data_file_name_no_ext_state = gr.State()
@@ -104,23 +105,22 @@ with block:
 
     # Load in data. Update column names dropdown when file uploaded
     in_files.upload(fn=initial_file_load, inputs=[in_files], outputs=[in_colnames, in_label, data_state, output_single_text, topic_model_state, embeddings_state, data_file_name_no_ext_state, label_list_state])
-    in_colnames.change(dummy_function, in_colnames, None)
 
     # Clean data
     custom_regex.upload(fn=custom_regex_load, inputs=[custom_regex], outputs=[custom_regex_text, custom_regex_state])
     clean_btn.click(fn=pre_clean, inputs=[data_state, in_colnames, data_file_name_no_ext_state, custom_regex_state, clean_text, drop_duplicate_text, anonymise_drop], outputs=[output_single_text, output_file, data_state, data_file_name_no_ext_state], api_name="clean")
 
     # Extract topics
-    topics_btn.click(fn=extract_topics, inputs=[data_state, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, data_file_name_no_ext_state, label_list_state, return_intermediate_files, embedding_super_compress, low_resource_mode_opt, save_topic_model, embeddings_state, embeddings_type_state, zero_shot_similarity, seed_number, calc_probs, vectoriser_state], outputs=[output_single_text, output_file, embeddings_state, embeddings_type_state, data_file_name_no_ext_state, topic_model_state, docs_state, vectoriser_state], api_name="topics")
+    topics_btn.click(fn=extract_topics, inputs=[data_state, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, data_file_name_no_ext_state, label_list_state, return_intermediate_files, embedding_super_compress, low_resource_mode_opt, save_topic_model, embeddings_state, embeddings_type_state, zero_shot_similarity, seed_number, calc_probs, vectoriser_state], outputs=[output_single_text, output_file, embeddings_state, embeddings_type_state, data_file_name_no_ext_state, topic_model_state, docs_state, vectoriser_state, assigned_topics_state], api_name="topics")
 
     # Reduce outliers
-    reduce_outliers_btn.click(fn=reduce_outliers, inputs=[topic_model_state, docs_state, embeddings_state, data_file_name_no_ext_state, save_topic_model], outputs=[output_single_text, output_file, topic_model_state], api_name="reduce_outliers")
+    reduce_outliers_btn.click(fn=reduce_outliers, inputs=[topic_model_state, docs_state, embeddings_state, data_file_name_no_ext_state, assigned_topics_state, vectoriser_state, save_topic_model], outputs=[output_single_text, output_file, topic_model_state], api_name="reduce_outliers")
 
     # Re-represent topic labels
     represent_llm_btn.click(fn=represent_topics, inputs=[topic_model_state, docs_state, data_file_name_no_ext_state, low_resource_mode_opt, save_topic_model, representation_type, vectoriser_state], outputs=[output_single_text, output_file, topic_model_state], api_name="represent_llm")
 
     # Save in Pytorch format
-    save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file])
+    save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file], api_name="pytorch_save")
 
     # Visualise topics
     plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, data_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, in_label, in_colnames, legend_label, sample_slide, visualisation_type_radio, seed_number], outputs=[vis_output_single_text, out_plot_file, plot, plot_2], api_name="plot")
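The new `assigned_topics_state` and `vectoriser_state` inputs wired into `reduce_outliers` above suggest the outlier pass now reuses the topic assignments from extraction and refreshes labels with the same vectoriser. A minimal sketch of that pattern using BERTopic's public API (the Space's actual `reduce_outliers` in funcs/topic_core_funcs.py may differ):

```python
# Hedged sketch: reassign outlier documents, then rebuild topic labels with
# the original CountVectorizer so the relabelled topics stay consistent.
from typing import List
from bertopic import BERTopic
from sklearn.feature_extraction.text import CountVectorizer

def reduce_outliers_sketch(topic_model: BERTopic, docs: List[str],
                           assigned_topics: List[int],
                           vectoriser: CountVectorizer) -> BERTopic:
    # Move documents out of the -1 (outlier) topic via c-TF-IDF similarity
    new_topics = topic_model.reduce_outliers(docs, assigned_topics,
                                             strategy="c-tf-idf")
    # Recompute topic representations with the vectoriser used at fit time
    topic_model.update_topics(docs, topics=new_topics,
                              vectorizer_model=vectoriser)
    return topic_model
```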
funcs/bertopic_hierarchical_documents.py
DELETED
@@ -1,336 +0,0 @@
-import numpy as np
-import pandas as pd
-import plotly.graph_objects as go
-import math
-
-from umap import UMAP
-from typing import List, Union
-
-
-def visualize_hierarchical_documents(topic_model,
-                                     docs: List[str],
-                                     hierarchical_topics: pd.DataFrame,
-                                     topics: List[int] = None,
-                                     embeddings: np.ndarray = None,
-                                     reduced_embeddings: np.ndarray = None,
-                                     sample: Union[float, int] = None,
-                                     hide_annotations: bool = False,
-                                     hide_document_hover: bool = True,
-                                     nr_levels: int = 10,
-                                     level_scale: str = 'linear',
-                                     custom_labels: Union[bool, str] = False,
-                                     title: str = "<b>Hierarchical Documents and Topics</b>",
-                                     width: int = 1200,
-                                     height: int = 750) -> go.Figure:
-    """ Visualize documents and their topics in 2D at different levels of hierarchy
-
-    Arguments:
-        docs: The documents you used when calling either `fit` or `fit_transform`
-        hierarchical_topics: A dataframe that contains a hierarchy of topics
-                             represented by their parents and their children
-        topics: A selection of topics to visualize.
-                Not to be confused with the topics that you get from `.fit_transform`.
-                For example, if you want to visualize only topics 1 through 5:
-                `topics = [1, 2, 3, 4, 5]`.
-        embeddings: The embeddings of all documents in `docs`.
-        reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
-        sample: The percentage of documents in each topic that you would like to keep.
-                Value can be between 0 and 1. Setting this value to, for example,
-                0.1 (10% of documents in each topic) makes it easier to visualize
-                millions of documents as a subset is chosen.
-        hide_annotations: Hide the names of the traces on top of each cluster.
-        hide_document_hover: Hide the content of the documents when hovering over
-                             specific points. Helps to speed up generation of visualizations.
-        nr_levels: The number of levels to be visualized in the hierarchy. First, the distances
-                   in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances.
-                   Then, for each list of distances, the merged topics are selected that have a
-                   distance less or equal to the maximum distance of the selected list of distances.
-                   NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to
-                   the length of `hierarchical_topics`.
-        level_scale: Whether to apply a linear or logarithmic (log) scale levels of the distance
-                     vector. Linear scaling will perform an equal number of merges at each level
-                     while logarithmic scaling will perform more mergers in earlier levels to
-                     provide more resolution at higher levels (this can be used for when the number
-                     of topics is large).
-        custom_labels: If bool, whether to use custom topic labels that were defined using
-                       `topic_model.set_topic_labels`.
-                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
-                       NOTE: Custom labels are only generated for the original
-                       un-merged topics.
-        title: Title of the plot.
-        width: The width of the figure.
-        height: The height of the figure.
-
-    Examples:
-
-    To visualize the topics simply run:
-
-    ```python
-    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics)
-    ```
-
-    Do note that this re-calculates the embeddings and reduces them to 2D.
-    The advised and prefered pipeline for using this function is as follows:
-
-    ```python
-    from sklearn.datasets import fetch_20newsgroups
-    from sentence_transformers import SentenceTransformer
-    from bertopic import BERTopic
-    from umap import UMAP
-
-    # Prepare embeddings
-    docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
-    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
-    embeddings = sentence_model.encode(docs, show_progress_bar=False)
-
-    # Train BERTopic and extract hierarchical topics
-    topic_model = BERTopic().fit(docs, embeddings)
-    hierarchical_topics = topic_model.hierarchical_topics(docs)
-
-    # Reduce dimensionality of embeddings, this step is optional
-    # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
-
-    # Run the visualization with the original embeddings
-    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings)
-
-    # Or, if you have reduced the original embeddings already:
-    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
-    ```
-
-    Or if you want to save the resulting figure:
-
-    ```python
-    fig = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
-    fig.write_html("path/to/file.html")
-    ```
-
-    NOTE:
-        This visualization was inspired by the scatter plot representation of Doc2Map:
-        https://github.com/louisgeisler/Doc2Map
-
-    <iframe src="../../getting_started/visualization/hierarchical_documents.html"
-    style="width:1000px; height: 770px; border: 0px;""></iframe>
-    """
-    topic_per_doc = topic_model.topics_
-
-    # Sample the data to optimize for visualization and dimensionality reduction
-    if sample is None or sample > 1:
-        sample = 1
-
-    indices = []
-    for topic in set(topic_per_doc):
-        s = np.where(np.array(topic_per_doc) == topic)[0]
-        size = len(s) if len(s) < 100 else int(len(s)*sample)
-        indices.extend(np.random.choice(s, size=size, replace=False))
-    indices = np.array(indices)
-
-    df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
-    df["doc"] = [docs[index] for index in indices]
-    df["topic"] = [topic_per_doc[index] for index in indices]
-
-    # Extract embeddings if not already done
-    if sample is None:
-        if embeddings is None and reduced_embeddings is None:
-            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
-        else:
-            embeddings_to_reduce = embeddings
-    else:
-        if embeddings is not None:
-            embeddings_to_reduce = embeddings[indices]
-        elif embeddings is None and reduced_embeddings is None:
-            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
-
-    # Reduce input embeddings
-    if reduced_embeddings is None:
-        umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
-        embeddings_2d = umap_model.embedding_
-    elif sample is not None and reduced_embeddings is not None:
-        embeddings_2d = reduced_embeddings[indices]
-    elif sample is None and reduced_embeddings is not None:
-        embeddings_2d = reduced_embeddings
-
-    # Combine data
-    df["x"] = embeddings_2d[:, 0]
-    df["y"] = embeddings_2d[:, 1]
-
-    # Create topic list for each level, levels are created by calculating the distance
-    distances = hierarchical_topics.Distance.to_list()
-    if level_scale == 'log' or level_scale == 'logarithmic':
-        log_indices = np.round(np.logspace(start=math.log(1,10), stop=math.log(len(distances)-1,10), num=nr_levels)).astype(int).tolist()
-        log_indices.reverse()
-        max_distances = [distances[i] for i in log_indices]
-    elif level_scale == 'lin' or level_scale == 'linear':
-        max_distances = [distances[indices[-1]] for indices in np.array_split(range(len(hierarchical_topics)), nr_levels)][::-1]
-    else:
-        raise ValueError("level_scale needs to be one of 'log' or 'linear'")
-
-    for index, max_distance in enumerate(max_distances):
-
-        # Get topics below `max_distance`
-        mapping = {topic: topic for topic in df.topic.unique()}
-        selection = hierarchical_topics.loc[hierarchical_topics.Distance <= max_distance, :]
-        selection.Parent_ID = selection.Parent_ID.astype(int)
-        selection = selection.sort_values("Parent_ID")
-
-        for row in selection.iterrows():
-            for topic in row[1].Topics:
-                mapping[topic] = row[1].Parent_ID
-
-        # Make sure the mappings are mapped 1:1
-        mappings = [True for _ in mapping]
-        while any(mappings):
-            for i, (key, value) in enumerate(mapping.items()):
-                if value in mapping.keys() and key != value:
-                    mapping[key] = mapping[value]
-                else:
-                    mappings[i] = False
-
-        # Create new column
-        df[f"level_{index+1}"] = df.topic.map(mapping)
-        df[f"level_{index+1}"] = df[f"level_{index+1}"].astype(int)
-
-    # Prepare topic names of original and merged topics
-    trace_names = []
-    topic_names = {}
-    for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
-        if topic < hierarchical_topics.Parent_ID.astype(int).min():
-            if topic_model.get_topic(topic):
-                if isinstance(custom_labels, str):
-                    trace_name = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3])
-                elif topic_model.custom_labels_ is not None and custom_labels:
-                    trace_name = topic_model.custom_labels_[topic + topic_model._outliers]
-                else:
-                    trace_name = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3])
-                topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": trace_name[:40]}
-                trace_names.append(trace_name)
-        else:
-            trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0]
-            plot_text = "_".join([name[:20] for name in trace_name.split("_")[:3]])
-            topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]}
-            trace_names.append(trace_name)
-
-    # Prepare traces
-    all_traces = []
-    for level in range(len(max_distances)):
-        traces = []
-
-        # Outliers
-        if topic_model._outliers:
-            traces.append(
-                go.Scattergl(
-                    x=df.loc[(df[f"level_{level+1}"] == -1), "x"],
-                    y=df.loc[df[f"level_{level+1}"] == -1, "y"],
-                    mode='markers+text',
-                    name="other",
-                    hoverinfo="text",
-                    hovertext=df.loc[(df[f"level_{level+1}"] == -1), "doc"] if not hide_document_hover else None,
-                    showlegend=False,
-                    marker=dict(color='#CFD8DC', size=5, opacity=0.5)
-                )
-            )
-
-        # Selected topics
-        if topics:
-            selection = df.loc[(df.topic.isin(topics)), :]
-            unique_topics = sorted([int(topic) for topic in selection[f"level_{level+1}"].unique()])
-        else:
-            unique_topics = sorted([int(topic) for topic in df[f"level_{level+1}"].unique()])
-
-        for topic in unique_topics:
-            if topic != -1:
-                if topics:
-                    selection = df.loc[(df[f"level_{level+1}"] == topic) &
-                                       (df.topic.isin(topics)), :]
-                else:
-                    selection = df.loc[df[f"level_{level+1}"] == topic, :]
-
-                if not hide_annotations:
-                    selection.loc[len(selection), :] = None
-                    selection["text"] = ""
-                    selection.loc[len(selection) - 1, "x"] = selection.x.mean()
-                    selection.loc[len(selection) - 1, "y"] = selection.y.mean()
-                    selection.loc[len(selection) - 1, "text"] = topic_names[int(topic)]["plot_text"]
-
-                traces.append(
-                    go.Scattergl(
-                        x=selection.x,
-                        y=selection.y,
-                        text=selection.text if not hide_annotations else None,
-                        hovertext=selection.doc if not hide_document_hover else None,
-                        hoverinfo="text",
-                        name=topic_names[int(topic)]["trace_name"],
-                        mode='markers+text',
-                        marker=dict(size=5, opacity=0.5)
-                    )
-                )
-
-        all_traces.append(traces)
-
-    # Track and count traces
-    nr_traces_per_set = [len(traces) for traces in all_traces]
-    trace_indices = [(0, nr_traces_per_set[0])]
-    for index, nr_traces in enumerate(nr_traces_per_set[1:]):
-        start = trace_indices[index][1]
-        end = nr_traces + start
-        trace_indices.append((start, end))
-
-    # Visualization
-    fig = go.Figure()
-    for traces in all_traces:
-        for trace in traces:
-            fig.add_trace(trace)
-
-    for index in range(len(fig.data)):
-        if index >= nr_traces_per_set[0]:
-            fig.data[index].visible = False
-
-    # Create and add slider
-    steps = []
-    for index, indices in enumerate(trace_indices):
-        step = dict(
-            method="update",
-            label=str(index),
-            args=[{"visible": [False] * len(fig.data)}]
-        )
-        for index in range(indices[1]-indices[0]):
-            step["args"][0]["visible"][index+indices[0]] = True
-        steps.append(step)
-
-    sliders = [dict(
-        currentvalue={"prefix": "Level: "},
-        pad={"t": 20},
-        steps=steps
-    )]
-
-    # Add grid in a 'plus' shape
-    x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
-    y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
-    fig.add_shape(type="line",
-                  x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
-                  line=dict(color="#CFD8DC", width=2))
-    fig.add_shape(type="line",
-                  x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
-                  line=dict(color="#9E9E9E", width=2))
-    fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
-    fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)
-
-    # Stylize layout
-    fig.update_layout(
-        sliders=sliders,
-        template="simple_white",
-        title={
-            'text': f"{title}",
-            'x': 0.5,
-            'xanchor': 'center',
-            'yanchor': 'top',
-            'font': dict(
-                size=22,
-                color="Black")
-        },
-        width=width,
-        height=height,
-    )
-
-    fig.update_xaxes(visible=False)
-    fig.update_yaxes(visible=False)
-    return fig
funcs/bertopic_hierarchical_documents_to_df.py
DELETED
@@ -1,250 +0,0 @@
-import numpy as np
-import pandas as pd
-import plotly.graph_objects as go
-import math
-
-from umap import UMAP
-from typing import List, Union
-
-
-def visualize_hierarchical_documents_to_df(topic_model,
-                                           docs: List[str],
-                                           hierarchical_topics: pd.DataFrame,
-                                           topics: List[int] = None,
-                                           embeddings: np.ndarray = None,
-                                           reduced_embeddings: np.ndarray = None,
-                                           sample: Union[float, int] = None,
-                                           hide_annotations: bool = False,
-                                           hide_document_hover: bool = True,
-                                           nr_levels: int = 10,
-                                           level_scale: str = 'linear',
-                                           custom_labels: Union[bool, str] = False,
-                                           title: str = "<b>Hierarchical Documents and Topics</b>",
-                                           width: int = 1200,
-                                           height: int = 750) -> go.Figure:
-    """ Visualize documents and their topics in 2D at different levels of hierarchy
-
-    Arguments:
-        docs: The documents you used when calling either `fit` or `fit_transform`
-        hierarchical_topics: A dataframe that contains a hierarchy of topics
-                             represented by their parents and their children
-        topics: A selection of topics to visualize.
-                Not to be confused with the topics that you get from `.fit_transform`.
-                For example, if you want to visualize only topics 1 through 5:
-                `topics = [1, 2, 3, 4, 5]`.
-        embeddings: The embeddings of all documents in `docs`.
-        reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
-        sample: The percentage of documents in each topic that you would like to keep.
-                Value can be between 0 and 1. Setting this value to, for example,
-                0.1 (10% of documents in each topic) makes it easier to visualize
-                millions of documents as a subset is chosen.
-        hide_annotations: Hide the names of the traces on top of each cluster.
-        hide_document_hover: Hide the content of the documents when hovering over
-                             specific points. Helps to speed up generation of visualizations.
-        nr_levels: The number of levels to be visualized in the hierarchy. First, the distances
-                   in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances.
-                   Then, for each list of distances, the merged topics are selected that have a
-                   distance less or equal to the maximum distance of the selected list of distances.
-                   NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to
-                   the length of `hierarchical_topics`.
-        level_scale: Whether to apply a linear or logarithmic (log) scale levels of the distance
-                     vector. Linear scaling will perform an equal number of merges at each level
-                     while logarithmic scaling will perform more mergers in earlier levels to
-                     provide more resolution at higher levels (this can be used for when the number
-                     of topics is large).
-        custom_labels: If bool, whether to use custom topic labels that were defined using
-                       `topic_model.set_topic_labels`.
-                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
-                       NOTE: Custom labels are only generated for the original
-                       un-merged topics.
-        title: Title of the plot.
-        width: The width of the figure.
-        height: The height of the figure.
-
-    Examples:
-
-    To visualize the topics simply run:
-
-    ```python
-    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics)
-    ```
-
-    Do note that this re-calculates the embeddings and reduces them to 2D.
-    The advised and prefered pipeline for using this function is as follows:
-
-    ```python
-    from sklearn.datasets import fetch_20newsgroups
-    from sentence_transformers import SentenceTransformer
-    from bertopic import BERTopic
-    from umap import UMAP
-
-    # Prepare embeddings
-    docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
-    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
-    embeddings = sentence_model.encode(docs, show_progress_bar=False)
-
-    # Train BERTopic and extract hierarchical topics
-    topic_model = BERTopic().fit(docs, embeddings)
-    hierarchical_topics = topic_model.hierarchical_topics(docs)
-
-    # Reduce dimensionality of embeddings, this step is optional
-    # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
-
-    # Run the visualization with the original embeddings
-    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings)
-
-    # Or, if you have reduced the original embeddings already:
-    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
-    ```
-
-    Or if you want to save the resulting figure:
-
-    ```python
-    fig = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
-    fig.write_html("path/to/file.html")
-    ```
-
-    NOTE:
-        This visualization was inspired by the scatter plot representation of Doc2Map:
-        https://github.com/louisgeisler/Doc2Map
-
-    <iframe src="../../getting_started/visualization/hierarchical_documents.html"
-    style="width:1000px; height: 770px; border: 0px;""></iframe>
-    """
-    topic_per_doc = topic_model.topics_
-
-    # Sample the data to optimize for visualization and dimensionality reduction
-    if sample is None or sample > 1:
-        sample = 1
-
-    indices = []
-    for topic in set(topic_per_doc):
-        s = np.where(np.array(topic_per_doc) == topic)[0]
-        size = len(s) if len(s) < 100 else int(len(s)*sample)
-        indices.extend(np.random.choice(s, size=size, replace=False))
-    indices = np.array(indices)
-
-    df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
-    df["doc"] = [docs[index] for index in indices]
-    df["topic"] = [topic_per_doc[index] for index in indices]
-
-    # Extract embeddings if not already done
-    if sample is None:
-        if embeddings is None and reduced_embeddings is None:
-            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
-        else:
-            embeddings_to_reduce = embeddings
-    else:
-        if embeddings is not None:
-            embeddings_to_reduce = embeddings[indices]
-        elif embeddings is None and reduced_embeddings is None:
-            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
-
-    # Reduce input embeddings
-    if reduced_embeddings is None:
-        umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
-        embeddings_2d = umap_model.embedding_
-    elif sample is not None and reduced_embeddings is not None:
-        embeddings_2d = reduced_embeddings[indices]
-    elif sample is None and reduced_embeddings is not None:
-        embeddings_2d = reduced_embeddings
-
-    # Combine data
-    df["x"] = embeddings_2d[:, 0]
-    df["y"] = embeddings_2d[:, 1]
-
-    # Create topic list for each level, levels are created by calculating the distance
-    distances = hierarchical_topics.Distance.to_list()
-    if level_scale == 'log' or level_scale == 'logarithmic':
-        log_indices = np.round(np.logspace(start=math.log(1,10), stop=math.log(len(distances)-1,10), num=nr_levels)).astype(int).tolist()
-        log_indices.reverse()
-        max_distances = [distances[i] for i in log_indices]
-    elif level_scale == 'lin' or level_scale == 'linear':
-        max_distances = [distances[indices[-1]] for indices in np.array_split(range(len(hierarchical_topics)), nr_levels)][::-1]
-    else:
-        raise ValueError("level_scale needs to be one of 'log' or 'linear'")
-
-    for index, max_distance in enumerate(max_distances):
-
-        # Get topics below `max_distance`
-        mapping = {topic: topic for topic in df.topic.unique()}
-        selection = hierarchical_topics.loc[hierarchical_topics.Distance <= max_distance, :]
-        selection.Parent_ID = selection.Parent_ID.astype(int)
-        selection = selection.sort_values("Parent_ID")
-
-        for row in selection.iterrows():
-            for topic in row[1].Topics:
-                mapping[topic] = row[1].Parent_ID
-
-        # Make sure the mappings are mapped 1:1
-        mappings = [True for _ in mapping]
-        while any(mappings):
-            for i, (key, value) in enumerate(mapping.items()):
-                if value in mapping.keys() and key != value:
-                    mapping[key] = mapping[value]
-                else:
-                    mappings[i] = False
-
-        # Create new column
-        df[f"level_{index+1}"] = df.topic.map(mapping)
-        df[f"level_{index+1}"] = df[f"level_{index+1}"].astype(int)
-
-    # Prepare topic names of original and merged topics
-    trace_names = []
-    topic_names = {}
-    for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
-        if topic < hierarchical_topics.Parent_ID.astype(int).min():
-            if topic_model.get_topic(topic):
-                if isinstance(custom_labels, str):
-                    trace_name = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3])
-                elif topic_model.custom_labels_ is not None and custom_labels:
-                    trace_name = topic_model.custom_labels_[topic + topic_model._outliers]
-                else:
-                    trace_name = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3])
-                topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": trace_name[:40]}
-                trace_names.append(trace_name)
-        else:
-            trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0]
-            plot_text = "_".join([name[:20] for name in trace_name.split("_")[:3]])
-            topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]}
-            trace_names.append(trace_name)
-
-    # Prepare traces
-    all_traces = []
-    for level in range(len(max_distances)):
-        traces = []
-
-        # Selected topics
-        if topics:
-            selection = df.loc[(df.topic.isin(topics)), :]
-            unique_topics = sorted([int(topic) for topic in selection[f"level_{level+1}"].unique()])
-        else:
-            unique_topics = sorted([int(topic) for topic in df[f"level_{level+1}"].unique()])
-
-        for topic in unique_topics:
-            if topic != -1:
-                if topics:
-                    selection = df.loc[(df[f"level_{level+1}"] == topic) &
-                                       (df.topic.isin(topics)), :]
-                else:
-                    selection = df.loc[df[f"level_{level+1}"] == topic, :]
-
-                if not hide_annotations:
-                    selection.loc[len(selection), :] = None
-                    selection["text"] = ""
-                    selection.loc[len(selection) - 1, "x"] = selection.x.mean()
-                    selection.loc[len(selection) - 1, "y"] = selection.y.mean()
-                    selection.loc[len(selection) - 1, "text"] = topic_names[int(topic)]["plot_text"]
-
-        all_traces.append(traces)
-
-    # Track and count traces
-    nr_traces_per_set = [len(traces) for traces in all_traces]
-    trace_indices = [(0, nr_traces_per_set[0])]
-    for index, nr_traces in enumerate(nr_traces_per_set[1:]):
-        start = trace_indices[index][1]
-        end = nr_traces + start
-        trace_indices.append((start, end))
-
-
-    return all_traces, selection, df
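Both deleted modules built a per-document dataframe with `doc`, `topic` and `level_1 … level_n` columns (the `_to_df` variant returned it directly via `all_traces, selection, df`). The commit message's point about hierarchical topic outputs being easier to join to a dataframe afterwards comes down to keeping that frame keyed on the document text; a hedged sketch of the join (the helper name and `text_col` parameter are illustrative, not this Space's code):

```python
# Hypothetical sketch: merge per-level hierarchical topic assignments back
# onto the source data, keyed on the document text carried in "doc".
import pandas as pd

def join_levels_to_source(source_df: pd.DataFrame, text_col: str,
                          hier_df: pd.DataFrame) -> pd.DataFrame:
    level_cols = [c for c in hier_df.columns if c.startswith("level_")]
    # Drop duplicate documents first so the merge stays one-to-one
    keyed = hier_df[["doc", "topic"] + level_cols].drop_duplicates("doc")
    return source_df.merge(keyed, left_on=text_col, right_on="doc", how="left")
```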
funcs/bertopic_vis_documents.py
CHANGED
@@ -1,10 +1,23 @@
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
|
|
3 |
import plotly.graph_objects as go
|
4 |
from plotly.subplots import make_subplots
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from umap import UMAP
|
7 |
-
from typing import List, Union
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
import itertools
|
10 |
import numpy as np
|
@@ -23,7 +36,7 @@ def visualize_documents_custom(topic_model,
|
|
23 |
custom_labels: Union[bool, str] = False,
|
24 |
title: str = "<b>Documents and Topics</b>",
|
25 |
width: int = 1200,
|
26 |
-
height: int = 750):
|
27 |
""" Visualize documents and their topics in 2D
|
28 |
|
29 |
Arguments:
|
@@ -164,9 +177,9 @@ def visualize_documents_custom(topic_model,
|
|
164 |
names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
|
165 |
else:
|
166 |
print("Not using custom labels")
|
167 |
-
names = [f"{topic}
|
168 |
|
169 |
-
print(names)
|
170 |
|
171 |
# Visualize
|
172 |
fig = go.Figure()
|
@@ -254,6 +267,350 @@ def visualize_documents_custom(topic_model,
|
|
254 |
fig.update_yaxes(visible=False)
|
255 |
return fig
|
256 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
def visualize_hierarchical_documents_custom(topic_model,
|
258 |
docs: List[str],
|
259 |
hover_labels: List[str],
|
@@ -269,7 +626,7 @@ def visualize_hierarchical_documents_custom(topic_model,
|
|
269 |
custom_labels: Union[bool, str] = False,
|
270 |
title: str = "<b>Hierarchical Documents and Topics</b>",
|
271 |
width: int = 1200,
|
272 |
-
height: int = 750) -> go.Figure:
|
273 |
""" Visualize documents and their topics in 2D at different levels of hierarchy
|
274 |
|
275 |
Arguments:
|
@@ -455,21 +812,22 @@ def visualize_hierarchical_documents_custom(topic_model,
|
|
455 |
# Prepare topic names of original and merged topics
|
456 |
trace_names = []
|
457 |
topic_names = {}
|
|
|
458 |
for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
|
459 |
if topic < hierarchical_topics.Parent_ID.astype(int).min():
|
460 |
if topic_model.get_topic(topic):
|
461 |
if isinstance(custom_labels, str):
|
462 |
-
trace_name = f"{topic}
|
463 |
elif topic_model.custom_labels_ is not None and custom_labels:
|
464 |
trace_name = topic_model.custom_labels_[topic + topic_model._outliers]
|
465 |
else:
|
466 |
-
trace_name = f"{topic}
|
467 |
-
topic_names[topic] = {"trace_name": trace_name[:
|
468 |
trace_names.append(trace_name)
|
469 |
else:
|
470 |
-
trace_name = f"{topic}
|
471 |
-
plot_text = "
|
472 |
-
topic_names[topic] = {"trace_name": trace_name[:
|
473 |
trace_names.append(trace_name)
|
474 |
|
475 |
# Prepare traces
|
@@ -598,7 +956,13 @@ def visualize_hierarchical_documents_custom(topic_model,
|
|
598 |
|
599 |
fig.update_xaxes(visible=False)
|
600 |
fig.update_yaxes(visible=False)
|
601 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
602 |
|
603 |
def visualize_barchart_custom(topic_model,
|
604 |
topics: List[int] = None,
|
@@ -607,7 +971,7 @@ def visualize_barchart_custom(topic_model,
|
|
607 |
custom_labels: Union[bool, str] = False,
|
608 |
title: str = "<b>Topic Word Scores</b>",
|
609 |
width: int = 250,
|
610 |
-
height: int = 250) -> go.Figure:
|
611 |
""" Visualize a barchart of selected topics
|
612 |
|
613 |
Arguments:
|
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
3 |
+
import gradio as gr
|
4 |
import plotly.graph_objects as go
|
5 |
from plotly.subplots import make_subplots
|
6 |
|
7 |
+
from bertopic._utils import check_documents_type, validate_distance_matrix
|
8 |
+
from bertopic.plotting._hierarchy import _get_annotations
|
9 |
+
import plotly.figure_factory as ff
|
10 |
+
from packaging import version
|
11 |
+
|
12 |
+
import math
|
13 |
from umap import UMAP
|
14 |
+
from typing import List, Union, Callable
|
15 |
+
|
16 |
+
from scipy.sparse import csr_matrix
|
17 |
+
from scipy.cluster import hierarchy as sch
|
18 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
19 |
+
from sklearn import __version__ as sklearn_version
|
20 |
+
from tqdm import tqdm
|
21 |
|
22 |
import itertools
|
23 |
import numpy as np
|
|
|
36 |
custom_labels: Union[bool, str] = False,
|
37 |
title: str = "<b>Documents and Topics</b>",
|
38 |
width: int = 1200,
|
39 |
+
height: int = 750, progress=gr.Progress(track_tqdm=True)):
|
40 |
""" Visualize documents and their topics in 2D
|
41 |
|
42 |
Arguments:
|
|
|
177 |
names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics]
|
178 |
else:
|
179 |
print("Not using custom labels")
|
180 |
+
names = [f"{topic} " + ", ".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics]
|
181 |
|
182 |
+
#print(names)
|
183 |
|
184 |
# Visualize
|
185 |
fig = go.Figure()
|
|
|
267 |
fig.update_yaxes(visible=False)
|
268 |
return fig
|
269 |
|
270 |
+
def hierarchical_topics_custom(self,
|
271 |
+
docs: List[str],
|
272 |
+
linkage_function: Callable[[csr_matrix], np.ndarray] = None,
|
273 |
+
distance_function: Callable[[csr_matrix], csr_matrix] = None, progress=gr.Progress(track_tqdm=True)) -> pd.DataFrame:
|
274 |
+
""" Create a hierarchy of topics
|
275 |
+
|
276 |
+
To create this hierarchy, BERTopic needs to be already fitted once.
|
277 |
+
Then, a hierarchy is calculated on the distance matrix of the c-TF-IDF
|
278 |
+
representation using `scipy.cluster.hierarchy.linkage`.
|
279 |
+
|
280 |
+
Based on that hierarchy, we calculate the topic representation at each
|
281 |
+
merged step. This is a local representation, as we only assume that the
|
282 |
+
chosen step is merged and not all others which typically improves the
|
283 |
+
topic representation.
|
284 |
+
|
285 |
+
Arguments:
|
286 |
+
docs: The documents you used when calling either `fit` or `fit_transform`
|
287 |
+
linkage_function: The linkage function to use. Default is:
|
288 |
+
`lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
|
289 |
+
distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
|
290 |
+
`lambda x: 1 - cosine_similarity(x)`.
|
291 |
+
You can pass any function that returns either a square matrix of
|
292 |
+
shape (n_samples, n_samples) with zeros on the diagonal and
|
293 |
+
non-negative values or condensed distance matrix of shape
|
294 |
+
(n_samples * (n_samples - 1) / 2,) containing the upper
|
295 |
+
triangular of the distance matrix.
|
296 |
+
|
297 |
+
Returns:
|
298 |
+
hierarchical_topics: A dataframe that contains a hierarchy of topics
|
299 |
+
represented by their parents and their children
|
300 |
+
|
301 |
+
Examples:
|
302 |
+
|
303 |
+
```python
|
304 |
+
from bertopic import BERTopic
|
305 |
+
topic_model = BERTopic()
|
306 |
+
topics, probs = topic_model.fit_transform(docs)
|
307 |
+
hierarchical_topics = topic_model.hierarchical_topics(docs)
|
308 |
+
```
|
309 |
+
|
310 |
+
A custom linkage function can be used as follows:
|
311 |
+
|
312 |
+
```python
|
313 |
+
from scipy.cluster import hierarchy as sch
|
314 |
+
from bertopic import BERTopic
|
315 |
+
topic_model = BERTopic()
|
316 |
+
topics, probs = topic_model.fit_transform(docs)
|
317 |
+
|
318 |
+
# Hierarchical topics
|
319 |
+
linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)
|
320 |
+
hierarchical_topics = topic_model.hierarchical_topics(docs, linkage_function=linkage_function)
|
321 |
+
```
|
322 |
+
"""
|
323 |
+
check_documents_type(docs)
|
324 |
+
if distance_function is None:
|
325 |
+
distance_function = lambda x: 1 - cosine_similarity(x)
|
326 |
+
|
327 |
+
if linkage_function is None:
|
328 |
+
linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)
|
329 |
+
|
330 |
+
# Calculate distance
|
331 |
+
embeddings = self.c_tf_idf_[self._outliers:]
|
332 |
+
X = distance_function(embeddings)
|
333 |
+
X = validate_distance_matrix(X, embeddings.shape[0])
|
334 |
+
|
335 |
+
# Use the 1-D condensed distance matrix as an input instead of the raw distance matrix
|
336 |
+
Z = linkage_function(X)
|
337 |
+
|
338 |
+
# Calculate basic bag-of-words to be iteratively merged later
|
339 |
+
documents = pd.DataFrame({"Document": docs,
|
340 |
+
"ID": range(len(docs)),
|
341 |
+
"Topic": self.topics_})
|
342 |
+
documents_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
|
343 |
+
documents_per_topic = documents_per_topic.loc[documents_per_topic.Topic != -1, :]
|
344 |
+
clean_documents = self._preprocess_text(documents_per_topic.Document.values)
|
345 |
+
|
346 |
+
# Scikit-Learn Deprecation: get_feature_names is deprecated in 1.0
|
347 |
+
# and will be removed in 1.2. Please use get_feature_names_out instead.
|
348 |
+
if version.parse(sklearn_version) >= version.parse("1.0.0"):
|
349 |
+
words = self.vectorizer_model.get_feature_names_out()
|
350 |
+
else:
|
351 |
+
words = self.vectorizer_model.get_feature_names()
|
352 |
+
|
353 |
+
bow = self.vectorizer_model.transform(clean_documents)
|
354 |
+
|
355 |
+
# Extract clusters
|
356 |
+
hier_topics = pd.DataFrame(columns=["Parent_ID", "Parent_Name", "Topics",
|
357 |
+
"Child_Left_ID", "Child_Left_Name",
|
358 |
+
"Child_Right_ID", "Child_Right_Name"])
|
359 |
+
for index in tqdm(range(len(Z))):
|
360 |
+
|
361 |
+
# Find clustered documents
|
362 |
+
clusters = sch.fcluster(Z, t=Z[index][2], criterion='distance') - self._outliers
|
363 |
+
nr_clusters = len(clusters)
|
364 |
+
|
365 |
+
# Extract first topic we find to get the set of topics in a merged topic
|
366 |
+
topic = None
|
367 |
+
val = Z[index][0]
|
368 |
+
while topic is None:
|
369 |
+
if val - len(clusters) < 0:
|
370 |
+
topic = int(val)
|
371 |
+
else:
|
372 |
+
val = Z[int(val - len(clusters))][0]
|
373 |
+
clustered_topics = [i for i, x in enumerate(clusters) if x == clusters[topic]]
|
374 |
+
|
375 |
+
# Group bow per cluster, calculate c-TF-IDF and extract words
|
376 |
+
grouped = csr_matrix(bow[clustered_topics].sum(axis=0))
|
377 |
+
c_tf_idf = self.ctfidf_model.transform(grouped)
|
378 |
+
selection = documents.loc[documents.Topic.isin(clustered_topics), :]
|
379 |
+
selection.Topic = 0
|
380 |
+
words_per_topic = self._extract_words_per_topic(words, selection, c_tf_idf, calculate_aspects=False)
|
381 |
+
|
382 |
+
# Extract parent's name and ID
|
383 |
+
parent_id = index + len(clusters)
|
384 |
+
parent_name = ", ".join([x[0] for x in words_per_topic[0]][:5])
|
385 |
+
|
386 |
+
# Extract child's name and ID
|
387 |
+
Z_id = Z[index][0]
|
388 |
+
child_left_id = Z_id if Z_id - nr_clusters < 0 else Z_id - nr_clusters
|
389 |
+
|
390 |
+
if Z_id - nr_clusters < 0:
|
391 |
+
child_left_name = ", ".join([x[0] for x in self.get_topic(Z_id)][:5])
|
392 |
+
else:
|
393 |
+
child_left_name = hier_topics.iloc[int(child_left_id)].Parent_Name
|
394 |
+
|
395 |
+
# Extract child's name and ID
|
396 |
+
Z_id = Z[index][1]
|
397 |
+
child_right_id = Z_id if Z_id - nr_clusters < 0 else Z_id - nr_clusters
|
398 |
+
|
399 |
+
if Z_id - nr_clusters < 0:
|
400 |
+
child_right_name = ", ".join([x[0] for x in self.get_topic(Z_id)][:5])
|
401 |
+
else:
|
402 |
+
child_right_name = hier_topics.iloc[int(child_right_id)].Parent_Name
|
403 |
+
|
404 |
+
# Save results
|
405 |
+
hier_topics.loc[len(hier_topics), :] = [parent_id, parent_name,
|
406 |
+
clustered_topics,
|
407 |
+
int(Z[index][0]), child_left_name,
|
408 |
+
int(Z[index][1]), child_right_name]
|
409 |
+
|
410 |
+
hier_topics["Distance"] = Z[:, 2]
|
411 |
+
hier_topics = hier_topics.sort_values("Parent_ID", ascending=False)
|
412 |
+
hier_topics[["Parent_ID", "Child_Left_ID", "Child_Right_ID"]] = hier_topics[["Parent_ID", "Child_Left_ID", "Child_Right_ID"]].astype(str)
|
413 |
+
|
414 |
+
return hier_topics
|
415 |
+
|
416 |
+
def visualize_hierarchy_custom(topic_model,
|
417 |
+
orientation: str = "left",
|
418 |
+
topics: List[int] = None,
|
419 |
+
top_n_topics: int = None,
|
420 |
+
custom_labels: Union[bool, str] = False,
|
421 |
+
title: str = "<b>Hierarchical Clustering</b>",
|
422 |
+
width: int = 1000,
|
423 |
+
height: int = 600,
|
424 |
+
hierarchical_topics: pd.DataFrame = None,
|
425 |
+
linkage_function: Callable[[csr_matrix], np.ndarray] = None,
|
426 |
+
distance_function: Callable[[csr_matrix], csr_matrix] = None,
|
427 |
+
color_threshold: int = 1) -> go.Figure:
|
428 |
+
""" Visualize a hierarchical structure of the topics
|
429 |
+
|
430 |
+
A ward linkage function is used to perform the
|
431 |
+
hierarchical clustering based on the cosine distance
|
432 |
+
matrix between topic embeddings.
|
433 |
+
|
434 |
+
Arguments:
|
435 |
+
topic_model: A fitted BERTopic instance.
|
436 |
+
orientation: The orientation of the figure.
|
437 |
+
Either 'left' or 'bottom'
|
438 |
+
topics: A selection of topics to visualize
|
439 |
+
top_n_topics: Only select the top n most frequent topics
|
440 |
+
custom_labels: If bool, whether to use custom topic labels that were defined using
|
441 |
+
`topic_model.set_topic_labels`.
|
442 |
+
If `str`, it uses labels from other aspects, e.g., "Aspect1".
|
443 |
+
NOTE: Custom labels are only generated for the original
|
444 |
+
un-merged topics.
|
445 |
+
title: Title of the plot.
|
446 |
+
width: The width of the figure. Only works if orientation is set to 'left'
|
447 |
+
height: The height of the figure. Only works if orientation is set to 'bottom'
|
448 |
+
hierarchical_topics: A dataframe that contains a hierarchy of topics
|
449 |
+
represented by their parents and their children.
|
450 |
+
NOTE: The hierarchical topic names are only visualized
|
451 |
+
if both `topics` and `top_n_topics` are not set.
|
452 |
+
linkage_function: The linkage function to use. Default is:
|
453 |
+
`lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
|
454 |
+
NOTE: Make sure to use the same `linkage_function` as used
|
455 |
+
in `topic_model.hierarchical_topics`.
|
456 |
+
distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
|
457 |
+
`lambda x: 1 - cosine_similarity(x)`.
|
458 |
+
You can pass any function that returns either a square matrix of
|
459 |
+
shape (n_samples, n_samples) with zeros on the diagonal and
|
460 |
+
non-negative values or condensed distance matrix of shape
|
461 |
+
(n_samples * (n_samples - 1) / 2,) containing the upper
|
462 |
+
triangular of the distance matrix.
|
463 |
+
NOTE: Make sure to use the same `distance_function` as used
|
464 |
+
in `topic_model.hierarchical_topics`.
|
465 |
+
color_threshold: Value at which the separation of clusters will be made which
|
466 |
+
will result in different colors for different clusters.
|
467 |
+
A higher value will typically lead in less colored clusters.
|
468 |
+
|
469 |
+
Returns:
|
470 |
+
fig: A plotly figure
|
471 |
+
|
472 |
+
Examples:
|
473 |
+
|
474 |
+
To visualize the hierarchical structure of
|
475 |
+
topics simply run:
|
476 |
+
|
477 |
+
```python
|
478 |
+
topic_model.visualize_hierarchy()
|
479 |
+
```
|
480 |
+
|
481 |
+
If you also want the labels visualized of hierarchical topics,
|
482 |
+
run the following:
|
483 |
+
+    ```python
+    # Extract hierarchical topics and their representations
+    hierarchical_topics = topic_model.hierarchical_topics(docs)
+
+    # Visualize these representations
+    topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
+    ```
+
+    If you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_hierarchy()
+    fig.write_html("path/to/file.html")
+    ```
+    <iframe src="../../getting_started/visualization/hierarchy.html"
+    style="width:1000px; height: 680px; border: 0px;"></iframe>
+    """
+    if distance_function is None:
+        distance_function = lambda x: 1 - cosine_similarity(x)
+
+    if linkage_function is None:
+        linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True)
+
+    # Select topics based on top_n and topics args
+    freq_df = topic_model.get_topic_freq()
+    freq_df = freq_df.loc[freq_df.Topic != -1, :]
+    if topics is not None:
+        topics = list(topics)
+    elif top_n_topics is not None:
+        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
+    else:
+        topics = sorted(freq_df.Topic.to_list())
+
+    # Select embeddings
+    all_topics = sorted(list(topic_model.get_topics().keys()))
+    indices = np.array([all_topics.index(topic) for topic in topics])
+
+    # Select topic embeddings
+    if topic_model.c_tf_idf_ is not None:
+        embeddings = topic_model.c_tf_idf_[indices]
+    else:
+        embeddings = np.array(topic_model.topic_embeddings_)[indices]
+
+    # Annotations
+    if hierarchical_topics is not None and len(topics) == len(freq_df.Topic.to_list()):
+        annotations = _get_annotations(topic_model=topic_model,
+                                       hierarchical_topics=hierarchical_topics,
+                                       embeddings=embeddings,
+                                       distance_function=distance_function,
+                                       linkage_function=linkage_function,
+                                       orientation=orientation,
+                                       custom_labels=custom_labels)
+    else:
+        annotations = None
+
+    # Wrap distance function to validate input and return a condensed distance matrix
+    distance_function_viz = lambda x: validate_distance_matrix(
+        distance_function(x), embeddings.shape[0])
+    # Create dendrogram
+    fig = ff.create_dendrogram(embeddings,
+                               orientation=orientation,
+                               distfun=distance_function_viz,
+                               linkagefun=linkage_function,
+                               hovertext=annotations,
+                               color_threshold=color_threshold)
+
+    # Create nicer labels
+    axis = "yaxis" if orientation == "left" else "xaxis"
+    if isinstance(custom_labels, str):
+        new_labels = [[[str(x), None]] + topic_model.topic_aspects_[custom_labels][x] for x in fig.layout[axis]["ticktext"]]
+        new_labels = [", ".join([label[0] for label in labels[:4]]) for labels in new_labels]
+        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
+    elif topic_model.custom_labels_ is not None and custom_labels:
+        new_labels = [topic_model.custom_labels_[topics[int(x)] + topic_model._outliers] for x in fig.layout[axis]["ticktext"]]
+    else:
+        new_labels = [[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)])
+                      for x in fig.layout[axis]["ticktext"]]
+        new_labels = [", ".join([label[0] for label in labels[:4]]) for labels in new_labels]
+        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
+
+    # Stylize layout
+    fig.update_layout(
+        plot_bgcolor='#ECEFF1',
+        template="plotly_white",
+        title={
+            'text': f"{title}",
+            'x': 0.5,
+            'xanchor': 'center',
+            'yanchor': 'top',
+            'font': dict(
+                size=22,
+                color="Black")
+        },
+        hoverlabel=dict(
+            bgcolor="white",
+            font_size=16,
+            font_family="Rockwell"
+        ),
+    )
+
+    # Stylize orientation
+    if orientation == "left":
+        fig.update_layout(height=200 + (15 * len(topics)),
+                          width=width,
+                          yaxis=dict(tickmode="array",
+                                     ticktext=new_labels))
+
+        # Fix empty space on the bottom of the graph
+        y_max = max([trace['y'].max() + 5 for trace in fig['data']])
+        y_min = min([trace['y'].min() - 5 for trace in fig['data']])
+        fig.update_layout(yaxis=dict(range=[y_min, y_max]))
+
+    else:
+        fig.update_layout(width=200 + (15 * len(topics)),
+                          height=height,
+                          xaxis=dict(tickmode="array",
+                                     ticktext=new_labels))
+
+    if hierarchical_topics is not None:
+        for index in [0, 3]:
+            axis = "x" if orientation == "left" else "y"
+            xs = [data["x"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
+            ys = [data["y"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
+            hovertext = [data["text"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
+
+            fig.add_trace(go.Scatter(x=xs, y=ys, marker_color='black',
+                                     hovertext=hovertext, hoverinfo="text",
+                                     mode='markers', showlegend=False))
+    return fig
+
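The `validate_distance_matrix` wrapper above exists because `scipy.cluster.hierarchy.linkage` expects a condensed (upper-triangular) distance matrix, while `1 - cosine_similarity(x)` returns a square one. A standalone sketch of the same conversion, using scipy's `squareform` directly rather than the BERTopic helper:

```python
import numpy as np
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import squareform
from sklearn.metrics.pairwise import cosine_similarity

X = np.random.rand(5, 16)            # stand-in for five topic embeddings
square = 1 - cosine_similarity(X)    # square 5 x 5 distance matrix
square = (square + square.T) / 2     # enforce exact symmetry for squareform
np.fill_diagonal(square, 0)          # zero the diagonal (floating-point residue)
condensed = squareform(square)       # condensed vector of length n*(n-1)/2
Z = sch.linkage(condensed, 'ward', optimal_ordering=True)
print(Z.shape)                       # (4, 4): one row per merge step
```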
 def visualize_hierarchical_documents_custom(topic_model,
                                             docs: List[str],
                                             hover_labels: List[str],
@@ ... @@
                                             custom_labels: Union[bool, str] = False,
                                             title: str = "<b>Hierarchical Documents and Topics</b>",
                                             width: int = 1200,
+                                            height: int = 750, progress=gr.Progress(track_tqdm=True)) -> go.Figure:
     """ Visualize documents and their topics in 2D at different levels of hierarchy
 
     Arguments:
@@ ... @@
     # Prepare topic names of original and merged topics
     trace_names = []
     topic_names = {}
+    trace_name_char_length = 60
     for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
         if topic < hierarchical_topics.Parent_ID.astype(int).min():
             if topic_model.get_topic(topic):
                 if isinstance(custom_labels, str):
+                    trace_name = f"{topic} " + ", ".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:5])
                 elif topic_model.custom_labels_ is not None and custom_labels:
                     trace_name = topic_model.custom_labels_[topic + topic_model._outliers]
                 else:
+                    trace_name = f"{topic} " + ", ".join([word[:20] for word, _ in topic_model.get_topic(topic)][:5])
+                topic_names[topic] = {"trace_name": trace_name[:trace_name_char_length], "plot_text": trace_name[:trace_name_char_length]}
                 trace_names.append(trace_name)
         else:
+            trace_name = f"{topic} " + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0]
+            plot_text = ", ".join([name[:20] for name in trace_name.split(" ")[:5]])
+            topic_names[topic] = {"trace_name": trace_name[:trace_name_char_length], "plot_text": plot_text[:trace_name_char_length]}
             trace_names.append(trace_name)
 
     # Prepare traces
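To see what the trimming above produces, here is the `plot_text` construction applied to a made-up merged-topic name (the topic id and words are illustrative only):

```python
trace_name_char_length = 60
trace_name = "231 housing_rent_landlord_tenant complaints about repairs"
# First five whitespace-separated tokens, each capped at 20 characters
plot_text = ", ".join([name[:20] for name in trace_name.split(" ")[:5]])
print(plot_text[:trace_name_char_length])
# 231, housing_rent_landlor, complaints, about, repairs
```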
@@ ... @@
 
     fig.update_xaxes(visible=False)
     fig.update_yaxes(visible=False)
+
+    hierarchy_topics_df = df.filter(regex=r'topic|^level').drop_duplicates(subset="topic")
+
+    topic_names = pd.DataFrame(topic_names).T
+
+
+    return fig, hierarchy_topics_df, topic_names
 
 def visualize_barchart_custom(topic_model,
                               topics: List[int] = None,
@@ ... @@
                               custom_labels: Union[bool, str] = False,
                               title: str = "<b>Topic Word Scores</b>",
                               width: int = 250,
+                              height: int = 250, progress=gr.Progress(track_tqdm=True)) -> go.Figure:
     """ Visualize a barchart of selected topics
 
     Arguments:
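Taken together, the custom functions in this module are called from `topic_core_funcs.py` roughly as follows (a sketch based on the calls later in this commit; `label_list`, `reduced_embeddings` and `sample_prop` are produced earlier in the app's visualisation step):

```python
from funcs.bertopic_vis_documents import (hierarchical_topics_custom,
                                          visualize_hierarchical_documents_custom,
                                          visualize_hierarchy_custom)

# Build the topic hierarchy once, then reuse it for both plots
hierarchical_topics = hierarchical_topics_custom(topic_model, docs)

topics_vis, hierarchy_df, hierarchy_topic_names = visualize_hierarchical_documents_custom(
    topic_model, docs, label_list, hierarchical_topics,
    reduced_embeddings=reduced_embeddings, sample=sample_prop,
    hide_document_hover=False, custom_labels=True, width=1200, height=750)

topics_vis_2 = visualize_hierarchy_custom(
    topic_model, hierarchical_topics=hierarchical_topics, width=1200, height=750)
```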
funcs/clean_funcs.py
CHANGED
@@ -33,18 +33,19 @@ multiple_spaces_regex = r'\s{2,}'
 
 def initial_clean(texts, custom_regex, progress=gr.Progress()):
     texts = pl.Series(texts).str.strip_chars()
-    text = texts.str.replace_all(html_pattern_regex, '')
-    text = text.str.replace_all(email_pattern_regex, '')
-    text = text.str.replace_all(nums_two_more_regex, '')
-    text = text.str.replace_all(postcode_pattern_regex, '')
-    text = text.str.replace_all(multiple_spaces_regex, '')
+    text = texts.str.replace_all(html_pattern_regex, ' ')
+    text = text.str.replace_all(email_pattern_regex, ' ')
+    text = text.str.replace_all(nums_two_more_regex, ' ')
+    text = text.str.replace_all(postcode_pattern_regex, ' ')
 
     # Allow for custom regex patterns to be removed
     if len(custom_regex) > 0:
         for pattern in custom_regex:
             raw_string_pattern = r'{}'.format(pattern)
             print("Removing regex pattern: ", raw_string_pattern)
-            text = text.str.replace_all(raw_string_pattern, '')
+            text = text.str.replace_all(raw_string_pattern, ' ')
+
+    text = text.str.replace_all(multiple_spaces_regex, ' ')
 
     text = text.to_list()
 
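The behavioural change here is that matched patterns are now replaced with a single space rather than deleted outright, and runs of whitespace are collapsed once at the end, after any custom patterns have run. Replacing with `''` can glue neighbouring tokens together when a pattern sits flush against a word; `' '` keeps a word boundary instead. A small standalone illustration using the same polars calls:

```python
import polars as pl

texts = pl.Series(["Contact test@example.com for details"])

# Strip a token but leave a space placeholder where it stood
spaced = texts.str.replace_all(r"\S+@\S+", " ")   # 'Contact   for details'
# Collapse the doubled/tripled spaces in one final pass
tidy = spaced.str.replace_all(r"\s{2,}", " ")
print(tidy.to_list())                             # ['Contact for details']
```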
funcs/topic_core_funcs.py
CHANGED
@@ -11,6 +11,8 @@ from bertopic import BERTopic
 from funcs.clean_funcs import initial_clean
 from funcs.helper_functions import read_file, zip_folder, delete_files_in_folder, save_topic_outputs
 from funcs.embeddings import make_or_load_embeddings
+from funcs.bertopic_vis_documents import visualize_documents_custom, visualize_hierarchical_documents_custom, hierarchical_topics_custom, visualize_hierarchy_custom
+
 
 from sentence_transformers import SentenceTransformer
 from sklearn.pipeline import make_pipeline
@@ -145,9 +147,7 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
     if not in_colnames:
         error_message = "Please enter one column name to use for cleaning and finding topics."
         print(error_message)
-        return error_message, None, data_file_name_no_ext, embeddings_out, None, None
-
-
+        return error_message, None, data_file_name_no_ext, embeddings_out, embeddings_type_state, data_file_name_no_ext, None, None, vectoriser_state, []
 
     in_colnames_list_first = in_colnames[0]
 
@@ -186,7 +186,9 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
 
         embeddings_type_state = "tfidf"
 
-        umap_model = TruncatedSVD(n_components=5, random_state=random_seed)
+        #umap_model = TruncatedSVD(n_components=5, random_state=random_seed)
+        # UMAP model uses BERTopic defaults
+        umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', low_memory=True, random_state=random_seed)
 
     embeddings_out = make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, embeddings_super_compress, low_resource_mode)
 
@@ -195,7 +197,7 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
 
     progress(0.3, desc= "Embeddings loaded. Creating BERTopic model")
 
-    fail_error_message = "Topic model creation failed. Try reducing minimum documents per topic on the slider above (try 15 or less), then click 'Extract topics' again."
+    fail_error_message = "Topic model creation failed. Try reducing minimum documents per topic on the slider above (try 15 or less), then click 'Extract topics' again. If that doesn't work, try running the first two clean steps on your data first (see Clean data above) to ensure there are no NaNs/missing texts in your data."
 
     if not candidate_topics:
 
@@ -217,10 +219,11 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
             topics_probs_out.to_csv(topics_probs_out_name)
             output_list.append(topics_probs_out_name)
 
-        except:
+        except Exception as error:
+            print(error)
             print(fail_error_message)
 
-            return fail_error_message, output_list, embeddings_out, data_file_name_no_ext, None, docs, vectoriser_model
+            return fail_error_message, output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, None, docs, vectoriser_model, []
 
 
     # Do this if you have pre-defined topics
@@ -229,7 +232,7 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
             error_message = "Zero shot topic modelling currently not compatible with low-resource embeddings. Please change this option to 'No' on the options tab and retry."
             print(error_message)
 
-            return error_message, output_list, embeddings_out, data_file_name_no_ext, None, docs, vectoriser_model
+            return error_message, output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, None, docs, vectoriser_model, []
 
         zero_shot_topics = read_file(candidate_topics.name)
         zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())
@@ -254,17 +257,21 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
             topics_probs_out.to_csv(topics_probs_out_name)
             output_list.append(topics_probs_out_name)
 
-        except:
+        except Exception as error:
+            print("An exception occurred:", error)
             print(fail_error_message)
 
-            return fail_error_message, output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, None, docs, vectoriser_model
+            return fail_error_message, output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, None, docs, vectoriser_model, []
 
         # For some reason, zero shot topic modelling exports assigned topics as a np.array instead of a list. Converting it back here.
         if isinstance(assigned_topics, np.ndarray):
             assigned_topics = assigned_topics.tolist()
 
-
+
+
+        # Zero shot modelling is a model merge, which wipes the c_tf_idf part of the resulting model completely. To get hierarchical modelling to work, we need to recreate this part of the model with the CountVectorizer options used to create the initial model. Since with zero shot we are merging two models that have exactly the same set of documents, the vocabulary should be the same, and so recreating the c_tf_idf component in this way shouldn't be a problem. Discussion here, and below based on Maarten's suggested code: https://github.com/MaartenGr/BERTopic/issues/1700
 
+        # Get document info
         doc_dets = topic_model.get_document_info(docs)
 
         documents_per_topic = doc_dets.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
@@ -277,13 +284,19 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
         c_tf_idf, _ = topic_model._c_tf_idf(documents_per_topic)
         topic_model.c_tf_idf_ = c_tf_idf
 
+        ###
+
+
+    # Check we have topics
     if not assigned_topics:
-
-        return "No topics found.", output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, topic_model, docs
-
+        return "No topics found.", output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, topic_model, docs, vectoriser_model, []
     else:
        print("Topic model created.")
 
+        # Tidy up topic label format a bit to have commas and spaces by default
+        new_topic_labels = topic_model.generate_topic_labels(nr_words=3, separator=", ")
+        topic_model.set_topic_labels(new_topic_labels)
+
     # Replace current topic labels if new ones loaded in
     if not custom_labels_df.empty:
         #custom_label_list = list(custom_labels_df.iloc[:,0])
@@ -315,9 +328,9 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
     time_out = f"All processes took {all_toc - all_tic:0.1f} seconds."
     print(time_out)
 
-    return output_text, output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, topic_model, docs, vectoriser_model
+    return output_text, output_list, embeddings_out, embeddings_type_state, data_file_name_no_ext, topic_model, docs, vectoriser_model, assigned_topics
 
-def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, save_topic_model, progress=gr.Progress(track_tqdm=True)):
+def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, assigned_topics, vectoriser_model, save_topic_model, progress=gr.Progress(track_tqdm=True)):
 
     progress(0, desc= "Preparing data")
 
@@ -325,7 +338,8 @@ def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, sa
 
     all_tic = time.perf_counter()
 
-
+    # This step not necessary?
+    #assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
 
     if isinstance(assigned_topics, np.ndarray):
         assigned_topics = assigned_topics.tolist()
@@ -339,7 +353,12 @@ def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, sa
     # Then, update the topics to the ones that considered the new data
 
     progress(0.6, desc= "Updating original model")
-
+
+    topic_model.update_topics(docs, topics=assigned_topics, vectorizer_model = vectoriser_model)
+
+    # Tidy up topic label format a bit to have commas and spaces by default
+    new_topic_labels = topic_model.generate_topic_labels(nr_words=3, separator=", ")
+    topic_model.set_topic_labels(new_topic_labels)
 
     print("Finished reducing outliers.")
 
@@ -375,7 +394,7 @@ def represent_topics(topic_model, docs, data_file_name_no_ext, low_resource_mode
 
     representation_model = create_representation_model(representation_type, llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
 
-    progress(0.
+    progress(0.3, desc= "Updating existing topics")
     topic_model.update_topics(docs, vectorizer_model=vectoriser_model, representation_model=representation_model)
 
     topic_dets = topic_model.get_topic_info()
@@ -394,8 +413,7 @@ def represent_topics(topic_model, docs, data_file_name_no_ext, low_resource_mode
     else:
         new_topic_labels = topic_model.generate_topic_labels(nr_words=3, separator=", ", aspect = representation_type)
 
-        topic_model.set_topic_labels(new_topic_labels)
-        #topic_model.set_topic_labels(list(topic_dets["Name"]))
+    topic_model.set_topic_labels(new_topic_labels)
 
     # Outputs
     progress(0.8, desc= "Saving outputs")
@@ -414,8 +432,7 @@ def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode
     output_list = []
     vis_tic = time.perf_counter()
 
-
-
+
     if not visualisation_type_radio:
         return "Please choose a visualisation type above.", output_list, None, None
 
@@ -475,7 +492,7 @@ def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode
 
     elif visualisation_type_radio == "Hierarchical view":
 
-        hierarchical_topics = topic_model
+        hierarchical_topics = hierarchical_topics_custom(topic_model, docs)
 
         # Print topic tree
         tree = topic_model.get_topic_tree(hierarchical_topics, tight_layout = True)
@@ -488,16 +505,28 @@ def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode
         output_list.append(tree_name)
 
         # Save new hierarchical topic model to file
-        hierarchical_topics_name = data_file_name_no_ext + '_' + '
+        hierarchical_topics_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topics_distz_' + today_rev + '.csv'
         hierarchical_topics.to_csv(hierarchical_topics_name)
         output_list.append(hierarchical_topics_name)
 
-
-
-
-
-
+
+        #try:
+        topics_vis, hierarchy_df, hierarchy_topic_names = visualize_hierarchical_documents_custom(topic_model, docs, label_list, hierarchical_topics, reduced_embeddings=reduced_embeddings, sample = sample_prop, hide_document_hover= False, custom_labels=True, width= 1200, height = 750)
+        topics_vis_2 = visualize_hierarchy_custom(topic_model, hierarchical_topics=hierarchical_topics, width= 1200, height = 750)
+
+        # Write hierarchical topics levels to df
+        hierarchy_df_name = data_file_name_no_ext + '_' + 'hierarchy_topics_df_' + today_rev + '.csv'
+        hierarchy_df.to_csv(hierarchy_df_name)
+        output_list.append(hierarchy_df_name)
+
+        # Write hierarchical topics names to df
+        hierarchy_topic_names_name = data_file_name_no_ext + '_' + 'hierarchy_topics_names_' + today_rev + '.csv'
+        hierarchy_topic_names.to_csv(hierarchy_topic_names_name)
+        output_list.append(hierarchy_topic_names_name)
+
+        #except:
+        #    error_message = "Visualisation preparation failed. Perhaps you need more topics to create the full hierarchy (more than 10)?"
+        #    return error_message, output_list, None, None
 
         topics_vis_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topic_doc_' + today_rev + '.html'
         topics_vis.write_html(topics_vis_name)
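For context on the reworked `reduce_outliers` step above: the pattern is to reassign the outlier (-1) documents with BERTopic's built-in `reduce_outliers`, refresh the model with `update_topics` while passing the same `CountVectorizer` that built it so the c-TF-IDF representations stay consistent, and finally regenerate the tidier "id, word1, word2, word3" labels. A sketch of that sequence using the standard BERTopic API (the exact outlier-reduction strategy used in the elided part of this file is assumed):

```python
# Reassign outlier documents to their nearest topic using the document embeddings
new_topics = topic_model.reduce_outliers(docs, assigned_topics,
                                         strategy="embeddings",
                                         embeddings=embeddings_out)

# Refresh topic representations with the vectoriser used to build the model
topic_model.update_topics(docs, topics=new_topics, vectorizer_model=vectoriser_model)

# Tidy default labels: "id, word1, word2, word3"
new_topic_labels = topic_model.generate_topic_labels(nr_words=3, separator=", ")
topic_model.set_topic_labels(new_topic_labels)
```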