ArneBinder
committed on
new demo setup with langchain retriever
based on https://github.com/ArneBinder/pie-document-level/pull/298
- app.py +471 -279
- model_utils.py +339 -65
- rendering_utils.py +50 -40
- requirements.txt +28 -5
- retrieve_and_dump_all_relevant.py +101 -0
- retriever/related_span_retriever_with_relations_from_other_docs.yaml +47 -0
- src/__init__.py +0 -0
- src/hf_pipeline/__init__.py +1 -0
- src/hf_pipeline/__pycache__/__init__.cpython-310.pyc +0 -0
- src/hf_pipeline/__pycache__/__init__.cpython-39.pyc +0 -0
- src/hf_pipeline/__pycache__/feature_extraction.cpython-310.pyc +0 -0
- src/hf_pipeline/__pycache__/feature_extraction.cpython-39.pyc +0 -0
- src/hf_pipeline/feature_extraction.py +317 -0
- src/langchain_modules/__init__.py +9 -0
- src/langchain_modules/__pycache__/__init__.cpython-310.pyc +0 -0
- src/langchain_modules/__pycache__/__init__.cpython-39.pyc +0 -0
- src/langchain_modules/__pycache__/basic_pie_document_store.cpython-39.pyc +0 -0
- src/langchain_modules/__pycache__/datasets_pie_document_store.cpython-39.pyc +0 -0
- src/langchain_modules/__pycache__/huggingface_span_embeddings.cpython-310.pyc +0 -0
- src/langchain_modules/__pycache__/huggingface_span_embeddings.cpython-39.pyc +0 -0
- src/langchain_modules/__pycache__/pie_document_store.cpython-39.pyc +0 -0
- src/langchain_modules/__pycache__/qdrant_span_vectorstore.cpython-310.pyc +0 -0
- src/langchain_modules/__pycache__/qdrant_span_vectorstore.cpython-39.pyc +0 -0
- src/langchain_modules/__pycache__/serializable_store.cpython-39.pyc +0 -0
- src/langchain_modules/__pycache__/span_embeddings.cpython-310.pyc +0 -0
- src/langchain_modules/__pycache__/span_embeddings.cpython-39.pyc +0 -0
- src/langchain_modules/__pycache__/span_retriever.cpython-310.pyc +0 -0
- src/langchain_modules/__pycache__/span_retriever.cpython-39.pyc +0 -0
- src/langchain_modules/__pycache__/span_vectorstore.cpython-39.pyc +0 -0
- src/langchain_modules/basic_pie_document_store.py +103 -0
- src/langchain_modules/datasets_pie_document_store.py +156 -0
- src/langchain_modules/huggingface_span_embeddings.py +192 -0
- src/langchain_modules/pie_document_store.py +88 -0
- src/langchain_modules/qdrant_span_vectorstore.py +349 -0
- src/langchain_modules/serializable_store.py +137 -0
- src/langchain_modules/span_embeddings.py +103 -0
- src/langchain_modules/span_retriever.py +860 -0
- src/langchain_modules/span_vectorstore.py +131 -0
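The file list above replaces the old DocumentStore/QdrantVectorStore/EmbeddingModel setup with a langchain-based retrieval stack under src/langchain_modules, wired up via model_utils.py and the YAML config added under retriever/. As rough orientation, a minimal sketch of how the new helpers fit together outside the Gradio UI; it only uses functions introduced in this commit, the model name is a placeholder, and the example text, document id, and span id are made up:

from model_utils import (
    load_argumentation_model,
    load_retriever,
    process_texts,
    retrieve_relevant_spans,
)

# read the retriever config added in this commit (app.py loads it from configs/retriever/)
with open("retriever/related_span_retriever_with_relations_from_other_docs.yaml") as f:
    retriever_config = f.read()

retriever = load_retriever(retriever_config, config_format="yaml", device="cpu")
argumentation_model = load_argumentation_model(
    model_name="<argumentation model, e.g. DEFAULT_MODEL_NAME from app.py>",  # placeholder
    device="cpu",
)

# annotate a text and store its ADU spans (plus embeddings) in the retriever
process_texts(
    texts=["Scholarly Argumentation Mining (SAM) has recently gained attention ..."],
    doc_ids=["example.txt"],
    argumentation_model=argumentation_model,
    retriever=retriever,
    split_regex_escaped=None,
    handle_parts_of_same=True,
)

# query ADUs from *other* documents that are related to a given span
relevant_adus = retrieve_relevant_spans(
    retriever=retriever,
    query_span_id="<span id from the retriever's vectorstore>",
    k=10,
    score_threshold=0.95,
    relation_label_mapping={"supports_reversed": "supported by", "contradicts_reversed": "contradicts"},
)
print(relevant_adus)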
app.py
CHANGED
Removed in this commit (previous version of app.py; the reworked version follows below):

@@ -1,32 +1,56 @@
-from typing import List, Optional, Tuple, Union

@@ -35,16 +59,31 @@ DEFAULT_MODEL_REVISION = "76300f8e534e2fcf695f00cb49bba166739b8d8a"
-DEFAULT_EMBEDDING_MAX_LENGTH = 512
-DEFAULT_EMBEDDING_BATCH_SIZE = 32

@@ -59,19 +98,38 @@
-TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
-html = render_pretty_table(
-html = render_displacy(

@@ -79,84 +137,47 @@
-text: str,
-models: Tuple[Pipeline, Optional[EmbeddingModel]],
-document_store: DocumentStore,
-split_regex_escaped: str,
-handle_parts_of_same: bool = False,
-) -> Tuple[
-dict,
-Union[
-TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
-],
-]:
-text=text,
-doc_id=doc_id,
-split_regex=unescape_regex(split_regex_escaped)
-if len(split_regex_escaped) > 0
-else None,
-)
-document = annotate_document(
-document=document,
-annotation_pipeline=models[0],
-embedding_model=models[1],
-handle_parts_of_same=handle_parts_of_same,
-)
-document_store.add_document(document)
-# remove the embeddings because they are very large
-if document.metadata.get("embeddings"):
-document.metadata = {k: v for k, v in document.metadata.items() if k != "embeddings"}
-file_names: List[str],
-models: Tuple[Pipeline, Optional[EmbeddingModel]],
-document_store: DocumentStore,
-split_regex_escaped: str,
-show_max_cross_doc_sims: bool = False,
-min_similarity: float = 0.95,
-handle_parts_of_same: bool = False,
-text=text,
-doc_id=base_file_name,
-split_regex=unescape_regex(split_regex_escaped)
-if len(split_regex_escaped) > 0
-else None,
-)
-new_document = annotate_document(
-document=new_document,
-annotation_pipeline=models[0],
-embedding_model=models[1],
-handle_parts_of_same=handle_parts_of_same,
-)
-new_documents.append(new_document)

@@ -167,30 +188,34 @@
-col_name = "doc_id"
-doc_id = processed_documents_df.iloc[row_idx][col_name]
-doc = document_store.get_document(doc_id, with_embeddings=False)
-return doc
-relation_types = arg_pipeline.taskmodule.labels_per_layer["binary_relations"]

@@ -202,21 +227,64 @@

@@ -253,21 +321,21 @@

@@ -275,19 +343,31 @@
-embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
-embedding_max_length=DEFAULT_EMBEDDING_MAX_LENGTH,
-embedding_batch_size=DEFAULT_EMBEDDING_BATCH_SIZE,

@@ -315,21 +395,10 @@
-document_store_state = gr.State(
-DocumentStore(
-span_annotation_caption="adu",
-relation_annotation_caption="relation",
-vector_store=QdrantVectorStore(),
-document_type=TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-if not HANDLE_PARTS_OF_SAME
-else TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
-span_layer_name="labeled_spans"
-if not HANDLE_PARTS_OF_SAME
-else "labeled_multi_spans",
-)
-)
-models_state = gr.State((argumentation_model, embedding_model))

@@ -341,63 +410,54 @@
-arxiv_id = gr.Textbox(
-label="arXiv paper ID",
-placeholder=f"e.g. {DEFAULT_ARXIV_ID}",
-max_lines=1,
-)
-load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
-load_arxiv_btn = gr.Button("Load Text from arXiv", variant="secondary")
-load_arxiv_btn.click(
-fn=load_text_from_arxiv,
-inputs=[arxiv_id, load_arxiv_only_abstract],
-outputs=[doc_text, doc_id],
-)
-label="Embedding Model Batch Size",
-minimum=1,
-maximum=128,
-step=1,
-value=DEFAULT_EMBEDDING_BATCH_SIZE,
-)

@@ -406,12 +466,12 @@
-document_state = gr.State()

@@ -424,9 +484,11 @@

@@ -436,14 +498,11 @@
-show_max_cross_docu_sims = gr.Checkbox(
-label="Show max cross-document similarities", value=False
-)

@@ -452,9 +511,22 @@

@@ -462,174 +534,294 @@
-retrieve_similar_adus_btn = gr.Button(
-all2all_adu_similarities_button = gr.Button(
-"Compute all ADU-to-ADU similarities"
-fn=render_annotated_document
-predict_btn.click(
-).success(**show_overview_kwargs)
-).success(close_accordion, inputs=[], outputs=[output_accordion]).then(
-**render_event_kwargs
-show_max_cross_docu_sims,
-min_similarity,
-inputs=[processed_documents_df,
-show_max_cross_docu_sims.change(**show_overview_kwargs)
-fn=upload_processed_documents
-document_state,
-relation_types,
-annotation_layer="labeled_spans"
-if not HANDLE_PARTS_OF_SAME
-else "labeled_multi_spans",
-use_predictions=True,
-annotation_layer="labeled_spans"
-if not HANDLE_PARTS_OF_SAME
-else "labeled_multi_spans",
-document_state,
|
|
1 |
+
import pyrootutils
|
2 |
+
|
3 |
+
root = pyrootutils.setup_root(
|
4 |
+
search_from=__file__,
|
5 |
+
indicator=[".project-root"],
|
6 |
+
pythonpath=True,
|
7 |
+
dotenv=True,
|
8 |
+
)
|
9 |
+
|
10 |
import json
|
11 |
import logging
|
12 |
import os.path
|
13 |
import re
|
14 |
import tempfile
|
15 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
|
|
16 |
|
17 |
import arxiv
|
18 |
import gradio as gr
|
19 |
import pandas as pd
|
20 |
import requests
|
21 |
import torch
|
22 |
+
import yaml
|
23 |
from bs4 import BeautifulSoup
|
24 |
+
from model_utils import (
|
25 |
+
add_annotated_pie_documents_from_dataset,
|
26 |
+
load_argumentation_model,
|
27 |
+
load_retriever,
|
28 |
+
process_texts,
|
29 |
+
retrieve_all_relevant_spans,
|
30 |
+
retrieve_all_similar_spans,
|
31 |
+
retrieve_relevant_spans,
|
32 |
+
retrieve_similar_spans,
|
33 |
)
|
34 |
+
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
|
35 |
+
from pytorch_ie import Annotation, Pipeline
|
36 |
+
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
|
37 |
from rendering_utils import HIGHLIGHT_SPANS_JS, render_displacy, render_pretty_table
|
38 |
+
|
39 |
+
from src.langchain_modules import (
|
40 |
+
DocumentAwareSpanRetriever,
|
41 |
+
DocumentAwareSpanRetrieverWithRelations,
|
42 |
+
)
|
43 |
|
44 |
logger = logging.getLogger(__name__)
|
45 |
|
46 |
+
|
47 |
+
def load_retriever_config(path: str) -> str:
|
48 |
+
with open(path, "r") as file:
|
49 |
+
yaml_string = file.read()
|
50 |
+
config = yaml.safe_load(yaml_string)
|
51 |
+
return yaml.dump(config)
|
52 |
+
|
53 |
+
|
54 |
RENDER_WITH_DISPLACY = "displaCy + highlighted arguments"
|
55 |
RENDER_WITH_PRETTY_TABLE = "Pretty Table"
|
56 |
|
|
|
59 |
# local path
|
60 |
# DEFAULT_MODEL_NAME = "models/dataset-sciarg/task-ner_re/v0.3/2024-05-28_23-33-46"
|
61 |
# DEFAULT_MODEL_REVISION = None
|
62 |
+
DEFAULT_RETRIEVER_CONFIG = load_retriever_config(
|
63 |
+
"configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml"
|
64 |
+
)
|
65 |
+
# 0.943180 from data_dir="predictions/default/2024-10-15_23-40-18"
|
66 |
+
DEFAULT_MIN_SIMILARITY = 0.95
|
67 |
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
68 |
DEFAULT_SPLIT_REGEX = "\n\n\n+"
|
69 |
DEFAULT_ARXIV_ID = "1706.03762"
|
70 |
+
DEFAULT_LOAD_PIE_DATASET_KWARGS_STR = json.dumps(
|
71 |
+
dict(path="pie/sciarg", name="resolve_parts_of_same", split="train"), indent=2
|
72 |
+
)
|
73 |
|
74 |
# Whether to handle segmented entities in the document. If True, labeled_spans are converted
|
75 |
# to labeled_multi_spans and binary_relations with label "parts_of_same" are used to merge them.
|
76 |
+
# This requires the networkx package to be installed.
|
77 |
HANDLE_PARTS_OF_SAME = True
|
78 |
+
LAYER_CAPTIONS = {
|
79 |
+
"labeled_multi_spans": "adus",
|
80 |
+
"binary_relations": "relations",
|
81 |
+
"labeled_partitions": "partitions",
|
82 |
+
}
|
83 |
+
RELATION_NAME_MAPPING = {
|
84 |
+
"supports_reversed": "supported by",
|
85 |
+
"contradicts_reversed": "contradicts",
|
86 |
+
}
|
87 |
|
88 |
|
89 |
def escape_regex(regex: str) -> str:
|
|
|
98 |
return result
|
99 |
|
100 |
|
101 |
+
def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) -> Dict:
|
102 |
+
document = retriever.get_document(doc_id=doc_id)
|
103 |
+
return retriever.docstore.as_dict(document)
|
104 |
+
|
105 |
+
|
106 |
def render_annotated_document(
|
107 |
+
retriever: DocumentAwareSpanRetrieverWithRelations,
|
108 |
+
document_id: str,
|
|
|
|
|
109 |
render_with: str,
|
110 |
render_kwargs_json: str,
|
111 |
) -> str:
|
112 |
+
text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
|
113 |
+
retriever=retriever, document_id=document_id
|
114 |
+
)
|
115 |
+
|
116 |
render_kwargs = json.loads(render_kwargs_json)
|
117 |
if render_with == RENDER_WITH_PRETTY_TABLE:
|
118 |
+
html = render_pretty_table(
|
119 |
+
text=text,
|
120 |
+
spans=spans,
|
121 |
+
span_id2idx=span_id2idx,
|
122 |
+
binary_relations=relations,
|
123 |
+
**render_kwargs,
|
124 |
+
)
|
125 |
elif render_with == RENDER_WITH_DISPLACY:
|
126 |
+
html = render_displacy(
|
127 |
+
text=text,
|
128 |
+
spans=spans,
|
129 |
+
span_id2idx=span_id2idx,
|
130 |
+
binary_relations=relations,
|
131 |
+
**render_kwargs,
|
132 |
+
)
|
133 |
else:
|
134 |
raise ValueError(f"Unknown render_with value: {render_with}")
|
135 |
|
|
|
137 |
|
138 |
|
139 |
def wrapped_process_text(
|
140 |
+
doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs
|
141 |
+
) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
try:
|
143 |
+
process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
except Exception as e:
|
145 |
raise gr.Error(f"Failed to process text: {e}")
|
|
|
|
|
|
|
146 |
# Return as dict and document to avoid serialization issues
|
147 |
+
return doc_id
|
148 |
|
149 |
|
150 |
def process_uploaded_files(
|
151 |
+
file_names: List[str], retriever: DocumentAwareSpanRetriever, **kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
) -> pd.DataFrame:
|
153 |
try:
|
154 |
+
doc_ids = []
|
155 |
+
texts = []
|
156 |
for file_name in file_names:
|
157 |
if file_name.lower().endswith(".txt"):
|
158 |
# read the file content
|
159 |
with open(file_name, "r", encoding="utf-8") as f:
|
160 |
text = f.read()
|
161 |
base_file_name = os.path.basename(file_name)
|
162 |
+
doc_ids.append(base_file_name)
|
163 |
+
texts.append(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
else:
|
165 |
raise gr.Error(f"Unsupported file format: {file_name}")
|
166 |
+
process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
|
167 |
except Exception as e:
|
168 |
raise gr.Error(f"Failed to process uploaded files: {e}")
|
169 |
|
170 |
+
return retriever.docstore.overview(layer_captions=LAYER_CAPTIONS, use_predictions=True)
|
171 |
+
|
172 |
+
|
173 |
+
def wrapped_add_annotated_pie_documents_from_dataset(
|
174 |
+
retriever: DocumentAwareSpanRetriever, verbose: bool, **kwargs
|
175 |
+
) -> pd.DataFrame:
|
176 |
+
try:
|
177 |
+
add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs)
|
178 |
+
except Exception as e:
|
179 |
+
raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}")
|
180 |
+
return retriever.docstore.overview(layer_captions=LAYER_CAPTIONS, use_predictions=True)
|
181 |
|
182 |
|
183 |
def open_accordion():
|
|
|
188 |
return gr.Accordion(open=False)
|
189 |
|
190 |
|
191 |
+
def get_cell_for_fixed_column_from_df(
|
192 |
evt: gr.SelectData,
|
193 |
+
df: pd.DataFrame,
|
194 |
+
column: str,
|
195 |
+
) -> Any:
|
196 |
+
"""Get the value of the fixed column for the selected row in the DataFrame.
|
197 |
+
This is required because it can *not* be done with a lambda function, as that would not receive
|
198 |
+
the evt parameter.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
evt: The event object.
|
202 |
+
df: The DataFrame.
|
203 |
+
column: The name of the column.
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
The value of the fixed column for the selected row.
|
207 |
+
"""
|
208 |
row_idx, col_idx = evt.index
|
209 |
+
doc_id = df.iloc[row_idx][column]
|
210 |
+
return doc_id
|
|
|
|
|
|
|
|
|
211 |
|
212 |
|
213 |
def set_relation_types(
|
214 |
+
argumentation_model: Pipeline,
|
215 |
default: Optional[List[str]] = None,
|
216 |
) -> gr.Dropdown:
|
217 |
+
if isinstance(argumentation_model.taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
|
218 |
+
relation_types = argumentation_model.taskmodule.labels_per_layer["binary_relations"]
|
|
|
219 |
else:
|
220 |
raise gr.Error("Unsupported taskmodule for relation types")
|
221 |
|
|
|
227 |
)
|
228 |
|
229 |
|
230 |
+
def get_span_annotation(
|
231 |
+
retriever: DocumentAwareSpanRetriever,
|
232 |
+
span_id: str,
|
233 |
+
) -> Annotation:
|
234 |
+
if span_id.strip() == "":
|
235 |
+
raise gr.Error("No span selected.")
|
236 |
+
try:
|
237 |
+
return retriever.get_span_by_id(span_id=span_id)
|
238 |
+
except Exception as e:
|
239 |
+
raise gr.Error(f"Failed to retrieve span annotation: {e}")
|
240 |
+
|
241 |
+
|
242 |
+
def get_text_spans_and_relations_from_document(
|
243 |
+
retriever: DocumentAwareSpanRetrieverWithRelations, document_id: str
|
244 |
+
) -> Tuple[
|
245 |
+
str,
|
246 |
+
Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
|
247 |
+
Dict[str, int],
|
248 |
+
Sequence[BinaryRelation],
|
249 |
+
]:
|
250 |
+
document = retriever.get_document(doc_id=document_id)
|
251 |
+
pie_document = retriever.docstore.unwrap(document)
|
252 |
+
use_predicted_annotations = retriever.use_predicted_annotations(document)
|
253 |
+
spans = retriever.get_base_layer(
|
254 |
+
pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
|
255 |
+
)
|
256 |
+
relations = retriever.get_relation_layer(
|
257 |
+
pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
|
258 |
+
)
|
259 |
+
span_id2idx = retriever.get_span_id2idx_from_doc(document)
|
260 |
+
return pie_document.text, spans, span_id2idx, relations
|
261 |
+
|
262 |
+
|
263 |
def download_processed_documents(
|
264 |
+
retriever: DocumentAwareSpanRetriever,
|
265 |
+
file_name: str = "retriever_store",
|
266 |
+
) -> Optional[str]:
|
267 |
+
if len(retriever.docstore) == 0:
|
268 |
+
gr.Warning("No documents to download.")
|
269 |
+
return None
|
270 |
+
|
271 |
+
# zip the directory
|
272 |
file_path = os.path.join(tempfile.gettempdir(), file_name)
|
273 |
+
|
274 |
+
gr.Info(f"Zipping the retriever store to '{file_name}' ...")
|
275 |
+
result_file_path = retriever.save_to_archive(base_name=file_path, format="zip")
|
276 |
+
|
277 |
+
return result_file_path
|
278 |
|
279 |
|
280 |
def upload_processed_documents(
|
281 |
file_name: str,
|
282 |
+
retriever: DocumentAwareSpanRetriever,
|
283 |
) -> pd.DataFrame:
|
284 |
+
# load the documents from the zip file or directory
|
285 |
+
retriever.load_from_disc(file_name)
|
286 |
+
# return the overview of the document store
|
287 |
+
return retriever.docstore.overview(layer_captions=LAYER_CAPTIONS, use_predictions=True)
|
288 |
|
289 |
|
290 |
def clean_spaces(text: str) -> str:
|
|
|
321 |
except arxiv.HTTPError as e:
|
322 |
raise gr.Error(f"Failed to fetch arXiv data: {e}")
|
323 |
if len(result) == 0:
|
324 |
+
raise gr.Error(f"Could not find any paper with arXiv ID '{arxiv_id}'")
|
325 |
first_result = result[0]
|
326 |
if abstract_only:
|
327 |
abstract_clean = first_result.summary.replace("\n", " ")
|
328 |
return abstract_clean, first_result.entry_id
|
329 |
if "/abs/" not in first_result.entry_id:
|
330 |
raise gr.Error(
|
331 |
+
f"Could not create the HTML URL for arXiv ID '{arxiv_id}' because its entry ID has "
|
332 |
f"an unexpected format: {first_result.entry_id}"
|
333 |
)
|
334 |
html_url = first_result.entry_id.replace("/abs/", "/html/")
|
335 |
request_result = requests.get(html_url)
|
336 |
if request_result.status_code != 200:
|
337 |
raise gr.Error(
|
338 |
+
f"Could not fetch the HTML content for arXiv ID '{arxiv_id}', status code: "
|
339 |
f"{request_result.status_code}"
|
340 |
)
|
341 |
html_content = request_result.text
|
|
|
343 |
return text_clean, html_url
|
344 |
|
345 |
|
346 |
+
def process_text_from_arxiv(
|
347 |
+
arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs
|
348 |
+
) -> str:
|
349 |
+
try:
|
350 |
+
text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only)
|
351 |
+
except Exception as e:
|
352 |
+
raise gr.Error(f"Failed to load text from arXiv: {e}")
|
353 |
+
return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs)
|
354 |
+
|
355 |
+
|
356 |
def main():
|
357 |
|
358 |
example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
|
359 |
|
360 |
+
print("Loading argumentation model ...")
|
361 |
+
argumentation_model = load_argumentation_model(
|
362 |
model_name=DEFAULT_MODEL_NAME,
|
363 |
revision=DEFAULT_MODEL_REVISION,
|
|
|
|
|
|
|
364 |
device=DEFAULT_DEVICE,
|
365 |
)
|
366 |
+
print("Loading retriever ...")
|
367 |
+
retriever = load_retriever(
|
368 |
+
DEFAULT_RETRIEVER_CONFIG, device=DEFAULT_DEVICE, config_format="yaml"
|
369 |
+
)
|
370 |
+
print("Models loaded.")
|
371 |
|
372 |
default_render_kwargs = {
|
373 |
"entity_options": {
|
|
|
395 |
}
|
396 |
|
397 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
398 |
# wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
|
399 |
+
# models_state = gr.State((argumentation_model, embedding_model))
|
400 |
+
argumentation_model_state = gr.State((argumentation_model,))
|
401 |
+
retriever_state = gr.State((retriever,))
|
402 |
with gr.Row():
|
403 |
with gr.Column(scale=1):
|
404 |
doc_id = gr.Textbox(
|
|
|
410 |
lines=20,
|
411 |
value=example_text,
|
412 |
)
|
413 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
with gr.Accordion("Model Configuration", open=False):
|
415 |
+
with gr.Accordion("argumentation structure", open=True):
|
416 |
+
model_name = gr.Textbox(
|
417 |
+
label="Model Name",
|
418 |
+
value=DEFAULT_MODEL_NAME,
|
419 |
+
)
|
420 |
+
model_revision = gr.Textbox(
|
421 |
+
label="Model Revision",
|
422 |
+
value=DEFAULT_MODEL_REVISION,
|
423 |
+
)
|
424 |
+
load_arg_model_btn = gr.Button("Load Argumentation Model")
|
425 |
+
|
426 |
+
with gr.Accordion("retriever", open=True):
|
427 |
+
retriever_config = gr.Textbox(
|
428 |
+
label="Retriever Configuration",
|
429 |
+
placeholder="Configuration for the retriever",
|
430 |
+
value=DEFAULT_RETRIEVER_CONFIG,
|
431 |
+
lines=len(DEFAULT_RETRIEVER_CONFIG.split("\n")),
|
432 |
+
)
|
433 |
+
load_retriever_btn = gr.Button("Load Retriever")
|
434 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
435 |
device = gr.Textbox(
|
436 |
label="Device (e.g. 'cuda' or 'cpu')",
|
437 |
value=DEFAULT_DEVICE,
|
438 |
)
|
439 |
+
load_arg_model_btn.click(
|
440 |
+
fn=lambda _model_name, _model_revision, _device: (
|
441 |
+
load_argumentation_model(
|
442 |
+
model_name=_model_name, revision=_model_revision, device=_device
|
443 |
+
),
|
444 |
+
),
|
445 |
+
inputs=[model_name, model_revision, device],
|
446 |
+
outputs=argumentation_model_state,
|
447 |
+
)
|
448 |
+
load_retriever_btn.click(
|
449 |
+
fn=lambda _retriever_config, _device, _previous_retriever: (
|
450 |
+
load_retriever(
|
451 |
+
retriever_config=_retriever_config,
|
452 |
+
device=_device,
|
453 |
+
previous_retriever=_previous_retriever[0],
|
454 |
+
config_format="yaml",
|
455 |
+
),
|
456 |
+
),
|
457 |
+
inputs=[retriever_config, device, retriever_state],
|
458 |
+
outputs=retriever_state,
|
459 |
)
|
460 |
+
|
461 |
split_regex_escaped = gr.Textbox(
|
462 |
label="Regex to partition the text",
|
463 |
placeholder="Regular expression pattern to split the text into partitions",
|
|
|
466 |
|
467 |
predict_btn = gr.Button("Analyse")
|
468 |
|
|
|
|
|
469 |
with gr.Column(scale=1):
|
470 |
|
471 |
+
selected_document_id = gr.Textbox(
|
472 |
+
label="Selected Document", max_lines=1, interactive=False
|
473 |
+
)
|
474 |
+
rendered_output = gr.HTML(label="Rendered Output")
|
475 |
|
476 |
with gr.Accordion("Render Options", open=False):
|
477 |
render_as = gr.Dropdown(
|
|
|
484 |
lines=5,
|
485 |
value=json.dumps(default_render_kwargs, indent=2),
|
486 |
)
|
487 |
+
render_btn = gr.Button("Re-render")
|
488 |
|
489 |
+
with gr.Accordion("See plain result ...", open=False) as document_json_accordion:
|
490 |
+
get_document_json_btn = gr.Button("Fetch annotated document as JSON")
|
491 |
+
document_json = gr.JSON(label="Model Output")
|
492 |
|
493 |
with gr.Column(scale=1):
|
494 |
with gr.Accordion(
|
|
|
498 |
headers=["id", "num_adus", "num_relations"],
|
499 |
interactive=False,
|
500 |
)
|
|
|
|
|
|
|
501 |
gr.Markdown("Data Snapshot:")
|
502 |
with gr.Row():
|
503 |
download_processed_documents_btn = gr.DownloadButton("Download")
|
504 |
upload_processed_documents_btn = gr.UploadButton(
|
505 |
+
"Upload", file_types=["files"]
|
506 |
)
|
507 |
|
508 |
upload_btn = gr.UploadButton(
|
|
|
511 |
file_count="multiple",
|
512 |
)
|
513 |
|
514 |
+
with gr.Accordion("Import text from arXiv", open=False):
|
515 |
+
arxiv_id = gr.Textbox(
|
516 |
+
label="arXiv paper ID",
|
517 |
+
placeholder=f"e.g. {DEFAULT_ARXIV_ID}",
|
518 |
+
max_lines=1,
|
519 |
+
)
|
520 |
+
load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
|
521 |
+
load_arxiv_btn = gr.Button("Load & process from arXiv", variant="secondary")
|
522 |
+
|
523 |
+
with gr.Accordion("Import annotated PIE dataset", open=False):
|
524 |
+
load_pie_dataset_kwargs_str = gr.Textbox(
|
525 |
+
label="Parameters for Loading the PIE Dataset",
|
526 |
+
value=DEFAULT_LOAD_PIE_DATASET_KWARGS_STR,
|
527 |
+
lines=len(DEFAULT_LOAD_PIE_DATASET_KWARGS_STR.split("\n")),
|
528 |
+
)
|
529 |
+
load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")
|
530 |
|
531 |
with gr.Accordion("Retrieval Configuration", open=False):
|
532 |
min_similarity = gr.Slider(
|
|
|
534 |
minimum=0.0,
|
535 |
maximum=1.0,
|
536 |
step=0.01,
|
537 |
+
value=DEFAULT_MIN_SIMILARITY,
|
538 |
)
|
539 |
top_k = gr.Slider(
|
540 |
label="Top K",
|
541 |
minimum=2,
|
542 |
maximum=50,
|
543 |
step=1,
|
544 |
+
value=10,
|
545 |
)
|
546 |
+
retrieve_similar_adus_btn = gr.Button(
|
547 |
+
"Retrieve *similar* ADUs for *selected* ADU"
|
|
|
|
|
|
|
548 |
)
|
549 |
+
similar_adus_df = gr.DataFrame(
|
550 |
+
headers=["doc_id", "adu_id", "score", "text"], interactive=False
|
551 |
)
|
552 |
+
retrieve_all_similar_adus_btn = gr.Button(
|
553 |
+
"Retrieve *similar* ADUs for *all* ADUs in the document"
|
554 |
+
)
|
555 |
+
all_similar_adus_df = gr.DataFrame(
|
556 |
+
headers=["doc_id", "query_adu_id", "adu_id", "score", "text"],
|
557 |
+
interactive=False,
|
558 |
+
)
|
559 |
+
retrieve_all_relevant_adus_btn = gr.Button(
|
560 |
+
"Retrieve *relevant* ADUs for *all* ADUs in the document"
|
561 |
+
)
|
562 |
+
all_relevant_adus_df = gr.DataFrame(
|
563 |
+
headers=["doc_id", "adu_id", "score", "text"], interactive=False
|
564 |
)
|
565 |
|
566 |
+
# currently not used
|
567 |
+
# relation_types = set_relation_types(
|
568 |
+
# argumentation_model_state.value[0], default=["supports_reversed", "contradicts_reversed"]
|
569 |
+
# )
|
570 |
+
|
571 |
+
# Dummy textbox to hold the hover adu id. On click on the rendered output,
|
572 |
+
# its content will be copied to selected_adu_id which will trigger the retrieval.
|
573 |
+
hover_adu_id = gr.Textbox(
|
574 |
+
label="ID (hover)",
|
575 |
+
elem_id="hover_adu_id",
|
576 |
+
interactive=False,
|
577 |
+
visible=False,
|
578 |
+
)
|
579 |
+
selected_adu_id = gr.Textbox(
|
580 |
+
label="ID (selected)",
|
581 |
+
elem_id="selected_adu_id",
|
582 |
interactive=False,
|
583 |
+
visible=False,
|
584 |
)
|
585 |
+
selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False)
|
586 |
+
|
587 |
+
with gr.Accordion("Relevant ADUs from other documents", open=True):
|
588 |
+
|
589 |
+
relevant_adus_df = gr.DataFrame(
|
590 |
+
headers=[
|
591 |
+
"relation",
|
592 |
+
"adu",
|
593 |
+
"reference_adu",
|
594 |
+
"doc_id",
|
595 |
+
"sim_score",
|
596 |
+
"rel_score",
|
597 |
+
],
|
598 |
+
interactive=False,
|
599 |
+
)
|
600 |
|
601 |
render_event_kwargs = dict(
|
602 |
+
fn=lambda _retriever, _document_id, _render_as, _render_kwargs: render_annotated_document(
|
603 |
+
retriever=_retriever[0],
|
604 |
+
document_id=_document_id,
|
605 |
+
render_with=_render_as,
|
606 |
+
render_kwargs_json=_render_kwargs,
|
607 |
+
),
|
608 |
+
inputs=[retriever_state, selected_document_id, render_as, render_kwargs],
|
609 |
outputs=rendered_output,
|
610 |
)
|
611 |
|
612 |
show_overview_kwargs = dict(
|
613 |
+
fn=lambda _retriever: _retriever[0].docstore.overview(
|
614 |
+
layer_captions=LAYER_CAPTIONS, use_predictions=True
|
615 |
),
|
616 |
+
inputs=[retriever_state],
|
617 |
outputs=[processed_documents_df],
|
618 |
)
|
619 |
+
predict_btn.click(
|
620 |
+
fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text(
|
621 |
+
text=_doc_text,
|
622 |
+
doc_id=_doc_id,
|
623 |
+
argumentation_model=_argumentation_model[0],
|
624 |
+
retriever=_retriever[0],
|
625 |
+
split_regex_escaped=(
|
626 |
+
unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
|
627 |
+
),
|
628 |
+
handle_parts_of_same=HANDLE_PARTS_OF_SAME,
|
629 |
+
),
|
630 |
inputs=[
|
631 |
doc_text,
|
632 |
doc_id,
|
633 |
+
argumentation_model_state,
|
634 |
+
retriever_state,
|
635 |
split_regex_escaped,
|
636 |
],
|
637 |
+
outputs=[selected_document_id],
|
638 |
api_name="predict",
|
639 |
+
).success(**show_overview_kwargs).success(**render_event_kwargs)
|
640 |
render_btn.click(**render_event_kwargs, api_name="render")
|
641 |
|
642 |
+
load_arxiv_btn.click(
|
643 |
+
fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv(
|
644 |
+
arxiv_id=_arxiv_id,
|
645 |
+
abstract_only=_load_arxiv_only_abstract,
|
646 |
+
argumentation_model=_argumentation_model[0],
|
647 |
+
retriever=_retriever[0],
|
648 |
+
split_regex_escaped=(
|
649 |
+
unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
|
650 |
+
),
|
651 |
+
handle_parts_of_same=HANDLE_PARTS_OF_SAME,
|
652 |
+
),
|
653 |
+
inputs=[
|
654 |
+
arxiv_id,
|
655 |
+
load_arxiv_only_abstract,
|
656 |
+
argumentation_model_state,
|
657 |
+
retriever_state,
|
658 |
+
split_regex_escaped,
|
659 |
+
],
|
660 |
+
outputs=[selected_document_id],
|
661 |
+
api_name="predict",
|
662 |
+
).success(**show_overview_kwargs)
|
663 |
+
|
664 |
+
load_pie_dataset_btn.click(
|
665 |
+
fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
|
666 |
+
).then(
|
667 |
+
fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset(
|
668 |
+
retriever=_retriever[0], verbose=True, **json.loads(_load_pie_dataset_kwargs_str)
|
669 |
+
),
|
670 |
+
inputs=[retriever_state, load_pie_dataset_kwargs_str],
|
671 |
+
outputs=[processed_documents_df],
|
672 |
+
)
|
673 |
+
|
674 |
+
selected_document_id.change(**render_event_kwargs)
|
675 |
+
|
676 |
+
get_document_json_btn.click(
|
677 |
+
fn=lambda _retriever, _document_id: get_document_as_dict(
|
678 |
+
retriever=_retriever[0], doc_id=_document_id
|
679 |
+
),
|
680 |
+
inputs=[retriever_state, selected_document_id],
|
681 |
outputs=[document_json],
|
|
|
|
|
682 |
)
|
683 |
|
684 |
upload_btn.upload(
|
685 |
fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
|
686 |
).then(
|
687 |
+
fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files(
|
688 |
+
file_names=_file_names,
|
689 |
+
argumentation_model=_argumentation_model[0],
|
690 |
+
retriever=_retriever[0],
|
691 |
+
split_regex_escaped=unescape_regex(_split_regex_escaped),
|
692 |
+
handle_parts_of_same=HANDLE_PARTS_OF_SAME,
|
693 |
+
),
|
694 |
inputs=[
|
695 |
upload_btn,
|
696 |
+
argumentation_model_state,
|
697 |
+
retriever_state,
|
698 |
split_regex_escaped,
|
|
|
|
|
699 |
],
|
700 |
outputs=[processed_documents_df],
|
701 |
)
|
702 |
processed_documents_df.select(
|
703 |
+
fn=get_cell_for_fixed_column_from_df,
|
704 |
+
inputs=[processed_documents_df, gr.State("doc_id")],
|
705 |
+
outputs=[selected_document_id],
|
706 |
)
|
|
|
707 |
|
708 |
download_processed_documents_btn.click(
|
709 |
+
fn=lambda _retriever: download_processed_documents(
|
710 |
+
_retriever[0], file_name="processed_documents"
|
711 |
+
),
|
712 |
+
inputs=[retriever_state],
|
713 |
outputs=[download_processed_documents_btn],
|
714 |
)
|
715 |
upload_processed_documents_btn.upload(
|
716 |
+
fn=lambda file_name, _retriever: upload_processed_documents(
|
717 |
+
file_name, retriever=_retriever[0]
|
718 |
+
),
|
719 |
+
inputs=[upload_processed_documents_btn, retriever_state],
|
720 |
outputs=[processed_documents_df],
|
721 |
)
|
722 |
|
723 |
retrieve_relevant_adus_event_kwargs = dict(
|
724 |
+
fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
|
725 |
+
retriever=_retriever[0],
|
726 |
+
query_span_id=_selected_adu_id,
|
727 |
+
k=_top_k,
|
728 |
+
score_threshold=_min_similarity,
|
729 |
+
relation_label_mapping=RELATION_NAME_MAPPING,
|
730 |
+
# columns=relevant_adus.headers
|
731 |
),
|
732 |
inputs=[
|
733 |
+
retriever_state,
|
734 |
selected_adu_id,
|
|
|
735 |
min_similarity,
|
736 |
top_k,
|
|
|
737 |
],
|
738 |
+
outputs=[relevant_adus_df],
|
739 |
+
)
|
740 |
+
relevant_adus_df.select(
|
741 |
+
fn=get_cell_for_fixed_column_from_df,
|
742 |
+
inputs=[relevant_adus_df, gr.State("doc_id")],
|
743 |
+
outputs=[selected_document_id],
|
744 |
)
|
745 |
|
746 |
selected_adu_id.change(
|
747 |
+
fn=lambda _retriever, _selected_adu_id: get_span_annotation(
|
748 |
+
retriever=_retriever[0], span_id=_selected_adu_id
|
|
|
|
|
|
|
|
|
749 |
),
|
750 |
+
inputs=[retriever_state, selected_adu_id],
|
751 |
outputs=[selected_adu_text],
|
752 |
).success(**retrieve_relevant_adus_event_kwargs)
|
753 |
|
754 |
retrieve_similar_adus_btn.click(
|
755 |
+
fn=lambda _retriever, _selected_adu_id, _min_similarity, _tok_k: retrieve_similar_spans(
|
756 |
+
retriever=_retriever[0],
|
757 |
+
query_span_id=_selected_adu_id,
|
758 |
+
k=_tok_k,
|
759 |
+
score_threshold=_min_similarity,
|
|
|
|
|
|
|
760 |
),
|
761 |
inputs=[
|
762 |
+
retriever_state,
|
763 |
selected_adu_id,
|
|
|
764 |
min_similarity,
|
765 |
top_k,
|
766 |
],
|
767 |
+
outputs=[similar_adus_df],
|
768 |
+
)
|
769 |
+
similar_adus_df.select(
|
770 |
+
fn=get_cell_for_fixed_column_from_df,
|
771 |
+
inputs=[similar_adus_df, gr.State("doc_id")],
|
772 |
+
outputs=[selected_document_id],
|
773 |
)
|
774 |
|
775 |
+
retrieve_all_similar_adus_btn.click(
|
776 |
+
fn=lambda _retriever, _document_id, _min_similarity, _tok_k: retrieve_all_similar_spans(
|
777 |
+
retriever=_retriever[0],
|
778 |
+
query_doc_id=_document_id,
|
779 |
+
k=_tok_k,
|
780 |
+
score_threshold=_min_similarity,
|
781 |
+
query_span_id_column="query_span_id",
|
782 |
+
),
|
783 |
+
inputs=[
|
784 |
+
retriever_state,
|
785 |
+
selected_document_id,
|
786 |
+
min_similarity,
|
787 |
+
top_k,
|
788 |
+
],
|
789 |
+
outputs=[all_similar_adus_df],
|
790 |
)
|
791 |
+
|
792 |
+
retrieve_all_relevant_adus_btn.click(
|
793 |
+
fn=lambda _retriever, _document_id, _min_similarity, _tok_k: retrieve_all_relevant_spans(
|
794 |
+
retriever=_retriever[0],
|
795 |
+
query_doc_id=_document_id,
|
796 |
+
k=_tok_k,
|
797 |
+
score_threshold=_min_similarity,
|
798 |
+
query_span_id_column="query_span_id",
|
799 |
),
|
800 |
+
inputs=[
|
801 |
+
retriever_state,
|
802 |
+
selected_document_id,
|
803 |
+
min_similarity,
|
804 |
+
top_k,
|
805 |
+
],
|
806 |
+
outputs=[all_relevant_adus_df],
|
807 |
+
)
|
808 |
+
|
809 |
+
# select query span id from the "retrieve all" result data frames
|
810 |
+
all_similar_adus_df.select(
|
811 |
+
fn=get_cell_for_fixed_column_from_df,
|
812 |
+
inputs=[all_similar_adus_df, gr.State("query_span_id")],
|
813 |
+
outputs=[selected_adu_id],
|
814 |
+
)
|
815 |
+
all_relevant_adus_df.select(
|
816 |
+
fn=get_cell_for_fixed_column_from_df,
|
817 |
+
inputs=[all_relevant_adus_df, gr.State("query_span_id")],
|
818 |
+
outputs=[selected_adu_id],
|
819 |
)
|
820 |
|
821 |
+
# argumentation_model_state.change(
|
822 |
+
# fn=lambda _argumentation_model: set_relation_types(_argumentation_model[0]),
|
823 |
+
# inputs=[argumentation_model_state],
|
824 |
+
# outputs=[relation_types],
|
825 |
# )
|
826 |
|
827 |
rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])
|
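Throughout the reworked app.py, the loaded pipeline and retriever are wrapped in one-element tuples before being stored in gr.State (see the comment "wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called"): Gradio invokes a callable passed as a component value to obtain a default, so the tuple keeps the model itself from being called. A minimal, self-contained sketch of that pattern with a toy callable standing in for the argumentation model (not part of the commit):

import gradio as gr


def analyse(text: str, model_state: tuple) -> str:
    model = model_state[0]  # unpack the wrapped callable
    return model(text)


with gr.Blocks() as demo:
    # str.upper stands in for the pipeline; the 1-tuple prevents Gradio from
    # calling it to produce a default value for the State component.
    model_state = gr.State((str.upper,))
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    gr.Button("Analyse").click(fn=analyse, inputs=[inp, model_state], outputs=[out])

if __name__ == "__main__":
    demo.launch()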
model_utils.py
CHANGED
Removed in this commit (previous version of model_utils.py; the reworked version follows below):

@@ -1,10 +1,10 @@
-from typing import Optional,
-from embedding import EmbeddingModel, HuggingfaceEmbeddingModel

@@ -13,31 +13,35 @@
-embedding_model: Optional[EmbeddingModel] = None,
-"""Annotate a document with the provided pipeline.
-extract embeddings for the labeled spans.
-embedding_model: The embedding model to use for extracting text span embeddings.

@@ -53,22 +57,6 @@
-if embedding_model is not None:
-text_span_embeddings = embedding_model(
-document=document,
-span_layer_name="labeled_spans" if not handle_parts_of_same else "labeled_multi_spans",
-)
-# convert keys to str because JSON keys must be strings
-text_span_embeddings_dict = {
-labeled_span_to_id(k): v.tolist() for k, v in text_span_embeddings.items()
-}
-document.metadata["embeddings"] = text_span_embeddings_dict
-else:
-gr.Warning(
-"No embedding model provided. Skipping embedding extraction. You can load an embedding "
-"model in the 'Model Configuration' section."
-)

@@ -101,6 +89,80 @@

@@ -133,48 +195,260 @@
-embedding_max_length: int = 512,
-embedding_batch_size: int = 16,
-return embedding_model
|
|
|
|
+ import json
  import logging
+ from typing import Iterable, Optional, Sequence, Union

  import gradio as gr
+ import pandas as pd
+ from pie_datasets import Dataset, IterableDataset, load_dataset
...
  from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger
  from pytorch_ie import Pipeline
  from pytorch_ie.annotations import LabeledSpan
...
      TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
      TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
  )
+ from typing_extensions import Protocol
+
+ from src.langchain_modules import DocumentAwareSpanRetriever
+ from src.langchain_modules.span_retriever import (
+     DocumentAwareSpanRetrieverWithRelations,
+     _parse_config,
+ )

  logger = logging.getLogger(__name__)


  def annotate_document(
      document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+     argumentation_model: Pipeline,
      handle_parts_of_same: bool = False,
  ) -> Union[
      TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
      TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
  ]:
+     """Annotate a document with the provided pipeline.

      Args:
          document: The document to annotate.
+         argumentation_model: The pipeline to use for annotation.
          handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
      """

      # execute prediction pipeline
+     argumentation_model(document)

      if handle_parts_of_same:
          merger = SpansViaRelationMerger(
...
          )
          document = merger(document)

      return document

...
      return document


+ def add_annotated_pie_documents(
+     retriever: DocumentAwareSpanRetriever,
+     pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
+     use_predicted_annotations: bool,
+     verbose: bool = False,
+ ) -> None:
+     if verbose:
+         gr.Info(f"Create span embeddings for {len(pie_documents)} documents...")
+     num_docs_before = len(retriever.docstore)
+     retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations)
+     # number of documents that were overwritten
+     num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore)
+     # warn if documents were overwritten
+     if num_overwritten_docs > 0:
+         gr.Warning(f"{num_overwritten_docs} documents were overwritten.")
+
+
+ def process_texts(
+     texts: Iterable[str],
+     doc_ids: Iterable[str],
+     argumentation_model: Pipeline,
+     retriever: DocumentAwareSpanRetriever,
+     split_regex_escaped: Optional[str],
+     handle_parts_of_same: bool = False,
+     verbose: bool = False,
+ ) -> None:
+     # check that doc_ids are unique
+     if len(set(doc_ids)) != len(list(doc_ids)):
+         raise gr.Error("Document IDs must be unique.")
+     pie_documents = [
+         create_document(text=text, doc_id=doc_id, split_regex=split_regex_escaped)
+         for text, doc_id in zip(texts, doc_ids)
+     ]
+     if verbose:
+         gr.Info(f"Annotate {len(pie_documents)} documents...")
+     pie_documents = [
+         annotate_document(
+             document=pie_document,
+             argumentation_model=argumentation_model,
+             handle_parts_of_same=handle_parts_of_same,
+         )
+         for pie_document in pie_documents
+     ]
+     add_annotated_pie_documents(
+         retriever=retriever,
+         pie_documents=pie_documents,
+         use_predicted_annotations=True,
+         verbose=verbose,
+     )
+
+
+ def add_annotated_pie_documents_from_dataset(
+     retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs
+ ) -> None:
+     try:
+         gr.Info(
+             "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2)
+         )
+         dataset = load_dataset(**load_dataset_kwargs)
+         if not isinstance(dataset, (Dataset, IterableDataset)):
+             raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
+         dataset_converted = dataset.to_document_type(
+             TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
+         )
+         add_annotated_pie_documents(
+             retriever=retriever,
+             pie_documents=dataset_converted,
+             use_predicted_annotations=False,
+             verbose=verbose,
+         )
+     except Exception as e:
+         raise gr.Error(f"Failed to load dataset: {e}")
+
+
  def load_argumentation_model(
      model_name: str,
      revision: Optional[str] = None,
...
      return model


+ def load_retriever(
+     retriever_config: str,
+     config_format: str,
      device: str = "cpu",
+     previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None,
+ ) -> DocumentAwareSpanRetrieverWithRelations:
+     try:
+         retriever_config = _parse_config(retriever_config, format=config_format)
+         # set device for the embeddings pipeline
+         retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device
+         result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config)
+         # if a previous retriever is provided, load all documents and vectors from the previous retriever
+         if previous_retriever is not None:
+             # documents
+             all_doc_ids = list(previous_retriever.docstore.yield_keys())
+             gr.Info(f"Storing {len(all_doc_ids)} documents from previous retriever...")
+             all_docs = previous_retriever.docstore.mget(all_doc_ids)
+             result.docstore.mset([(doc.id, doc) for doc in all_docs])
+             # spans (with vectors)
+             all_span_ids = list(previous_retriever.vectorstore.yield_keys())
+             all_spans = previous_retriever.vectorstore.mget(all_span_ids)
+             result.vectorstore.mset([(span.id, span) for span in all_spans])
+
+         gr.Info("Retriever loaded successfully.")
+         return result
+     except Exception as e:
+         raise gr.Error(f"Failed to load retriever: {e}")
+
+
+ def retrieve_similar_spans(
+     retriever: DocumentAwareSpanRetriever,
+     query_span_id: str,
+     **kwargs,
+ ) -> pd.DataFrame:
+     if not query_span_id.strip():
+         raise gr.Error("No query span selected.")
+     try:
+         retrieval_result = retriever.invoke(input=query_span_id, **kwargs)
+         records = []
+         for similar_span_doc in retrieval_result:
+             pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc)
+             span_ann = metadata["attached_span"]
+             records.append(
+                 {
+                     "doc_id": pie_doc.id,
+                     "span_id": similar_span_doc.id,
+                     "score": metadata["relevance_score"],
+                     "label": span_ann.label,
+                     "text": str(span_ann),
+                 }
              )
+         return (
+             pd.DataFrame(records, columns=["doc_id", "score", "label", "text", "span_id"])
+             .sort_values(by="score", ascending=False)
+             .round(3)
+         )
+     except Exception as e:
+         raise gr.Error(f"Failed to retrieve similar ADUs: {e}")
+
+
+ def retrieve_relevant_spans(
+     retriever: DocumentAwareSpanRetriever,
+     query_span_id: str,
+     relation_label_mapping: Optional[dict[str, str]] = None,
+     **kwargs,
+ ) -> pd.DataFrame:
+     if not query_span_id.strip():
+         raise gr.Error("No query span selected.")
+     try:
+         relation_label_mapping = relation_label_mapping or {}
+         retrieval_result = retriever.invoke(input=query_span_id, return_related=True, **kwargs)
+         records = []
+         for relevant_span_doc in retrieval_result:
+             pie_doc, metadata = retriever.docstore.unwrap_with_metadata(relevant_span_doc)
+             span_ann = metadata["attached_span"]
+             tail_span_ann = metadata["attached_tail_span"]
+             mapped_relation_label = relation_label_mapping.get(
+                 metadata["relation_label"], metadata["relation_label"]
+             )
+             records.append(
+                 {
+                     "doc_id": pie_doc.id,
+                     "type": mapped_relation_label,
+                     "rel_score": metadata["relation_score"],
+                     "text": str(tail_span_ann),
+                     "span_id": relevant_span_doc.id,
+                     "label": tail_span_ann.label,
+                     "ref_score": metadata["relevance_score"],
+                     "ref_label": span_ann.label,
+                     "ref_text": str(span_ann),
+                     "ref_span_id": metadata["head_id"],
+                 }
+             )
+         return (
+             pd.DataFrame(
+                 records,
+                 columns=[
+                     "type",
+                     # omitted for now, we get no valid relation scores for the generative model
+                     # "rel_score",
+                     "ref_score",
+                     "label",
+                     "text",
+                     "ref_label",
+                     "ref_text",
+                     "doc_id",
+                     "span_id",
+                     "ref_span_id",
+                 ],
+             )
+             .sort_values(by=["ref_score"], ascending=False)
+             .round(3)
+         )
+     except Exception as e:
+         raise gr.Error(f"Failed to retrieve relevant ADUs: {e}")


+ class RetrieverCallable(Protocol):
+     def __call__(
+         self,
+         retriever: DocumentAwareSpanRetriever,
+         query_span_id: str,
+         **kwargs,
+     ) -> Optional[pd.DataFrame]:
+         pass

+
+ def _retrieve_for_all_spans(
+     retriever: DocumentAwareSpanRetriever,
+     query_doc_id: str,
+     retrieve_func: RetrieverCallable,
+     query_span_id_column: str = "query_span_id",
+     **kwargs,
+ ) -> Optional[pd.DataFrame]:
+     if not query_doc_id.strip():
+         raise gr.Error("No query document selected.")
+     try:
+         span_id2idx = retriever.get_span_id2idx_from_doc(query_doc_id)
+         gr.Info(f"Retrieving results for {len(span_id2idx)} ADUs in document {query_doc_id}...")
+         span_results = {
+             query_span_id: retrieve_func(
+                 retriever=retriever,
+                 query_span_id=query_span_id,
+                 **kwargs,
+             )
+             for query_span_id in span_id2idx.keys()
+         }
+         span_results_not_empty = {
+             query_span_id: df
+             for query_span_id, df in span_results.items()
+             if df is not None and not df.empty
+         }
+
+         # add column with query_span_id
+         for query_span_id, query_span_result in span_results_not_empty.items():
+             query_span_result[query_span_id_column] = query_span_id
+
+         if len(span_results_not_empty) == 0:
+             gr.Info(f"No results found for any ADU in document {query_doc_id}.")
+             return None
+         else:
+             result = pd.concat(span_results_not_empty.values(), ignore_index=True)
+             gr.Info(f"Retrieved {len(result)} ADUs for document {query_doc_id}.")
+             return result
+     except Exception as e:
+         raise gr.Error(
+             f'Failed to retrieve results for all ADUs in document "{query_doc_id}": {e}'
+         )
+
+
+ def retrieve_all_similar_spans(
+     retriever: DocumentAwareSpanRetriever,
+     query_doc_id: str,
+     **kwargs,
+ ) -> Optional[pd.DataFrame]:
+     return _retrieve_for_all_spans(
+         retriever=retriever,
+         query_doc_id=query_doc_id,
+         retrieve_func=retrieve_similar_spans,
+         **kwargs,
      )

+
+ def retrieve_all_relevant_spans(
+     retriever: DocumentAwareSpanRetriever,
+     query_doc_id: str,
+     **kwargs,
+ ) -> Optional[pd.DataFrame]:
+     return _retrieve_for_all_spans(
+         retriever=retriever,
+         query_doc_id=query_doc_id,
+         retrieve_func=retrieve_relevant_spans,
+         **kwargs,
+     )
+
+
+ class RetrieverForAllSpansCallable(Protocol):
+     def __call__(
+         self,
+         retriever: DocumentAwareSpanRetriever,
+         query_doc_id: str,
+         **kwargs,
+     ) -> Optional[pd.DataFrame]:
+         pass
+
+
+ def _retrieve_for_all_documents(
+     retriever: DocumentAwareSpanRetriever,
+     retrieve_func: RetrieverForAllSpansCallable,
+     query_doc_id_column: str = "query_doc_id",
+     **kwargs,
+ ) -> Optional[pd.DataFrame]:
+     try:
+         all_doc_ids = list(retriever.docstore.yield_keys())
+         gr.Info(f"Retrieving results for {len(all_doc_ids)} documents...")
+         doc_results = {
+             doc_id: retrieve_func(retriever=retriever, query_doc_id=doc_id, **kwargs)
+             for doc_id in all_doc_ids
+         }
+         doc_results_not_empty = {
+             doc_id: df for doc_id, df in doc_results.items() if df is not None and not df.empty
+         }
+         # add column with query_doc_id
+         for doc_id, doc_result in doc_results_not_empty.items():
+             doc_result[query_doc_id_column] = doc_id
+
+         if len(doc_results_not_empty) == 0:
+             gr.Info("No results found for any document.")
+             return None
+         else:
+             result = pd.concat(doc_results_not_empty, ignore_index=True)
+             gr.Info(f"Retrieved {len(result)} ADUs for all documents.")
+             return result
+     except Exception as e:
+         raise gr.Error(f"Failed to retrieve results for all documents: {e}")
+
+
+ def retrieve_all_similar_spans_for_all_documents(
+     retriever: DocumentAwareSpanRetriever,
+     **kwargs,
+ ) -> Optional[pd.DataFrame]:
+     return _retrieve_for_all_documents(
+         retriever=retriever,
+         retrieve_func=retrieve_all_similar_spans,
+         **kwargs,
+     )
+
+
+ def retrieve_all_relevant_spans_for_all_documents(
+     retriever: DocumentAwareSpanRetriever,
+     **kwargs,
+ ) -> Optional[pd.DataFrame]:
+     return _retrieve_for_all_documents(
+         retriever=retriever,
+         retrieve_func=retrieve_all_relevant_spans,
+         **kwargs,
+     )
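
Usage sketch (not part of the commit): the helpers above all return pandas DataFrames, and extra keyword arguments are forwarded to `retriever.invoke`. The snippet assumes a retriever that was built with `load_retriever(...)` and already contains annotated documents (e.g. added via `process_texts(...)`); the `k` and `score_threshold` values mirror the `search_kwargs` used in `retrieve_and_dump_all_relevant.py` below.

    def show_relevant_for_span(retriever: DocumentAwareSpanRetriever, span_id: str) -> None:
        # one row per related ADU from another document, ranked by the reference span's score
        df = retrieve_relevant_spans(
            retriever=retriever, query_span_id=span_id, k=10, score_threshold=0.95
        )
        print(df[["type", "ref_score", "label", "text", "doc_id"]].head())
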
rendering_utils.py
CHANGED

@@ -1,14 +1,9 @@
  import json
  import logging
  from collections import defaultdict
- from typing import Dict, List, Optional, Union

- from annotation_utils import labeled_span_to_id
  from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
- from pytorch_ie.documents import (
-     TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
-     TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
- )
  from rendering_utils_displacy import EntityRenderer

  logger = logging.getLogger(__name__)

@@ -92,7 +87,20 @@ HIGHLIGHT_SPANS_JS = """
          });
      }
  }
- function
      // get the textarea element that holds the reference adu id
      let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
      // set the value of the input field

@@ -104,6 +112,8 @@ HIGHLIGHT_SPANS_JS = """

      const entities = document.querySelectorAll('.entity');
      entities.forEach(entity => {
          const alreadyHasListener = entity.getAttribute('data-has-listener');
          if (alreadyHasListener) {
              return;

@@ -111,22 +121,30 @@ HIGHLIGHT_SPANS_JS = """
          entity.addEventListener('mouseover', () => {
              const entityId = entity.getAttribute('data-entity-id');
              highlightRelationArguments(entityId);
-
          });
          entity.addEventListener('mouseout', () => {
              highlightRelationArguments(null);
          });
          entity.setAttribute('data-has-listener', 'true');
      });
  }
  """


  def render_pretty_table(
-
-
-
-     ],
      **render_kwargs,
  ):
      from prettytable import PrettyTable

@@ -134,7 +152,7 @@ def render_pretty_table(
      t = PrettyTable()
      t.field_names = ["head", "tail", "relation"]
      t.align = "l"
-     for relation in list(
          t.add_row([str(relation.head), str(relation.tail), relation.label])

      html = t.get_html_string(format=True)

@@ -144,29 +162,19 @@


  def render_displacy(
-
-
-
-     ],
      inject_relations=True,
      colors_hover=None,
      entity_options={},
      **render_kwargs,
  ):

-     if isinstance(document, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions):
-         span_layer = document.labeled_spans
-     elif isinstance(
-         document, TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
-     ):
-         span_layer = document.labeled_multi_spans
-     else:
-         raise ValueError(f"Unsupported document type: {type(document)}")
-
-     span_annotations = list(span_layer) + list(span_layer.predictions)
      ents = []
-     for
-
          # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
          # on hover and to inject the relation data.
          if isinstance(labeled_span, LabeledSpan):

@@ -192,7 +200,7 @@ def render_displacy(
              raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")

      spacy_doc = {
-         "text":
          # the ents MUST be sorted by start and end
          "ents": sorted(ents, key=lambda x: (x["start"], x["end"])),
          "title": None,

@@ -207,12 +215,10 @@

      html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
      if inject_relations:
-         binary_relations = list(document.binary_relations) + list(
-             document.binary_relations.predictions
-         )
          html = inject_relation_data(
              html,
-
              binary_relations=binary_relations,
              additional_colors=colors_hover,
          )

@@ -221,8 +227,9 @@

  def inject_relation_data(
      html: str,
-
-
      additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
  ) -> str:
      from bs4 import BeautifulSoup

@@ -236,7 +243,7 @@ def inject_relation_data(
          entity2heads[relation.tail].append((relation.head, relation.label))
          entity2tails[relation.head].append((relation.tail, relation.label))

-
      # Add unique IDs to each entity
      entities = soup.find_all(class_="entity")
      for entity in entities:

@@ -247,7 +254,8 @@
              entity[f"data-color-{key}"] = (
                  json.dumps(color) if isinstance(color, dict) else color
              )
-

          # sanity check.
          if isinstance(entity_annotation, LabeledSpan):

@@ -265,14 +273,16 @@
          entity["data-label"] = entity_annotation.label
          entity["data-relation-tails"] = json.dumps(
              [
-                 {"entity-id":
                  for tail, label in entity2tails.get(entity_annotation, [])
              ]
          )
          entity["data-relation-heads"] = json.dumps(
              [
-                 {"entity-id":
                  for head, label in entity2heads.get(entity_annotation, [])
              ]
          )


  import json
  import logging
  from collections import defaultdict
+ from typing import Dict, List, Optional, Sequence, Union

  from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
  from rendering_utils_displacy import EntityRenderer

  logger = logging.getLogger(__name__)
...
          });
      }
  }
+ function setHoverAduId(entityId) {
+     // get the textarea element that holds the reference adu id
+     let hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
+     // set the value of the input field
+     hoverAduIdDiv.value = entityId;
+     // trigger an input event to update the state
+     var event = new Event('input');
+     hoverAduIdDiv.dispatchEvent(event);
+ }
+ function setReferenceAduIdFromHover() {
+     // get the hover adu id
+     const hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
+     // get the value of the input field
+     const entityId = hoverAduIdDiv.value;
      // get the textarea element that holds the reference adu id
      let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
      // set the value of the input field
...

      const entities = document.querySelectorAll('.entity');
      entities.forEach(entity => {
+         // make the cursor a pointer
+         entity.style.cursor = 'pointer';
          const alreadyHasListener = entity.getAttribute('data-has-listener');
          if (alreadyHasListener) {
              return;
...
          entity.addEventListener('mouseover', () => {
              const entityId = entity.getAttribute('data-entity-id');
              highlightRelationArguments(entityId);
+             setHoverAduId(entityId);
          });
          entity.addEventListener('mouseout', () => {
              highlightRelationArguments(null);
          });
          entity.setAttribute('data-has-listener', 'true');
      });
+     const entityContainer = document.querySelector('.entities');
+     if (entityContainer) {
+         entityContainer.addEventListener('click', () => {
+             setReferenceAduIdFromHover();
+         });
+         // make the cursor a pointer
+         // entityContainer.style.cursor = 'pointer';
+     }
  }
  """


  def render_pretty_table(
+     text: str,
+     spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+     span_id2idx: Dict[str, int],
+     binary_relations: Sequence[BinaryRelation],
      **render_kwargs,
  ):
      from prettytable import PrettyTable
...
      t = PrettyTable()
      t.field_names = ["head", "tail", "relation"]
      t.align = "l"
+     for relation in list(binary_relations) + list(binary_relations):
          t.add_row([str(relation.head), str(relation.tail), relation.label])

      html = t.get_html_string(format=True)
...


  def render_displacy(
+     text: str,
+     spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+     span_id2idx: Dict[str, int],
+     binary_relations: Sequence[BinaryRelation],
      inject_relations=True,
      colors_hover=None,
      entity_options={},
      **render_kwargs,
  ):

      ents = []
+     for entity_id, idx in span_id2idx.items():
+         labeled_span = spans[idx]
          # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
          # on hover and to inject the relation data.
          if isinstance(labeled_span, LabeledSpan):
...
              raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")

      spacy_doc = {
+         "text": text,
          # the ents MUST be sorted by start and end
          "ents": sorted(ents, key=lambda x: (x["start"], x["end"])),
          "title": None,
...

      html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
      if inject_relations:
          html = inject_relation_data(
              html,
+             spans=spans,
+             span_id2idx=span_id2idx,
              binary_relations=binary_relations,
              additional_colors=colors_hover,
          )
...

  def inject_relation_data(
      html: str,
+     spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+     span_id2idx: Dict[str, int],
+     binary_relations: Sequence[BinaryRelation],
      additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
  ) -> str:
      from bs4 import BeautifulSoup
...
          entity2heads[relation.tail].append((relation.head, relation.label))
          entity2tails[relation.head].append((relation.tail, relation.label))

+     annotation2id = {spans[span_idx]: span_id for span_id, span_idx in span_id2idx.items()}
      # Add unique IDs to each entity
      entities = soup.find_all(class_="entity")
      for entity in entities:
...
              entity[f"data-color-{key}"] = (
                  json.dumps(color) if isinstance(color, dict) else color
              )
+
+         entity_annotation = spans[span_id2idx[entity["data-entity-id"]]]

          # sanity check.
          if isinstance(entity_annotation, LabeledSpan):
...
          entity["data-label"] = entity_annotation.label
          entity["data-relation-tails"] = json.dumps(
              [
+                 {"entity-id": annotation2id[tail], "label": label}
                  for tail, label in entity2tails.get(entity_annotation, [])
+                 if tail in annotation2id
              ]
          )
          entity["data-relation-heads"] = json.dumps(
              [
+                 {"entity-id": annotation2id[head], "label": label}
                  for head, label in entity2heads.get(entity_annotation, [])
+                 if head in annotation2id
              ]
          )
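
Sketch of the new rendering API (not part of the commit): the renderers no longer take a PIE document but plain text, a span list, and an explicit `span_id2idx` mapping. The ids and spans below are made up; in the demo they come from the retriever.

    from pytorch_ie.annotations import BinaryRelation, LabeledSpan
    from rendering_utils import render_displacy

    text = "Darwin proposed natural selection. This explains adaptation."
    spans = [
        LabeledSpan(start=0, end=34, label="background_claim"),
        LabeledSpan(start=35, end=61, label="own_claim"),
    ]
    relations = [BinaryRelation(head=spans[1], tail=spans[0], label="supports")]
    span_id2idx = {"span-0": 0, "span-1": 1}  # hypothetical span ids

    html = render_displacy(
        text=text,
        spans=spans,
        span_id2idx=span_id2idx,
        binary_relations=relations,
        inject_relations=True,
    )
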
requirements.txt
CHANGED

@@ -1,11 +1,34 @@
-
- gradio==4.36.0
  prettytable==3.10.0
- pie-modules==0.12.0
  beautifulsoup4==4.12.3
- datasets==2.14.4
  # numpy 2.0.0 breaks the code
  numpy==1.25.2
- qdrant-client==1.9.1
  scipy==1.13.0
  arxiv==2.1.3

+ gradio==4.44.1
  prettytable==3.10.0
  beautifulsoup4==4.12.3
  # numpy 2.0.0 breaks the code
  numpy==1.25.2
  scipy==1.13.0
  arxiv==2.1.3
+ pyrootutils>=1.0.0,<1.1.0
+
+ ########## from root requirements ##########
+ # --------- pytorch-ie --------- #
+ pytorch-ie>=0.29.6,<0.32.0
+ pie-datasets>=0.10.5,<0.11.0
+ pie-modules>=0.14.0,<0.15.0
+
+ # --------- models -------- #
+ adapters>=0.1.2,<0.2.0
+ # ADU retrieval (and demo, in future):
+ langchain>=0.3.0,<0.4.0
+ langchain-core>=0.3.0,<0.4.0
+ langchain-community>=0.3.0,<0.4.0
+ # we use QDrant as vectorstore backend
+ langchain-qdrant>=0.1.0,<0.2.0
+ qdrant-client>=1.12.0,<2.0.0
+ # 0.26 seems to be broken when used with adapters, see https://github.com/adapter-hub/adapters/issues/748
+ huggingface_hub<0.26.0  # 0.26 seems to be broken
+ # to to handle segmented entities (if HANDLE_PARTS_OF_SAME=True)
+ networkx>=3.0.0,<4.0.0
+
+ # --------- config --------- #
+ hydra-core>=1.3.0
+
+ # --------- dev --------- #
+ pre-commit  # hooks for applying linters on commit
retrieve_and_dump_all_relevant.py
ADDED
@@ -0,0 +1,101 @@
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
import logging

from demo.model_utils import (
    retrieve_all_relevant_spans,
    retrieve_all_relevant_spans_for_all_documents,
    retrieve_relevant_spans,
)
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations

logger = logging.getLogger(__name__)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument("-k", "--top_k", type=int, default=10)
    parser.add_argument("-t", "--threshold", type=float, default=0.95)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    if not args.output_path.endswith(".json"):
        raise ValueError("only support json output")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)

    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    logger.info(f"use search_kwargs: {search_kwargs}")

    if args.query_span_id is not None:
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

    if all_spans_for_all_documents is None:
        logger.warning("no relevant spans found in any document")
        exit(0)

    logger.info(f"dumping results to {args.output_path}...")
    all_spans_for_all_documents.to_json(args.output_path)

    logger.info("done")
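
Sketch for inspecting the JSON dump written by this script (not part of the commit; the path is hypothetical). When run over all documents, the frame has one row per retrieved (query span, related span) pair, with the `query_span_id` / `query_doc_id` columns added by `_retrieve_for_all_spans` and `_retrieve_for_all_documents`.

    import pandas as pd

    df = pd.read_json("all_relevant_adus.json")
    # strongest cross-document support/contradiction links first
    print(
        df.sort_values("ref_score", ascending=False)[
            ["query_doc_id", "query_span_id", "type", "ref_score", "text", "doc_id"]
        ].head()
    )
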
retriever/related_span_retriever_with_relations_from_other_docs.yaml
ADDED
@@ -0,0 +1,47 @@
_target_: src.langchain_modules.DocumentAwareSpanRetrieverWithRelations
reversed_relations_suffix: _reversed
relation_labels:
  - supports_reversed
  - contradicts_reversed
retrieve_from_same_document: false
retrieve_from_different_documents: true
pie_document_type:
  _target_: pie_modules.utils.resolve_type
  type_or_str: pytorch_ie.documents.TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
docstore:
  _target_: src.langchain_modules.DatasetsPieDocumentStore
search_kwargs:
  k: 10
search_type: similarity_score_threshold
vectorstore:
  _target_: src.langchain_modules.QdrantSpanVectorStore
  embedding:
    _target_: src.langchain_modules.HuggingFaceSpanEmbeddings
    model:
      _target_: src.models.utils.adapters.load_model_with_adapter
      model_kwargs:
        pretrained_model_name_or_path: allenai/specter2_base
      adapter_kwargs:
        adapter_name_or_path: allenai/specter2
        load_as: proximity
        source: hf
    pipeline_kwargs:
      tokenizer: allenai/specter2_base
      stride: 64
      batch_size: 32
      model_max_length: 512
  client:
    _target_: qdrant_client.QdrantClient
    location: ":memory:"
  collection_name: adus
  vector_params:
    distance:
      _target_: qdrant_client.http.models.Distance
      value: Cosine
label_mapping:
  background_claim:
    - background_claim
    - own_claim
  own_claim:
    - background_claim
    - own_claim
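
Sketch for loading a retriever from this config (not part of the commit): `instantiate_from_config_file` and `load_from_disc` are the entry points used by retrieve_and_dump_all_relevant.py; the dump path below is hypothetical. The Gradio app instead parses the config with `_parse_config` and overrides `vectorstore.embedding.pipeline_kwargs.device` before instantiation (see `load_retriever` in app.py).

    from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations

    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        "retriever/related_span_retriever_with_relations_from_other_docs.yaml"
    )
    # zip or directory containing a previously saved docstore + span vectors
    retriever.load_from_disc("retriever_dump.zip")
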
src/__init__.py
ADDED
File without changes

src/hf_pipeline/__init__.py
ADDED
@@ -0,0 +1 @@
from .feature_extraction import FeatureExtractionPipelineWithStriding

src/hf_pipeline/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (241 Bytes).

src/hf_pipeline/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (239 Bytes).

src/hf_pipeline/__pycache__/feature_extraction.cpython-310.pyc
ADDED
Binary file (10.9 kB).

src/hf_pipeline/__pycache__/feature_extraction.cpython-39.pyc
ADDED
Binary file (10.9 kB).
src/hf_pipeline/feature_extraction.py
ADDED
@@ -0,0 +1,317 @@
import types
from collections import defaultdict
from typing import List, Optional, Tuple, Union

import torch
from transformers.pipelines.base import ArgumentHandler, ChunkPipeline, Dataset
from transformers.utils import is_tf_available, is_torch_available

if is_tf_available():
    import tensorflow as tf
    from transformers.models.auto.modeling_tf_auto import (
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    )
if is_torch_available():
    from transformers.models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES


def list_of_dicts2dict_of_lists(list_of_dicts: list[dict]) -> dict[str, list]:
    return {k: [d[k] for d in list_of_dicts] for k in list_of_dicts[0].keys()}


class FeatureExtractionArgumentHandler(ArgumentHandler):
    """Handles arguments for feature extraction."""

    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
            inputs = list(inputs)
            batch_size = len(inputs)
        elif isinstance(inputs, str):
            inputs = [inputs]
            batch_size = 1
        elif (
            Dataset is not None
            and isinstance(inputs, Dataset)
            or isinstance(inputs, types.GeneratorType)
        ):
            return inputs, None
        else:
            raise ValueError("At least one input is required.")

        offset_mapping = kwargs.get("offset_mapping")
        if offset_mapping:
            if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
                offset_mapping = [offset_mapping]
            if len(offset_mapping) != batch_size:
                raise ValueError("offset_mapping should have the same batch size as the input")
        return inputs, offset_mapping


class FeatureExtractionPipelineWithStriding(ChunkPipeline):
    """Same as transformers.FeatureExtractionPipeline, but with long input handling. Inspired by
    transformers.TokenClassificationPipeline. The functionality is triggered when providing the
    "stride" parameter (can be 0). When passing "create_unique_embeddings_per_token=True", only one
    embedding (and other data, see flags below) per token will be returned (this makes use of
    min_distance_to_border, see "return_min_distance_to_border" below for details). Note that this
    removes data for special token positions!

    Per default, it will return just the embeddings. If any of the return_ADDITIONAL_RESULT is
    enabled (see below), it will return dictionaries with "last_hidden_state" and all
    ADDITIONAL_RESULT depending on the flags.

    Flags to return additional results:
        return_offset_mapping: If enabled, return the offset mapping.
        return_special_tokens_mask: If enabled, return the special tokens mask.
        return_sequence_indices: If enabled, return the sequence indices.
        return_position_ids: If enabled, return the position ids from, values are in [0, model_max_length).
        return_min_distance_to_border: If enabled, return the minimum distance to the "border" of
            the input that gets passed into the model. This is useful when striding is used which may
            produce multiple embeddings for a token (compare values in offset_mapping). In this case,
            min_distance_to_border can be used to select the embedding that is more in the center
            of the input by choosing the entry with the *higher* min_distance_to_border.
    """

    default_input_names = "sequences"

    def __init__(self, args_parser=FeatureExtractionArgumentHandler(), *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
            if self.framework == "tf"
            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
        )

        self._args_parser = args_parser

    def _sanitize_parameters(
        self,
        offset_mapping: Optional[List[Tuple[int, int]]] = None,
        stride: Optional[int] = None,
        create_unique_embeddings_per_token: Optional[bool] = False,
        return_offset_mapping: Optional[bool] = None,
        return_special_tokens_mask: Optional[bool] = None,
        return_sequence_indices: Optional[bool] = None,
        return_position_ids: Optional[bool] = None,
        return_min_distance_to_border: Optional[bool] = None,
        return_tensors=None,
    ):
        preprocess_params = {}
        if offset_mapping is not None:
            preprocess_params["offset_mapping"] = offset_mapping

        if stride is not None:
            if stride >= self.tokenizer.model_max_length:
                raise ValueError(
                    "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
                )

            if self.tokenizer.is_fast:
                tokenizer_params = {
                    "return_overflowing_tokens": True,
                    "padding": True,
                    "stride": stride,
                }
                preprocess_params["tokenizer_params"] = tokenizer_params  # type: ignore
            else:
                raise ValueError(
                    "`stride` was provided to process all the text but you're using a slow tokenizer."
                    " Please use a fast tokenizer."
                )
        postprocess_params = {}
        if create_unique_embeddings_per_token is not None:
            postprocess_params["create_unique_embeddings_per_token"] = (
                create_unique_embeddings_per_token
            )
        if return_offset_mapping is not None:
            postprocess_params["return_offset_mapping"] = return_offset_mapping
        if return_special_tokens_mask is not None:
            postprocess_params["return_special_tokens_mask"] = return_special_tokens_mask
        if return_sequence_indices is not None:
            postprocess_params["return_sequence_indices"] = return_sequence_indices
        if return_position_ids is not None:
            postprocess_params["return_position_ids"] = return_position_ids
        if return_min_distance_to_border is not None:
            postprocess_params["return_min_distance_to_border"] = return_min_distance_to_border
        if return_tensors is not None:
            postprocess_params["return_tensors"] = return_tensors
        return preprocess_params, {}, postprocess_params

    def __call__(self, inputs: Union[str, List[str]], **kwargs):

        _inputs, offset_mapping = self._args_parser(inputs, **kwargs)
        if offset_mapping:
            kwargs["offset_mapping"] = offset_mapping

        return super().__call__(inputs, **kwargs)

    def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
        tokenizer_params = preprocess_params.pop("tokenizer_params", {})
        truncation = (
            True
            if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0
            else False
        )
        inputs = self.tokenizer(
            sentence,
            return_tensors=self.framework,
            truncation=truncation,
            return_special_tokens_mask=True,
            return_offsets_mapping=self.tokenizer.is_fast,
            **tokenizer_params,
        )
        inputs.pop("overflow_to_sample_mapping", None)
        num_chunks = len(inputs["input_ids"])

        for i in range(num_chunks):
            if self.framework == "tf":
                model_inputs = {k: tf.expand_dims(v[i], 0) for k, v in inputs.items()}
            else:
                model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
            if offset_mapping is not None:
                model_inputs["offset_mapping"] = offset_mapping
            model_inputs["sentence"] = sentence if i == 0 else None
            model_inputs["is_last"] = i == num_chunks - 1

            yield model_inputs

    def _forward(self, model_inputs, **kwargs):
        # Forward
        special_tokens_mask = model_inputs.pop("special_tokens_mask")
        offset_mapping = model_inputs.pop("offset_mapping", None)
        sentence = model_inputs.pop("sentence")
        is_last = model_inputs.pop("is_last")
        if self.framework == "tf":
            last_hidden_state = self.model(**model_inputs)[0]
        else:
            output = self.model(**model_inputs)
            last_hidden_state = (
                output["last_hidden_state"] if isinstance(output, dict) else output[0]
            )

        return {
            "last_hidden_state": last_hidden_state,
            "special_tokens_mask": special_tokens_mask,
            "offset_mapping": offset_mapping,
            "sentence": sentence,
            "is_last": is_last,
            **model_inputs,
        }

    def postprocess_tensor(self, data, return_tensors=False):
        if return_tensors:
            return data
        if self.framework == "pt":
            return data.tolist()
        elif self.framework == "tf":
            return data.numpy().tolist()
        else:
            raise ValueError(f"unknown framework: {self.framework}")

    def make_embeddings_unique_per_token(
        self, data, offset_mapping, special_tokens_mask, min_distance_to_border
    ):
        char_offsets2token_pos = defaultdict(list)
        bs, seq_len = offset_mapping.shape[:2]
        if bs != 1:
            raise ValueError(f"expected result batch size 1, but it is: {bs}")
        for token_idx, ((char_start, shar_end), is_special_token, min_dist) in enumerate(
            zip(
                offset_mapping[0].tolist(),
                special_tokens_mask[0].tolist(),
                min_distance_to_border[0].tolist(),
            )
        ):

            if not is_special_token:
                char_offsets2token_pos[(char_start, shar_end)].append((token_idx, min_dist))

        # tokens_with_multiple_embeddings = {k: v for k, v in char_offsets2token_pos.items() if len(v) > 1}
        char_offsets2best_token_pos = {
            k: max(v, key=lambda pos_dist: pos_dist[1])[0]
            for k, v in char_offsets2token_pos.items()
        }
        # sort by char offsets (start and end)
        sorted_char_offsets_token_positions = sorted(
            char_offsets2best_token_pos.items(),
            key=lambda char_offsets_tok_pos: (
                char_offsets_tok_pos[0][0],
                char_offsets_tok_pos[0][1],
            ),
        )
        best_indices = [tok_pos for char_offsets, tok_pos in sorted_char_offsets_token_positions]

        result = {k: v[0][best_indices].unsqueeze(0) for k, v in data.items()}
        return result

    def postprocess(
        self,
        all_outputs,
        create_unique_embeddings_per_token: bool = False,
        return_offset_mapping: bool = False,
        return_special_tokens_mask: bool = False,
        return_sequence_indices: bool = False,
        return_position_ids: bool = False,
        return_min_distance_to_border: bool = False,
        return_tensors: bool = False,
    ):

        all_outputs_dict = list_of_dicts2dict_of_lists(all_outputs)
        if self.framework == "pt":
            result = {
                "last_hidden_state": torch.concat(all_outputs_dict["last_hidden_state"], axis=1)
            }
            if return_offset_mapping or create_unique_embeddings_per_token:
                result["offset_mapping"] = torch.concat(all_outputs_dict["offset_mapping"], axis=1)
            if return_special_tokens_mask or create_unique_embeddings_per_token:
                result["special_tokens_mask"] = torch.concat(
                    all_outputs_dict["special_tokens_mask"], axis=1
                )
            if return_sequence_indices:
                sequence_indices = []
                for seq_idx, model_outputs in enumerate(all_outputs):
                    sequence_indices.append(torch.ones_like(model_outputs["input_ids"]) * seq_idx)
                result["sequence_indices"] = torch.concat(sequence_indices, axis=1)
            if return_position_ids:
                position_ids = []
                for seq_idx, model_outputs in enumerate(all_outputs):
                    seq_len = model_outputs["input_ids"].size(1)
                    position_ids.append(torch.arange(seq_len).unsqueeze(0))
                result["indices"] = torch.concat(position_ids, axis=1)
            if return_min_distance_to_border or create_unique_embeddings_per_token:
                min_distance_to_border = []
                for seq_idx, model_outputs in enumerate(all_outputs):
                    seq_len = model_outputs["input_ids"].size(1)
                    current_indices = torch.arange(seq_len).unsqueeze(0)
                    min_distance_to_border.append(
                        torch.minimum(current_indices, seq_len - current_indices)
                    )
                result["min_distance_to_border"] = torch.concat(min_distance_to_border, axis=1)
        elif self.framework == "tf":
            raise NotImplementedError()
        else:
            raise ValueError(f"unknown framework: {self.framework}")

        if create_unique_embeddings_per_token:
            offset_mapping = result["offset_mapping"]
            if not return_offset_mapping:
                del result["offset_mapping"]
            special_tokens_mask = result["special_tokens_mask"]
            if not return_special_tokens_mask:
                del result["special_tokens_mask"]
            min_distance_to_border = result["min_distance_to_border"]
            if not return_min_distance_to_border:
                del result["min_distance_to_border"]
            result = self.make_embeddings_unique_per_token(
                data=result,
                offset_mapping=offset_mapping,
                special_tokens_mask=special_tokens_mask,
                min_distance_to_border=min_distance_to_border,
            )

        result = {
            k: self.postprocess_tensor(v, return_tensors=return_tensors) for k, v in result.items()
        }
        if set(result) == {"last_hidden_state"}:
            return result["last_hidden_state"]
        else:
            return result
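
Usage sketch for the striding feature-extraction pipeline (not part of the commit): the model and tokenizer names are illustrative and match the retriever config above; in the demo the pipeline is constructed by HuggingFaceSpanEmbeddings from the config's pipeline_kwargs.

    from transformers import AutoModel, AutoTokenizer

    from src.hf_pipeline import FeatureExtractionPipelineWithStriding

    tokenizer = AutoTokenizer.from_pretrained("allenai/specter2_base")
    model = AutoModel.from_pretrained("allenai/specter2_base")
    pipe = FeatureExtractionPipelineWithStriding(model=model, tokenizer=tokenizer)

    # stride triggers the long-input handling; create_unique_embeddings_per_token keeps one
    # embedding per token, choosing the window where the token is farthest from the border
    out = pipe(
        "some very long document ...",
        stride=64,
        create_unique_embeddings_per_token=True,
        return_offset_mapping=True,
    )
    # out["last_hidden_state"]: one vector per non-special token
    # out["offset_mapping"]: character offsets to map embeddings back to span positions
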
src/langchain_modules/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .basic_pie_document_store import BasicPieDocumentStore
from .datasets_pie_document_store import DatasetsPieDocumentStore
from .huggingface_span_embeddings import HuggingFaceSpanEmbeddings
from .pie_document_store import PieDocumentStore
from .qdrant_span_vectorstore import QdrantSpanVectorStore
from .serializable_store import SerializableStore
from .span_embeddings import SpanEmbeddings
from .span_retriever import DocumentAwareSpanRetriever, DocumentAwareSpanRetrieverWithRelations
from .span_vectorstore import SpanVectorStore

src/langchain_modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (489 Bytes).

src/langchain_modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (801 Bytes).

src/langchain_modules/__pycache__/basic_pie_document_store.cpython-39.pyc
ADDED
Binary file (4.79 kB).

src/langchain_modules/__pycache__/datasets_pie_document_store.cpython-39.pyc
ADDED
Binary file (7.68 kB).

src/langchain_modules/__pycache__/huggingface_span_embeddings.cpython-310.pyc
ADDED
Binary file (5.31 kB).

src/langchain_modules/__pycache__/huggingface_span_embeddings.cpython-39.pyc
ADDED
Binary file (7.1 kB).

src/langchain_modules/__pycache__/pie_document_store.cpython-39.pyc
ADDED
Binary file (4.21 kB).

src/langchain_modules/__pycache__/qdrant_span_vectorstore.cpython-310.pyc
ADDED
Binary file (6.82 kB).

src/langchain_modules/__pycache__/qdrant_span_vectorstore.cpython-39.pyc
ADDED
Binary file (12.9 kB).

src/langchain_modules/__pycache__/serializable_store.cpython-39.pyc
ADDED
Binary file (4.1 kB).

src/langchain_modules/__pycache__/span_embeddings.cpython-310.pyc
ADDED
Binary file (3.42 kB).

src/langchain_modules/__pycache__/span_embeddings.cpython-39.pyc
ADDED
Binary file (3.91 kB).

src/langchain_modules/__pycache__/span_retriever.cpython-310.pyc
ADDED
Binary file (13 kB).

src/langchain_modules/__pycache__/span_retriever.cpython-39.pyc
ADDED
Binary file (28 kB).

src/langchain_modules/__pycache__/span_vectorstore.cpython-39.pyc
ADDED
Binary file (5.31 kB).
src/langchain_modules/basic_pie_document_store.py
ADDED
@@ -0,0 +1,103 @@
import json
import logging
import os
import shutil
from itertools import islice
from typing import Iterator, List, Optional, Sequence, Tuple

from langchain.storage import create_kv_docstore
from langchain_core.documents import Document as LCDocument
from langchain_core.stores import BaseStore, ByteStore
from pie_datasets import Dataset, DatasetDict

from .pie_document_store import PieDocumentStore

logger = logging.getLogger(__name__)


class BasicPieDocumentStore(PieDocumentStore):
    """PIE Document store that uses a client to store and retrieve documents."""

    def __init__(
        self,
        client: Optional[BaseStore[str, LCDocument]] = None,
        byte_store: Optional[ByteStore] = None,
    ):
        if byte_store is not None:
            client = create_kv_docstore(byte_store)
        elif client is None:
            raise Exception("You must pass a `byte_store` parameter.")

        self.client = client

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        return self.client.mget(keys)

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        self.client.mset(items)

    def mdelete(self, keys: Sequence[str]) -> None:
        self.client.mdelete(keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        return self.client.yield_keys(prefix=prefix)

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
        all_doc_ids = []
        all_metadata = []
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove existing directory
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        doc_ids_iter = iter(self.client.yield_keys())
        while batch_doc_ids := list(islice(doc_ids_iter, batch_size or 1000)):
            all_doc_ids.extend(batch_doc_ids)
            docs = self.client.mget(batch_doc_ids)
            pie_docs = []
            for doc in docs:
                pie_doc = doc.metadata[self.METADATA_KEY_PIE_DOCUMENT]
                pie_docs.append(pie_doc)
                all_metadata.append(
                    {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
                )
            pie_dataset = Dataset.from_documents(pie_docs)
            DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path)
        if len(all_doc_ids) > 0:
            doc_ids_path = os.path.join(path, "doc_ids.json")
            with open(doc_ids_path, "w") as f:
                json.dump(all_doc_ids, f)
        if len(all_metadata) > 0:
            metadata_path = os.path.join(path, "metadata.json")
            with open(metadata_path, "w") as f:
                json.dump(all_metadata, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, don't load any documents."
            )
            return None
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path)
        pie_docs = pie_dataset["train"]
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, don't load any metadata.")
            all_metadata = [{} for _ in pie_docs]
        docs = [
            self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(pie_docs, all_metadata)
        ]
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(f"File {doc_ids_path} does not exist, don't load any document ids.")
            all_doc_ids = [doc.id for doc in pie_docs]
        self.client.mset(zip(all_doc_ids, docs))
        logger.info(f"Loaded {len(docs)} documents from {path} into docstore")
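
Sketch of a round-trip through this store (not part of the commit): it assumes `wrap`, `unwrap_with_metadata`, and `METADATA_KEY_PIE_DOCUMENT` from the PieDocumentStore base class (pie_document_store.py, also added in this commit), and uses langchain's InMemoryStore as the client to avoid serialization; the document is illustrative.

    from langchain.storage import InMemoryStore
    from pytorch_ie.documents import (
        TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    )

    from src.langchain_modules import BasicPieDocumentStore

    store = BasicPieDocumentStore(client=InMemoryStore())
    pie_doc = TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions(
        text="An example abstract.", id="doc-0"
    )
    # wrap the PIE document into a langchain Document and store it under its id
    store.mset([("doc-0", store.wrap(pie_doc))])
    print([doc.metadata[store.METADATA_KEY_PIE_DOCUMENT].id for doc in store.mget(["doc-0"])])
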
src/langchain_modules/datasets_pie_document_store.py
ADDED
@@ -0,0 +1,156 @@
import json
import logging
import os
import shutil
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple

from datasets import Dataset as HFDataset
from langchain_core.documents import Document as LCDocument
from pie_datasets import Dataset, DatasetDict, concatenate_datasets
from pytorch_ie.documents import TextBasedDocument

from .pie_document_store import PieDocumentStore

logger = logging.getLogger(__name__)


class DatasetsPieDocumentStore(PieDocumentStore):
    """PIE Document store that uses Huggingface Datasets as the backend."""

    def __init__(self) -> None:
        self._data: Optional[Dataset] = None
        # keys map to indices in the dataset
        self._keys: Dict[str, int] = {}
        self._metadata: Dict[str, Any] = {}

    def __len__(self):
        return len(self._keys)

    def _get_pie_docs_by_indices(self, indices: Iterable[int]) -> Sequence[TextBasedDocument]:
        if self._data is None:
            return []
        return self._data.apply_hf_func(func=HFDataset.select, indices=indices)

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        if self._data is None or len(keys) == 0:
            return []
        keys_in_data = [key for key in keys if key in self._keys]
        indices = [self._keys[key] for key in keys_in_data]
        dataset = self._get_pie_docs_by_indices(indices)
        metadatas = [self._metadata.get(key, {}) for key in keys_in_data]
        return [self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(dataset, metadatas)]

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        if len(items) == 0:
            return
        keys, new_docs = zip(*items)
        pie_docs, metadatas = zip(*[self.unwrap_with_metadata(doc) for doc in new_docs])
        if self._data is None:
            idx_start = 0
            self._data = Dataset.from_documents(pie_docs)
        else:
            # we pass the features to the new dataset to mitigate issues caused by
            # slightly different inferred features
            dataset = Dataset.from_documents(pie_docs, features=self._data.features)
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        keys_dict = {key: idx for idx, key in zip(range(idx_start, len(self._data)), keys)}
        self._keys.update(keys_dict)
        self._metadata.update(
            {key: metadata for key, metadata in zip(keys, metadatas) if metadata}
        )

    def add_pie_dataset(
        self,
        dataset: Dataset,
        keys: Optional[List[str]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        if len(dataset) == 0:
            return
        if keys is None:
            keys = [doc.id for doc in dataset]
        if len(keys) != len(set(keys)):
            raise ValueError("Keys must be unique.")
        if None in keys:
            raise ValueError("Keys must not be None.")
        if metadatas is None:
            metadatas = [{} for _ in range(len(dataset))]
        if len(keys) != len(dataset) or len(keys) != len(metadatas):
            raise ValueError("Keys, dataset and metadatas must have the same length.")

        if self._data is None:
            idx_start = 0
            self._data = dataset
        else:
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        keys_dict = {key: idx for idx, key in zip(range(idx_start, len(self._data)), keys)}
        self._keys.update(keys_dict)
        metadatas_dict = {key: metadata for key, metadata in zip(keys, metadatas) if metadata}
        self._metadata.update(metadatas_dict)

    def mdelete(self, keys: Sequence[str]) -> None:
        for key in keys:
            idx = self._keys.pop(key, None)
            if idx is not None:
                self._metadata.pop(key, None)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        return (key for key in self._keys if prefix is None or key.startswith(prefix))

    def _purge_invalid_entries(self):
        if self._data is None or len(self._keys) == len(self._data):
            return
        self._data = self._get_pie_docs_by_indices(self._keys.values())

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
        self._purge_invalid_entries()
        if len(self) == 0:
            logger.warning("No documents to save.")
            return

        all_doc_ids = list(self._keys)
        all_metadatas: List[Dict[str, Any]] = [self._metadata.get(key, {}) for key in all_doc_ids]
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove existing directory
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        DatasetDict({"train": self._data}).to_json(pie_documents_path)
        doc_ids_path = os.path.join(path, "doc_ids.json")
        with open(doc_ids_path, "w") as f:
            json.dump(all_doc_ids, f)
        metadata_path = os.path.join(path, "metadata.json")
        with open(metadata_path, "w") as f:
            json.dump(all_metadatas, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
def _load_from_directory(self, path: str, **kwargs) -> None:
|
130 |
+
doc_ids_path = os.path.join(path, "doc_ids.json")
|
131 |
+
if os.path.exists(doc_ids_path):
|
132 |
+
with open(doc_ids_path, "r") as f:
|
133 |
+
all_doc_ids = json.load(f)
|
134 |
+
else:
|
135 |
+
logger.warning(f"File {doc_ids_path} does not exist, don't load any document ids.")
|
136 |
+
all_doc_ids = None
|
137 |
+
metadata_path = os.path.join(path, "metadata.json")
|
138 |
+
if os.path.exists(metadata_path):
|
139 |
+
with open(metadata_path, "r") as f:
|
140 |
+
all_metadata = json.load(f)
|
141 |
+
else:
|
142 |
+
logger.warning(f"File {metadata_path} does not exist, don't load any metadata.")
|
143 |
+
all_metadata = None
|
144 |
+
pie_documents_path = os.path.join(path, "pie_documents")
|
145 |
+
if not os.path.exists(pie_documents_path):
|
146 |
+
logger.warning(
|
147 |
+
f"Directory {pie_documents_path} does not exist, don't load any documents."
|
148 |
+
)
|
149 |
+
return None
|
150 |
+
# If we have a dataset already loaded, we use its features to load the new dataset
|
151 |
+
# This is to mitigate issues caused by slightly different inferred features.
|
152 |
+
features = self._data.features if self._data is not None else None
|
153 |
+
pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path, features=features)
|
154 |
+
pie_docs = pie_dataset["train"]
|
155 |
+
self.add_pie_dataset(pie_docs, keys=all_doc_ids, metadatas=all_metadata)
|
156 |
+
logger.info(f"Loaded {len(pie_docs)} documents from {path} into docstore")
|
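
For orientation, a minimal usage sketch of the datasets-backed docstore. All names besides the class and its methods are illustrative (a `TextBasedDocument` instance `nlp_doc` with an `id` is assumed, and the import assumes the package `__init__` re-exports the class; otherwise import from the module directly):

    from src.langchain_modules import DatasetsPieDocumentStore  # assumed re-export

    store = DatasetsPieDocumentStore()
    # wrap() turns a pie document into a langchain Document; mset() appends it to the backing HF dataset
    store.mset([(nlp_doc.id, store.wrap(nlp_doc, source="demo"))])
    retrieved = store.mget([nlp_doc.id])[0]
    pie_doc, metadata = store.unwrap_with_metadata(retrieved)
    # writes pie_documents/, doc_ids.json and metadata.json into the target directory
    store.save_to_directory("dumps/docstore")
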
src/langchain_modules/huggingface_span_embeddings.py
ADDED
@@ -0,0 +1,192 @@
+import logging
+from typing import Any, Dict, List, Optional, Union  # type: ignore[import-not-found]
+
+import torch
+from pydantic import BaseModel, ConfigDict, Field
+from transformers import pipeline
+
+from ..hf_pipeline import FeatureExtractionPipelineWithStriding
+from .span_embeddings import SpanEmbeddings
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MODEL_NAME = "allenai/specter2_base"
+
+
+class HuggingFaceSpanEmbeddings(BaseModel, SpanEmbeddings):
+    """An implementation of SpanEmbeddings using a modified HuggingFace Transformers
+    feature-extraction pipeline, adapted for long text inputs by chunking with optional stride
+    (see src.hf_pipeline.FeatureExtractionPipelineWithStriding).
+
+    Note that calculating embeddings for multiple spans is efficient when all spans for a
+    text are passed in a single call to embed_document_spans, as the text embedding is computed
+    only once per unique text, and the span embeddings are simply pooled from these text embeddings.
+
+    It accepts any model that can be used with the HuggingFace feature-extraction pipeline, also
+    models with adapters such as SPECTER2 (see https://huggingface.co/allenai/specter2). In this case,
+    the model should be loaded beforehand and passed as parameter 'model' instead of the model identifier.
+    See https://huggingface.co/docs/transformers/main_classes/pipelines for further information.
+
+    To use, you should have the ``transformers`` python package installed.
+
+    Example:
+        .. code-block:: python
+            from transformers import AutoModel
+
+            model = "allenai/specter2_base"
+            pipeline_kwargs = {'device': 'cpu', 'stride': 64, 'batch_size': 8}
+            encode_kwargs = {'normalize_embeddings': False}
+            hf = HuggingFaceSpanEmbeddings(
+                model=model,
+                pipeline_kwargs=pipeline_kwargs,
+            )
+
+            text = "This is a test sentence."
+
+            # calculate embeddings for text[0:4]="This" and text[15:23]="sentence"
+            embeddings = hf.embed_document_spans(texts=[text, text], starts=[0, 15], ends=[4, 23])
+    """
+
+    client: Any = None  #: :meta private:
+    model: Optional[Any] = DEFAULT_MODEL_NAME
+    """Model name to use."""
+    pooling_strategy: str = "mean"
+    pipeline_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Keyword arguments to pass to the Huggingface pipeline constructor."""
+    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Keyword arguments to pass when calling the pipeline."""
+    # show_progress: bool = False
+    """Whether to show a progress bar."""
+    model_max_length: Optional[int] = None
+    """The maximum input length of the model. Required for some model checkpoints with outdated configs."""
+
+    def __init__(self, **kwargs: Any):
+        """Initialize the feature-extraction pipeline."""
+        super().__init__(**kwargs)
+
+        self.client = pipeline(
+            "feature-extraction",
+            model=self.model,
+            pipeline_class=FeatureExtractionPipelineWithStriding,
+            trust_remote_code=True,
+            **self.pipeline_kwargs,
+        )
+
+        # The Transformers library is buggy since 4.40.0,
+        # see https://github.com/huggingface/transformers/issues/30643,
+        # so we need to set the max_length to e.g. 512 manually
+        if self.model_max_length is not None:
+            self.client.tokenizer.model_max_length = self.model_max_length
+
+        # Check if the model has a valid max length
+        max_input_size = self.client.tokenizer.model_max_length
+        if max_input_size > 1e5:  # A high threshold to catch "unlimited" values
+            raise ValueError(
+                "The tokenizer does not specify a valid `model_max_length` attribute. "
+                "Consider setting it manually by passing `model_max_length` to the "
+                "HuggingFaceSpanEmbeddings constructor."
+            )
+
+    model_config = ConfigDict(
+        extra="forbid",
+        protected_namespaces=(),
+    )
+
+    def get_span_embedding(
+        self,
+        last_hidden_state: torch.Tensor,
+        offset_mapping: torch.Tensor,
+        start: Union[int, List[int]],
+        end: Union[int, List[int]],
+        **unused_kwargs,
+    ) -> Optional[List[float]]:
+        """Pool the span embeddings."""
+        if isinstance(start, int):
+            start = [start]
+        if isinstance(end, int):
+            end = [end]
+        if len(start) != len(end):
+            raise ValueError("start and end should have the same length.")
+        if len(start) == 0:
+            raise ValueError("start and end should not be empty.")
+        if last_hidden_state.shape[0] != 1:
+            raise ValueError("last_hidden_state should have a batch size of 1.")
+        if last_hidden_state.shape[0] != offset_mapping.shape[0]:
+            raise ValueError(
+                "last_hidden_state and offset_mapping should have the same batch size."
+            )
+        offset_mapping = offset_mapping[0]
+        last_hidden_state = last_hidden_state[0]
+
+        mask = (start[0] <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= end[0])
+        for s, e in zip(start[1:], end[1:]):
+            mask = mask | ((s <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= e))
+        span_embeddings = last_hidden_state[mask]
+        if span_embeddings.shape[0] == 0:
+            return None
+        if self.pooling_strategy == "mean":
+            return span_embeddings.mean(dim=0).tolist()
+        elif self.pooling_strategy == "max":
+            return span_embeddings.max(dim=0).values.tolist()
+        else:
+            raise ValueError(f"Unknown pool strategy: {self.pooling_strategy}")
+
+    def embed_document_spans(
+        self,
+        texts: List[str],
+        starts: Union[List[int], List[List[int]]],
+        ends: Union[List[int], List[List[int]]],
+    ) -> List[Optional[List[float]]]:
+        """Compute doc embeddings using a HuggingFace transformer model.
+
+        Args:
+            texts: The list of texts to embed.
+            starts: The list of start indices or list of lists of start indices (multi-span).
+            ends: The list of end indices or list of lists of end indices (multi-span).
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        pipeline_kwargs = self.encode_kwargs.copy()
+        pipeline_kwargs["return_offset_mapping"] = True
+        # we enable long text handling by providing the stride parameter
+        if pipeline_kwargs.get("stride", None) is None:
+            pipeline_kwargs["stride"] = 0
+        # when stride is positive, we need to create unique embeddings per token
+        if pipeline_kwargs["stride"] > 0:
+            pipeline_kwargs["create_unique_embeddings_per_token"] = True
+        # we ask for tensors to efficiently compute the span embeddings
+        pipeline_kwargs["return_tensors"] = True
+
+        unique_texts = sorted(set(texts))
+        idx2unique_idx = {i: unique_texts.index(text) for i, text in enumerate(texts)}
+        pipeline_results = self.client(unique_texts, **pipeline_kwargs)
+        embeddings = [
+            self.get_span_embedding(
+                start=starts[idx], end=ends[idx], **pipeline_results[idx2unique_idx[idx]]
+            )
+            for idx in range(len(texts))
+        ]
+        return embeddings
+
+    def embed_query_span(
+        self, text: str, start: Union[int, List[int]], end: Union[int, List[int]]
+    ) -> Optional[List[float]]:
+        """Compute query embeddings using a HuggingFace transformer model.
+
+        Args:
+            text: The text to embed.
+            start: The start index or list of start indices (multi-span).
+            end: The end index or list of end indices (multi-span).
+
+        Returns:
+            Embeddings for the text.
+        """
+        starts: Union[List[int], List[List[int]]] = [start]  # type: ignore[assignment]
+        ends: Union[List[int], List[List[int]]] = [end]  # type: ignore[assignment]
+        return self.embed_document_spans([text], starts=starts, ends=ends)[0]
+
+    @property
+    def embedding_dim(self) -> int:
+        """Get the embedding dimension."""
+        return self.client.model.config.hidden_size
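
The core of `get_span_embedding` is selecting the token vectors whose character offsets fall inside the requested span and pooling them. A stripped-down sketch of that masking logic with toy tensors, independent of any model (the offset values correspond to the docstring example text):

    import torch

    # last_hidden_state: one vector per token; offset_mapping: (char_start, char_end) per token
    last_hidden_state = torch.randn(6, 4)
    offset_mapping = torch.tensor([[0, 4], [5, 7], [8, 9], [10, 14], [15, 23], [23, 24]])

    start, end = 15, 23  # character span "sentence" in "This is a test sentence."
    mask = (start <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= end)
    span_vectors = last_hidden_state[mask]
    mean_pooled = span_vectors.mean(dim=0)        # pooling_strategy="mean"
    max_pooled = span_vectors.max(dim=0).values   # pooling_strategy="max"
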
src/langchain_modules/pie_document_store.py
ADDED
@@ -0,0 +1,88 @@
+import abc
+import logging
+from copy import copy
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+import pandas as pd
+from langchain_core.documents import Document as LCDocument
+from langchain_core.stores import BaseStore
+from pytorch_ie.documents import TextBasedDocument
+
+from .serializable_store import SerializableStore
+
+logger = logging.getLogger(__name__)
+
+
+class PieDocumentStore(SerializableStore, BaseStore[str, LCDocument], abc.ABC):
+    """Abstract base class for document stores specialized in storing and retrieving pie documents."""
+
+    METADATA_KEY_PIE_DOCUMENT: str = "pie_document"
+    """Key for the pie document in the (langchain) document metadata."""
+
+    def wrap(self, pie_document: TextBasedDocument, **metadata) -> LCDocument:
+        """Wrap the pie document in an LCDocument."""
+        return LCDocument(
+            id=pie_document.id,
+            page_content="",
+            metadata={self.METADATA_KEY_PIE_DOCUMENT: pie_document, **metadata},
+        )
+
+    def unwrap(self, document: LCDocument) -> TextBasedDocument:
+        """Get the pie document from the langchain document."""
+        return document.metadata[self.METADATA_KEY_PIE_DOCUMENT]
+
+    def unwrap_with_metadata(self, document: LCDocument) -> Tuple[TextBasedDocument, dict]:
+        """Get the pie document and metadata from the langchain document."""
+        metadata = copy(document.metadata)
+        pie_document = metadata.pop(self.METADATA_KEY_PIE_DOCUMENT)
+        return pie_document, metadata
+
+    @abc.abstractmethod
+    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
+        pass
+
+    @abc.abstractmethod
+    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
+        pass
+
+    @abc.abstractmethod
+    def mdelete(self, keys: Sequence[str]) -> None:
+        pass
+
+    @abc.abstractmethod
+    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
+        pass
+
+    def __len__(self):
+        return len(list(self.yield_keys()))
+
+    def overview(self, layer_captions: dict, use_predictions: bool = False) -> pd.DataFrame:
+        """Get an overview of the document store, including the number of items in each layer for each document
+        in the store.
+
+        Args:
+            layer_captions: A dictionary mapping layer names to captions.
+            use_predictions: Whether to use predictions instead of the actual layers.
+
+        Returns:
+            DataFrame: A pandas DataFrame containing the overview.
+        """
+        rows = []
+        for doc_id in self.yield_keys():
+            document = self.mget([doc_id])[0]
+            pie_document = self.unwrap(document)
+            layers = {
+                caption: pie_document[layer_name] for layer_name, caption in layer_captions.items()
+            }
+            layer_sizes = {
+                f"num_{caption}s": len(layer) + (len(layer.predictions) if use_predictions else 0)
+                for caption, layer in layers.items()
+            }
+            rows.append({"doc_id": doc_id, **layer_sizes})
+        df = pd.DataFrame(rows)
+        return df
+
+    def as_dict(self, document: LCDocument) -> dict:
+        """Convert the langchain document to a dictionary."""
+        pie_document, metadata = self.unwrap_with_metadata(document)
+        return {self.METADATA_KEY_PIE_DOCUMENT: pie_document.asdict(), "metadata": metadata}
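
For reference, the wrap/unwrap round trip that all concrete stores rely on (a sketch; `store` is any concrete subclass such as the datasets-backed one above, and `doc` is any `TextBasedDocument` with an `id`):

    lc_doc = store.wrap(doc, split="train")   # pie document goes into metadata["pie_document"]
    assert store.unwrap(lc_doc) is doc        # unwrap returns the very same pie document object
    pie_doc, metadata = store.unwrap_with_metadata(lc_doc)
    assert metadata == {"split": "train"}     # remaining metadata without the pie document
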
src/langchain_modules/qdrant_span_vectorstore.py
ADDED
@@ -0,0 +1,349 @@
+from __future__ import annotations
+
+import json
+import logging
+import os
+import uuid
+from collections import defaultdict
+from itertools import islice
+from typing import (  # type: ignore[import-not-found]
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
+
+import numpy as np
+from langchain_core.documents import Document as LCDocument
+from langchain_qdrant import QdrantVectorStore, RetrievalMode
+from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
+from qdrant_client import QdrantClient, models
+from qdrant_client.http.models import Record
+
+from .span_embeddings import SpanEmbeddings
+from .span_vectorstore import SpanVectorStore
+
+logger = logging.getLogger(__name__)
+
+
+class QdrantSpanVectorStore(SpanVectorStore, QdrantVectorStore):
+    """An implementation of the SpanVectorStore interface that uses Qdrant
+    as backend for storing and retrieving span embeddings."""
+
+    EMBEDDINGS_FILE = "embeddings.npy"
+    PAYLOADS_FILE = "payloads.json"
+    INDEX_FILE = "index.json"
+
+    def __init__(
+        self,
+        client: QdrantClient,
+        collection_name: str,
+        embedding: SpanEmbeddings,
+        vector_params: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        if not client.collection_exists(collection_name):
+            logger.info(f'Collection "{collection_name}" does not exist. Creating it now.')
+            client.create_collection(
+                collection_name=collection_name,
+                vectors_config=models.VectorParams(size=embedding.embedding_dim, **vector_params),
+            )
+        else:
+            logger.info(f'Collection "{collection_name}" already exists.')
+        super().__init__(
+            client=client, collection_name=collection_name, embedding=embedding, **kwargs
+        )
+
+    def __len__(self):
+        return self.client.get_collection(collection_name=self.collection_name).points_count
+
+    def get_by_ids_with_vectors(self, ids: Sequence[str | int], /) -> List[LCDocument]:
+        results = self.client.retrieve(
+            self.collection_name, ids, with_payload=True, with_vectors=True
+        )
+
+        return [
+            self._document_from_point(
+                result,
+                self.collection_name,
+                self.content_payload_key,
+                self.metadata_payload_key,
+            )
+            for result in results
+        ]
+
+    def construct_filter(
+        self,
+        query_span: Union[Span, MultiSpan],
+        metadata_doc_id_key: str,
+        doc_id_whitelist: Optional[Sequence[str]] = None,
+        doc_id_blacklist: Optional[Sequence[str]] = None,
+    ) -> Optional[models.Filter]:
+        """Construct a filter for the retrieval. It should enforce that:
+        - if the span is labeled, the retrieved span has the same label, or
+        - if, in addition, a label mapping is provided, the retrieved span has a label that is in the mapping for the query span's label
+        - if `doc_id_whitelist` is provided, the retrieved span is from a document in the whitelist
+        - if `doc_id_blacklist` is provided, the retrieved span is not from a document in the blacklist
+
+        Args:
+            query_span: The query span.
+            metadata_doc_id_key: The key in the metadata that holds the document id.
+            doc_id_whitelist: A list of document ids to restrict the retrieval to.
+            doc_id_blacklist: A list of document ids to exclude from the retrieval.
+
+        Returns:
+            A filter object.
+        """
+        filter_kwargs = defaultdict(list)
+        # if the span is labeled, enforce that the retrieved span has the same label
+        if isinstance(query_span, (LabeledSpan, LabeledMultiSpan)):
+            if self.label_mapping is not None:
+                target_labels = self.label_mapping.get(query_span.label, [])
+            else:
+                target_labels = [query_span.label]
+            filter_kwargs["must"].append(
+                models.FieldCondition(
+                    key=f"metadata.{self.METADATA_SPAN_KEY}.label",
+                    match=models.MatchAny(any=target_labels),
+                )
+            )
+        elif self.label_mapping is not None:
+            raise TypeError("Label mapping is only supported for labeled spans")
+
+        if doc_id_blacklist is not None and doc_id_whitelist is not None:
+            overlap = set(doc_id_whitelist) & set(doc_id_blacklist)
+            if len(overlap) > 0:
+                raise ValueError(
+                    f"Overlap between doc_id_whitelist and doc_id_blacklist: {overlap}"
+                )
+
+        if doc_id_whitelist is not None:
+            filter_kwargs["must"].append(
+                models.FieldCondition(
+                    key=f"metadata.{metadata_doc_id_key}",
+                    match=(
+                        models.MatchValue(value=doc_id_whitelist[0])
+                        if len(doc_id_whitelist) == 1
+                        else models.MatchAny(any=doc_id_whitelist)
+                    ),
+                )
+            )
+        if doc_id_blacklist is not None:
+            filter_kwargs["must_not"].append(
+                models.FieldCondition(
+                    key=f"metadata.{metadata_doc_id_key}",
+                    match=(
+                        models.MatchValue(value=doc_id_blacklist[0])
+                        if len(doc_id_blacklist) == 1
+                        else models.MatchAny(any=doc_id_blacklist)
+                    ),
+                )
+            )
+        if len(filter_kwargs) > 0:
+            return models.Filter(**filter_kwargs)
+        return None
+
+    @classmethod
+    def _document_from_point(
+        cls,
+        scored_point: Any,
+        collection_name: str,
+        content_payload_key: str,
+        metadata_payload_key: str,
+    ) -> LCDocument:
+        metadata = scored_point.payload.get(metadata_payload_key) or {}
+        metadata["_collection_name"] = collection_name
+        if hasattr(scored_point, "score"):
+            metadata[cls.RELEVANCE_SCORE_KEY] = scored_point.score
+        if hasattr(scored_point, "vector"):
+            metadata[cls.METADATA_VECTOR_KEY] = scored_point.vector
+        return LCDocument(
+            id=scored_point.id,
+            page_content=scored_point.payload.get(content_payload_key, ""),
+            metadata=metadata,
+        )
+
+    def _build_vectors_with_metadata(
+        self,
+        texts: Iterable[str],
+        metadatas: List[dict],
+    ) -> List[models.VectorStruct]:
+        starts = [metadata[self.METADATA_SPAN_KEY][self.SPAN_START_KEY] for metadata in metadatas]
+        ends = [metadata[self.METADATA_SPAN_KEY][self.SPAN_END_KEY] for metadata in metadatas]
+        if self.retrieval_mode == RetrievalMode.DENSE:
+            batch_embeddings = self.embeddings.embed_document_spans(list(texts), starts, ends)
+            return [
+                {
+                    self.vector_name: vector,
+                }
+                for vector in batch_embeddings
+            ]
+
+        elif self.retrieval_mode == RetrievalMode.SPARSE:
+            raise ValueError("Sparse retrieval mode is not yet implemented.")
+
+        elif self.retrieval_mode == RetrievalMode.HYBRID:
+            raise NotImplementedError("Hybrid retrieval mode is not yet implemented.")
+        else:
+            raise ValueError(f"Unknown retrieval mode. {self.retrieval_mode} to build vectors.")
+
+    def _build_payloads_from_metadata(
+        self,
+        metadatas: Iterable[dict],
+        metadata_payload_key: str,
+    ) -> List[dict]:
+        payloads = [{metadata_payload_key: metadata} for metadata in metadatas]
+
+        return payloads
+
+    def _generate_batches(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[Sequence[str | int]] = None,
+        batch_size: int = 64,
+    ) -> Generator[tuple[list[str | int], list[models.PointStruct]], Any, None]:
+        """Generate batches of points to index. Same as in `QdrantVectorStore` but metadata is used
+        to build vectors and payloads."""
+
+        texts_iterator = iter(texts)
+        if metadatas is None:
+            raise ValueError("Metadata must be provided to generate batches.")
+        metadatas_iterator = iter(metadatas)
+        ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
+
+        while batch_texts := list(islice(texts_iterator, batch_size)):
+            batch_metadatas = list(islice(metadatas_iterator, batch_size))
+            batch_ids = list(islice(ids_iterator, batch_size))
+            points = [
+                models.PointStruct(
+                    id=point_id,
+                    vector=vector,
+                    payload=payload,
+                )
+                for point_id, vector, payload in zip(
+                    batch_ids,
+                    self._build_vectors_with_metadata(batch_texts, metadatas=batch_metadatas),
+                    # we do not save the text in the payload because the text is the full
+                    # document which is usually already saved in the docstore
+                    self._build_payloads_from_metadata(
+                        metadatas=batch_metadatas,
+                        metadata_payload_key=self.metadata_payload_key,
+                    ),
+                )
+                if vector[self.vector_name] is not None
+            ]
+
+            yield [point.id for point in points], points
+
+    def similarity_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filter: Optional[models.Filter] = None,
+        search_params: Optional[models.SearchParams] = None,
+        offset: int = 0,
+        score_threshold: Optional[float] = None,
+        consistency: Optional[models.ReadConsistency] = None,
+        **kwargs: Any,
+    ) -> List[Tuple[LCDocument, float]]:
+        """Return docs most similar to query vector.
+
+        Returns:
+            List of documents most similar to the query text and distance for each.
+        """
+        query_options = {
+            "collection_name": self.collection_name,
+            "query_filter": filter,
+            "search_params": search_params,
+            "limit": k,
+            "offset": offset,
+            "with_payload": True,
+            "with_vectors": False,
+            "score_threshold": score_threshold,
+            "consistency": consistency,
+            **kwargs,
+        }
+
+        results = self.client.query_points(
+            query=embedding,
+            using=self.vector_name,
+            **query_options,
+        ).points
+
+        return [
+            (
+                self._document_from_point(
+                    result,
+                    self.collection_name,
+                    self.content_payload_key,
+                    self.metadata_payload_key,
+                ),
+                result.score,
+            )
+            for result in results
+        ]
+
+    def _as_indices_vectors_payloads(self) -> Tuple[List[str], np.ndarray, List[Any]]:
+        data, _ = self.client.scroll(
+            collection_name=self.collection_name, with_vectors=True, limit=len(self)
+        )
+        vectors_np = np.array([point.vector for point in data])
+        payloads = [point.payload for point in data]
+        emb_ids = [point.id for point in data]
+        return emb_ids, vectors_np, payloads
+
+    # TODO: or use create_snapshot and restore_snapshot?
+    def _save_to_directory(self, path: str, **kwargs) -> None:
+        indices, vectors, payloads = self._as_indices_vectors_payloads()
+        np.save(os.path.join(path, self.EMBEDDINGS_FILE), vectors)
+        with open(os.path.join(path, self.PAYLOADS_FILE), "w") as f:
+            json.dump(payloads, f, indent=2)
+        with open(os.path.join(path, self.INDEX_FILE), "w") as f:
+            json.dump(indices, f)
+
+    def _load_from_directory(self, path: str, **kwargs) -> None:
+        with open(os.path.join(path, self.INDEX_FILE), "r") as f:
+            index = json.load(f)
+        embeddings_np: np.ndarray = np.load(os.path.join(path, self.EMBEDDINGS_FILE))
+        with open(os.path.join(path, self.PAYLOADS_FILE), "r") as f:
+            payloads = json.load(f)
+        points = models.Batch(ids=index, vectors=embeddings_np.tolist(), payloads=payloads)
+        self.client.upsert(
+            collection_name=self.collection_name,
+            points=points,
+        )
+        logger.info(f"Loaded {len(index)} points into collection {self.collection_name}.")
+
+    def mget(self, keys: Sequence[str]) -> list[Optional[Record]]:
+        return self.client.retrieve(
+            self.collection_name, ids=keys, with_payload=True, with_vectors=True
+        )
+
+    def mset(self, key_value_pairs: Sequence[tuple[str, Record]]) -> None:
+        self.client.upsert(
+            collection_name=self.collection_name,
+            points=models.Batch(
+                ids=[key for key, _ in key_value_pairs],
+                vectors=[value.vector for _, value in key_value_pairs],
+                payloads=[value.payload for _, value in key_value_pairs],
+            ),
+        )
+
+    def mdelete(self, keys: Sequence[str]) -> None:
+        self.client.delete(collection_name=self.collection_name, points_selector=keys)
+
+    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
+        for point in self.client.scroll(
+            collection_name=self.collection_name,
+            with_vectors=False,
+            with_payload=False,
+            limit=len(self),
+        )[0]:
+            yield point.id
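
A rough construction sketch for local experiments. The parameter values below are assumptions, not taken from this diff (the retriever YAML config in this PR is the authoritative way to wire the store up), and the exact set of required keyword arguments of the custom class may differ:

    from qdrant_client import QdrantClient

    client = QdrantClient(location=":memory:")  # in-memory Qdrant instance, no server needed
    # vector_params is forwarded into models.VectorParams next to the embedding size,
    # so at least the distance metric has to be provided here (assumed value)
    store = QdrantSpanVectorStore(
        client=client,
        collection_name="span_embeddings",
        embedding=span_embeddings,  # any SpanEmbeddings instance, e.g. HuggingFaceSpanEmbeddings
        vector_params={"distance": "Cosine"},
    )
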
src/langchain_modules/serializable_store.py
ADDED
@@ -0,0 +1,137 @@
+import logging
+import os
+import shutil
+import tempfile
+from abc import ABC, abstractmethod
+
+logger = logging.getLogger(__name__)
+
+
+class SerializableStore(ABC):
+    """Abstract base class for serializable stores."""
+
+    @abstractmethod
+    def _save_to_directory(self, path: str, **kwargs) -> None:
+        """Save the store to a directory.
+
+        Args:
+            path (str): The path to a directory.
+        """
+
+    @abstractmethod
+    def _load_from_directory(self, path: str, **kwargs) -> None:
+        """Load the store from a directory.
+
+        Args:
+            path (str): The path to the file.
+        """
+
+    def save_to_directory(self, path: str, **kwargs) -> None:
+        """Save the store to a directory.
+
+        Args:
+            path (str): The path to a directory.
+        """
+        os.makedirs(path, exist_ok=True)
+        self._save_to_directory(path, **kwargs)
+
+    def load_from_directory(self, path: str, **kwargs) -> None:
+        """Load the store from a directory.
+
+        Args:
+            path (str): The path to the file.
+        """
+        if not os.path.exists(path):
+            raise FileNotFoundError(f'Directory "{path}" not found.')
+
+        self._load_from_directory(path, **kwargs)
+
+    def save_to_archive(
+        self,
+        base_name: str,
+        format: str,
+    ) -> str:
+        """Save the store to an archive.
+
+        Args:
+            base_name (str): The base name of the archive.
+            format (str): The format of the archive.
+
+        Returns:
+            str: The path to the archive.
+        """
+        temp_dir = os.path.join(tempfile.gettempdir(), "retriever_store")
+        # remove the temporary directory if it already exists
+        if os.path.exists(temp_dir):
+            shutil.rmtree(temp_dir)
+        # save the documents to the directory
+        self.save_to_directory(temp_dir)
+        # zip the directory
+        result_file_path = shutil.make_archive(
+            base_name=base_name, root_dir=temp_dir, format=format
+        )
+        # remove the temporary directory
+        shutil.rmtree(temp_dir)
+        return result_file_path
+
+    def load_from_archive(
+        self,
+        file_name: str,
+    ) -> None:
+        """Load the store from an archive.
+
+        Args:
+            file_name (str): The path to the archive.
+        """
+        if not os.path.exists(file_name):
+            raise FileNotFoundError(f'Archive file "{file_name}" not found.')
+
+        temp_dir = os.path.join(tempfile.gettempdir(), "retriever_store")
+        # remove the temporary directory if it already exists
+        if os.path.exists(temp_dir):
+            shutil.rmtree(temp_dir)
+
+        # unzip the file
+        shutil.unpack_archive(file_name, temp_dir)
+        # load the documents from the directory
+        self.load_from_directory(temp_dir)
+        # remove the temporary directory
+        shutil.rmtree(temp_dir)
+
+    def save_to_disc(self, path: str) -> None:
+        """Save the store to disc. Depending on the path, this can be a directory or an archive.
+
+        Args:
+            path (str): The path to a directory or an archive.
+        """
+        # if it is a zip file, save to archive
+        if path.lower().endswith(".zip"):
+            logger.info(f"Saving to archive at {path} ...")
+            # get base name without extension
+            base_name = os.path.splitext(path)[0]
+            result_path = self.save_to_archive(base_name, format="zip")
+            if not result_path.endswith(path):
+                logger.warning(f"Saved to {result_path} instead of {path}.")
+        # if it does not have an extension, save to directory
+        elif not os.path.splitext(path)[1]:
+            logger.info(f"Saving to directory at {path} ...")
+            self.save_to_directory(path)
+        else:
+            raise ValueError("Unsupported file format. Only .zip and directories are supported.")
+
+    def load_from_disc(self, path: str) -> None:
+        """Load the store from disc. Depending on the path, this can be a directory or an archive.
+
+        Args:
+            path (str): The path to a directory or an archive.
+        """
+        # if it is a zip file, load from archive
+        if path.lower().endswith(".zip"):
+            logger.info(f"Loading from archive at {path} ...")
+            self.load_from_archive(path)
+        # if it is a directory, load from directory
+        elif os.path.isdir(path):
+            logger.info(f"Loading from directory at {path} ...")
+            self.load_from_directory(path)
+        else:
+            raise ValueError("Unsupported file format. Only .zip and directories are supported.")
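
How the dispatch works in practice, for any `SerializableStore` subclass (a sketch; the paths are placeholders):

    # path ending in ".zip" -> save_to_archive(): the directory is dumped first and then
    # zipped via shutil.make_archive
    store.save_to_disc("dumps/docstore.zip")
    # path without extension -> plain directory via save_to_directory()
    store.save_to_disc("dumps/docstore")

    # loading mirrors the dispatch: archives are unpacked into a temporary directory first
    store.load_from_disc("dumps/docstore.zip")
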
src/langchain_modules/span_embeddings.py
ADDED
@@ -0,0 +1,103 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union  # type: ignore[import-not-found]
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.runnables.config import run_in_executor
+
+
+class SpanEmbeddings(Embeddings, ABC):
+    """Interface for models that embed text spans within documents."""
+
+    @abstractmethod
+    def embed_document_spans(
+        self,
+        texts: list[str],
+        starts: Union[list[int], List[List[int]]],
+        ends: Union[list[int], List[List[int]]],
+    ) -> list[Optional[list[float]]]:
+        """Embed search docs.
+
+        Args:
+            texts: List of text to embed.
+            starts: List of start indices or list of lists of start indices (multi-span).
+            ends: List of end indices or list of lists of end indices (multi-span).
+
+        Returns:
+            List of embeddings.
+        """
+
+    @abstractmethod
+    def embed_query_span(
+        self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
+    ) -> Optional[list[float]]:
+        """Embed query text.
+
+        Args:
+            text: Text to embed.
+            start: Start index or list of start indices (multi-span).
+            end: End index or list of end indices (multi-span).
+
+        Returns:
+            Embedding.
+        """
+
+    def embed_documents(self, texts: list[str]) -> list[Optional[list[float]]]:
+        """Embed search docs.
+
+        Args:
+            texts: List of text to embed.
+
+        Returns:
+            List of embeddings.
+        """
+        return self.embed_document_spans(texts, [0] * len(texts), [len(text) for text in texts])
+
+    def embed_query(self, text: str) -> Optional[list[float]]:
+        """Embed query text.
+
+        Args:
+            text: Text to embed.
+
+        Returns:
+            Embedding.
+        """
+        return self.embed_query_span(text, 0, len(text))
+
+    async def aembed_document_spans(
+        self,
+        texts: list[str],
+        starts: Union[list[int], list[list[int]]],
+        ends: Union[list[int], list[list[int]]],
+    ) -> list[Optional[list[float]]]:
+        """Asynchronous Embed search docs.
+
+        Args:
+            texts: List of text to embed.
+            starts: List of start indices or list of lists of start indices (multi-span).
+            ends: List of end indices or list of lists of end indices (multi-span).
+
+        Returns:
+            List of embeddings.
+        """
+        return await run_in_executor(None, self.embed_document_spans, texts, starts, ends)
+
+    async def aembed_query_spans(
+        self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
+    ) -> Optional[list[float]]:
+        """Asynchronous Embed query text.
+
+        Args:
+            text: Text to embed.
+            start: Start index or list of start indices (multi-span).
+            end: End index or list of end indices (multi-span).
+
+        Returns:
+            Embedding.
+        """
+        return await run_in_executor(None, self.embed_query_span, text, start, end)
+
+    @property
+    @abstractmethod
+    def embedding_dim(self) -> int:
+        """Get the embedding dimension."""
+        ...
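
Only the two span-level methods and `embedding_dim` are abstract; the whole-text `embed_documents`/`embed_query` fall out of them by embedding the span (0, len(text)). A toy implementation for wiring tests, not part of this diff and semantically meaningless on purpose:

    class DummySpanEmbeddings(SpanEmbeddings):
        """Deterministic pseudo-embeddings, useful only to exercise the interface."""

        def __init__(self, dim: int = 8):
            self._dim = dim

        def embed_document_spans(self, texts, starts, ends):
            return [self.embed_query_span(t, s, e) for t, s, e in zip(texts, starts, ends)]

        def embed_query_span(self, text, start, end):
            # stringify start/end so multi-span lists are hashable as well
            seed = hash((text, str(start), str(end))) % (2**16)
            return [((seed * (i + 1)) % 97) / 97.0 for i in range(self._dim)]

        @property
        def embedding_dim(self) -> int:
            return self._dim
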
src/langchain_modules/span_retriever.py
ADDED
@@ -0,0 +1,860 @@
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import uuid
|
4 |
+
from collections import defaultdict
|
5 |
+
from copy import copy
|
6 |
+
from enum import Enum
|
7 |
+
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union
|
8 |
+
|
9 |
+
from langchain_core.callbacks import (
|
10 |
+
AsyncCallbackManagerForRetrieverRun,
|
11 |
+
CallbackManagerForRetrieverRun,
|
12 |
+
)
|
13 |
+
from langchain_core.documents import BaseDocumentCompressor
|
14 |
+
from langchain_core.documents import Document as LCDocument
|
15 |
+
from langchain_core.retrievers import BaseRetriever
|
16 |
+
from pydantic import Field
|
17 |
+
from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
|
18 |
+
from pytorch_ie.core.document import BaseAnnotationList
|
19 |
+
from pytorch_ie.documents import (
|
20 |
+
TextBasedDocument,
|
21 |
+
TextDocumentWithLabeledMultiSpans,
|
22 |
+
TextDocumentWithLabeledSpans,
|
23 |
+
TextDocumentWithSpans,
|
24 |
+
)
|
25 |
+
|
26 |
+
from .pie_document_store import PieDocumentStore
|
27 |
+
from .serializable_store import SerializableStore
|
28 |
+
from .span_vectorstore import SpanVectorStore
|
29 |
+
|
30 |
+
logger = logging.getLogger(__name__)
|
31 |
+
|
32 |
+
|
33 |
+
def _parse_config(config_string: str, format: str) -> Dict[str, Any]:
|
34 |
+
"""Parse a configuration string."""
|
35 |
+
if format == "json":
|
36 |
+
import json
|
37 |
+
|
38 |
+
return json.loads(config_string)
|
39 |
+
elif format == "yaml":
|
40 |
+
import yaml
|
41 |
+
|
42 |
+
return yaml.safe_load(config_string)
|
43 |
+
else:
|
44 |
+
raise ValueError(f"Unsupported format: {format}. Use 'json' or 'yaml'.")
|
45 |
+
|
46 |
+
|
47 |
+
METADATA_KEY_CHILD_ID2IDX = "child_id2idx"
|
48 |
+
|
49 |
+
|
50 |
+
class SpanNotFoundError(ValueError):
|
51 |
+
def __init__(self, span_id: str, doc_id: Optional[str] = None, msg: Optional[str] = None):
|
52 |
+
if msg is None:
|
53 |
+
if doc_id is not None:
|
54 |
+
msg = f"Span with id [{span_id}] not found in document [{doc_id}]"
|
55 |
+
else:
|
56 |
+
msg = f"Span with id [{span_id}] not found in the vectorstore"
|
57 |
+
super().__init__(msg)
|
58 |
+
self.span_id = span_id
|
59 |
+
self.doc_id = doc_id
|
60 |
+
|
61 |
+
|
62 |
+
class DocumentNotFoundError(ValueError):
|
63 |
+
def __init__(self, doc_id: str, msg: Optional[str] = None):
|
64 |
+
msg = msg or f"Document with id [{doc_id}] not found in the docstore"
|
65 |
+
super().__init__(msg)
|
66 |
+
self.doc_id = doc_id
|
67 |
+
|
68 |
+
|
69 |
+
class SearchType(str, Enum):
|
70 |
+
"""Enumerator of the types of search to perform."""
|
71 |
+
|
72 |
+
similarity = "similarity"
|
73 |
+
"""Similarity search."""
|
74 |
+
similarity_score_threshold = "similarity_score_threshold"
|
75 |
+
"""Similarity search with a score threshold."""
|
76 |
+
mmr = "mmr"
|
77 |
+
"""Maximal Marginal Relevance reranking of similarity search."""
|
78 |
+
|
79 |
+
|
80 |
+
class DocumentAwareSpanRetriever(BaseRetriever, SerializableStore):
|
81 |
+
"""Retriever for contextualized text spans, i.e. spans within text documents.
|
82 |
+
It accepts spans as queries and retrieves spans with their containing document.
|
83 |
+
Note that the query span (and its document) must already be in the retriever's
|
84 |
+
store."""
|
85 |
+
|
86 |
+
pie_document_type: Type[TextBasedDocument]
|
87 |
+
"""The name of the span annotation layer in the pie document."""
|
88 |
+
use_predicted_annotations_key: str = "use_predicted_annotations"
|
89 |
+
"""Whether to use the predicted annotations or the gold annotations."""
|
90 |
+
retrieve_from_same_document: bool = False
|
91 |
+
"""Whether to retrieve spans exclusively from the same document as the query span."""
|
92 |
+
retrieve_from_different_documents: bool = False
|
93 |
+
"""Whether to retrieve spans exclusively from different documents than the query span."""
|
94 |
+
|
95 |
+
# content from langchain_core.retrievers.MultiVectorRetriever
|
96 |
+
vectorstore: SpanVectorStore
|
97 |
+
"""The underlying vectorstore to use to store small chunks
|
98 |
+
and their embedding vectors"""
|
99 |
+
docstore: PieDocumentStore
|
100 |
+
"""The storage interface for the parent documents"""
|
101 |
+
id_key: str = "doc_id"
|
102 |
+
"""The key to use to track the parent id. This will be stored in the
|
103 |
+
metadata of child documents."""
|
104 |
+
search_kwargs: dict = Field(default_factory=dict)
|
105 |
+
"""Keyword arguments to pass to the search function."""
|
106 |
+
search_type: SearchType = SearchType.similarity
|
107 |
+
"""Type of search to perform (similarity / mmr)"""
|
108 |
+
|
109 |
+
# content taken from langchain_core.retrievers.ParentDocumentRetriever
|
110 |
+
child_metadata_fields: Optional[Sequence[str]] = None
|
111 |
+
"""Metadata fields to leave in child documents. If None, leave all parent document
|
112 |
+
metadata.
|
113 |
+
"""
|
114 |
+
|
115 |
+
# re-ranking
|
116 |
+
compressor: Optional[BaseDocumentCompressor] = None
|
117 |
+
"""Compressor for compressing retrieved documents."""
|
118 |
+
compressor_context_size: int = 50
|
119 |
+
"""Size of the context to use around the query and retrieved spans when compressing."""
|
120 |
+
compressor_query_context_size: Optional[int] = 10
|
121 |
+
"""Size of the context to use around the query when compressing. If None, will use the
|
122 |
+
same value as `compressor_context_size`."""
|
123 |
+
|
124 |
+
@classmethod
|
125 |
+
def instantiate_from_config(
|
126 |
+
cls, config: Dict[str, Any], overwrites: Optional[Dict[str, Any]] = None
|
127 |
+
) -> "DocumentAwareSpanRetriever":
|
128 |
+
"""Instantiate a retriever from a configuration dictionary."""
|
129 |
+
from hydra.utils import instantiate
|
130 |
+
|
131 |
+
return instantiate(config, **(overwrites or {}))
|
132 |
+
|
133 |
+
@classmethod
|
134 |
+
def instantiate_from_config_string(
|
135 |
+
cls, config_string: str, format: str, overwrites: Optional[Dict[str, Any]] = None
|
136 |
+
) -> "DocumentAwareSpanRetriever":
|
137 |
+
"""Instantiate a retriever from a configuration string."""
|
138 |
+
return cls.instantiate_from_config(
|
139 |
+
_parse_config(config_string, format=format), overwrites=overwrites
|
140 |
+
)
|
141 |
+
|
142 |
+
@classmethod
|
143 |
+
def instantiate_from_config_file(
|
144 |
+
cls, config_path: str, overwrites: Optional[Dict[str, Any]] = None
|
145 |
+
) -> "DocumentAwareSpanRetriever":
|
146 |
+
"""Instantiate a retriever from a configuration file."""
|
147 |
+
with open(config_path, "r") as file:
|
148 |
+
config_string = file.read()
|
149 |
+
if config_path.endswith(".json"):
|
150 |
+
return cls.instantiate_from_config_string(
|
151 |
+
config_string, format="json", overwrites=overwrites
|
152 |
+
)
|
153 |
+
elif config_path.endswith(".yaml"):
|
154 |
+
return cls.instantiate_from_config_string(
|
155 |
+
config_string, format="yaml", overwrites=overwrites
|
156 |
+
)
|
157 |
+
else:
|
158 |
+
raise ValueError(f"Unsupported file extension: {config_path}")
|
159 |
+
|
160 |
+
@property
|
161 |
+
def pie_annotation_layer_name(self) -> str:
|
162 |
+
if issubclass(self.pie_document_type, TextDocumentWithSpans):
|
163 |
+
return "spans"
|
164 |
+
elif issubclass(self.pie_document_type, TextDocumentWithLabeledSpans):
|
165 |
+
return "labeled_spans"
|
166 |
+
elif issubclass(self.pie_document_type, TextDocumentWithLabeledMultiSpans):
|
167 |
+
return "labeled_multi_spans"
|
168 |
+
else:
|
169 |
+
raise ValueError(
|
170 |
+
f"Unsupported pie document type: {self.pie_document_type}. "
|
171 |
+
"Must be one of TextDocumentWithSpans, TextDocumentWithLabeledSpans, "
|
172 |
+
"or TextDocumentWithLabeledMultiSpans."
|
173 |
+
)
|
174 |
+
|
175 |
+
def _span_to_dict(self, span: Union[Span, MultiSpan]) -> dict:
|
176 |
+
span_dict = {}
|
177 |
+
if isinstance(span, Span):
|
178 |
+
span_dict[self.vectorstore.SPAN_START_KEY] = span.start
|
179 |
+
span_dict[self.vectorstore.SPAN_END_KEY] = span.end
|
180 |
+
span_dict["type"] = "Span"
|
181 |
+
elif isinstance(span, MultiSpan):
|
182 |
+
starts, ends = zip(*span.slices)
|
183 |
+
span_dict[self.vectorstore.SPAN_START_KEY] = starts
|
184 |
+
span_dict[self.vectorstore.SPAN_END_KEY] = ends
|
185 |
+
span_dict["type"] = "MultiSpan"
|
186 |
+
else:
|
187 |
+
raise ValueError(f"Unsupported span type: {type(span)}")
|
188 |
+
if isinstance(span, (LabeledSpan, LabeledMultiSpan)):
|
189 |
+
span_dict["label"] = span.label
|
190 |
+
span_dict["score"] = span.score
|
191 |
+
return span_dict
|
192 |
+
|
193 |
+
def _dict_to_span(self, span_dict: dict) -> Union[Span, MultiSpan]:
|
194 |
+
|
195 |
+
if span_dict["type"] == "Span":
|
196 |
+
kwargs = dict(
|
197 |
+
start=span_dict[self.vectorstore.SPAN_START_KEY],
|
198 |
+
end=span_dict[self.vectorstore.SPAN_END_KEY],
|
199 |
+
)
|
200 |
+
if "label" in span_dict:
|
201 |
+
kwargs["label"] = span_dict["label"]
|
202 |
+
kwargs["score"] = span_dict["score"]
|
203 |
+
return LabeledSpan(**kwargs)
|
204 |
+
else:
|
205 |
+
return Span(**kwargs)
|
206 |
+
elif span_dict["type"] == "MultiSpan":
|
207 |
+
starts = span_dict[self.vectorstore.SPAN_START_KEY]
|
208 |
+
ends = span_dict[self.vectorstore.SPAN_END_KEY]
|
209 |
+
slices = tuple((start, end) for start, end in zip(starts, ends))
|
210 |
+
kwargs = dict(slices=slices)
|
211 |
+
if "label" in span_dict:
|
212 |
+
kwargs["label"] = span_dict["label"]
|
213 |
+
kwargs["score"] = span_dict["score"]
|
214 |
+
return LabeledMultiSpan(**kwargs)
|
215 |
+
else:
|
216 |
+
return MultiSpan(**kwargs)
|
217 |
+
else:
|
218 |
+
raise ValueError(f"Unsupported span type: {span_dict['type']}")
|
219 |
+
|
220 |
+
    def use_predicted_annotations(self, doc: LCDocument) -> bool:
        """Check if the document uses predicted spans."""
        return doc.metadata.get(self.use_predicted_annotations_key, True)

    def get_document(self, doc_id: str) -> LCDocument:
        """Get a document by its id."""
        documents = self.docstore.mget([doc_id])
        if len(documents) == 0 or documents[0] is None:
            raise DocumentNotFoundError(doc_id=doc_id)
        if len(documents) > 1:
            raise ValueError(f"Multiple documents found with id: {doc_id}")
        return documents[0]

    def get_span_document(self, span_id: str, with_vector: bool = False) -> LCDocument:
        """Get a span document by its id."""
        if with_vector:
            span_docs = self.vectorstore.get_by_ids_with_vectors([span_id])
        else:
            span_docs = self.vectorstore.get_by_ids([span_id])
        if len(span_docs) == 0 or span_docs[0] is None:
            raise SpanNotFoundError(span_id=span_id)
        if len(span_docs) > 1:
            raise ValueError(f"Multiple span documents found with id: {span_id}")
        return span_docs[0]

    def get_base_layer(
        self, pie_document: TextBasedDocument, use_predicted_annotations: bool
    ) -> BaseAnnotationList:
        """Get the base annotation layer of the pie document."""
        if self.pie_annotation_layer_name not in pie_document:
            raise ValueError(
                f'The pie document must contain the annotation layer "{self.pie_annotation_layer_name}"'
            )
        layer = pie_document[self.pie_annotation_layer_name]
        return layer.predictions if use_predicted_annotations else layer

    def get_span_by_id(self, span_id: str) -> Union[Span, MultiSpan]:
        """Get a span annotation by its id."""
        span_doc = self.get_span_document(span_id)
        doc_id = span_doc.metadata[self.id_key]
        doc = self.get_document(doc_id)
        return self.get_span_from_doc_by_id(doc=doc, span_id=span_id)

    def get_span_from_doc_by_id(self, doc: LCDocument, span_id: str) -> Union[Span, MultiSpan]:
        """Get a span annotation from the given document by its span id."""
        base_layer = self.get_base_layer(
            self.docstore.unwrap(doc),
            use_predicted_annotations=self.use_predicted_annotations(doc),
        )
        span_idx = doc.metadata[METADATA_KEY_CHILD_ID2IDX].get(span_id)
        if span_idx is None:
            raise SpanNotFoundError(span_id=span_id, doc_id=doc.id)
        return base_layer[span_idx]

    def get_span_id2idx_from_doc(self, doc: Union[LCDocument, str]) -> Dict[str, int]:
        """Get all span ids from a document.

        Args:
            doc: Document or document id

        Returns:
            Dictionary mapping span ids to their index in the base layer.
        """
        if isinstance(doc, str):
            doc = self.get_document(doc)
        return doc.metadata[METADATA_KEY_CHILD_ID2IDX]
    def prepare_search_kwargs(
        self,
        span_id: str,
        doc_id_whitelist: Optional[List[str]] = None,
        doc_id_blacklist: Optional[List[str]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[dict, LCDocument]:
        # get the span document
        query_span_doc = self.get_span_document(span_id, with_vector=True)
        query_doc_id = query_span_doc.metadata[self.id_key]
        query_doc = self.get_document(query_doc_id)

        # TODO: why do we do this? Just to be the same as the result of the search when doing compression?
        # add "pie_document" to the metadata
        query_span_doc.metadata[self.docstore.METADATA_KEY_PIE_DOCUMENT] = self.docstore.unwrap(
            query_doc
        )

        search_kwargs = copy(self.search_kwargs)
        search_kwargs.update(kwargs or {})

        query_span = self.get_span_from_doc_by_id(doc=query_doc, span_id=span_id)

        if self.retrieve_from_different_documents and self.retrieve_from_same_document:
            raise ValueError("Cannot retrieve from both same and different documents")

        if self.retrieve_from_same_document:
            if doc_id_whitelist is None:
                doc_id_whitelist = [query_doc_id]
            elif query_doc_id not in doc_id_whitelist:
                doc_id_whitelist.append(query_doc_id)

        if self.retrieve_from_different_documents:
            if doc_id_blacklist is None:
                doc_id_blacklist = [query_doc_id]
            elif query_doc_id not in doc_id_blacklist:
                doc_id_blacklist.append(query_doc_id)

        query_filter = self.vectorstore.construct_filter(
            query_span=query_span,
            metadata_doc_id_key=self.id_key,
            doc_id_whitelist=doc_id_whitelist,
            doc_id_blacklist=doc_id_blacklist,
        )
        if query_filter is not None:
            search_kwargs["filter"] = query_filter

        # get the vector of the reference span
        search_kwargs["embedding"] = query_span_doc.metadata[self.vectorstore.METADATA_VECTOR_KEY]
        return search_kwargs, query_span_doc
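The whitelist/blacklist handling above is what allows retrieval to be restricted to, or exclude, specific documents at query time. A hedged sketch (reusing the hypothetical `retriever` from the earlier example; in current langchain-core, `invoke` forwards extra keyword arguments to `_get_relevant_documents`):

# only retrieve from the two listed documents (ids are made up)
hits = retriever.invoke("some-span-id", doc_id_whitelist=["doc-0001", "doc-0002"])

# or: retrieve from every document except one
hits = retriever.invoke("some-span-id", doc_id_blacklist=["doc-0001"])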
    def _prepare_query_for_compression(self, query_doc: LCDocument) -> str:
        return self._prepare_doc_for_compression(
            query_doc, context_size=self.compressor_query_context_size
        ).page_content

    def _prepare_doc_for_compression(
        self, doc: LCDocument, context_size: Optional[int] = None
    ) -> LCDocument:
        if context_size is None:
            context_size = self.compressor_context_size
        pie_doc: TextBasedDocument = self.docstore.unwrap(doc)
        text = pie_doc.text
        span_dict = doc.metadata[self.vectorstore.METADATA_SPAN_KEY]
        span_start = span_dict[self.vectorstore.SPAN_START_KEY]
        span_end = span_dict[self.vectorstore.SPAN_END_KEY]
        if isinstance(span_start, list):
            span_start = span_start[0]
        if isinstance(span_end, list):
            span_end = span_end[0]
        context_start = span_start - context_size
        context_end = span_end + context_size
        doc.page_content = text[max(0, context_start) : min(context_end, len(text))]
        # save the original relevance score and remove it because otherwise we will not be able to get
        # the reranking relevance score
        if "relevance_score" in doc.metadata:
            doc.metadata["relevance_score_without_reranking"] = doc.metadata.pop("relevance_score")
        return doc
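A quick worked example of the context slicing above (the numbers are made up):

# illustration only: how the compression window is computed
text = "x" * 1000                      # stand-in for pie_doc.text
span_start, span_end = 250, 310        # character offsets of the span
context_size = 100                     # e.g. compressor_context_size
window = text[max(0, span_start - context_size) : min(span_end + context_size, len(text))]
assert window == text[150:410]         # the span plus 100 characters on each side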
    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        doc_id_whitelist: Optional[List[str]] = None,
        doc_id_blacklist: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[LCDocument]:
        """Get span documents relevant to a query span.

        Args:
            query: The span id to find relevant spans for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant span documents with metadata from the parent document
        """
        search_kwargs, query_span_doc = self.prepare_search_kwargs(
            span_id=query,
            kwargs=kwargs,
            doc_id_whitelist=doc_id_whitelist,
            doc_id_blacklist=doc_id_blacklist,
        )
        if self.search_type == SearchType.mmr:
            span_docs = self.vectorstore.max_marginal_relevance_search_by_vector(**search_kwargs)
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = self.vectorstore.similarity_search_with_score_by_vector(
                **search_kwargs
            )
            span_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            span_docs = self.vectorstore.similarity_search_by_vector(**search_kwargs)

        # We do this to maintain the order of the ids that are returned
        doc_ids = []
        for span_doc in span_docs:
            if self.id_key not in span_doc.metadata:
                raise ValueError(f"Metadata must contain the key {self.id_key}")
            if span_doc.metadata[self.id_key] not in doc_ids:
                doc_ids.append(span_doc.metadata[self.id_key])
        docs = self.docstore.mget(doc_ids)
        doc_id2doc = dict(zip(doc_ids, docs))
        for span_doc in span_docs:
            doc = doc_id2doc[span_doc.metadata[self.id_key]]
            span_doc.metadata.update(doc.metadata)
            span_doc.metadata["attached_span"] = self.get_span_from_doc_by_id(
                doc=doc, span_id=span_doc.id
            )
            span_doc.metadata["query_span_id"] = query
        # filter out the query span doc
        span_docs_filtered = [
            span_doc for span_doc in span_docs if span_doc.id != query_span_doc.id
        ]
        if self.compressor is None:
            return span_docs_filtered
        if span_docs_filtered:
            prepared_docs = [
                self._prepare_doc_for_compression(sub_doc) for sub_doc in span_docs_filtered
            ]
            prepared_query = self._prepare_query_for_compression(query_span_doc)
            compressed_docs = self.compressor.compress_documents(
                documents=prepared_docs, query=prepared_query, callbacks=run_manager.get_child()
            )
            return list(compressed_docs)
        else:
            return []
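The metadata attached in the loop above is what downstream consumers rely on; a hedged sketch of reading the results (the key names come from the code above, the rest is illustrative):

for span_doc in retriever.invoke("some-span-id"):
    parent_doc_id = span_doc.metadata[retriever.id_key]
    span = span_doc.metadata["attached_span"]           # the pytorch-ie (Labeled)Span annotation
    score = span_doc.metadata.get("relevance_score")    # may or may not be set, depending on the vectorstore/search type
    print(parent_doc_id, getattr(span, "label", None), score)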
    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        doc_id_whitelist: Optional[List[str]] = None,
        doc_id_blacklist: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[LCDocument]:
        """Asynchronously get span documents relevant to a query span.

        Args:
            query: The span id to find relevant spans for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant span documents with metadata from the parent document
        """
        search_kwargs, query_span_doc = self.prepare_search_kwargs(
            span_id=query,
            kwargs=kwargs,
            doc_id_whitelist=doc_id_whitelist,
            doc_id_blacklist=doc_id_blacklist,
        )
        if self.search_type == SearchType.mmr:
            span_docs = await self.vectorstore.amax_marginal_relevance_search_by_vector(
                **search_kwargs
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_score_by_vector(**search_kwargs)
            )
            span_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            span_docs = await self.vectorstore.asimilarity_search_by_vector(**search_kwargs)

        # We do this to maintain the order of the ids that are returned
        ids = []
        for span_doc in span_docs:
            if self.id_key not in span_doc.metadata:
                raise ValueError(f"Metadata must contain the key {self.id_key}")
            if span_doc.metadata[self.id_key] not in ids:
                ids.append(span_doc.metadata[self.id_key])
        docs = await self.docstore.amget(ids)
        doc_id2doc = dict(zip(ids, docs))
        for span_doc in span_docs:
            doc = doc_id2doc[span_doc.metadata[self.id_key]]
            span_doc.metadata.update(doc.metadata)
            span_doc.metadata["attached_span"] = self.get_span_from_doc_by_id(
                doc=doc, span_id=span_doc.id
            )
            span_doc.metadata["query_span_id"] = query
        # filter out the query span doc
        span_docs_filtered = [
            span_doc for span_doc in span_docs if span_doc.id != query_span_doc.id
        ]

        if self.compressor is None:
            return span_docs_filtered
        if span_docs_filtered:
            prepared_docs = [
                self._prepare_doc_for_compression(sub_doc) for sub_doc in span_docs_filtered
            ]
            prepared_query = self._prepare_query_for_compression(query_span_doc)
            compressed_docs = await self.compressor.acompress_documents(
                prepared_docs, query=prepared_query, callbacks=run_manager.get_child()
            )
            return list(compressed_docs)
        else:
            return []
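The async path mirrors the synchronous one; a minimal usage sketch (same hypothetical retriever and span id as before, and assuming `ainvoke` forwards keyword arguments like `invoke` does):

import asyncio

async def retrieve_async(span_id: str):
    return await retriever.ainvoke(span_id)

docs = asyncio.run(retrieve_async("some-span-id"))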
    def create_span_documents(
        self, documents: List[LCDocument]
    ) -> Tuple[List[LCDocument], Dict[str, int]]:
        span_docs = []
        id2idx = {}
        for i, doc in enumerate(documents):
            pie_doc, metadata = self.docstore.unwrap_with_metadata(doc)
            base_layer = self.get_base_layer(
                pie_doc, use_predicted_annotations=self.use_predicted_annotations(doc)
            )
            if len(base_layer) == 0:
                logger.warning(f"No spans found in document {i} (id: {doc.id})")
            for idx, labeled_span in enumerate(base_layer):
                _metadata = {k: v for k, v in metadata.items() if k != METADATA_KEY_CHILD_ID2IDX}
                # save as dict to avoid serialization issues
                _metadata[self.vectorstore.METADATA_SPAN_KEY] = self._span_to_dict(labeled_span)
                new_doc = LCDocument(
                    id=str(uuid.uuid4()), page_content=pie_doc.text, metadata=_metadata
                )
                span_docs.append(new_doc)
                id2idx[new_doc.id] = idx
        return span_docs, id2idx
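For intuition, each span document built above is a regular LangChain document that carries the full parent text as page_content and the serialized span in its metadata; roughly like this (values are illustrative; the parent document id is attached under `self.id_key` by `_split_docs_for_adding` right below):

from langchain_core.documents import Document as LCDocument

example_span_doc = LCDocument(
    id="1f6e9a2c-0000-0000-0000-000000000000",  # random uuid4 assigned above
    page_content="<full text of the parent pie document>",
    metadata={
        # serialized span, stored under vectorstore.METADATA_SPAN_KEY ("pie_labeled_span")
        "pie_labeled_span": {"start": 12, "end": 47, "type": "Span", "label": "own_claim", "score": 0.98},
    },
)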
    def _split_docs_for_adding(
        self,
        documents: List[LCDocument],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
    ) -> Tuple[List[LCDocument], List[Tuple[str, LCDocument]]]:
        if ids is None:
            doc_ids = [doc.id for doc in documents]
            if not add_to_docstore:
                raise ValueError("If ids are not passed in, `add_to_docstore` MUST be True")
        else:
            if len(documents) != len(ids):
                raise ValueError(
                    "Got uneven list of documents and ids. "
                    "If `ids` is provided, should be same length as `documents`."
                )
            doc_ids = ids

        if len(set(doc_ids)) != len(doc_ids):
            raise ValueError("IDs must be unique")

        docs = []
        full_docs = []
        for i, doc in enumerate(documents):
            _id = doc_ids[i]
            sub_docs, sub_doc_id2idx = self.create_span_documents([doc])
            if self.child_metadata_fields is not None:
                for sub_doc in sub_docs:
                    sub_doc.metadata = {k: sub_doc.metadata[k] for k in self.child_metadata_fields}
            for sub_doc in sub_docs:
                # Add the parent id to the child document metadata
                sub_doc.metadata[self.id_key] = _id
            docs.extend(sub_docs)
            doc.metadata[METADATA_KEY_CHILD_ID2IDX] = sub_doc_id2idx
            full_docs.append((_id, doc))

        return docs, full_docs

    def remove_missing_span_ids_from_document(
        self, document: LCDocument, span_ids: Set[str]
    ) -> LCDocument:
        """Remove invalid span ids from the span-to-idx mapping of the document.

        Args:
            document: Document to remove invalid span ids from
            span_ids: Set of valid span ids

        Returns:
            Document with invalid span ids removed
        """
        span_id2idx = document.metadata[METADATA_KEY_CHILD_ID2IDX]
        new_doc = copy(document)
        filtered_span_id2idx = {
            span_id: idx for span_id, idx in span_id2idx.items() if span_id in span_ids
        }
        new_doc.metadata[METADATA_KEY_CHILD_ID2IDX] = filtered_span_id2idx
        missed_span_ids = set(span_id2idx.keys()) - span_ids
        if len(missed_span_ids) > 0:
            layer = self.get_base_layer(
                self.docstore.unwrap(document),
                use_predicted_annotations=self.use_predicted_annotations(document),
            )
            resolved_missed_spans = [
                layer[span_id2idx[span_id]].resolve() for span_id in missed_span_ids
            ]
            logger.warning(
                f"Document {document.id} contains spans that can not be added to the "
                f"vectorstore because no vector could be calculated:\n{resolved_missed_spans}.\n"
                "These spans will not be queryable."
            )
        return document

    def add_documents(
        self,
        documents: List[LCDocument],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
        **kwargs: Any,
    ) -> None:
        """Add documents to the docstore and vectorstore.

        Args:
            documents: List of documents to add
            ids: Optional list of ids for documents. If provided should be the same
                length as the list of documents. Can be provided if parent documents
                are already in the document store and you don't want to re-add
                to the docstore. If not provided, random UUIDs will be used as
                ids.
            add_to_docstore: Boolean of whether to add documents to docstore.
                This can be false if and only if `ids` are provided. You may want
                to set this to False if the documents are already in the docstore
                and you don't want to re-add them.
        """
        docs, full_docs = self._split_docs_for_adding(documents, ids, add_to_docstore)
        added_span_ids = self.vectorstore.add_documents(docs, **kwargs)
        full_docs = [
            (doc_id, self.remove_missing_span_ids_from_document(doc, set(added_span_ids)))
            for doc_id, doc in full_docs
        ]
        if add_to_docstore:
            self.docstore.mset(full_docs)

    async def aadd_documents(
        self,
        documents: List[LCDocument],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
        **kwargs: Any,
    ) -> None:
        """Asynchronous variant of `add_documents`."""
        docs, full_docs = self._split_docs_for_adding(documents, ids, add_to_docstore)
        added_span_ids = await self.vectorstore.aadd_documents(docs, **kwargs)
        full_docs = [
            (doc_id, self.remove_missing_span_ids_from_document(doc, set(added_span_ids)))
            for doc_id, doc in full_docs
        ]
        if add_to_docstore:
            await self.docstore.amset(full_docs)

    def delete_documents(self, ids: List[str]) -> None:
        """Remove documents from the docstore and vectorstore.

        Args:
            ids: List of ids to remove
        """
        # get all child ids
        child_ids = []
        for doc in self.docstore.mget(ids):
            child_ids.extend(doc.metadata[METADATA_KEY_CHILD_ID2IDX])

        self.vectorstore.delete(child_ids)
        self.docstore.mdelete(ids)

    async def adelete_documents(self, ids: List[str]) -> None:
        """Asynchronously remove documents from the docstore and vectorstore.

        Args:
            ids: List of ids to remove
        """
        # get all child ids
        child_ids = []
        docs: List[LCDocument] = await self.docstore.amget(ids)
        for doc in docs:
            child_ids.extend(doc.metadata[METADATA_KEY_CHILD_ID2IDX])

        await self.vectorstore.adelete(child_ids)
        await self.docstore.amdelete(ids)

    def add_pie_documents(
        self,
        documents: List[TextBasedDocument],
        use_predicted_annotations: bool,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Add pie documents to the retriever.

        Args:
            documents: List of pie documents to add
            use_predicted_annotations: Whether to use the predicted annotations or the gold annotations
            metadata: Optional metadata to add to each document
        """
        metadata = metadata or {}
        metadata = copy(metadata)
        metadata[self.use_predicted_annotations_key] = use_predicted_annotations
        docs = [self.docstore.wrap(doc, **metadata) for doc in documents]

        # delete any existing documents with the same ids (simply overwriting would keep the spans)
        new_docs_ids = [doc.id for doc in docs]
        existing_docs = self.docstore.mget(new_docs_ids)
        existing_doc_ids = [doc.id for doc in existing_docs]
        self.delete_documents(existing_doc_ids)

        self.add_documents(docs)
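A hedged sketch of feeding freshly annotated pie documents into the retriever (how app.py wires this up is not shown here; the metadata key is illustrative):

# `annotated_docs` are the TextBasedDocument instances produced by the NER/RE pipeline,
# with predictions stored in the annotation layers' `.predictions`.
retriever.add_pie_documents(
    annotated_docs,
    use_predicted_annotations=True,
    metadata={"source": "user_upload"},  # optional, copied into every wrapped document
)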
    def _save_to_directory(self, path: str, **kwargs) -> None:
        logger.info(f'Saving docstore and vectorstore to "{path}" ...')
        self.docstore.save_to_directory(os.path.join(path, "docstore"))
        self.vectorstore.save_to_directory(os.path.join(path, "vectorstore"))

    def _load_from_directory(self, path: str, **kwargs) -> None:
        logger.info(f'Loading docstore and vectorstore from "{path}" ...')
        self.docstore.load_from_directory(os.path.join(path, "docstore"))
        self.vectorstore.load_from_directory(os.path.join(path, "vectorstore"))


METADATA_KEY_RELATION_LABEL2TAILS_WITH_SCORES = "relation_label2tails_with_scores"


class DocumentAwareSpanRetrieverWithRelations(DocumentAwareSpanRetriever):
    """Retriever for related contextualized text spans, i.e. spans linked by relations
    to reference spans that are similar to the query span. It accepts spans as queries and
    retrieves spans with their containing document and the reference span."""

    relation_layer_name: str = "binary_relations"
    """The name of the relation annotation layer in the pie document."""
    relation_labels: Optional[List[str]] = None
    """The list of relation labels to consider."""
    span_labels: Optional[List[str]] = None
    """The list of span labels to consider."""
    reversed_relations_suffix: Optional[str] = None
    """Optional suffix that, appended to a relation label, denotes the reversed relation.
    If None, reversed relations are not considered."""

    def get_relation_layer(
        self, pie_document: TextBasedDocument, use_predicted_annotations: bool
    ) -> BaseAnnotationList:
        """Get the relation layer of the pie document."""
        if self.relation_layer_name not in pie_document:
            raise ValueError(
                f'The pie document must contain the annotation layer "{self.relation_layer_name}"'
            )
        layer = pie_document[self.relation_layer_name]
        return layer.predictions if use_predicted_annotations else layer

    def create_span_documents(
        self, documents: List[LCDocument]
    ) -> Tuple[List[LCDocument], Dict[str, int]]:
        span_docs = []
        id2idx = {}
        for i, doc in enumerate(documents):
            pie_doc, metadata = self.docstore.unwrap_with_metadata(doc)
            use_predicted_annotations = self.use_predicted_annotations(doc)
            base_layer = self.get_base_layer(
                pie_doc, use_predicted_annotations=use_predicted_annotations
            )
            if len(base_layer) == 0:
                logger.warning(f"No spans found in document {i} (id: {doc.id})")
            id2span = {str(uuid.uuid4()): span for span in base_layer}
            span2id = {span: span_id for span_id, span in id2span.items()}
            if len(id2span) != len(span2id):
                raise ValueError("Span ids and spans must be unique")
            relations = self.get_relation_layer(
                pie_doc, use_predicted_annotations=use_predicted_annotations
            )
            head2label2tails_with_scores: Dict[str, Dict[str, List[Tuple[str, float]]]] = (
                defaultdict(lambda: defaultdict(list))
            )

            for relation in relations:
                if self.relation_labels is None or relation.label in self.relation_labels:
                    head2label2tails_with_scores[span2id[relation.head]][relation.label].append(
                        (span2id[relation.tail], relation.score)
                    )
                if self.reversed_relations_suffix is not None:
                    reversed_label = f"{relation.label}{self.reversed_relations_suffix}"
                    if self.relation_labels is None or reversed_label in self.relation_labels:
                        head2label2tails_with_scores[span2id[relation.tail]][
                            reversed_label
                        ].append((span2id[relation.head], relation.score))

            for idx, span in enumerate(base_layer):
                span_id = span2id[span]
                _metadata = {k: v for k, v in metadata.items() if k != METADATA_KEY_CHILD_ID2IDX}
                # save as dict to avoid serialization issues
                _metadata[self.vectorstore.METADATA_SPAN_KEY] = self._span_to_dict(span)
                relation_label2tails_with_scores = head2label2tails_with_scores[span_id]
                _metadata[METADATA_KEY_RELATION_LABEL2TAILS_WITH_SCORES] = dict(
                    relation_label2tails_with_scores
                )
                new_doc = LCDocument(id=span_id, page_content=pie_doc.text, metadata=_metadata)
                span_docs.append(new_doc)
                id2idx[span_id] = idx
        return span_docs, id2idx

    def _get_relevant_documents(
        self,
        query: str,
        return_related: bool = False,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[LCDocument]:
        """Get span documents relevant to a query span. We follow one hop of relations.

        Args:
            query: The span id to find relevant spans for
            return_related: Whether to return related spans
            run_manager: The callbacks handler to use

        Returns:
            List of relevant span documents with metadata from the parent document
        """
        similar_span_docs = super()._get_relevant_documents(
            query=query, run_manager=run_manager, **kwargs
        )
        if not return_related:
            return similar_span_docs

        related_docs = []
        for head_span_doc in similar_span_docs:
            doc_id = head_span_doc.metadata[self.id_key]
            doc = self.get_document(doc_id)
            query_span_id = head_span_doc.metadata["query_span_id"]

            for relation_label, tails_with_score in head_span_doc.metadata[
                METADATA_KEY_RELATION_LABEL2TAILS_WITH_SCORES
            ].items():
                for tail_id, relation_score in tails_with_score:
                    # in the case that we query against the same document,
                    # we don't want to return the same span as the query span
                    if tail_id == query_span_id:
                        continue

                    try:
                        attached_tail_span = self.get_span_from_doc_by_id(doc=doc, span_id=tail_id)
                    # this may happen if the tail span could not be added to the vectorstore, e.g. because
                    # the token span length is zero and no vector could be calculated
                    except SpanNotFoundError:
                        logger.warning(
                            f"Tail span with id [{tail_id}] not found in the vectorstore. Skipping."
                        )
                        continue

                    # TODO: handle via filter? see vectorstore.construct_filter
                    if self.span_labels is not None:
                        if not isinstance(attached_tail_span, (LabeledSpan, LabeledMultiSpan)):
                            raise ValueError(
                                "Span must be a labeled span if span_labels is provided"
                            )
                        if attached_tail_span.label not in self.span_labels:
                            continue

                    related_docs.append(
                        LCDocument(
                            id=tail_id,
                            page_content="",
                            metadata={
                                "relation_score": relation_score,
                                "head_id": head_span_doc.id,
                                "relation_label": relation_label,
                                "attached_tail_span": attached_tail_span,
                                **head_span_doc.metadata,
                            },
                        )
                    )
        return related_docs
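To close out this file, a short sketch of querying the relation-following retriever and reading the extra metadata it attaches (the key names come from the code above; `retriever` and the span id are the hypothetical objects from the earlier sketches):

related = retriever.invoke("some-span-id", return_related=True)
for doc in related:
    print(
        doc.metadata["relation_label"],      # label of the followed relation
        doc.metadata["relation_score"],
        doc.metadata["head_id"],             # id of the similar span that anchored this hop
        doc.metadata["attached_tail_span"],  # the related pytorch-ie span annotation
    )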
src/langchain_modules/span_vectorstore.py
ADDED
@@ -0,0 +1,131 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from langchain_core.documents import Document as LCDocument
from langchain_core.runnables import run_in_executor
from langchain_core.stores import BaseStore
from langchain_core.vectorstores import VectorStore
from pytorch_ie.annotations import MultiSpan, Span

from .serializable_store import SerializableStore
from .span_embeddings import SpanEmbeddings


class SpanVectorStore(VectorStore, BaseStore, SerializableStore, ABC):
    """Abstract base class for vector stores specialized in storing
    and retrieving embeddings for text spans within documents."""

    METADATA_SPAN_KEY: str = "pie_labeled_span"
    """Key for the span data in the (langchain) document metadata."""
    SPAN_START_KEY: str = "start"
    """Key for the start of the span in the span data."""
    SPAN_END_KEY: str = "end"
    """Key for the end of the span in the span data."""
    METADATA_VECTOR_KEY: str = "vector"
    """Key for the vector in the (langchain) document metadata."""
    RELEVANCE_SCORE_KEY: str = "relevance_score"
    """Key for the relevance score in the (langchain) document metadata."""

    def __init__(
        self,
        label_mapping: Optional[Dict[str, List[str]]] = None,
        **kwargs: Any,
    ):
        """Initialize the SpanVectorStore.

        Args:
            label_mapping: Mapping from query span labels to target span labels. If provided,
                only spans with a label in the mapping for the query span's label are retrieved.
            **kwargs: Additional arguments.
        """
        self.label_mapping = label_mapping
        super().__init__(**kwargs)
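To illustrate `label_mapping` (the values below are hypothetical, not taken from the shipped retriever config): with this mapping, a query span labeled "own_claim" only retrieves spans labeled "own_claim" or "background_claim", while "data" spans only retrieve other "data" spans.

label_mapping = {
    "own_claim": ["own_claim", "background_claim"],
    "background_claim": ["own_claim", "background_claim"],
    "data": ["data"],
}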
    @property
    def embeddings(self) -> SpanEmbeddings:
        """Get the dense embeddings instance that is being used.

        Raises:
            ValueError: If the underlying embeddings are not a `SpanEmbeddings` instance.

        Returns:
            SpanEmbeddings: The span embeddings instance.
        """
        result = super().embeddings
        if not isinstance(result, SpanEmbeddings):
            raise ValueError(f"Embeddings must be of type SpanEmbeddings, but got: {result}")
        return result

    @abstractmethod
    def get_by_ids_with_vectors(self, ids: Sequence[Union[str, int]], /) -> List[LCDocument]:
        """Get documents by their ids.

        Args:
            ids: List of document ids.

        Returns:
            List of documents including their vectors in the metadata at key `METADATA_VECTOR_KEY`.
        """
        ...

    @abstractmethod
    def construct_filter(
        self,
        query_span: Union[Span, MultiSpan],
        metadata_doc_id_key: str,
        doc_id_whitelist: Optional[Sequence[str]] = None,
        doc_id_blacklist: Optional[Sequence[str]] = None,
    ) -> Any:
        """Construct a filter for the retrieval. It should enforce that:

        - if the query span is labeled, the retrieved span has the same label or, if a label
          mapping is provided, a label that the mapping allows for the query span's label,
        - if `doc_id_whitelist` is provided, the retrieved span comes from a document in the
          whitelist, and
        - if `doc_id_blacklist` is provided, the retrieved span does not come from a document
          in the blacklist.

        Args:
            query_span: The query span.
            metadata_doc_id_key: The key in the metadata that holds the document id.
            doc_id_whitelist: A list of document ids to restrict the retrieval to.
            doc_id_blacklist: A list of document ids to exclude from the retrieval.

        Returns:
            A filter object.
        """
        ...
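The concrete Qdrant implementation lives in qdrant_span_vectorstore.py; purely for illustration, a filter with the semantics described above could be assembled with qdrant-client like this (the payload key paths are assumptions, not the repo's actual schema):

from typing import Optional, Sequence
from qdrant_client import models

def sketch_construct_filter(
    query_label: Optional[str],
    metadata_doc_id_key: str,
    doc_id_whitelist: Optional[Sequence[str]] = None,
    doc_id_blacklist: Optional[Sequence[str]] = None,
    label_mapping: Optional[dict] = None,
) -> Optional[models.Filter]:
    must, must_not = [], []
    if query_label is not None:
        # same label, or the labels allowed by the mapping for the query label
        allowed = label_mapping.get(query_label, [query_label]) if label_mapping else [query_label]
        must.append(
            models.FieldCondition(
                key="metadata.pie_labeled_span.label",  # assumed payload path
                match=models.MatchAny(any=list(allowed)),
            )
        )
    if doc_id_whitelist is not None:
        must.append(
            models.FieldCondition(
                key=f"metadata.{metadata_doc_id_key}",  # assumed payload path
                match=models.MatchAny(any=list(doc_id_whitelist)),
            )
        )
    if doc_id_blacklist is not None:
        must_not.append(
            models.FieldCondition(
                key=f"metadata.{metadata_doc_id_key}",
                match=models.MatchAny(any=list(doc_id_blacklist)),
            )
        )
    if not must and not must_not:
        return None
    return models.Filter(must=must, must_not=must_not)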
    @abstractmethod
    def similarity_search_with_score_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Tuple[LCDocument, float]]:
        """Return docs most similar to the embedding vector, together with their scores.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: Arguments to pass to the search method.

        Returns:
            List of (document, score) tuples most similar to the query vector.
        """
        ...

    async def asimilarity_search_with_score_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Tuple[LCDocument, float]]:
        """Async return docs most similar to the embedding vector, together with their scores.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: Arguments to pass to the search method.

        Returns:
            List of (document, score) tuples most similar to the query vector.
        """
        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        return await run_in_executor(
            None, self.similarity_search_with_score_by_vector, embedding, k=k, **kwargs
        )