ArneBinder committed
Commit 2cc87ec · verified · 1 Parent(s): a347ab7

new demo setup with langchain retriever


based on https://github.com/ArneBinder/pie-document-level/pull/298
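In short, the demo no longer keeps its own DocumentStore/EmbeddingModel; annotated documents and their span embeddings are now stored and queried through a LangChain-style span retriever (src/langchain_modules) that is instantiated from a YAML config. A rough sketch of how the new pieces fit together, assembled from the calls visible in the diff below (the model name and the "<...>" values are placeholders, and the exact signatures may differ from the actual modules):

# Hypothetical usage sketch; the function names come from model_utils.py below,
# but the model name, span id and some argument values are assumptions.
from model_utils import (
    load_argumentation_model,
    load_retriever,
    process_texts,
    retrieve_relevant_spans,
)

# load the argumentation (ADU + relation) pipeline and the span retriever
argumentation_model = load_argumentation_model(
    model_name="<DEFAULT_MODEL_NAME>",  # placeholder, see app.py
    revision=None,
    device="cpu",
)
# path as used in app.py for DEFAULT_RETRIEVER_CONFIG
with open(
    "configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml"
) as f:
    retriever = load_retriever(f.read(), config_format="yaml", device="cpu")

# annotate a text and index its ADU spans in the retriever's doc-/vectorstore
process_texts(
    texts=["Some scholarly text ..."],
    doc_ids=["my_doc"],
    argumentation_model=argumentation_model,
    retriever=retriever,
    split_regex_escaped=None,
    handle_parts_of_same=True,
)

# given a span id from the store, retrieve relevant ADUs from other documents
relevant = retrieve_relevant_spans(
    retriever=retriever,
    query_span_id="<span id>",  # placeholder
    k=10,
    score_threshold=0.95,
    relation_label_mapping={
        "supports_reversed": "supported by",
        "contradicts_reversed": "contradicts",
    },
)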

Files changed (38)
  1. app.py +471 -279
  2. model_utils.py +339 -65
  3. rendering_utils.py +50 -40
  4. requirements.txt +28 -5
  5. retrieve_and_dump_all_relevant.py +101 -0
  6. retriever/related_span_retriever_with_relations_from_other_docs.yaml +47 -0
  7. src/__init__.py +0 -0
  8. src/hf_pipeline/__init__.py +1 -0
  9. src/hf_pipeline/__pycache__/__init__.cpython-310.pyc +0 -0
  10. src/hf_pipeline/__pycache__/__init__.cpython-39.pyc +0 -0
  11. src/hf_pipeline/__pycache__/feature_extraction.cpython-310.pyc +0 -0
  12. src/hf_pipeline/__pycache__/feature_extraction.cpython-39.pyc +0 -0
  13. src/hf_pipeline/feature_extraction.py +317 -0
  14. src/langchain_modules/__init__.py +9 -0
  15. src/langchain_modules/__pycache__/__init__.cpython-310.pyc +0 -0
  16. src/langchain_modules/__pycache__/__init__.cpython-39.pyc +0 -0
  17. src/langchain_modules/__pycache__/basic_pie_document_store.cpython-39.pyc +0 -0
  18. src/langchain_modules/__pycache__/datasets_pie_document_store.cpython-39.pyc +0 -0
  19. src/langchain_modules/__pycache__/huggingface_span_embeddings.cpython-310.pyc +0 -0
  20. src/langchain_modules/__pycache__/huggingface_span_embeddings.cpython-39.pyc +0 -0
  21. src/langchain_modules/__pycache__/pie_document_store.cpython-39.pyc +0 -0
  22. src/langchain_modules/__pycache__/qdrant_span_vectorstore.cpython-310.pyc +0 -0
  23. src/langchain_modules/__pycache__/qdrant_span_vectorstore.cpython-39.pyc +0 -0
  24. src/langchain_modules/__pycache__/serializable_store.cpython-39.pyc +0 -0
  25. src/langchain_modules/__pycache__/span_embeddings.cpython-310.pyc +0 -0
  26. src/langchain_modules/__pycache__/span_embeddings.cpython-39.pyc +0 -0
  27. src/langchain_modules/__pycache__/span_retriever.cpython-310.pyc +0 -0
  28. src/langchain_modules/__pycache__/span_retriever.cpython-39.pyc +0 -0
  29. src/langchain_modules/__pycache__/span_vectorstore.cpython-39.pyc +0 -0
  30. src/langchain_modules/basic_pie_document_store.py +103 -0
  31. src/langchain_modules/datasets_pie_document_store.py +156 -0
  32. src/langchain_modules/huggingface_span_embeddings.py +192 -0
  33. src/langchain_modules/pie_document_store.py +88 -0
  34. src/langchain_modules/qdrant_span_vectorstore.py +349 -0
  35. src/langchain_modules/serializable_store.py +137 -0
  36. src/langchain_modules/span_embeddings.py +103 -0
  37. src/langchain_modules/span_retriever.py +860 -0
  38. src/langchain_modules/span_vectorstore.py +131 -0
app.py CHANGED
@@ -1,32 +1,56 @@
1
  import json
2
  import logging
3
  import os.path
4
  import re
5
  import tempfile
6
- from functools import partial
7
- from typing import List, Optional, Tuple, Union
8
 
9
  import arxiv
10
  import gradio as gr
11
  import pandas as pd
12
  import requests
13
  import torch
 
14
  from bs4 import BeautifulSoup
15
- from document_store import DocumentStore, get_annotation_from_document
16
- from embedding import EmbeddingModel
17
- from model_utils import annotate_document, create_document, load_models
18
- from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
19
- from pytorch_ie import Pipeline
20
- from pytorch_ie.documents import (
21
- TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
22
- TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
 
23
  )
 
 
 
24
  from rendering_utils import HIGHLIGHT_SPANS_JS, render_displacy, render_pretty_table
25
- from transformers import PreTrainedModel, PreTrainedTokenizer
26
- from vector_store import QdrantVectorStore, SimpleVectorStore
 
 
 
27
 
28
  logger = logging.getLogger(__name__)
29
30
  RENDER_WITH_DISPLACY = "displaCy + highlighted arguments"
31
  RENDER_WITH_PRETTY_TABLE = "Pretty Table"
32
 
@@ -35,16 +59,31 @@ DEFAULT_MODEL_REVISION = "76300f8e534e2fcf695f00cb49bba166739b8d8a"
35
  # local path
36
  # DEFAULT_MODEL_NAME = "models/dataset-sciarg/task-ner_re/v0.3/2024-05-28_23-33-46"
37
  # DEFAULT_MODEL_REVISION = None
38
- DEFAULT_EMBEDDING_MODEL_NAME = "allenai/scibert_scivocab_uncased"
39
  DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
40
- DEFAULT_EMBEDDING_MAX_LENGTH = 512
41
- DEFAULT_EMBEDDING_BATCH_SIZE = 32
42
  DEFAULT_SPLIT_REGEX = "\n\n\n+"
43
  DEFAULT_ARXIV_ID = "1706.03762"
 
 
 
44
 
45
  # Whether to handle segmented entities in the document. If True, labeled_spans are converted
46
  # to labeled_multi_spans and binary_relations with label "parts_of_same" are used to merge them.
 
47
  HANDLE_PARTS_OF_SAME = True
48
 
49
 
50
  def escape_regex(regex: str) -> str:
@@ -59,19 +98,38 @@ def unescape_regex(regex: str) -> str:
59
  return result
60
 
61
 
 
 
 
 
 
62
  def render_annotated_document(
63
- document: Union[
64
- TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
65
- TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
66
- ],
67
  render_with: str,
68
  render_kwargs_json: str,
69
  ) -> str:
70
  render_kwargs = json.loads(render_kwargs_json)
71
  if render_with == RENDER_WITH_PRETTY_TABLE:
72
- html = render_pretty_table(document, **render_kwargs)
73
  elif render_with == RENDER_WITH_DISPLACY:
74
- html = render_displacy(document, **render_kwargs)
75
  else:
76
  raise ValueError(f"Unknown render_with value: {render_with}")
77
 
@@ -79,84 +137,47 @@ def render_annotated_document(
79
 
80
 
81
  def wrapped_process_text(
82
- text: str,
83
- doc_id: str,
84
- models: Tuple[Pipeline, Optional[EmbeddingModel]],
85
- document_store: DocumentStore,
86
- split_regex_escaped: str,
87
- handle_parts_of_same: bool = False,
88
- ) -> Tuple[
89
- dict,
90
- Union[
91
- TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
92
- TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
93
- ],
94
- ]:
95
  try:
96
- document = create_document(
97
- text=text,
98
- doc_id=doc_id,
99
- split_regex=unescape_regex(split_regex_escaped)
100
- if len(split_regex_escaped) > 0
101
- else None,
102
- )
103
- document = annotate_document(
104
- document=document,
105
- annotation_pipeline=models[0],
106
- embedding_model=models[1],
107
- handle_parts_of_same=handle_parts_of_same,
108
- )
109
- document_store.add_document(document)
110
  except Exception as e:
111
  raise gr.Error(f"Failed to process text: {e}")
112
- # remove the embeddings because they are very large
113
- if document.metadata.get("embeddings"):
114
- document.metadata = {k: v for k, v in document.metadata.items() if k != "embeddings"}
115
  # Return as dict and document to avoid serialization issues
116
- return document.asdict(), document
117
 
118
 
119
  def process_uploaded_files(
120
- file_names: List[str],
121
- models: Tuple[Pipeline, Optional[EmbeddingModel]],
122
- document_store: DocumentStore,
123
- split_regex_escaped: str,
124
- show_max_cross_doc_sims: bool = False,
125
- min_similarity: float = 0.95,
126
- handle_parts_of_same: bool = False,
127
  ) -> pd.DataFrame:
128
  try:
129
- new_documents = []
 
130
  for file_name in file_names:
131
  if file_name.lower().endswith(".txt"):
132
  # read the file content
133
  with open(file_name, "r", encoding="utf-8") as f:
134
  text = f.read()
135
  base_file_name = os.path.basename(file_name)
136
- gr.Info(f"Processing file '{base_file_name}' ...")
137
- new_document = create_document(
138
- text=text,
139
- doc_id=base_file_name,
140
- split_regex=unescape_regex(split_regex_escaped)
141
- if len(split_regex_escaped) > 0
142
- else None,
143
- )
144
- new_document = annotate_document(
145
- document=new_document,
146
- annotation_pipeline=models[0],
147
- embedding_model=models[1],
148
- handle_parts_of_same=handle_parts_of_same,
149
- )
150
- new_documents.append(new_document)
151
  else:
152
  raise gr.Error(f"Unsupported file format: {file_name}")
153
- document_store.add_documents(new_documents)
154
  except Exception as e:
155
  raise gr.Error(f"Failed to process uploaded files: {e}")
156
 
157
- return document_store.overview(
158
- with_max_cross_doc_sims=show_max_cross_doc_sims, min_similarity=min_similarity
159
- )
160
 
161
 
162
  def open_accordion():
@@ -167,30 +188,34 @@ def close_accordion():
167
  return gr.Accordion(open=False)
168
 
169
 
170
- def select_processed_document(
171
  evt: gr.SelectData,
172
- processed_documents_df: pd.DataFrame,
173
- document_store: DocumentStore,
174
- ) -> Union[
175
- TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
176
- TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
177
- ]:
178
  row_idx, col_idx = evt.index
179
- col_name = processed_documents_df.columns[col_idx]
180
- if not col_name.endswith("doc_id"):
181
- col_name = "doc_id"
182
- doc_id = processed_documents_df.iloc[row_idx][col_name]
183
- doc = document_store.get_document(doc_id, with_embeddings=False)
184
- return doc
185
 
186
 
187
  def set_relation_types(
188
- models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
189
  default: Optional[List[str]] = None,
190
  ) -> gr.Dropdown:
191
- arg_pipeline = models[0]
192
- if isinstance(arg_pipeline.taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
193
- relation_types = arg_pipeline.taskmodule.labels_per_layer["binary_relations"]
194
  else:
195
  raise gr.Error("Unsupported taskmodule for relation types")
196
 
@@ -202,21 +227,64 @@ def set_relation_types(
202
  )
203

204
 
205
  def download_processed_documents(
206
- document_store: DocumentStore,
207
- file_name: str = "processed_documents.json",
208
- ) -> str:
209
  file_path = os.path.join(tempfile.gettempdir(), file_name)
210
- document_store.save_to_file(file_path, indent=2)
211
- return file_path
 
 
 
212
 
213
 
214
  def upload_processed_documents(
215
  file_name: str,
216
- document_store: DocumentStore,
217
  ) -> pd.DataFrame:
218
- document_store.add_documents_from_file(file_name)
219
- return document_store.overview()
 
 
220
 
221
 
222
  def clean_spaces(text: str) -> str:
@@ -253,21 +321,21 @@ def load_text_from_arxiv(arxiv_id: str, abstract_only: bool = False) -> Tuple[st
253
  except arxiv.HTTPError as e:
254
  raise gr.Error(f"Failed to fetch arXiv data: {e}")
255
  if len(result) == 0:
256
- raise gr.Error(f"Could not find any paper with arXive ID '{arxiv_id}'")
257
  first_result = result[0]
258
  if abstract_only:
259
  abstract_clean = first_result.summary.replace("\n", " ")
260
  return abstract_clean, first_result.entry_id
261
  if "/abs/" not in first_result.entry_id:
262
  raise gr.Error(
263
- f"Could not create the HTML URL for arXive ID '{arxiv_id}' because its entry ID has "
264
  f"an unexpected format: {first_result.entry_id}"
265
  )
266
  html_url = first_result.entry_id.replace("/abs/", "/html/")
267
  request_result = requests.get(html_url)
268
  if request_result.status_code != 200:
269
  raise gr.Error(
270
- f"Could not fetch the HTML content for arXive ID '{arxiv_id}', status code: "
271
  f"{request_result.status_code}"
272
  )
273
  html_content = request_result.text
@@ -275,19 +343,31 @@ def load_text_from_arxiv(arxiv_id: str, abstract_only: bool = False) -> Tuple[st
275
  return text_clean, html_url
276

277
 
278
  def main():
279
 
280
  example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
281
 
282
- print("Loading models ...")
283
- argumentation_model, embedding_model = load_models(
284
  model_name=DEFAULT_MODEL_NAME,
285
  revision=DEFAULT_MODEL_REVISION,
286
- embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
287
- embedding_max_length=DEFAULT_EMBEDDING_MAX_LENGTH,
288
- embedding_batch_size=DEFAULT_EMBEDDING_BATCH_SIZE,
289
  device=DEFAULT_DEVICE,
290
  )
 
291
 
292
  default_render_kwargs = {
293
  "entity_options": {
@@ -315,21 +395,10 @@ def main():
315
  }
316
 
317
  with gr.Blocks() as demo:
318
- document_store_state = gr.State(
319
- DocumentStore(
320
- span_annotation_caption="adu",
321
- relation_annotation_caption="relation",
322
- vector_store=QdrantVectorStore(),
323
- document_type=TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
324
- if not HANDLE_PARTS_OF_SAME
325
- else TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
326
- span_layer_name="labeled_spans"
327
- if not HANDLE_PARTS_OF_SAME
328
- else "labeled_multi_spans",
329
- )
330
- )
331
  # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
332
- models_state = gr.State((argumentation_model, embedding_model))
 
 
333
  with gr.Row():
334
  with gr.Column(scale=1):
335
  doc_id = gr.Textbox(
@@ -341,63 +410,54 @@ def main():
341
  lines=20,
342
  value=example_text,
343
  )
344
- with gr.Accordion("Load Text from arXiv", open=False):
345
- arxiv_id = gr.Textbox(
346
- label="arXiv paper ID",
347
- placeholder=f"e.g. {DEFAULT_ARXIV_ID}",
348
- max_lines=1,
349
- )
350
- load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
351
- load_arxiv_btn = gr.Button("Load Text from arXiv", variant="secondary")
352
- load_arxiv_btn.click(
353
- fn=load_text_from_arxiv,
354
- inputs=[arxiv_id, load_arxiv_only_abstract],
355
- outputs=[doc_text, doc_id],
356
- )
357
  with gr.Accordion("Model Configuration", open=False):
358
- model_name = gr.Textbox(
359
- label="Model Name",
360
- value=DEFAULT_MODEL_NAME,
361
- )
362
- model_revision = gr.Textbox(
363
- label="Model Revision",
364
- value=DEFAULT_MODEL_REVISION,
365
- )
366
- embedding_model_name = gr.Textbox(
367
- label=f"Embedding Model Name (e.g. {DEFAULT_EMBEDDING_MODEL_NAME})",
368
- value=DEFAULT_EMBEDDING_MODEL_NAME,
369
- )
370
- embedding_max_length = gr.Slider(
371
- label="Embedding Model Max Length",
372
- minimum=16,
373
- maximum=2048,
374
- step=8,
375
- value=DEFAULT_EMBEDDING_MAX_LENGTH,
376
- )
377
- embedding_batch_size = gr.Slider(
378
- label="Embedding Model Batch Size",
379
- minimum=1,
380
- maximum=128,
381
- step=1,
382
- value=DEFAULT_EMBEDDING_BATCH_SIZE,
383
- )
384
  device = gr.Textbox(
385
  label="Device (e.g. 'cuda' or 'cpu')",
386
  value=DEFAULT_DEVICE,
387
  )
388
- load_models_btn = gr.Button("Load Models")
389
- load_models_btn.click(
390
- fn=load_models,
391
- inputs=[
392
- model_name,
393
- model_revision,
394
- embedding_model_name,
395
- embedding_max_length,
396
- embedding_batch_size,
397
- device,
398
- ],
399
- outputs=models_state,
 
400
  )
 
401
  split_regex_escaped = gr.Textbox(
402
  label="Regex to partition the text",
403
  placeholder="Regular expression pattern to split the text into partitions",
@@ -406,12 +466,12 @@ def main():
406
 
407
  predict_btn = gr.Button("Analyse")
408
 
409
- document_state = gr.State()
410
-
411
  with gr.Column(scale=1):
412
 
413
- with gr.Accordion("See plain result ...", open=False) as output_accordion:
414
- document_json = gr.JSON(label="Model Output")
 
 
415
 
416
  with gr.Accordion("Render Options", open=False):
417
  render_as = gr.Dropdown(
@@ -424,9 +484,11 @@ def main():
424
  lines=5,
425
  value=json.dumps(default_render_kwargs, indent=2),
426
  )
427
- render_btn = gr.Button("Re-render")
428
 
429
- rendered_output = gr.HTML(label="Rendered Output")
 
 
430
 
431
  with gr.Column(scale=1):
432
  with gr.Accordion(
@@ -436,14 +498,11 @@ def main():
436
  headers=["id", "num_adus", "num_relations"],
437
  interactive=False,
438
  )
439
- show_max_cross_docu_sims = gr.Checkbox(
440
- label="Show max cross-document similarities", value=False
441
- )
442
  gr.Markdown("Data Snapshot:")
443
  with gr.Row():
444
  download_processed_documents_btn = gr.DownloadButton("Download")
445
  upload_processed_documents_btn = gr.UploadButton(
446
- "Upload", file_types=["json"]
447
  )
448
 
449
  upload_btn = gr.UploadButton(
@@ -452,9 +511,22 @@ def main():
452
  file_count="multiple",
453
  )
454
 
455
- with gr.Accordion("Selected ADU", open=False):
456
- selected_adu_id = gr.Textbox(label="ID", elem_id="selected_adu_id")
457
- selected_adu_text = gr.Textbox(label="Text")
 
458
 
459
  with gr.Accordion("Retrieval Configuration", open=False):
460
  min_similarity = gr.Slider(
@@ -462,174 +534,294 @@ def main():
462
  minimum=0.0,
463
  maximum=1.0,
464
  step=0.01,
465
- value=0.95,
466
  )
467
  top_k = gr.Slider(
468
  label="Top K",
469
  minimum=2,
470
  maximum=50,
471
  step=1,
472
- value=20,
473
  )
474
- retrieve_similar_adus_btn = gr.Button("Retrieve similar ADUs")
475
- similar_adus = gr.DataFrame(headers=["doc_id", "adu_id", "score", "text"])
476
-
477
- all2all_adu_similarities_button = gr.Button(
478
- "Compute all ADU-to-ADU similarities"
479
  )
480
- all2all_adu_similarities = gr.DataFrame(
481
- headers=["sim_score", "doc_id", "other_doc_id", "text", "other_text"]
482
  )
483
-
484
- relation_types = set_relation_types(
485
- models_state.value, default=["supports", "contradicts"]
 
486
  )
487
 
488
- # retrieve_relevant_adus_btn = gr.Button("Retrieve relevant ADUs")
489
- relevant_adus = gr.DataFrame(
490
- label="Relevant ADUs from other documents",
491
- headers=[
492
- "relation",
493
- "adu",
494
- "reference_adu",
495
- "doc_id",
496
- "sim_score",
497
- "rel_score",
498
- ],
 
 
 
 
 
499
  interactive=False,
 
500
  )
 
501
 
502
  render_event_kwargs = dict(
503
- fn=render_annotated_document,
504
- inputs=[document_state, render_as, render_kwargs],
 
 
 
 
 
505
  outputs=rendered_output,
506
  )
507
 
508
  show_overview_kwargs = dict(
509
- fn=lambda document_store, show_max_sims, min_sim: document_store.overview(
510
- with_max_cross_doc_sims=show_max_sims, min_similarity=min_sim
511
  ),
512
- inputs=[document_store_state, show_max_cross_docu_sims, min_similarity],
513
  outputs=[processed_documents_df],
514
  )
515
- predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
516
- fn=partial(wrapped_process_text, handle_parts_of_same=HANDLE_PARTS_OF_SAME),
 
517
  inputs=[
518
  doc_text,
519
  doc_id,
520
- models_state,
521
- document_store_state,
522
  split_regex_escaped,
523
  ],
524
- outputs=[document_json, document_state],
525
  api_name="predict",
526
- ).success(**show_overview_kwargs)
527
  render_btn.click(**render_event_kwargs, api_name="render")
528
 
529
- document_state.change(
530
- fn=lambda doc: doc.asdict(),
531
- inputs=[document_state],
 
532
  outputs=[document_json],
533
- ).success(close_accordion, inputs=[], outputs=[output_accordion]).then(
534
- **render_event_kwargs
535
  )
536
 
537
  upload_btn.upload(
538
  fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
539
  ).then(
540
- fn=partial(process_uploaded_files, handle_parts_of_same=HANDLE_PARTS_OF_SAME),
 
541
  inputs=[
542
  upload_btn,
543
- models_state,
544
- document_store_state,
545
  split_regex_escaped,
546
- show_max_cross_docu_sims,
547
- min_similarity,
548
  ],
549
  outputs=[processed_documents_df],
550
  )
551
  processed_documents_df.select(
552
- select_processed_document,
553
- inputs=[processed_documents_df, document_store_state],
554
- outputs=[document_state],
555
  )
556
- show_max_cross_docu_sims.change(**show_overview_kwargs)
557
 
558
  download_processed_documents_btn.click(
559
- fn=partial(download_processed_documents, file_name="processed_documents.zip"),
560
- inputs=[document_store_state],
 
 
561
  outputs=[download_processed_documents_btn],
562
  )
563
  upload_processed_documents_btn.upload(
564
- fn=upload_processed_documents,
565
- inputs=[upload_processed_documents_btn, document_store_state],
 
 
566
  outputs=[processed_documents_df],
567
  )
568
 
569
  retrieve_relevant_adus_event_kwargs = dict(
570
- fn=partial(
571
- DocumentStore.get_related_annotations_from_other_documents_df,
572
- columns=relevant_adus.headers,
 
 
 
 
573
  ),
574
  inputs=[
575
- document_store_state,
576
  selected_adu_id,
577
- document_state,
578
  min_similarity,
579
  top_k,
580
- relation_types,
581
  ],
582
- outputs=[relevant_adus],
 
 
 
 
 
583
  )
584
 
585
  selected_adu_id.change(
586
- fn=partial(
587
- get_annotation_from_document,
588
- annotation_layer="labeled_spans"
589
- if not HANDLE_PARTS_OF_SAME
590
- else "labeled_multi_spans",
591
- use_predictions=True,
592
  ),
593
- inputs=[document_state, selected_adu_id],
594
  outputs=[selected_adu_text],
595
  ).success(**retrieve_relevant_adus_event_kwargs)
596
 
597
  retrieve_similar_adus_btn.click(
598
- fn=lambda document_store, ann_id, document, min_sim, k: document_store.get_similar_annotations_df(
599
- ref_annotation_id=ann_id,
600
- ref_document=document,
601
- min_similarity=min_sim,
602
- top_k=k,
603
- annotation_layer="labeled_spans"
604
- if not HANDLE_PARTS_OF_SAME
605
- else "labeled_multi_spans",
606
  ),
607
  inputs=[
608
- document_store_state,
609
  selected_adu_id,
610
- document_state,
611
  min_similarity,
612
  top_k,
613
  ],
614
- outputs=[similar_adus],
 
 
 
 
 
615
  )
616
 
617
- models_state.change(
618
- fn=set_relation_types,
619
- inputs=[models_state],
620
- outputs=[relation_types],
 
621
  )
622
- all2all_adu_similarities_button.click(
623
- fn=partial(
624
- DocumentStore.get_all2all_adu_similarities,
625
- columns=all2all_adu_similarities.headers,
 
 
 
 
626
  ),
627
- inputs=[document_store_state, min_similarity],
628
- outputs=[all2all_adu_similarities],
 
629
  )
630
 
631
- # retrieve_relevant_adus_btn.click(
632
- # **retrieve_relevant_adus_event_kwargs
 
 
633
  # )
634
 
635
  rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])
 
1
+ import pyrootutils
2
+
3
+ root = pyrootutils.setup_root(
4
+ search_from=__file__,
5
+ indicator=[".project-root"],
6
+ pythonpath=True,
7
+ dotenv=True,
8
+ )
9
+
10
  import json
11
  import logging
12
  import os.path
13
  import re
14
  import tempfile
15
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
16
 
17
  import arxiv
18
  import gradio as gr
19
  import pandas as pd
20
  import requests
21
  import torch
22
+ import yaml
23
  from bs4 import BeautifulSoup
24
+ from model_utils import (
25
+ add_annotated_pie_documents_from_dataset,
26
+ load_argumentation_model,
27
+ load_retriever,
28
+ process_texts,
29
+ retrieve_all_relevant_spans,
30
+ retrieve_all_similar_spans,
31
+ retrieve_relevant_spans,
32
+ retrieve_similar_spans,
33
  )
34
+ from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
35
+ from pytorch_ie import Annotation, Pipeline
36
+ from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
37
  from rendering_utils import HIGHLIGHT_SPANS_JS, render_displacy, render_pretty_table
38
+
39
+ from src.langchain_modules import (
40
+ DocumentAwareSpanRetriever,
41
+ DocumentAwareSpanRetrieverWithRelations,
42
+ )
43
 
44
  logger = logging.getLogger(__name__)
45
 
46
+
47
+ def load_retriever_config(path: str) -> str:
48
+ with open(path, "r") as file:
49
+ yaml_string = file.read()
50
+ config = yaml.safe_load(yaml_string)
51
+ return yaml.dump(config)
52
+
53
+
54
  RENDER_WITH_DISPLACY = "displaCy + highlighted arguments"
55
  RENDER_WITH_PRETTY_TABLE = "Pretty Table"
56
 
 
59
  # local path
60
  # DEFAULT_MODEL_NAME = "models/dataset-sciarg/task-ner_re/v0.3/2024-05-28_23-33-46"
61
  # DEFAULT_MODEL_REVISION = None
62
+ DEFAULT_RETRIEVER_CONFIG = load_retriever_config(
63
+ "configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml"
64
+ )
65
+ # 0.943180 from data_dir="predictions/default/2024-10-15_23-40-18"
66
+ DEFAULT_MIN_SIMILARITY = 0.95
67
  DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
68
  DEFAULT_SPLIT_REGEX = "\n\n\n+"
69
  DEFAULT_ARXIV_ID = "1706.03762"
70
+ DEFAULT_LOAD_PIE_DATASET_KWARGS_STR = json.dumps(
71
+ dict(path="pie/sciarg", name="resolve_parts_of_same", split="train"), indent=2
72
+ )
73
 
74
  # Whether to handle segmented entities in the document. If True, labeled_spans are converted
75
  # to labeled_multi_spans and binary_relations with label "parts_of_same" are used to merge them.
76
+ # This requires the networkx package to be installed.
77
  HANDLE_PARTS_OF_SAME = True
78
+ LAYER_CAPTIONS = {
79
+ "labeled_multi_spans": "adus",
80
+ "binary_relations": "relations",
81
+ "labeled_partitions": "partitions",
82
+ }
83
+ RELATION_NAME_MAPPING = {
84
+ "supports_reversed": "supported by",
85
+ "contradicts_reversed": "contradicts",
86
+ }
87
 
88
 
89
  def escape_regex(regex: str) -> str:
 
98
  return result
99
 
100
 
101
+ def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) -> Dict:
102
+ document = retriever.get_document(doc_id=doc_id)
103
+ return retriever.docstore.as_dict(document)
104
+
105
+
106
  def render_annotated_document(
107
+ retriever: DocumentAwareSpanRetrieverWithRelations,
108
+ document_id: str,
 
 
109
  render_with: str,
110
  render_kwargs_json: str,
111
  ) -> str:
112
+ text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
113
+ retriever=retriever, document_id=document_id
114
+ )
115
+
116
  render_kwargs = json.loads(render_kwargs_json)
117
  if render_with == RENDER_WITH_PRETTY_TABLE:
118
+ html = render_pretty_table(
119
+ text=text,
120
+ spans=spans,
121
+ span_id2idx=span_id2idx,
122
+ binary_relations=relations,
123
+ **render_kwargs,
124
+ )
125
  elif render_with == RENDER_WITH_DISPLACY:
126
+ html = render_displacy(
127
+ text=text,
128
+ spans=spans,
129
+ span_id2idx=span_id2idx,
130
+ binary_relations=relations,
131
+ **render_kwargs,
132
+ )
133
  else:
134
  raise ValueError(f"Unknown render_with value: {render_with}")
135
 
 
137
 
138
 
139
  def wrapped_process_text(
140
+ doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs
141
+ ) -> str:
 
142
  try:
143
+ process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs)
 
144
  except Exception as e:
145
  raise gr.Error(f"Failed to process text: {e}")
 
 
 
146
  # Return as dict and document to avoid serialization issues
147
+ return doc_id
148
 
149
 
150
  def process_uploaded_files(
151
+ file_names: List[str], retriever: DocumentAwareSpanRetriever, **kwargs
 
152
  ) -> pd.DataFrame:
153
  try:
154
+ doc_ids = []
155
+ texts = []
156
  for file_name in file_names:
157
  if file_name.lower().endswith(".txt"):
158
  # read the file content
159
  with open(file_name, "r", encoding="utf-8") as f:
160
  text = f.read()
161
  base_file_name = os.path.basename(file_name)
162
+ doc_ids.append(base_file_name)
163
+ texts.append(text)
 
164
  else:
165
  raise gr.Error(f"Unsupported file format: {file_name}")
166
+ process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
167
  except Exception as e:
168
  raise gr.Error(f"Failed to process uploaded files: {e}")
169
 
170
+ return retriever.docstore.overview(layer_captions=LAYER_CAPTIONS, use_predictions=True)
171
+
172
+
173
+ def wrapped_add_annotated_pie_documents_from_dataset(
174
+ retriever: DocumentAwareSpanRetriever, verbose: bool, **kwargs
175
+ ) -> pd.DataFrame:
176
+ try:
177
+ add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs)
178
+ except Exception as e:
179
+ raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}")
180
+ return retriever.docstore.overview(layer_captions=LAYER_CAPTIONS, use_predictions=True)
181
 
182
 
183
  def open_accordion():
 
188
  return gr.Accordion(open=False)
189
 
190
 
191
+ def get_cell_for_fixed_column_from_df(
192
  evt: gr.SelectData,
193
+ df: pd.DataFrame,
194
+ column: str,
195
+ ) -> Any:
196
+ """Get the value of the fixed column for the selected row in the DataFrame.
197
+ This is required because it can *not* be done with a lambda function, as that
198
+ would not get the evt parameter.
199
+
200
+ Args:
201
+ evt: The event object.
202
+ df: The DataFrame.
203
+ column: The name of the column.
204
+
205
+ Returns:
206
+ The value of the fixed column for the selected row.
207
+ """
208
  row_idx, col_idx = evt.index
209
+ doc_id = df.iloc[row_idx][column]
210
+ return doc_id
 
 
 
 
211
 
212
 
213
  def set_relation_types(
214
+ argumentation_model: Pipeline,
215
  default: Optional[List[str]] = None,
216
  ) -> gr.Dropdown:
217
+ if isinstance(argumentation_model.taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
218
+ relation_types = argumentation_model.taskmodule.labels_per_layer["binary_relations"]
 
219
  else:
220
  raise gr.Error("Unsupported taskmodule for relation types")
221
 
 
227
  )
228
 
229
 
230
+ def get_span_annotation(
231
+ retriever: DocumentAwareSpanRetriever,
232
+ span_id: str,
233
+ ) -> Annotation:
234
+ if span_id.strip() == "":
235
+ raise gr.Error("No span selected.")
236
+ try:
237
+ return retriever.get_span_by_id(span_id=span_id)
238
+ except Exception as e:
239
+ raise gr.Error(f"Failed to retrieve span annotation: {e}")
240
+
241
+
242
+ def get_text_spans_and_relations_from_document(
243
+ retriever: DocumentAwareSpanRetrieverWithRelations, document_id: str
244
+ ) -> Tuple[
245
+ str,
246
+ Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
247
+ Dict[str, int],
248
+ Sequence[BinaryRelation],
249
+ ]:
250
+ document = retriever.get_document(doc_id=document_id)
251
+ pie_document = retriever.docstore.unwrap(document)
252
+ use_predicted_annotations = retriever.use_predicted_annotations(document)
253
+ spans = retriever.get_base_layer(
254
+ pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
255
+ )
256
+ relations = retriever.get_relation_layer(
257
+ pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
258
+ )
259
+ span_id2idx = retriever.get_span_id2idx_from_doc(document)
260
+ return pie_document.text, spans, span_id2idx, relations
261
+
262
+
263
  def download_processed_documents(
264
+ retriever: DocumentAwareSpanRetriever,
265
+ file_name: str = "retriever_store",
266
+ ) -> Optional[str]:
267
+ if len(retriever.docstore) == 0:
268
+ gr.Warning("No documents to download.")
269
+ return None
270
+
271
+ # zip the directory
272
  file_path = os.path.join(tempfile.gettempdir(), file_name)
273
+
274
+ gr.Info(f"Zipping the retriever store to '{file_name}' ...")
275
+ result_file_path = retriever.save_to_archive(base_name=file_path, format="zip")
276
+
277
+ return result_file_path
278
 
279
 
280
  def upload_processed_documents(
281
  file_name: str,
282
+ retriever: DocumentAwareSpanRetriever,
283
  ) -> pd.DataFrame:
284
+ # load the documents from the zip file or directory
285
+ retriever.load_from_disc(file_name)
286
+ # return the overview of the document store
287
+ return retriever.docstore.overview(layer_captions=LAYER_CAPTIONS, use_predictions=True)
288
 
289
 
290
  def clean_spaces(text: str) -> str:
 
321
  except arxiv.HTTPError as e:
322
  raise gr.Error(f"Failed to fetch arXiv data: {e}")
323
  if len(result) == 0:
324
+ raise gr.Error(f"Could not find any paper with arXiv ID '{arxiv_id}'")
325
  first_result = result[0]
326
  if abstract_only:
327
  abstract_clean = first_result.summary.replace("\n", " ")
328
  return abstract_clean, first_result.entry_id
329
  if "/abs/" not in first_result.entry_id:
330
  raise gr.Error(
331
+ f"Could not create the HTML URL for arXiv ID '{arxiv_id}' because its entry ID has "
332
  f"an unexpected format: {first_result.entry_id}"
333
  )
334
  html_url = first_result.entry_id.replace("/abs/", "/html/")
335
  request_result = requests.get(html_url)
336
  if request_result.status_code != 200:
337
  raise gr.Error(
338
+ f"Could not fetch the HTML content for arXiv ID '{arxiv_id}', status code: "
339
  f"{request_result.status_code}"
340
  )
341
  html_content = request_result.text
 
343
  return text_clean, html_url
344
 
345
 
346
+ def process_text_from_arxiv(
347
+ arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs
348
+ ) -> str:
349
+ try:
350
+ text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only)
351
+ except Exception as e:
352
+ raise gr.Error(f"Failed to load text from arXiv: {e}")
353
+ return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs)
354
+
355
+
356
  def main():
357
 
358
  example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
359
 
360
+ print("Loading argumentation model ...")
361
+ argumentation_model = load_argumentation_model(
362
  model_name=DEFAULT_MODEL_NAME,
363
  revision=DEFAULT_MODEL_REVISION,
 
 
 
364
  device=DEFAULT_DEVICE,
365
  )
366
+ print("Loading retriever ...")
367
+ retriever = load_retriever(
368
+ DEFAULT_RETRIEVER_CONFIG, device=DEFAULT_DEVICE, config_format="yaml"
369
+ )
370
+ print("Models loaded.")
371
 
372
  default_render_kwargs = {
373
  "entity_options": {
 
395
  }
396
 
397
  with gr.Blocks() as demo:
 
398
  # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
399
+ # models_state = gr.State((argumentation_model, embedding_model))
400
+ argumentation_model_state = gr.State((argumentation_model,))
401
+ retriever_state = gr.State((retriever,))
402
  with gr.Row():
403
  with gr.Column(scale=1):
404
  doc_id = gr.Textbox(
 
410
  lines=20,
411
  value=example_text,
412
  )
413
+
 
414
  with gr.Accordion("Model Configuration", open=False):
415
+ with gr.Accordion("argumentation structure", open=True):
416
+ model_name = gr.Textbox(
417
+ label="Model Name",
418
+ value=DEFAULT_MODEL_NAME,
419
+ )
420
+ model_revision = gr.Textbox(
421
+ label="Model Revision",
422
+ value=DEFAULT_MODEL_REVISION,
423
+ )
424
+ load_arg_model_btn = gr.Button("Load Argumentation Model")
425
+
426
+ with gr.Accordion("retriever", open=True):
427
+ retriever_config = gr.Textbox(
428
+ label="Retriever Configuration",
429
+ placeholder="Configuration for the retriever",
430
+ value=DEFAULT_RETRIEVER_CONFIG,
431
+ lines=len(DEFAULT_RETRIEVER_CONFIG.split("\n")),
432
+ )
433
+ load_retriever_btn = gr.Button("Load Retriever")
434
+
 
435
  device = gr.Textbox(
436
  label="Device (e.g. 'cuda' or 'cpu')",
437
  value=DEFAULT_DEVICE,
438
  )
439
+ load_arg_model_btn.click(
440
+ fn=lambda _model_name, _model_revision, _device: (
441
+ load_argumentation_model(
442
+ model_name=_model_name, revision=_model_revision, device=_device
443
+ ),
444
+ ),
445
+ inputs=[model_name, model_revision, device],
446
+ outputs=argumentation_model_state,
447
+ )
448
+ load_retriever_btn.click(
449
+ fn=lambda _retriever_config, _device, _previous_retriever: (
450
+ load_retriever(
451
+ retriever_config=_retriever_config,
452
+ device=_device,
453
+ previous_retriever=_previous_retriever[0],
454
+ config_format="yaml",
455
+ ),
456
+ ),
457
+ inputs=[retriever_config, device, retriever_state],
458
+ outputs=retriever_state,
459
  )
460
+
461
  split_regex_escaped = gr.Textbox(
462
  label="Regex to partition the text",
463
  placeholder="Regular expression pattern to split the text into partitions",
 
466
 
467
  predict_btn = gr.Button("Analyse")
468
 
 
 
469
  with gr.Column(scale=1):
470
 
471
+ selected_document_id = gr.Textbox(
472
+ label="Selected Document", max_lines=1, interactive=False
473
+ )
474
+ rendered_output = gr.HTML(label="Rendered Output")
475
 
476
  with gr.Accordion("Render Options", open=False):
477
  render_as = gr.Dropdown(
 
484
  lines=5,
485
  value=json.dumps(default_render_kwargs, indent=2),
486
  )
487
+ render_btn = gr.Button("Re-render")
488
 
489
+ with gr.Accordion("See plain result ...", open=False) as document_json_accordion:
490
+ get_document_json_btn = gr.Button("Fetch annotated document as JSON")
491
+ document_json = gr.JSON(label="Model Output")
492
 
493
  with gr.Column(scale=1):
494
  with gr.Accordion(
 
498
  headers=["id", "num_adus", "num_relations"],
499
  interactive=False,
500
  )
 
 
 
501
  gr.Markdown("Data Snapshot:")
502
  with gr.Row():
503
  download_processed_documents_btn = gr.DownloadButton("Download")
504
  upload_processed_documents_btn = gr.UploadButton(
505
+ "Upload", file_types=["files"]
506
  )
507
 
508
  upload_btn = gr.UploadButton(
 
511
  file_count="multiple",
512
  )
513
 
514
+ with gr.Accordion("Import text from arXiv", open=False):
515
+ arxiv_id = gr.Textbox(
516
+ label="arXiv paper ID",
517
+ placeholder=f"e.g. {DEFAULT_ARXIV_ID}",
518
+ max_lines=1,
519
+ )
520
+ load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
521
+ load_arxiv_btn = gr.Button("Load & process from arXiv", variant="secondary")
522
+
523
+ with gr.Accordion("Import annotated PIE dataset", open=False):
524
+ load_pie_dataset_kwargs_str = gr.Textbox(
525
+ label="Parameters for Loading the PIE Dataset",
526
+ value=DEFAULT_LOAD_PIE_DATASET_KWARGS_STR,
527
+ lines=len(DEFAULT_LOAD_PIE_DATASET_KWARGS_STR.split("\n")),
528
+ )
529
+ load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")
530
 
531
  with gr.Accordion("Retrieval Configuration", open=False):
532
  min_similarity = gr.Slider(
 
534
  minimum=0.0,
535
  maximum=1.0,
536
  step=0.01,
537
+ value=DEFAULT_MIN_SIMILARITY,
538
  )
539
  top_k = gr.Slider(
540
  label="Top K",
541
  minimum=2,
542
  maximum=50,
543
  step=1,
544
+ value=10,
545
  )
546
+ retrieve_similar_adus_btn = gr.Button(
547
+ "Retrieve *similar* ADUs for *selected* ADU"
 
 
 
548
  )
549
+ similar_adus_df = gr.DataFrame(
550
+ headers=["doc_id", "adu_id", "score", "text"], interactive=False
551
  )
552
+ retrieve_all_similar_adus_btn = gr.Button(
553
+ "Retrieve *similar* ADUs for *all* ADUs in the document"
554
+ )
555
+ all_similar_adus_df = gr.DataFrame(
556
+ headers=["doc_id", "query_adu_id", "adu_id", "score", "text"],
557
+ interactive=False,
558
+ )
559
+ retrieve_all_relevant_adus_btn = gr.Button(
560
+ "Retrieve *relevant* ADUs for *all* ADUs in the document"
561
+ )
562
+ all_relevant_adus_df = gr.DataFrame(
563
+ headers=["doc_id", "adu_id", "score", "text"], interactive=False
564
  )
565
 
566
+ # currently not used
567
+ # relation_types = set_relation_types(
568
+ # argumentation_model_state.value[0], default=["supports_reversed", "contradicts_reversed"]
569
+ # )
570
+
571
+ # Dummy textbox to hold the hover adu id. On click on the rendered output,
572
+ # its content will be copied to selected_adu_id which will trigger the retrieval.
573
+ hover_adu_id = gr.Textbox(
574
+ label="ID (hover)",
575
+ elem_id="hover_adu_id",
576
+ interactive=False,
577
+ visible=False,
578
+ )
579
+ selected_adu_id = gr.Textbox(
580
+ label="ID (selected)",
581
+ elem_id="selected_adu_id",
582
  interactive=False,
583
+ visible=False,
584
  )
585
+ selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False)
586
+
587
+ with gr.Accordion("Relevant ADUs from other documents", open=True):
588
+
589
+ relevant_adus_df = gr.DataFrame(
590
+ headers=[
591
+ "relation",
592
+ "adu",
593
+ "reference_adu",
594
+ "doc_id",
595
+ "sim_score",
596
+ "rel_score",
597
+ ],
598
+ interactive=False,
599
+ )
600
 
601
  render_event_kwargs = dict(
602
+ fn=lambda _retriever, _document_id, _render_as, _render_kwargs: render_annotated_document(
603
+ retriever=_retriever[0],
604
+ document_id=_document_id,
605
+ render_with=_render_as,
606
+ render_kwargs_json=_render_kwargs,
607
+ ),
608
+ inputs=[retriever_state, selected_document_id, render_as, render_kwargs],
609
  outputs=rendered_output,
610
  )
611
 
612
  show_overview_kwargs = dict(
613
+ fn=lambda _retriever: _retriever[0].docstore.overview(
614
+ layer_captions=LAYER_CAPTIONS, use_predictions=True
615
  ),
616
+ inputs=[retriever_state],
617
  outputs=[processed_documents_df],
618
  )
619
+ predict_btn.click(
620
+ fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text(
621
+ text=_doc_text,
622
+ doc_id=_doc_id,
623
+ argumentation_model=_argumentation_model[0],
624
+ retriever=_retriever[0],
625
+ split_regex_escaped=(
626
+ unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
627
+ ),
628
+ handle_parts_of_same=HANDLE_PARTS_OF_SAME,
629
+ ),
630
  inputs=[
631
  doc_text,
632
  doc_id,
633
+ argumentation_model_state,
634
+ retriever_state,
635
  split_regex_escaped,
636
  ],
637
+ outputs=[selected_document_id],
638
  api_name="predict",
639
+ ).success(**show_overview_kwargs).success(**render_event_kwargs)
640
  render_btn.click(**render_event_kwargs, api_name="render")
641
 
642
+ load_arxiv_btn.click(
643
+ fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv(
644
+ arxiv_id=_arxiv_id,
645
+ abstract_only=_load_arxiv_only_abstract,
646
+ argumentation_model=_argumentation_model[0],
647
+ retriever=_retriever[0],
648
+ split_regex_escaped=(
649
+ unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
650
+ ),
651
+ handle_parts_of_same=HANDLE_PARTS_OF_SAME,
652
+ ),
653
+ inputs=[
654
+ arxiv_id,
655
+ load_arxiv_only_abstract,
656
+ argumentation_model_state,
657
+ retriever_state,
658
+ split_regex_escaped,
659
+ ],
660
+ outputs=[selected_document_id],
661
+ api_name="predict",
662
+ ).success(**show_overview_kwargs)
663
+
664
+ load_pie_dataset_btn.click(
665
+ fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
666
+ ).then(
667
+ fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset(
668
+ retriever=_retriever[0], verbose=True, **json.loads(_load_pie_dataset_kwargs_str)
669
+ ),
670
+ inputs=[retriever_state, load_pie_dataset_kwargs_str],
671
+ outputs=[processed_documents_df],
672
+ )
673
+
674
+ selected_document_id.change(**render_event_kwargs)
675
+
676
+ get_document_json_btn.click(
677
+ fn=lambda _retriever, _document_id: get_document_as_dict(
678
+ retriever=_retriever[0], doc_id=_document_id
679
+ ),
680
+ inputs=[retriever_state, selected_document_id],
681
  outputs=[document_json],
 
 
682
  )
683
 
684
  upload_btn.upload(
685
  fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
686
  ).then(
687
+ fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files(
688
+ file_names=_file_names,
689
+ argumentation_model=_argumentation_model[0],
690
+ retriever=_retriever[0],
691
+ split_regex_escaped=unescape_regex(_split_regex_escaped),
692
+ handle_parts_of_same=HANDLE_PARTS_OF_SAME,
693
+ ),
694
  inputs=[
695
  upload_btn,
696
+ argumentation_model_state,
697
+ retriever_state,
698
  split_regex_escaped,
 
 
699
  ],
700
  outputs=[processed_documents_df],
701
  )
702
  processed_documents_df.select(
703
+ fn=get_cell_for_fixed_column_from_df,
704
+ inputs=[processed_documents_df, gr.State("doc_id")],
705
+ outputs=[selected_document_id],
706
  )
 
707
 
708
  download_processed_documents_btn.click(
709
+ fn=lambda _retriever: download_processed_documents(
710
+ _retriever[0], file_name="processed_documents"
711
+ ),
712
+ inputs=[retriever_state],
713
  outputs=[download_processed_documents_btn],
714
  )
715
  upload_processed_documents_btn.upload(
716
+ fn=lambda file_name, _retriever: upload_processed_documents(
717
+ file_name, retriever=_retriever[0]
718
+ ),
719
+ inputs=[upload_processed_documents_btn, retriever_state],
720
  outputs=[processed_documents_df],
721
  )
722
 
723
  retrieve_relevant_adus_event_kwargs = dict(
724
+ fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
725
+ retriever=_retriever[0],
726
+ query_span_id=_selected_adu_id,
727
+ k=_top_k,
728
+ score_threshold=_min_similarity,
729
+ relation_label_mapping=RELATION_NAME_MAPPING,
730
+ # columns=relevant_adus.headers
731
  ),
732
  inputs=[
733
+ retriever_state,
734
  selected_adu_id,
 
735
  min_similarity,
736
  top_k,
 
737
  ],
738
+ outputs=[relevant_adus_df],
739
+ )
740
+ relevant_adus_df.select(
741
+ fn=get_cell_for_fixed_column_from_df,
742
+ inputs=[relevant_adus_df, gr.State("doc_id")],
743
+ outputs=[selected_document_id],
744
  )
745
 
746
  selected_adu_id.change(
747
+ fn=lambda _retriever, _selected_adu_id: get_span_annotation(
748
+ retriever=_retriever[0], span_id=_selected_adu_id
 
 
 
 
749
  ),
750
+ inputs=[retriever_state, selected_adu_id],
751
  outputs=[selected_adu_text],
752
  ).success(**retrieve_relevant_adus_event_kwargs)
753
 
754
  retrieve_similar_adus_btn.click(
755
+ fn=lambda _retriever, _selected_adu_id, _min_similarity, _tok_k: retrieve_similar_spans(
756
+ retriever=_retriever[0],
757
+ query_span_id=_selected_adu_id,
758
+ k=_tok_k,
759
+ score_threshold=_min_similarity,
 
 
 
760
  ),
761
  inputs=[
762
+ retriever_state,
763
  selected_adu_id,
 
764
  min_similarity,
765
  top_k,
766
  ],
767
+ outputs=[similar_adus_df],
768
+ )
769
+ similar_adus_df.select(
770
+ fn=get_cell_for_fixed_column_from_df,
771
+ inputs=[similar_adus_df, gr.State("doc_id")],
772
+ outputs=[selected_document_id],
773
  )
774
 
775
+ retrieve_all_similar_adus_btn.click(
776
+ fn=lambda _retriever, _document_id, _min_similarity, _tok_k: retrieve_all_similar_spans(
777
+ retriever=_retriever[0],
778
+ query_doc_id=_document_id,
779
+ k=_tok_k,
780
+ score_threshold=_min_similarity,
781
+ query_span_id_column="query_span_id",
782
+ ),
783
+ inputs=[
784
+ retriever_state,
785
+ selected_document_id,
786
+ min_similarity,
787
+ top_k,
788
+ ],
789
+ outputs=[all_similar_adus_df],
790
  )
791
+
792
+ retrieve_all_relevant_adus_btn.click(
793
+ fn=lambda _retriever, _document_id, _min_similarity, _tok_k: retrieve_all_relevant_spans(
794
+ retriever=_retriever[0],
795
+ query_doc_id=_document_id,
796
+ k=_tok_k,
797
+ score_threshold=_min_similarity,
798
+ query_span_id_column="query_span_id",
799
  ),
800
+ inputs=[
801
+ retriever_state,
802
+ selected_document_id,
803
+ min_similarity,
804
+ top_k,
805
+ ],
806
+ outputs=[all_relevant_adus_df],
807
+ )
808
+
809
+ # select query span id from the "retrieve all" result data frames
810
+ all_similar_adus_df.select(
811
+ fn=get_cell_for_fixed_column_from_df,
812
+ inputs=[all_similar_adus_df, gr.State("query_span_id")],
813
+ outputs=[selected_adu_id],
814
+ )
815
+ all_relevant_adus_df.select(
816
+ fn=get_cell_for_fixed_column_from_df,
817
+ inputs=[all_relevant_adus_df, gr.State("query_span_id")],
818
+ outputs=[selected_adu_id],
819
  )
820
 
821
+ # argumentation_model_state.change(
822
+ # fn=lambda _argumentation_model: set_relation_types(_argumentation_model[0]),
823
+ # inputs=[argumentation_model_state],
824
+ # outputs=[relation_types],
825
  # )
826
 
827
  rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])
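For reference, the same retriever store can also be filled with gold-annotated documents instead of model predictions, which is what the "Load & Embed PIE Dataset" button above wires up. A minimal sketch (reusing the retriever from the sketch near the top; dataset parameters are the defaults from DEFAULT_LOAD_PIE_DATASET_KWARGS_STR, everything else is an assumption):

# Hypothetical sketch: index an already-annotated PIE dataset in the retriever,
# mirroring wrapped_add_annotated_pie_documents_from_dataset in app.py above.
from model_utils import add_annotated_pie_documents_from_dataset

add_annotated_pie_documents_from_dataset(
    retriever=retriever,  # a loaded DocumentAwareSpanRetrieverWithRelations
    verbose=True,
    path="pie/sciarg",
    name="resolve_parts_of_same",
    split="train",
)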
model_utils.py CHANGED
@@ -1,10 +1,10 @@
 
1
  import logging
2
- from typing import Optional, Tuple, Union
3
 
4
  import gradio as gr
5
- import torch
6
- from annotation_utils import labeled_span_to_id
7
- from embedding import EmbeddingModel, HuggingfaceEmbeddingModel
8
  from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger
9
  from pytorch_ie import Pipeline
10
  from pytorch_ie.annotations import LabeledSpan
@@ -13,31 +13,35 @@ from pytorch_ie.documents import (
13
  TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
14
  TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
15
  )
 
16
 
17
  logger = logging.getLogger(__name__)
18
 
19
 
20
  def annotate_document(
21
  document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
22
- annotation_pipeline: Pipeline,
23
- embedding_model: Optional[EmbeddingModel] = None,
24
  handle_parts_of_same: bool = False,
25
  ) -> Union[
26
  TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
27
  TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
28
  ]:
29
- """Annotate a document with the provided pipeline. If an embedding model is provided, also
30
- extract embeddings for the labeled spans.
31
 
32
  Args:
33
  document: The document to annotate.
34
- annotation_pipeline: The pipeline to use for annotation.
35
- embedding_model: The embedding model to use for extracting text span embeddings.
36
  handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
37
  """
38
 
39
  # execute prediction pipeline
40
- annotation_pipeline(document)
41
 
42
  if handle_parts_of_same:
43
  merger = SpansViaRelationMerger(
@@ -53,22 +57,6 @@ def annotate_document(
53
  )
54
  document = merger(document)
55
 
56
- if embedding_model is not None:
57
- text_span_embeddings = embedding_model(
58
- document=document,
59
- span_layer_name="labeled_spans" if not handle_parts_of_same else "labeled_multi_spans",
60
- )
61
- # convert keys to str because JSON keys must be strings
62
- text_span_embeddings_dict = {
63
- labeled_span_to_id(k): v.tolist() for k, v in text_span_embeddings.items()
64
- }
65
- document.metadata["embeddings"] = text_span_embeddings_dict
66
- else:
67
- gr.Warning(
68
- "No embedding model provided. Skipping embedding extraction. You can load an embedding "
69
- "model in the 'Model Configuration' section."
70
- )
71
-
72
  return document
73
 
74
 
@@ -101,6 +89,80 @@ def create_document(
101
  return document
102

103
 
104
  def load_argumentation_model(
105
  model_name: str,
106
  revision: Optional[str] = None,
@@ -133,48 +195,260 @@ def load_argumentation_model(
133
  return model
134
 
135
 
136
- def load_embedding_model(
137
- embedding_model_name: Optional[str] = None,
138
- # embedding_model_revision: Optional[str] = None,
139
- embedding_max_length: int = 512,
140
- embedding_batch_size: int = 16,
141
  device: str = "cpu",
142
- ) -> Optional[EmbeddingModel]:
143
- if embedding_model_name is not None and embedding_model_name.strip():
144
- try:
145
- embedding_model = HuggingfaceEmbeddingModel(
146
- embedding_model_name.strip(),
147
- # revision=embedding_model_revision,
148
- device=device,
149
- max_length=embedding_max_length,
150
- batch_size=embedding_batch_size,
 
151
  )
152
- gr.Info(f"Loaded embedding model: model_name={embedding_model_name}, device={device}")
153
- except Exception as e:
154
- raise gr.Error(f"Failed to load embedding model: {e}")
155
- else:
156
- embedding_model = None
 
157
 
158
- return embedding_model
159
 
160
 
161
- def load_models(
162
- model_name: str,
163
- revision: Optional[str] = None,
164
- embedding_model_name: Optional[str] = None,
165
- # embedding_model_revision: Optional[str] = None,
166
- embedding_max_length: int = 512,
167
- embedding_batch_size: int = 16,
168
- device: str = "cpu",
169
- ) -> Tuple[Pipeline, Optional[EmbeddingModel]]:
170
- torch.cuda.empty_cache()
171
- argumentation_model = load_argumentation_model(model_name, revision=revision, device=device)
172
- embedding_model = load_embedding_model(
173
- embedding_model_name=embedding_model_name,
174
- # embedding_model_revision=embedding_model_revision,
175
- embedding_max_length=embedding_max_length,
176
- embedding_batch_size=embedding_batch_size,
177
- device=device,
 
178
  )
179
 
180
- return argumentation_model, embedding_model
 
1
+ import json
2
  import logging
3
+ from typing import Iterable, Optional, Sequence, Union
4
 
5
  import gradio as gr
6
+ import pandas as pd
7
+ from pie_datasets import Dataset, IterableDataset, load_dataset
 
8
  from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger
9
  from pytorch_ie import Pipeline
10
  from pytorch_ie.annotations import LabeledSpan
 
13
  TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
14
  TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
15
  )
16
+ from typing_extensions import Protocol
17
+
18
+ from src.langchain_modules import DocumentAwareSpanRetriever
19
+ from src.langchain_modules.span_retriever import (
20
+ DocumentAwareSpanRetrieverWithRelations,
21
+ _parse_config,
22
+ )
23
 
24
  logger = logging.getLogger(__name__)
25
 
26
 
27
  def annotate_document(
28
  document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
29
+ argumentation_model: Pipeline,
 
30
  handle_parts_of_same: bool = False,
31
  ) -> Union[
32
  TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
33
  TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
34
  ]:
35
+ """Annotate a document with the provided pipeline.
 
36
 
37
  Args:
38
  document: The document to annotate.
39
+ argumentation_model: The pipeline to use for annotation.
 
40
  handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
41
  """
42
 
43
  # execute prediction pipeline
44
+ argumentation_model(document)
45
 
46
  if handle_parts_of_same:
47
  merger = SpansViaRelationMerger(
 
57
  )
58
  document = merger(document)
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  return document
61
 
62
 
 
89
  return document
90
 
91
 
92
+ def add_annotated_pie_documents(
93
+ retriever: DocumentAwareSpanRetriever,
94
+ pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
95
+ use_predicted_annotations: bool,
96
+ verbose: bool = False,
97
+ ) -> None:
98
+ if verbose:
99
+ gr.Info(f"Create span embeddings for {len(pie_documents)} documents...")
100
+ num_docs_before = len(retriever.docstore)
101
+ retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations)
102
+ # number of documents that were overwritten
103
+ num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore)
104
+ # warn if documents were overwritten
105
+ if num_overwritten_docs > 0:
106
+ gr.Warning(f"{num_overwritten_docs} documents were overwritten.")
107
+
108
+
109
+ def process_texts(
110
+ texts: Iterable[str],
111
+ doc_ids: Iterable[str],
112
+ argumentation_model: Pipeline,
113
+ retriever: DocumentAwareSpanRetriever,
114
+ split_regex_escaped: Optional[str],
115
+ handle_parts_of_same: bool = False,
116
+ verbose: bool = False,
117
+ ) -> None:
118
+ # check that doc_ids are unique
119
+ if len(set(doc_ids)) != len(list(doc_ids)):
120
+ raise gr.Error("Document IDs must be unique.")
121
+ pie_documents = [
122
+ create_document(text=text, doc_id=doc_id, split_regex=split_regex_escaped)
123
+ for text, doc_id in zip(texts, doc_ids)
124
+ ]
125
+ if verbose:
126
+ gr.Info(f"Annotate {len(pie_documents)} documents...")
127
+ pie_documents = [
128
+ annotate_document(
129
+ document=pie_document,
130
+ argumentation_model=argumentation_model,
131
+ handle_parts_of_same=handle_parts_of_same,
132
+ )
133
+ for pie_document in pie_documents
134
+ ]
135
+ add_annotated_pie_documents(
136
+ retriever=retriever,
137
+ pie_documents=pie_documents,
138
+ use_predicted_annotations=True,
139
+ verbose=verbose,
140
+ )
141
+
142
+
143
+ def add_annotated_pie_documents_from_dataset(
144
+ retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs
145
+ ) -> None:
146
+ try:
147
+ gr.Info(
148
+ "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2)
149
+ )
150
+ dataset = load_dataset(**load_dataset_kwargs)
151
+ if not isinstance(dataset, (Dataset, IterableDataset)):
152
+ raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
153
+ dataset_converted = dataset.to_document_type(
154
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
155
+ )
156
+ add_annotated_pie_documents(
157
+ retriever=retriever,
158
+ pie_documents=dataset_converted,
159
+ use_predicted_annotations=False,
160
+ verbose=verbose,
161
+ )
162
+ except Exception as e:
163
+ raise gr.Error(f"Failed to load dataset: {e}")
164
+
165
+
166
  def load_argumentation_model(
167
  model_name: str,
168
  revision: Optional[str] = None,
 
195
  return model
196
 
197
 
198
+ def load_retriever(
199
+ retriever_config: str,
200
+ config_format: str,
 
 
201
  device: str = "cpu",
202
+ previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None,
203
+ ) -> DocumentAwareSpanRetrieverWithRelations:
204
+ try:
205
+ retriever_config = _parse_config(retriever_config, format=config_format)
206
+ # set device for the embeddings pipeline
207
+ retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device
208
+ result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config)
209
+ # if a previous retriever is provided, load all documents and vectors from the previous retriever
210
+ if previous_retriever is not None:
211
+ # documents
212
+ all_doc_ids = list(previous_retriever.docstore.yield_keys())
213
+ gr.Info(f"Storing {len(all_doc_ids)} documents from previous retriever...")
214
+ all_docs = previous_retriever.docstore.mget(all_doc_ids)
215
+ result.docstore.mset([(doc.id, doc) for doc in all_docs])
216
+ # spans (with vectors)
217
+ all_span_ids = list(previous_retriever.vectorstore.yield_keys())
218
+ all_spans = previous_retriever.vectorstore.mget(all_span_ids)
219
+ result.vectorstore.mset([(span.id, span) for span in all_spans])
220
+
221
+ gr.Info("Retriever loaded successfully.")
222
+ return result
223
+ except Exception as e:
224
+ raise gr.Error(f"Failed to load retriever: {e}")
225
+
226
+
227
+ def retrieve_similar_spans(
228
+ retriever: DocumentAwareSpanRetriever,
229
+ query_span_id: str,
230
+ **kwargs,
231
+ ) -> pd.DataFrame:
232
+ if not query_span_id.strip():
233
+ raise gr.Error("No query span selected.")
234
+ try:
235
+ retrieval_result = retriever.invoke(input=query_span_id, **kwargs)
236
+ records = []
237
+ for similar_span_doc in retrieval_result:
238
+ pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc)
239
+ span_ann = metadata["attached_span"]
240
+ records.append(
241
+ {
242
+ "doc_id": pie_doc.id,
243
+ "span_id": similar_span_doc.id,
244
+ "score": metadata["relevance_score"],
245
+ "label": span_ann.label,
246
+ "text": str(span_ann),
247
+ }
248
  )
249
+ return (
250
+ pd.DataFrame(records, columns=["doc_id", "score", "label", "text", "span_id"])
251
+ .sort_values(by="score", ascending=False)
252
+ .round(3)
253
+ )
254
+ except Exception as e:
255
+ raise gr.Error(f"Failed to retrieve similar ADUs: {e}")
256
+
257
+
258
+ def retrieve_relevant_spans(
259
+ retriever: DocumentAwareSpanRetriever,
260
+ query_span_id: str,
261
+ relation_label_mapping: Optional[dict[str, str]] = None,
262
+ **kwargs,
263
+ ) -> pd.DataFrame:
264
+ if not query_span_id.strip():
265
+ raise gr.Error("No query span selected.")
266
+ try:
267
+ relation_label_mapping = relation_label_mapping or {}
268
+ retrieval_result = retriever.invoke(input=query_span_id, return_related=True, **kwargs)
269
+ records = []
270
+ for relevant_span_doc in retrieval_result:
271
+ pie_doc, metadata = retriever.docstore.unwrap_with_metadata(relevant_span_doc)
272
+ span_ann = metadata["attached_span"]
273
+ tail_span_ann = metadata["attached_tail_span"]
274
+ mapped_relation_label = relation_label_mapping.get(
275
+ metadata["relation_label"], metadata["relation_label"]
276
+ )
277
+ records.append(
278
+ {
279
+ "doc_id": pie_doc.id,
280
+ "type": mapped_relation_label,
281
+ "rel_score": metadata["relation_score"],
282
+ "text": str(tail_span_ann),
283
+ "span_id": relevant_span_doc.id,
284
+ "label": tail_span_ann.label,
285
+ "ref_score": metadata["relevance_score"],
286
+ "ref_label": span_ann.label,
287
+ "ref_text": str(span_ann),
288
+ "ref_span_id": metadata["head_id"],
289
+ }
290
+ )
291
+ return (
292
+ pd.DataFrame(
293
+ records,
294
+ columns=[
295
+ "type",
296
+ # omitted for now, we get no valid relation scores for the generative model
297
+ # "rel_score",
298
+ "ref_score",
299
+ "label",
300
+ "text",
301
+ "ref_label",
302
+ "ref_text",
303
+ "doc_id",
304
+ "span_id",
305
+ "ref_span_id",
306
+ ],
307
+ )
308
+ .sort_values(by=["ref_score"], ascending=False)
309
+ .round(3)
310
+ )
311
+ except Exception as e:
312
+ raise gr.Error(f"Failed to retrieve relevant ADUs: {e}")
313
 
 
314
 
315
+ class RetrieverCallable(Protocol):
316
+ def __call__(
317
+ self,
318
+ retriever: DocumentAwareSpanRetriever,
319
+ query_span_id: str,
320
+ **kwargs,
321
+ ) -> Optional[pd.DataFrame]:
322
+ pass
323
 
324
+
325
+ def _retrieve_for_all_spans(
326
+ retriever: DocumentAwareSpanRetriever,
327
+ query_doc_id: str,
328
+ retrieve_func: RetrieverCallable,
329
+ query_span_id_column: str = "query_span_id",
330
+ **kwargs,
331
+ ) -> Optional[pd.DataFrame]:
332
+ if not query_doc_id.strip():
333
+ raise gr.Error("No query document selected.")
334
+ try:
335
+ span_id2idx = retriever.get_span_id2idx_from_doc(query_doc_id)
336
+ gr.Info(f"Retrieving results for {len(span_id2idx)} ADUs in document {query_doc_id}...")
337
+ span_results = {
338
+ query_span_id: retrieve_func(
339
+ retriever=retriever,
340
+ query_span_id=query_span_id,
341
+ **kwargs,
342
+ )
343
+ for query_span_id in span_id2idx.keys()
344
+ }
345
+ span_results_not_empty = {
346
+ query_span_id: df
347
+ for query_span_id, df in span_results.items()
348
+ if df is not None and not df.empty
349
+ }
350
+
351
+ # add column with query_span_id
352
+ for query_span_id, query_span_result in span_results_not_empty.items():
353
+ query_span_result[query_span_id_column] = query_span_id
354
+
355
+ if len(span_results_not_empty) == 0:
356
+ gr.Info(f"No results found for any ADU in document {query_doc_id}.")
357
+ return None
358
+ else:
359
+ result = pd.concat(span_results_not_empty.values(), ignore_index=True)
360
+ gr.Info(f"Retrieved {len(result)} ADUs for document {query_doc_id}.")
361
+ return result
362
+ except Exception as e:
363
+ raise gr.Error(
364
+ f'Failed to retrieve results for all ADUs in document "{query_doc_id}": {e}'
365
+ )
366
+
367
+
368
+ def retrieve_all_similar_spans(
369
+ retriever: DocumentAwareSpanRetriever,
370
+ query_doc_id: str,
371
+ **kwargs,
372
+ ) -> Optional[pd.DataFrame]:
373
+ return _retrieve_for_all_spans(
374
+ retriever=retriever,
375
+ query_doc_id=query_doc_id,
376
+ retrieve_func=retrieve_similar_spans,
377
+ **kwargs,
378
  )
379
 
380
+
381
+ def retrieve_all_relevant_spans(
382
+ retriever: DocumentAwareSpanRetriever,
383
+ query_doc_id: str,
384
+ **kwargs,
385
+ ) -> Optional[pd.DataFrame]:
386
+ return _retrieve_for_all_spans(
387
+ retriever=retriever,
388
+ query_doc_id=query_doc_id,
389
+ retrieve_func=retrieve_relevant_spans,
390
+ **kwargs,
391
+ )
392
+
393
+
394
+ class RetrieverForAllSpansCallable(Protocol):
395
+ def __call__(
396
+ self,
397
+ retriever: DocumentAwareSpanRetriever,
398
+ query_doc_id: str,
399
+ **kwargs,
400
+ ) -> Optional[pd.DataFrame]:
401
+ pass
402
+
403
+
404
+ def _retrieve_for_all_documents(
405
+ retriever: DocumentAwareSpanRetriever,
406
+ retrieve_func: RetrieverForAllSpansCallable,
407
+ query_doc_id_column: str = "query_doc_id",
408
+ **kwargs,
409
+ ) -> Optional[pd.DataFrame]:
410
+ try:
411
+ all_doc_ids = list(retriever.docstore.yield_keys())
412
+ gr.Info(f"Retrieving results for {len(all_doc_ids)} documents...")
413
+ doc_results = {
414
+ doc_id: retrieve_func(retriever=retriever, query_doc_id=doc_id, **kwargs)
415
+ for doc_id in all_doc_ids
416
+ }
417
+ doc_results_not_empty = {
418
+ doc_id: df for doc_id, df in doc_results.items() if df is not None and not df.empty
419
+ }
420
+ # add column with query_doc_id
421
+ for doc_id, doc_result in doc_results_not_empty.items():
422
+ doc_result[query_doc_id_column] = doc_id
423
+
424
+ if len(doc_results_not_empty) == 0:
425
+ gr.Info("No results found for any document.")
426
+ return None
427
+ else:
428
+ result = pd.concat(doc_results_not_empty, ignore_index=True)
429
+ gr.Info(f"Retrieved {len(result)} ADUs for all documents.")
430
+ return result
431
+ except Exception as e:
432
+ raise gr.Error(f"Failed to retrieve results for all documents: {e}")
433
+
434
+
435
+ def retrieve_all_similar_spans_for_all_documents(
436
+ retriever: DocumentAwareSpanRetriever,
437
+ **kwargs,
438
+ ) -> Optional[pd.DataFrame]:
439
+ return _retrieve_for_all_documents(
440
+ retriever=retriever,
441
+ retrieve_func=retrieve_all_similar_spans,
442
+ **kwargs,
443
+ )
444
+
445
+
446
+ def retrieve_all_relevant_spans_for_all_documents(
447
+ retriever: DocumentAwareSpanRetriever,
448
+ **kwargs,
449
+ ) -> Optional[pd.DataFrame]:
450
+ return _retrieve_for_all_documents(
451
+ retriever=retriever,
452
+ retrieve_func=retrieve_all_relevant_spans,
453
+ **kwargs,
454
+ )
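Usage sketch for the helpers above (illustrative only, not part of the committed code; the model name, texts, and search parameters are placeholders, and the retriever config is the one added below in retriever/related_span_retriever_with_relations_from_other_docs.yaml):

    # build the argumentation pipeline and the span retriever
    argumentation_model = load_argumentation_model(model_name="<argumentation-model>", device="cpu")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        "retriever/related_span_retriever_with_relations_from_other_docs.yaml"
    )
    # annotate two texts and index their ADUs in the retriever
    process_texts(
        texts=["First document text ...", "Second document text ..."],
        doc_ids=["doc_a", "doc_b"],
        argumentation_model=argumentation_model,
        retriever=retriever,
        split_regex_escaped=None,
    )
    # cross-document retrieval: ADUs from other documents related to any ADU of doc_a
    relevant = retrieve_all_relevant_spans(
        retriever=retriever, query_doc_id="doc_a", k=10, score_threshold=0.95
    )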
rendering_utils.py CHANGED
@@ -1,14 +1,9 @@
1
  import json
2
  import logging
3
  from collections import defaultdict
4
- from typing import Dict, List, Optional, Union
5
 
6
- from annotation_utils import labeled_span_to_id
7
  from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
8
- from pytorch_ie.documents import (
9
- TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
10
- TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
11
- )
12
  from rendering_utils_displacy import EntityRenderer
13
 
14
  logger = logging.getLogger(__name__)
@@ -92,7 +87,20 @@ HIGHLIGHT_SPANS_JS = """
92
  });
93
  }
94
  }
95
- function setReferenceAduId(entityId) {
 
 
 
 
96
  // get the textarea element that holds the reference adu id
97
  let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
98
  // set the value of the input field
@@ -104,6 +112,8 @@ HIGHLIGHT_SPANS_JS = """
104
 
105
  const entities = document.querySelectorAll('.entity');
106
  entities.forEach(entity => {
 
 
107
  const alreadyHasListener = entity.getAttribute('data-has-listener');
108
  if (alreadyHasListener) {
109
  return;
@@ -111,22 +121,30 @@ HIGHLIGHT_SPANS_JS = """
111
  entity.addEventListener('mouseover', () => {
112
  const entityId = entity.getAttribute('data-entity-id');
113
  highlightRelationArguments(entityId);
114
- setReferenceAduId(entityId);
115
  });
116
  entity.addEventListener('mouseout', () => {
117
  highlightRelationArguments(null);
118
  });
119
  entity.setAttribute('data-has-listener', 'true');
120
  });
 
 
 
 
121
  }
122
  """
123
 
124
 
125
  def render_pretty_table(
126
- document: Union[
127
- TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
128
- TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
129
- ],
130
  **render_kwargs,
131
  ):
132
  from prettytable import PrettyTable
@@ -134,7 +152,7 @@ def render_pretty_table(
134
  t = PrettyTable()
135
  t.field_names = ["head", "tail", "relation"]
136
  t.align = "l"
137
- for relation in list(document.binary_relations) + list(document.binary_relations.predictions):
138
  t.add_row([str(relation.head), str(relation.tail), relation.label])
139
 
140
  html = t.get_html_string(format=True)
@@ -144,29 +162,19 @@ def render_pretty_table(
144
 
145
 
146
  def render_displacy(
147
- document: Union[
148
- TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
149
- TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
150
- ],
151
  inject_relations=True,
152
  colors_hover=None,
153
  entity_options={},
154
  **render_kwargs,
155
  ):
156
 
157
- if isinstance(document, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions):
158
- span_layer = document.labeled_spans
159
- elif isinstance(
160
- document, TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
161
- ):
162
- span_layer = document.labeled_multi_spans
163
- else:
164
- raise ValueError(f"Unsupported document type: {type(document)}")
165
-
166
- span_annotations = list(span_layer) + list(span_layer.predictions)
167
  ents = []
168
- for labeled_span in span_annotations:
169
- entity_id = labeled_span_to_id(labeled_span)
170
  # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
171
  # on hover and to inject the relation data.
172
  if isinstance(labeled_span, LabeledSpan):
@@ -192,7 +200,7 @@ def render_displacy(
192
  raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")
193
 
194
  spacy_doc = {
195
- "text": document.text,
196
  # the ents MUST be sorted by start and end
197
  "ents": sorted(ents, key=lambda x: (x["start"], x["end"])),
198
  "title": None,
@@ -207,12 +215,10 @@ def render_displacy(
207
 
208
  html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
209
  if inject_relations:
210
- binary_relations = list(document.binary_relations) + list(
211
- document.binary_relations.predictions
212
- )
213
  html = inject_relation_data(
214
  html,
215
- span_annotations=span_annotations,
 
216
  binary_relations=binary_relations,
217
  additional_colors=colors_hover,
218
  )
@@ -221,8 +227,9 @@ def render_displacy(
221
 
222
  def inject_relation_data(
223
  html: str,
224
- span_annotations: Union[List[LabeledSpan], List[LabeledMultiSpan]],
225
- binary_relations: List[BinaryRelation],
 
226
  additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
227
  ) -> str:
228
  from bs4 import BeautifulSoup
@@ -236,7 +243,7 @@ def inject_relation_data(
236
  entity2heads[relation.tail].append((relation.head, relation.label))
237
  entity2tails[relation.head].append((relation.tail, relation.label))
238
 
239
- ann_id2annotation = {labeled_span_to_id(entity): entity for entity in span_annotations}
240
  # Add unique IDs to each entity
241
  entities = soup.find_all(class_="entity")
242
  for entity in entities:
@@ -247,7 +254,8 @@ def inject_relation_data(
247
  entity[f"data-color-{key}"] = (
248
  json.dumps(color) if isinstance(color, dict) else color
249
  )
250
- entity_annotation = ann_id2annotation[entity["data-entity-id"]]
 
251
 
252
  # sanity check.
253
  if isinstance(entity_annotation, LabeledSpan):
@@ -265,14 +273,16 @@ def inject_relation_data(
265
  entity["data-label"] = entity_annotation.label
266
  entity["data-relation-tails"] = json.dumps(
267
  [
268
- {"entity-id": labeled_span_to_id(tail), "label": label}
269
  for tail, label in entity2tails.get(entity_annotation, [])
 
270
  ]
271
  )
272
  entity["data-relation-heads"] = json.dumps(
273
  [
274
- {"entity-id": labeled_span_to_id(head), "label": label}
275
  for head, label in entity2heads.get(entity_annotation, [])
 
276
  ]
277
  )
278
 
 
1
  import json
2
  import logging
3
  from collections import defaultdict
4
+ from typing import Dict, List, Optional, Sequence, Union
5
 
 
6
  from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
 
 
 
 
7
  from rendering_utils_displacy import EntityRenderer
8
 
9
  logger = logging.getLogger(__name__)
 
87
  });
88
  }
89
  }
90
+ function setHoverAduId(entityId) {
91
+ // get the textarea element that holds the reference adu id
92
+ let hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
93
+ // set the value of the input field
94
+ hoverAduIdDiv.value = entityId;
95
+ // trigger an input event to update the state
96
+ var event = new Event('input');
97
+ hoverAduIdDiv.dispatchEvent(event);
98
+ }
99
+ function setReferenceAduIdFromHover() {
100
+ // get the hover adu id
101
+ const hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
102
+ // get the value of the input field
103
+ const entityId = hoverAduIdDiv.value;
104
  // get the textarea element that holds the reference adu id
105
  let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
106
  // set the value of the input field
 
112
 
113
  const entities = document.querySelectorAll('.entity');
114
  entities.forEach(entity => {
115
+ // make the cursor a pointer
116
+ entity.style.cursor = 'pointer';
117
  const alreadyHasListener = entity.getAttribute('data-has-listener');
118
  if (alreadyHasListener) {
119
  return;
 
121
  entity.addEventListener('mouseover', () => {
122
  const entityId = entity.getAttribute('data-entity-id');
123
  highlightRelationArguments(entityId);
124
+ setHoverAduId(entityId);
125
  });
126
  entity.addEventListener('mouseout', () => {
127
  highlightRelationArguments(null);
128
  });
129
  entity.setAttribute('data-has-listener', 'true');
130
  });
131
+ const entityContainer = document.querySelector('.entities');
132
+ if (entityContainer) {
133
+ entityContainer.addEventListener('click', () => {
134
+ setReferenceAduIdFromHover();
135
+ });
136
+ // make the cursor a pointer
137
+ // entityContainer.style.cursor = 'pointer';
138
+ }
139
  }
140
  """
141
 
142
 
143
  def render_pretty_table(
144
+ text: str,
145
+ spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
146
+ span_id2idx: Dict[str, int],
147
+ binary_relations: Sequence[BinaryRelation],
148
  **render_kwargs,
149
  ):
150
  from prettytable import PrettyTable
 
152
  t = PrettyTable()
153
  t.field_names = ["head", "tail", "relation"]
154
  t.align = "l"
155
+ for relation in binary_relations:
156
  t.add_row([str(relation.head), str(relation.tail), relation.label])
157
 
158
  html = t.get_html_string(format=True)
 
162
 
163
 
164
  def render_displacy(
165
+ text: str,
166
+ spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
167
+ span_id2idx: Dict[str, int],
168
+ binary_relations: Sequence[BinaryRelation],
169
  inject_relations=True,
170
  colors_hover=None,
171
  entity_options={},
172
  **render_kwargs,
173
  ):
174
 
 
 
 
 
175
  ents = []
176
+ for entity_id, idx in span_id2idx.items():
177
+ labeled_span = spans[idx]
178
  # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
179
  # on hover and to inject the relation data.
180
  if isinstance(labeled_span, LabeledSpan):
 
200
  raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")
201
 
202
  spacy_doc = {
203
+ "text": text,
204
  # the ents MUST be sorted by start and end
205
  "ents": sorted(ents, key=lambda x: (x["start"], x["end"])),
206
  "title": None,
 
215
 
216
  html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
217
  if inject_relations:
 
 
 
218
  html = inject_relation_data(
219
  html,
220
+ spans=spans,
221
+ span_id2idx=span_id2idx,
222
  binary_relations=binary_relations,
223
  additional_colors=colors_hover,
224
  )
 
227
 
228
  def inject_relation_data(
229
  html: str,
230
+ spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
231
+ span_id2idx: Dict[str, int],
232
+ binary_relations: Sequence[BinaryRelation],
233
  additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
234
  ) -> str:
235
  from bs4 import BeautifulSoup
 
243
  entity2heads[relation.tail].append((relation.head, relation.label))
244
  entity2tails[relation.head].append((relation.tail, relation.label))
245
 
246
+ annotation2id = {spans[span_idx]: span_id for span_id, span_idx in span_id2idx.items()}
247
  # Add unique IDs to each entity
248
  entities = soup.find_all(class_="entity")
249
  for entity in entities:
 
254
  entity[f"data-color-{key}"] = (
255
  json.dumps(color) if isinstance(color, dict) else color
256
  )
257
+
258
+ entity_annotation = spans[span_id2idx[entity["data-entity-id"]]]
259
 
260
  # sanity check.
261
  if isinstance(entity_annotation, LabeledSpan):
 
273
  entity["data-label"] = entity_annotation.label
274
  entity["data-relation-tails"] = json.dumps(
275
  [
276
+ {"entity-id": annotation2id[tail], "label": label}
277
  for tail, label in entity2tails.get(entity_annotation, [])
278
+ if tail in annotation2id
279
  ]
280
  )
281
  entity["data-relation-heads"] = json.dumps(
282
  [
283
+ {"entity-id": annotation2id[head], "label": label}
284
  for head, label in entity2heads.get(entity_annotation, [])
285
+ if head in annotation2id
286
  ]
287
  )
288
 
requirements.txt CHANGED
@@ -1,11 +1,34 @@
1
- pytorch-ie==0.31.1
2
- gradio==4.36.0
3
  prettytable==3.10.0
4
- pie-modules==0.12.0
5
  beautifulsoup4==4.12.3
6
- datasets==2.14.4
7
  # numpy 2.0.0 breaks the code
8
  numpy==1.25.2
9
- qdrant-client==1.9.1
10
  scipy==1.13.0
11
  arxiv==2.1.3
 
 
 
 
1
+ gradio==4.44.1
 
2
  prettytable==3.10.0
 
3
  beautifulsoup4==4.12.3
 
4
  # numpy 2.0.0 breaks the code
5
  numpy==1.25.2
 
6
  scipy==1.13.0
7
  arxiv==2.1.3
8
+ pyrootutils>=1.0.0,<1.1.0
9
+
10
+ ########## from root requirements ##########
11
+ # --------- pytorch-ie --------- #
12
+ pytorch-ie>=0.29.6,<0.32.0
13
+ pie-datasets>=0.10.5,<0.11.0
14
+ pie-modules>=0.14.0,<0.15.0
15
+
16
+ # --------- models -------- #
17
+ adapters>=0.1.2,<0.2.0
18
+ # ADU retrieval (and demo, in future):
19
+ langchain>=0.3.0,<0.4.0
20
+ langchain-core>=0.3.0,<0.4.0
21
+ langchain-community>=0.3.0,<0.4.0
22
+ # we use QDrant as vectorstore backend
23
+ langchain-qdrant>=0.1.0,<0.2.0
24
+ qdrant-client>=1.12.0,<2.0.0
25
+ # 0.26 seems to be broken when used with adapters, see https://github.com/adapter-hub/adapters/issues/748
26
+ huggingface_hub<0.26.0 # 0.26 seems to be broken
27
+ # to handle segmented entities (if HANDLE_PARTS_OF_SAME=True)
28
+ networkx>=3.0.0,<4.0.0
29
+
30
+ # --------- config --------- #
31
+ hydra-core>=1.3.0
32
+
33
+ # --------- dev --------- #
34
+ pre-commit # hooks for applying linters on commit
retrieve_and_dump_all_relevant.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
1
+ import pyrootutils
2
+
3
+ root = pyrootutils.setup_root(
4
+ search_from=__file__,
5
+ indicator=[".project-root"],
6
+ pythonpath=True,
7
+ dotenv=True,
8
+ )
9
+
10
+ import argparse
11
+ import logging
12
+
13
+ from demo.model_utils import (
14
+ retrieve_all_relevant_spans,
15
+ retrieve_all_relevant_spans_for_all_documents,
16
+ retrieve_relevant_spans,
17
+ )
18
+ from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ if __name__ == "__main__":
24
+
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument(
27
+ "-c",
28
+ "--config_path",
29
+ type=str,
30
+ default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
31
+ )
32
+ parser.add_argument(
33
+ "--data_path",
34
+ type=str,
35
+ required=True,
36
+ help="Path to a zip or directory containing a retriever dump.",
37
+ )
38
+ parser.add_argument("-k", "--top_k", type=int, default=10)
39
+ parser.add_argument("-t", "--threshold", type=float, default=0.95)
40
+ parser.add_argument(
41
+ "-o",
42
+ "--output_path",
43
+ type=str,
44
+ required=True,
45
+ )
46
+ parser.add_argument(
47
+ "--query_doc_id",
48
+ type=str,
49
+ default=None,
50
+ help="If provided, retrieve all spans for only this query document.",
51
+ )
52
+ parser.add_argument(
53
+ "--query_span_id",
54
+ type=str,
55
+ default=None,
56
+ help="If provided, retrieve all spans for only this query span.",
57
+ )
58
+ args = parser.parse_args()
59
+
60
+ logging.basicConfig(
61
+ format="%(asctime)s %(levelname)-8s %(message)s",
62
+ level=logging.INFO,
63
+ datefmt="%Y-%m-%d %H:%M:%S",
64
+ )
65
+
66
+ if not args.output_path.endswith(".json"):
67
+ raise ValueError("only support json output")
68
+
69
+ logger.info(f"instantiating retriever from {args.config_path}...")
70
+ retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
71
+ args.config_path
72
+ )
73
+ logger.info(f"loading data from {args.data_path}...")
74
+ retriever.load_from_disc(args.data_path)
75
+
76
+ search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
77
+ logger.info(f"use search_kwargs: {search_kwargs}")
78
+
79
+ if args.query_span_id is not None:
80
+ logger.warning(f"retrieving results for single span: {args.query_span_id}")
81
+ all_spans_for_all_documents = retrieve_relevant_spans(
82
+ retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
83
+ )
84
+ elif args.query_doc_id is not None:
85
+ logger.warning(f"retrieving results for single document: {args.query_doc_id}")
86
+ all_spans_for_all_documents = retrieve_all_relevant_spans(
87
+ retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
88
+ )
89
+ else:
90
+ all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
91
+ retriever=retriever, **search_kwargs
92
+ )
93
+
94
+ if all_spans_for_all_documents is None:
95
+ logger.warning("no relevant spans found in any document")
96
+ exit(0)
97
+
98
+ logger.info(f"dumping results to {args.output_path}...")
99
+ all_spans_for_all_documents.to_json(args.output_path)
100
+
101
+ logger.info("done")
retriever/related_span_retriever_with_relations_from_other_docs.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
1
+ _target_: src.langchain_modules.DocumentAwareSpanRetrieverWithRelations
2
+ reversed_relations_suffix: _reversed
3
+ relation_labels:
4
+ - supports_reversed
5
+ - contradicts_reversed
6
+ retrieve_from_same_document: false
7
+ retrieve_from_different_documents: true
8
+ pie_document_type:
9
+ _target_: pie_modules.utils.resolve_type
10
+ type_or_str: pytorch_ie.documents.TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
11
+ docstore:
12
+ _target_: src.langchain_modules.DatasetsPieDocumentStore
13
+ search_kwargs:
14
+ k: 10
15
+ search_type: similarity_score_threshold
16
+ vectorstore:
17
+ _target_: src.langchain_modules.QdrantSpanVectorStore
18
+ embedding:
19
+ _target_: src.langchain_modules.HuggingFaceSpanEmbeddings
20
+ model:
21
+ _target_: src.models.utils.adapters.load_model_with_adapter
22
+ model_kwargs:
23
+ pretrained_model_name_or_path: allenai/specter2_base
24
+ adapter_kwargs:
25
+ adapter_name_or_path: allenai/specter2
26
+ load_as: proximity
27
+ source: hf
28
+ pipeline_kwargs:
29
+ tokenizer: allenai/specter2_base
30
+ stride: 64
31
+ batch_size: 32
32
+ model_max_length: 512
33
+ client:
34
+ _target_: qdrant_client.QdrantClient
35
+ location: ":memory:"
36
+ collection_name: adus
37
+ vector_params:
38
+ distance:
39
+ _target_: qdrant_client.http.models.Distance
40
+ value: Cosine
41
+ label_mapping:
42
+ background_claim:
43
+ - background_claim
44
+ - own_claim
45
+ own_claim:
46
+ - background_claim
47
+ - own_claim
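A minimal sketch of using this config (illustrative; it mirrors retrieve_and_dump_all_relevant.py above, the dump path is a placeholder):

    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        "retriever/related_span_retriever_with_relations_from_other_docs.yaml"
    )
    retriever.load_from_disc("<path/to/retriever_dump>")  # zip file or directory with a previous dump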
src/__init__.py ADDED
File without changes
src/hf_pipeline/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .feature_extraction import FeatureExtractionPipelineWithStriding
src/hf_pipeline/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (241 Bytes).
src/hf_pipeline/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (239 Bytes).
src/hf_pipeline/__pycache__/feature_extraction.cpython-310.pyc ADDED
Binary file (10.9 kB).
src/hf_pipeline/__pycache__/feature_extraction.cpython-39.pyc ADDED
Binary file (10.9 kB).
src/hf_pipeline/feature_extraction.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
1
+ import types
2
+ from collections import defaultdict
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from transformers.pipelines.base import ArgumentHandler, ChunkPipeline, Dataset
7
+ from transformers.utils import is_tf_available, is_torch_available
8
+
9
+ if is_tf_available():
10
+ import tensorflow as tf
11
+ from transformers.models.auto.modeling_tf_auto import (
12
+ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
13
+ )
14
+ if is_torch_available():
15
+ from transformers.models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
16
+
17
+
18
+ def list_of_dicts2dict_of_lists(list_of_dicts: list[dict]) -> dict[str, list]:
19
+ return {k: [d[k] for d in list_of_dicts] for k in list_of_dicts[0].keys()}
20
+
21
+
22
+ class FeatureExtractionArgumentHandler(ArgumentHandler):
23
+ """Handles arguments for feature extraction."""
24
+
25
+ def __call__(self, inputs: Union[str, List[str]], **kwargs):
26
+ if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
27
+ inputs = list(inputs)
28
+ batch_size = len(inputs)
29
+ elif isinstance(inputs, str):
30
+ inputs = [inputs]
31
+ batch_size = 1
32
+ elif (
33
+ Dataset is not None
34
+ and isinstance(inputs, Dataset)
35
+ or isinstance(inputs, types.GeneratorType)
36
+ ):
37
+ return inputs, None
38
+ else:
39
+ raise ValueError("At least one input is required.")
40
+
41
+ offset_mapping = kwargs.get("offset_mapping")
42
+ if offset_mapping:
43
+ if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
44
+ offset_mapping = [offset_mapping]
45
+ if len(offset_mapping) != batch_size:
46
+ raise ValueError("offset_mapping should have the same batch size as the input")
47
+ return inputs, offset_mapping
48
+
49
+
50
+ class FeatureExtractionPipelineWithStriding(ChunkPipeline):
51
+ """Same as transformers.FeatureExtractionPipeline, but with long input handling. Inspired by
52
+ transformers.TokenClassificationPipeline. The functionality is triggered when providing the
53
+ "stride" parameter (can be 0). When passing "create_unique_embeddings_per_token=True", only one
54
+ embedding (and other data, see flags below) per token will be returned (this makes use of
55
+ min_distance_to_border, see "return_min_distance_to_border" below for details). Note that this
56
+ removes data for special token positions!
57
+
58
+ Per default, it will return just the embeddings. If any of the return_ADDITIONAL_RESULT is
59
+ enabled (see below), it will return dictionaries with "last_hidden_state" and all
60
+ ADDITIONAL_RESULT depending on the flags.
61
+
62
+ Flags to return additional results:
63
+ return_offset_mapping: If enabled, return the offset mapping.
64
+ return_special_tokens_mask: If enabled, return the special tokens mask.
65
+ return_sequence_indices: If enabled, return the sequence indices.
66
+ return_position_ids: If enabled, return the position ids; values are in [0, model_max_length).
67
+ return_min_distance_to_border: If enabled, return the minimum distance to the "border" of
68
+ the input that gets passed into the model. This is useful when striding is used which may
69
+ produce multiple embeddings for a token (compare values in offset_mapping). In this case,
70
+ min_distance_to_border can be used to select the embedding that is more in the center
71
+ of the input by choosing the entry with the *higher* min_distance_to_border.
72
+ """
73
+
74
+ default_input_names = "sequences"
75
+
76
+ def __init__(self, args_parser=FeatureExtractionArgumentHandler(), *args, **kwargs):
77
+ super().__init__(*args, **kwargs)
78
+ self.check_model_type(
79
+ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
80
+ if self.framework == "tf"
81
+ else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
82
+ )
83
+
84
+ self._args_parser = args_parser
85
+
86
+ def _sanitize_parameters(
87
+ self,
88
+ offset_mapping: Optional[List[Tuple[int, int]]] = None,
89
+ stride: Optional[int] = None,
90
+ create_unique_embeddings_per_token: Optional[bool] = False,
91
+ return_offset_mapping: Optional[bool] = None,
92
+ return_special_tokens_mask: Optional[bool] = None,
93
+ return_sequence_indices: Optional[bool] = None,
94
+ return_position_ids: Optional[bool] = None,
95
+ return_min_distance_to_border: Optional[bool] = None,
96
+ return_tensors=None,
97
+ ):
98
+ preprocess_params = {}
99
+ if offset_mapping is not None:
100
+ preprocess_params["offset_mapping"] = offset_mapping
101
+
102
+ if stride is not None:
103
+ if stride >= self.tokenizer.model_max_length:
104
+ raise ValueError(
105
+ "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
106
+ )
107
+
108
+ if self.tokenizer.is_fast:
109
+ tokenizer_params = {
110
+ "return_overflowing_tokens": True,
111
+ "padding": True,
112
+ "stride": stride,
113
+ }
114
+ preprocess_params["tokenizer_params"] = tokenizer_params # type: ignore
115
+ else:
116
+ raise ValueError(
117
+ "`stride` was provided to process all the text but you're using a slow tokenizer."
118
+ " Please use a fast tokenizer."
119
+ )
120
+ postprocess_params = {}
121
+ if create_unique_embeddings_per_token is not None:
122
+ postprocess_params["create_unique_embeddings_per_token"] = (
123
+ create_unique_embeddings_per_token
124
+ )
125
+ if return_offset_mapping is not None:
126
+ postprocess_params["return_offset_mapping"] = return_offset_mapping
127
+ if return_special_tokens_mask is not None:
128
+ postprocess_params["return_special_tokens_mask"] = return_special_tokens_mask
129
+ if return_sequence_indices is not None:
130
+ postprocess_params["return_sequence_indices"] = return_sequence_indices
131
+ if return_position_ids is not None:
132
+ postprocess_params["return_position_ids"] = return_position_ids
133
+ if return_min_distance_to_border is not None:
134
+ postprocess_params["return_min_distance_to_border"] = return_min_distance_to_border
135
+ if return_tensors is not None:
136
+ postprocess_params["return_tensors"] = return_tensors
137
+ return preprocess_params, {}, postprocess_params
138
+
139
+ def __call__(self, inputs: Union[str, List[str]], **kwargs):
140
+
141
+ _inputs, offset_mapping = self._args_parser(inputs, **kwargs)
142
+ if offset_mapping:
143
+ kwargs["offset_mapping"] = offset_mapping
144
+
145
+ return super().__call__(inputs, **kwargs)
146
+
147
+ def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
148
+ tokenizer_params = preprocess_params.pop("tokenizer_params", {})
149
+ truncation = (
150
+ True
151
+ if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0
152
+ else False
153
+ )
154
+ inputs = self.tokenizer(
155
+ sentence,
156
+ return_tensors=self.framework,
157
+ truncation=truncation,
158
+ return_special_tokens_mask=True,
159
+ return_offsets_mapping=self.tokenizer.is_fast,
160
+ **tokenizer_params,
161
+ )
162
+ inputs.pop("overflow_to_sample_mapping", None)
163
+ num_chunks = len(inputs["input_ids"])
164
+
165
+ for i in range(num_chunks):
166
+ if self.framework == "tf":
167
+ model_inputs = {k: tf.expand_dims(v[i], 0) for k, v in inputs.items()}
168
+ else:
169
+ model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
170
+ if offset_mapping is not None:
171
+ model_inputs["offset_mapping"] = offset_mapping
172
+ model_inputs["sentence"] = sentence if i == 0 else None
173
+ model_inputs["is_last"] = i == num_chunks - 1
174
+
175
+ yield model_inputs
176
+
177
+ def _forward(self, model_inputs, **kwargs):
178
+ # Forward
179
+ special_tokens_mask = model_inputs.pop("special_tokens_mask")
180
+ offset_mapping = model_inputs.pop("offset_mapping", None)
181
+ sentence = model_inputs.pop("sentence")
182
+ is_last = model_inputs.pop("is_last")
183
+ if self.framework == "tf":
184
+ last_hidden_state = self.model(**model_inputs)[0]
185
+ else:
186
+ output = self.model(**model_inputs)
187
+ last_hidden_state = (
188
+ output["last_hidden_state"] if isinstance(output, dict) else output[0]
189
+ )
190
+
191
+ return {
192
+ "last_hidden_state": last_hidden_state,
193
+ "special_tokens_mask": special_tokens_mask,
194
+ "offset_mapping": offset_mapping,
195
+ "sentence": sentence,
196
+ "is_last": is_last,
197
+ **model_inputs,
198
+ }
199
+
200
+ def postprocess_tensor(self, data, return_tensors=False):
201
+ if return_tensors:
202
+ return data
203
+ if self.framework == "pt":
204
+ return data.tolist()
205
+ elif self.framework == "tf":
206
+ return data.numpy().tolist()
207
+ else:
208
+ raise ValueError(f"unknown framework: {self.framework}")
209
+
210
+ def make_embeddings_unique_per_token(
211
+ self, data, offset_mapping, special_tokens_mask, min_distance_to_border
212
+ ):
213
+ char_offsets2token_pos = defaultdict(list)
214
+ bs, seq_len = offset_mapping.shape[:2]
215
+ if bs != 1:
216
+ raise ValueError(f"expected result batch size 1, but it is: {bs}")
217
+ for token_idx, ((char_start, char_end), is_special_token, min_dist) in enumerate(
218
+ zip(
219
+ offset_mapping[0].tolist(),
220
+ special_tokens_mask[0].tolist(),
221
+ min_distance_to_border[0].tolist(),
222
+ )
223
+ ):
224
+
225
+ if not is_special_token:
226
+ char_offsets2token_pos[(char_start, char_end)].append((token_idx, min_dist))
227
+
228
+ # tokens_with_multiple_embeddings = {k: v for k, v in char_offsets2token_pos.items() if len(v) > 1}
229
+ char_offsets2best_token_pos = {
230
+ k: max(v, key=lambda pos_dist: pos_dist[1])[0]
231
+ for k, v in char_offsets2token_pos.items()
232
+ }
233
+ # sort by char offsets (start and end)
234
+ sorted_char_offsets_token_positions = sorted(
235
+ char_offsets2best_token_pos.items(),
236
+ key=lambda char_offsets_tok_pos: (
237
+ char_offsets_tok_pos[0][0],
238
+ char_offsets_tok_pos[0][1],
239
+ ),
240
+ )
241
+ best_indices = [tok_pos for char_offsets, tok_pos in sorted_char_offsets_token_positions]
242
+
243
+ result = {k: v[0][best_indices].unsqueeze(0) for k, v in data.items()}
244
+ return result
245
+
246
+ def postprocess(
247
+ self,
248
+ all_outputs,
249
+ create_unique_embeddings_per_token: bool = False,
250
+ return_offset_mapping: bool = False,
251
+ return_special_tokens_mask: bool = False,
252
+ return_sequence_indices: bool = False,
253
+ return_position_ids: bool = False,
254
+ return_min_distance_to_border: bool = False,
255
+ return_tensors: bool = False,
256
+ ):
257
+
258
+ all_outputs_dict = list_of_dicts2dict_of_lists(all_outputs)
259
+ if self.framework == "pt":
260
+ result = {
261
+ "last_hidden_state": torch.concat(all_outputs_dict["last_hidden_state"], axis=1)
262
+ }
263
+ if return_offset_mapping or create_unique_embeddings_per_token:
264
+ result["offset_mapping"] = torch.concat(all_outputs_dict["offset_mapping"], axis=1)
265
+ if return_special_tokens_mask or create_unique_embeddings_per_token:
266
+ result["special_tokens_mask"] = torch.concat(
267
+ all_outputs_dict["special_tokens_mask"], axis=1
268
+ )
269
+ if return_sequence_indices:
270
+ sequence_indices = []
271
+ for seq_idx, model_outputs in enumerate(all_outputs):
272
+ sequence_indices.append(torch.ones_like(model_outputs["input_ids"]) * seq_idx)
273
+ result["sequence_indices"] = torch.concat(sequence_indices, axis=1)
274
+ if return_position_ids:
275
+ position_ids = []
276
+ for seq_idx, model_outputs in enumerate(all_outputs):
277
+ seq_len = model_outputs["input_ids"].size(1)
278
+ position_ids.append(torch.arange(seq_len).unsqueeze(0))
279
+ result["indices"] = torch.concat(position_ids, axis=1)
280
+ if return_min_distance_to_border or create_unique_embeddings_per_token:
281
+ min_distance_to_border = []
282
+ for seq_idx, model_outputs in enumerate(all_outputs):
283
+ seq_len = model_outputs["input_ids"].size(1)
284
+ current_indices = torch.arange(seq_len).unsqueeze(0)
285
+ min_distance_to_border.append(
286
+ torch.minimum(current_indices, seq_len - current_indices)
287
+ )
288
+ result["min_distance_to_border"] = torch.concat(min_distance_to_border, axis=1)
289
+ elif self.framework == "tf":
290
+ raise NotImplementedError()
291
+ else:
292
+ raise ValueError(f"unknown framework: {self.framework}")
293
+
294
+ if create_unique_embeddings_per_token:
295
+ offset_mapping = result["offset_mapping"]
296
+ if not return_offset_mapping:
297
+ del result["offset_mapping"]
298
+ special_tokens_mask = result["special_tokens_mask"]
299
+ if not return_special_tokens_mask:
300
+ del result["special_tokens_mask"]
301
+ min_distance_to_border = result["min_distance_to_border"]
302
+ if not return_min_distance_to_border:
303
+ del result["min_distance_to_border"]
304
+ result = self.make_embeddings_unique_per_token(
305
+ data=result,
306
+ offset_mapping=offset_mapping,
307
+ special_tokens_mask=special_tokens_mask,
308
+ min_distance_to_border=min_distance_to_border,
309
+ )
310
+
311
+ result = {
312
+ k: self.postprocess_tensor(v, return_tensors=return_tensors) for k, v in result.items()
313
+ }
314
+ if set(result) == {"last_hidden_state"}:
315
+ return result["last_hidden_state"]
316
+ else:
317
+ return result
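Usage sketch for the striding pipeline (illustrative only; it mirrors how HuggingFaceSpanEmbeddings below constructs its client, and the model name and flag values are just examples):

    from transformers import pipeline

    # build the feature-extraction pipeline with the custom chunking class
    extractor = pipeline(
        "feature-extraction",
        model="allenai/specter2_base",
        pipeline_class=FeatureExtractionPipelineWithStriding,
    )
    # overlapping chunks with stride 64; keep exactly one embedding per token
    result = extractor(
        "some very long input text ...",
        stride=64,
        create_unique_embeddings_per_token=True,
        return_offset_mapping=True,
    )
    # result["last_hidden_state"] holds one embedding per (non-special) token,
    # aligned with the character offsets in result["offset_mapping"]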
src/langchain_modules/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
1
+ from .basic_pie_document_store import BasicPieDocumentStore
2
+ from .datasets_pie_document_store import DatasetsPieDocumentStore
3
+ from .huggingface_span_embeddings import HuggingFaceSpanEmbeddings
4
+ from .pie_document_store import PieDocumentStore
5
+ from .qdrant_span_vectorstore import QdrantSpanVectorStore
6
+ from .serializable_store import SerializableStore
7
+ from .span_embeddings import SpanEmbeddings
8
+ from .span_retriever import DocumentAwareSpanRetriever, DocumentAwareSpanRetrieverWithRelations
9
+ from .span_vectorstore import SpanVectorStore
src/langchain_modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (489 Bytes).
src/langchain_modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (801 Bytes).
src/langchain_modules/__pycache__/basic_pie_document_store.cpython-39.pyc ADDED
Binary file (4.79 kB).
src/langchain_modules/__pycache__/datasets_pie_document_store.cpython-39.pyc ADDED
Binary file (7.68 kB).
src/langchain_modules/__pycache__/huggingface_span_embeddings.cpython-310.pyc ADDED
Binary file (5.31 kB).
src/langchain_modules/__pycache__/huggingface_span_embeddings.cpython-39.pyc ADDED
Binary file (7.1 kB).
src/langchain_modules/__pycache__/pie_document_store.cpython-39.pyc ADDED
Binary file (4.21 kB).
src/langchain_modules/__pycache__/qdrant_span_vectorstore.cpython-310.pyc ADDED
Binary file (6.82 kB).
src/langchain_modules/__pycache__/qdrant_span_vectorstore.cpython-39.pyc ADDED
Binary file (12.9 kB).
src/langchain_modules/__pycache__/serializable_store.cpython-39.pyc ADDED
Binary file (4.1 kB).
src/langchain_modules/__pycache__/span_embeddings.cpython-310.pyc ADDED
Binary file (3.42 kB).
src/langchain_modules/__pycache__/span_embeddings.cpython-39.pyc ADDED
Binary file (3.91 kB).
src/langchain_modules/__pycache__/span_retriever.cpython-310.pyc ADDED
Binary file (13 kB).
src/langchain_modules/__pycache__/span_retriever.cpython-39.pyc ADDED
Binary file (28 kB).
src/langchain_modules/__pycache__/span_vectorstore.cpython-39.pyc ADDED
Binary file (5.31 kB).
src/langchain_modules/basic_pie_document_store.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import shutil
5
+ from itertools import islice
6
+ from typing import Iterator, List, Optional, Sequence, Tuple
7
+
8
+ from langchain.storage import create_kv_docstore
9
+ from langchain_core.documents import Document as LCDocument
10
+ from langchain_core.stores import BaseStore, ByteStore
11
+ from pie_datasets import Dataset, DatasetDict
12
+
13
+ from .pie_document_store import PieDocumentStore
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class BasicPieDocumentStore(PieDocumentStore):
19
+ """PIE Document store that uses a client to store and retrieve documents."""
20
+
21
+ def __init__(
22
+ self,
23
+ client: Optional[BaseStore[str, LCDocument]] = None,
24
+ byte_store: Optional[ByteStore] = None,
25
+ ):
26
+ if byte_store is not None:
27
+ client = create_kv_docstore(byte_store)
28
+ elif client is None:
29
+ raise Exception("You must pass a `byte_store` parameter.")
30
+
31
+ self.client = client
32
+
33
+ def mget(self, keys: Sequence[str]) -> List[LCDocument]:
34
+ return self.client.mget(keys)
35
+
36
+ def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
37
+ self.client.mset(items)
38
+
39
+ def mdelete(self, keys: Sequence[str]) -> None:
40
+ self.client.mdelete(keys)
41
+
42
+ def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
43
+ return self.client.yield_keys(prefix=prefix)
44
+
45
+ def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
46
+ all_doc_ids = []
47
+ all_metadata = []
48
+ pie_documents_path = os.path.join(path, "pie_documents")
49
+ if os.path.exists(pie_documents_path):
50
+ # remove existing directory
51
+ logger.warning(f"Removing existing directory: {pie_documents_path}")
52
+ shutil.rmtree(pie_documents_path)
53
+ os.makedirs(pie_documents_path, exist_ok=True)
54
+ doc_ids_iter = iter(self.client.yield_keys())
55
+ while batch_doc_ids := list(islice(doc_ids_iter, batch_size or 1000)):
56
+ all_doc_ids.extend(batch_doc_ids)
57
+ docs = self.client.mget(batch_doc_ids)
58
+ pie_docs = []
59
+ for doc in docs:
60
+ pie_doc = doc.metadata[self.METADATA_KEY_PIE_DOCUMENT]
61
+ pie_docs.append(pie_doc)
62
+ all_metadata.append(
63
+ {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
64
+ )
65
+ pie_dataset = Dataset.from_documents(pie_docs)
66
+ DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path)
67
+ if len(all_doc_ids) > 0:
68
+ doc_ids_path = os.path.join(path, "doc_ids.json")
69
+ with open(doc_ids_path, "w") as f:
70
+ json.dump(all_doc_ids, f)
71
+ if len(all_metadata) > 0:
72
+ metadata_path = os.path.join(path, "metadata.json")
73
+ with open(metadata_path, "w") as f:
74
+ json.dump(all_metadata, f)
75
+
76
+ def _load_from_directory(self, path: str, **kwargs) -> None:
77
+ pie_documents_path = os.path.join(path, "pie_documents")
78
+ if not os.path.exists(pie_documents_path):
79
+ logger.warning(
80
+ f"Directory {pie_documents_path} does not exist, don't load any documents."
81
+ )
82
+ return None
83
+ pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path)
84
+ pie_docs = pie_dataset["train"]
85
+ metadata_path = os.path.join(path, "metadata.json")
86
+ if os.path.exists(metadata_path):
87
+ with open(metadata_path, "r") as f:
88
+ all_metadata = json.load(f)
89
+ else:
90
+ logger.warning(f"File {metadata_path} does not exist, don't load any metadata.")
91
+ all_metadata = [{} for _ in pie_docs]
92
+ docs = [
93
+ self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(pie_docs, all_metadata)
94
+ ]
95
+ doc_ids_path = os.path.join(path, "doc_ids.json")
96
+ if os.path.exists(doc_ids_path):
97
+ with open(doc_ids_path, "r") as f:
98
+ all_doc_ids = json.load(f)
99
+ else:
100
+ logger.warning(f"File {doc_ids_path} does not exist, don't load any document ids.")
101
+ all_doc_ids = [doc.id for doc in pie_docs]
102
+ self.client.mset(zip(all_doc_ids, docs))
103
+ logger.info(f"Loaded {len(docs)} documents from {path} into docstore")
src/langchain_modules/datasets_pie_document_store.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import shutil
5
+ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple
6
+
7
+ from datasets import Dataset as HFDataset
8
+ from langchain_core.documents import Document as LCDocument
9
+ from pie_datasets import Dataset, DatasetDict, concatenate_datasets
10
+ from pytorch_ie.documents import TextBasedDocument
11
+
12
+ from .pie_document_store import PieDocumentStore
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class DatasetsPieDocumentStore(PieDocumentStore):
18
+ """PIE Document store that uses Huggingface Datasets as the backend."""
19
+
20
+ def __init__(self) -> None:
21
+ self._data: Optional[Dataset] = None
22
+ # keys map to indices in the dataset
23
+ self._keys: Dict[str, int] = {}
24
+ self._metadata: Dict[str, Any] = {}
25
+
26
+ def __len__(self):
27
+ return len(self._keys)
28
+
29
+ def _get_pie_docs_by_indices(self, indices: Iterable[int]) -> Sequence[TextBasedDocument]:
30
+ if self._data is None:
31
+ return []
32
+ return self._data.apply_hf_func(func=HFDataset.select, indices=indices)
33
+
34
+ def mget(self, keys: Sequence[str]) -> List[LCDocument]:
35
+ if self._data is None or len(keys) == 0:
36
+ return []
37
+ keys_in_data = [key for key in keys if key in self._keys]
38
+ indices = [self._keys[key] for key in keys_in_data]
39
+ dataset = self._get_pie_docs_by_indices(indices)
40
+ metadatas = [self._metadata.get(key, {}) for key in keys_in_data]
41
+ return [self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(dataset, metadatas)]
42
+
43
+ def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
44
+ if len(items) == 0:
45
+ return
46
+ keys, new_docs = zip(*items)
47
+ pie_docs, metadatas = zip(*[self.unwrap_with_metadata(doc) for doc in new_docs])
48
+ if self._data is None:
49
+ idx_start = 0
50
+ self._data = Dataset.from_documents(pie_docs)
51
+ else:
52
+ # we pass the features to the new dataset to mitigate issues caused by
53
+ # slightly different inferred features
54
+ dataset = Dataset.from_documents(pie_docs, features=self._data.features)
55
+ idx_start = len(self._data)
56
+ self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
57
+ keys_dict = {key: idx for idx, key in zip(range(idx_start, len(self._data)), keys)}
58
+ self._keys.update(keys_dict)
59
+ self._metadata.update(
60
+ {key: metadata for key, metadata in zip(keys, metadatas) if metadata}
61
+ )
62
+
63
+ def add_pie_dataset(
64
+ self,
65
+ dataset: Dataset,
66
+ keys: Optional[List[str]] = None,
67
+ metadatas: Optional[List[Dict[str, Any]]] = None,
68
+ ) -> None:
69
+ if len(dataset) == 0:
70
+ return
71
+ if keys is None:
72
+ keys = [doc.id for doc in dataset]
73
+ if len(keys) != len(set(keys)):
74
+ raise ValueError("Keys must be unique.")
75
+ if None in keys:
76
+ raise ValueError("Keys must not be None.")
77
+ if metadatas is None:
78
+ metadatas = [{} for _ in range(len(dataset))]
79
+ if len(keys) != len(dataset) or len(keys) != len(metadatas):
80
+ raise ValueError("Keys, dataset and metadatas must have the same length.")
81
+
82
+ if self._data is None:
83
+ idx_start = 0
84
+ self._data = dataset
85
+ else:
86
+ idx_start = len(self._data)
87
+ self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
88
+ keys_dict = {key: idx for idx, key in zip(range(idx_start, len(self._data)), keys)}
89
+ self._keys.update(keys_dict)
90
+ metadatas_dict = {key: metadata for key, metadata in zip(keys, metadatas) if metadata}
91
+ self._metadata.update(metadatas_dict)
92
+
93
+ def mdelete(self, keys: Sequence[str]) -> None:
94
+ for key in keys:
95
+ idx = self._keys.pop(key, None)
96
+ if idx is not None:
97
+ self._metadata.pop(key, None)
98
+
99
+ def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
100
+ return (key for key in self._keys if prefix is None or key.startswith(prefix))
101
+
102
+ def _purge_invalid_entries(self):
103
+ if self._data is None or len(self._keys) == len(self._data):
104
+ return
105
+ self._data = self._get_pie_docs_by_indices(self._keys.values())
+ # re-map keys to the new, contiguous indices of the re-selected dataset
+ self._keys = {key: new_idx for new_idx, key in enumerate(self._keys)}
106
+
107
+ def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
108
+ self._purge_invalid_entries()
109
+ if len(self) == 0:
110
+ logger.warning("No documents to save.")
111
+ return
112
+
113
+ all_doc_ids = list(self._keys)
114
+ all_metadatas: List[Dict[str, Any]] = [self._metadata.get(key, {}) for key in all_doc_ids]
115
+ pie_documents_path = os.path.join(path, "pie_documents")
116
+ if os.path.exists(pie_documents_path):
117
+ # remove existing directory
118
+ logger.warning(f"Removing existing directory: {pie_documents_path}")
119
+ shutil.rmtree(pie_documents_path)
120
+ os.makedirs(pie_documents_path, exist_ok=True)
121
+ DatasetDict({"train": self._data}).to_json(pie_documents_path)
122
+ doc_ids_path = os.path.join(path, "doc_ids.json")
123
+ with open(doc_ids_path, "w") as f:
124
+ json.dump(all_doc_ids, f)
125
+ metadata_path = os.path.join(path, "metadata.json")
126
+ with open(metadata_path, "w") as f:
127
+ json.dump(all_metadatas, f)
128
+
129
+ def _load_from_directory(self, path: str, **kwargs) -> None:
130
+ doc_ids_path = os.path.join(path, "doc_ids.json")
131
+ if os.path.exists(doc_ids_path):
132
+ with open(doc_ids_path, "r") as f:
133
+ all_doc_ids = json.load(f)
134
+ else:
135
+ logger.warning(f"File {doc_ids_path} does not exist, don't load any document ids.")
136
+ all_doc_ids = None
137
+ metadata_path = os.path.join(path, "metadata.json")
138
+ if os.path.exists(metadata_path):
139
+ with open(metadata_path, "r") as f:
140
+ all_metadata = json.load(f)
141
+ else:
142
+ logger.warning(f"File {metadata_path} does not exist, don't load any metadata.")
143
+ all_metadata = None
144
+ pie_documents_path = os.path.join(path, "pie_documents")
145
+ if not os.path.exists(pie_documents_path):
146
+ logger.warning(
147
+ f"Directory {pie_documents_path} does not exist, don't load any documents."
148
+ )
149
+ return None
150
+ # If we have a dataset already loaded, we use its features to load the new dataset
151
+ # This is to mitigate issues caused by slightly different inferred features.
152
+ features = self._data.features if self._data is not None else None
153
+ pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path, features=features)
154
+ pie_docs = pie_dataset["train"]
155
+ self.add_pie_dataset(pie_docs, keys=all_doc_ids, metadatas=all_metadata)
156
+ logger.info(f"Loaded {len(pie_docs)} documents from {path} into docstore")
src/langchain_modules/huggingface_span_embeddings.py ADDED
@@ -0,0 +1,192 @@
1
+ import logging
2
+ from typing import Any, Dict, List, Optional, Union # type: ignore[import-not-found]
3
+
4
+ import torch
5
+ from pydantic import BaseModel, ConfigDict, Field
6
+ from transformers import pipeline
7
+
8
+ from ..hf_pipeline import FeatureExtractionPipelineWithStriding
9
+ from .span_embeddings import SpanEmbeddings
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ DEFAULT_MODEL_NAME = "allenai/specter2_base"
14
+
15
+
16
+ class HuggingFaceSpanEmbeddings(BaseModel, SpanEmbeddings):
17
+ """An implementation of SpanEmbeddings using a modified HuggingFace Transformers
18
+ feature-extraction pipeline, adapted for long text inputs by chunking with optional stride
19
+ (see src.hf_pipeline.FeatureExtractionPipelineWithStriding).
20
+
21
+ Note that calculating embeddings for multiple spans is efficient when all spans for a
22
+ text are passed in a single call to embed_document_spans, as the text embedding is computed
23
+ only once per unique text, and the span embeddings are simply pooled from these text embeddings.
24
+
25
+ It accepts any model that can be used with the HuggingFace feature-extraction pipeline, including
26
+ models with adapters such as SPECTER2 (see https://huggingface.co/allenai/specter2). In this case,
27
+ the model should be loaded beforehand and passed as parameter 'model' instead of the model identifier.
28
+ See https://huggingface.co/docs/transformers/main_classes/pipelines for further information.
29
+
30
+ To use, you should have the ``transformers`` python package installed.
31
+
32
+ Example:
33
+ .. code-block:: python
34
+ from transformers import AutoModel
35
+
36
+ model = "allenai/specter2_base"
37
+ pipeline_kwargs = {'device': 'cpu', 'stride': 64, 'batch_size': 8}
38
+ encode_kwargs = {'normalize_embeddings': False}
39
+ hf = HuggingFaceSpanEmbeddings(
40
+ model=model,
41
+ pipeline_kwargs=pipeline_kwargs,
42
+ )
43
+
44
+ text = "This is a test sentence."
45
+
46
+ # calculate embeddings for text[0:4]="This" and text[15:23]="sentence"
47
+ embeddings = hf.embed_document_spans(texts=[text, text], starts=[0, 15], ends=[4, 23])
48
+ """
49
+
50
+ client: Any = None #: :meta private:
51
+ model: Optional[Any] = DEFAULT_MODEL_NAME
52
+ """Model name or preloaded model to use."""
53
+ pooling_strategy: str = "mean"
54
+ pipeline_kwargs: Dict[str, Any] = Field(default_factory=dict)
55
+ """Keyword arguments to pass to the Huggingface pipeline constructor."""
56
+ encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
57
+ """Keyword arguments to pass when calling the pipeline."""
58
+ # show_progress: bool = False
59
+ """Whether to show a progress bar."""
60
+ model_max_length: Optional[int] = None
61
+ """The maximum input length of the model. Required for some model checkpoints with outdated configs."""
62
+
63
+ def __init__(self, **kwargs: Any):
64
+ """Initialize the sentence_transformer."""
65
+ super().__init__(**kwargs)
66
+
67
+ self.client = pipeline(
68
+ "feature-extraction",
69
+ model=self.model,
70
+ pipeline_class=FeatureExtractionPipelineWithStriding,
71
+ trust_remote_code=True,
72
+ **self.pipeline_kwargs,
73
+ )
74
+
75
+ # The Transformers library is buggy since 4.40.0,
76
+ # see https://github.com/huggingface/transformers/issues/30643,
77
+ # so we need to set the max_length to e.g. 512 manually
78
+ if self.model_max_length is not None:
79
+ self.client.tokenizer.model_max_length = self.model_max_length
80
+
81
+ # Check if the model has a valid max length
82
+ max_input_size = self.client.tokenizer.model_max_length
83
+ if max_input_size > 1e5: # A high threshold to catch "unlimited" values
84
+ raise ValueError(
85
+ "The tokenizer does not specify a valid `model_max_length` attribute. "
86
+ "Consider setting it manually by passing `model_max_length` to the "
87
+ "HuggingFaceSpanEmbeddings constructor."
88
+ )
89
+
90
+ model_config = ConfigDict(
91
+ extra="forbid",
92
+ protected_namespaces=(),
93
+ )
94
+
95
+ def get_span_embedding(
96
+ self,
97
+ last_hidden_state: torch.Tensor,
98
+ offset_mapping: torch.Tensor,
99
+ start: Union[int, List[int]],
100
+ end: Union[int, List[int]],
101
+ **unused_kwargs,
102
+ ) -> Optional[List[float]]:
103
+ """Pool the span embeddings."""
104
+ if isinstance(start, int):
105
+ start = [start]
106
+ if isinstance(end, int):
107
+ end = [end]
108
+ if len(start) != len(end):
109
+ raise ValueError("start and end should have the same length.")
110
+ if len(start) == 0:
111
+ raise ValueError("start and end should not be empty.")
112
+ if last_hidden_state.shape[0] != 1:
113
+ raise ValueError("last_hidden_state should have a batch size of 1.")
114
+ if last_hidden_state.shape[0] != offset_mapping.shape[0]:
115
+ raise ValueError(
116
+ "last_hidden_state and offset_mapping should have the same batch size."
117
+ )
118
+ offset_mapping = offset_mapping[0]
119
+ last_hidden_state = last_hidden_state[0]
120
+
121
+ mask = (start[0] <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= end[0])
122
+ for s, e in zip(start[1:], end[1:]):
123
+ mask = mask | ((s <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= e))
124
+ span_embeddings = last_hidden_state[mask]
125
+ if span_embeddings.shape[0] == 0:
126
+ return None
127
+ if self.pooling_strategy == "mean":
128
+ return span_embeddings.mean(dim=0).tolist()
129
+ elif self.pooling_strategy == "max":
130
+ return span_embeddings.max(dim=0).values.tolist()
131
+ else:
132
+ raise ValueError(f"Unknown pool strategy: {self.pooling_strategy}")
133
+
134
+ def embed_document_spans(
135
+ self,
136
+ texts: List[str],
137
+ starts: Union[List[int], List[List[int]]],
138
+ ends: Union[List[int], List[List[int]]],
139
+ ) -> List[Optional[List[float]]]:
140
+ """Compute doc embeddings using a HuggingFace transformer model.
141
+
142
+ Args:
143
+ texts: The list of texts to embed.
144
+ starts: The list of start indices or list of lists of start indices (multi-span).
145
+ ends: The list of end indices or list of lists of end indices (multi-span).
146
+
147
+ Returns:
148
+ List of embeddings, one for each text.
149
+ """
150
+ pipeline_kwargs = self.encode_kwargs.copy()
151
+ pipeline_kwargs["return_offset_mapping"] = True
152
+ # we enable long text handling by providing the stride parameter
153
+ if pipeline_kwargs.get("stride", None) is None:
154
+ pipeline_kwargs["stride"] = 0
155
+ # when stride is positive, we need to create unique embeddings per token
156
+ if pipeline_kwargs["stride"] > 0:
157
+ pipeline_kwargs["create_unique_embeddings_per_token"] = True
158
+ # we ask for tensors to efficiently compute the span embeddings
159
+ pipeline_kwargs["return_tensors"] = True
160
+
161
+ unique_texts = sorted(set(texts))
162
+ idx2unique_idx = {i: unique_texts.index(text) for i, text in enumerate(texts)}
163
+ pipeline_results = self.client(unique_texts, **pipeline_kwargs)
164
+ embeddings = [
165
+ self.get_span_embedding(
166
+ start=starts[idx], end=ends[idx], **pipeline_results[idx2unique_idx[idx]]
167
+ )
168
+ for idx in range(len(texts))
169
+ ]
170
+ return embeddings
171
+
172
+ def embed_query_span(
173
+ self, text: str, start: Union[int, List[int]], end: Union[int, List[int]]
174
+ ) -> Optional[List[float]]:
175
+ """Compute query embeddings using a HuggingFace transformer model.
176
+
177
+ Args:
178
+ text: The text to embed.
179
+ start: The start index or list of start indices (multi-span).
180
+ end: The end index or list of end indices (multi-span).
181
+
182
+ Returns:
183
+ Embeddings for the text.
184
+ """
185
+ starts: Union[List[int], List[List[int]]] = [start] # type: ignore[assignment]
186
+ ends: Union[List[int], List[List[int]]] = [end] # type: ignore[assignment]
187
+ return self.embed_document_spans([text], starts=starts, ends=ends)[0]
188
+
189
+ @property
190
+ def embedding_dim(self) -> int:
191
+ """Get the embedding dimension."""
192
+ return self.client.model.config.hidden_size
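A short usage sketch of the span-embedding API defined above, following its own docstring; the import path `src.langchain_modules` and the device/stride settings are assumptions.

```python
# Sketch: model, device and stride values are illustrative.
from src.langchain_modules import HuggingFaceSpanEmbeddings

embeddings = HuggingFaceSpanEmbeddings(
    model="allenai/specter2_base",
    pipeline_kwargs={"device": "cpu", "stride": 64, "batch_size": 8},
    model_max_length=512,  # workaround for checkpoints with an unset tokenizer max length
)

text = "This is a test sentence."
# one query span: text[15:23] == "sentence"
query_vector = embeddings.embed_query_span(text, start=15, end=23)
# batched document spans over the same text: "This" and "sentence"
doc_vectors = embeddings.embed_document_spans([text, text], starts=[0, 15], ends=[4, 23])
print(len(query_vector), len(doc_vectors))
```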
src/langchain_modules/pie_document_store.py ADDED
@@ -0,0 +1,88 @@
1
+ import abc
2
+ import logging
3
+ from copy import copy
4
+ from typing import Iterator, List, Optional, Sequence, Tuple
5
+
6
+ import pandas as pd
7
+ from langchain_core.documents import Document as LCDocument
8
+ from langchain_core.stores import BaseStore
9
+ from pytorch_ie.documents import TextBasedDocument
10
+
11
+ from .serializable_store import SerializableStore
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class PieDocumentStore(SerializableStore, BaseStore[str, LCDocument], abc.ABC):
17
+ """Abstract base class for document stores specialized in storing and retrieving pie documents."""
18
+
19
+ METADATA_KEY_PIE_DOCUMENT: str = "pie_document"
20
+ """Key for the pie document in the (langchain) document metadata."""
21
+
22
+ def wrap(self, pie_document: TextBasedDocument, **metadata) -> LCDocument:
23
+ """Wrap the pie document in an LCDocument."""
24
+ return LCDocument(
25
+ id=pie_document.id,
26
+ page_content="",
27
+ metadata={self.METADATA_KEY_PIE_DOCUMENT: pie_document, **metadata},
28
+ )
29
+
30
+ def unwrap(self, document: LCDocument) -> TextBasedDocument:
31
+ """Get the pie document from the langchain document."""
32
+ return document.metadata[self.METADATA_KEY_PIE_DOCUMENT]
33
+
34
+ def unwrap_with_metadata(self, document: LCDocument) -> Tuple[TextBasedDocument, dict]:
35
+ """Get the pie document and metadata from the langchain document."""
36
+ metadata = copy(document.metadata)
37
+ pie_document = metadata.pop(self.METADATA_KEY_PIE_DOCUMENT)
38
+ return pie_document, metadata
39
+
40
+ @abc.abstractmethod
41
+ def mget(self, keys: Sequence[str]) -> List[LCDocument]:
42
+ pass
43
+
44
+ @abc.abstractmethod
45
+ def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
46
+ pass
47
+
48
+ @abc.abstractmethod
49
+ def mdelete(self, keys: Sequence[str]) -> None:
50
+ pass
51
+
52
+ @abc.abstractmethod
53
+ def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
54
+ pass
55
+
56
+ def __len__(self):
57
+ return len(list(self.yield_keys()))
58
+
59
+ def overview(self, layer_captions: dict, use_predictions: bool = False) -> pd.DataFrame:
60
+ """Get an overview of the document store, including the number of items in each layer for each document
61
+ in the store.
62
+
63
+ Args:
64
+ layer_captions: A dictionary mapping layer names to captions.
65
+ use_predictions: Whether to use predictions instead of the actual layers.
66
+
67
+ Returns:
68
+ DataFrame: A pandas DataFrame containing the overview.
69
+ """
70
+ rows = []
71
+ for doc_id in self.yield_keys():
72
+ document = self.mget([doc_id])[0]
73
+ pie_document = self.unwrap(document)
74
+ layers = {
75
+ caption: pie_document[layer_name] for layer_name, caption in layer_captions.items()
76
+ }
77
+ layer_sizes = {
78
+ f"num_{caption}s": len(layer) + (len(layer.predictions) if use_predictions else 0)
79
+ for caption, layer in layers.items()
80
+ }
81
+ rows.append({"doc_id": doc_id, **layer_sizes})
82
+ df = pd.DataFrame(rows)
83
+ return df
84
+
85
+ def as_dict(self, document: LCDocument) -> dict:
86
+ """Convert the langchain document to a dictionary."""
87
+ pie_document, metadata = self.unwrap_with_metadata(document)
88
+ return {self.METADATA_KEY_PIE_DOCUMENT: pie_document.asdict(), "metadata": metadata}
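`PieDocumentStore` itself is abstract; the toy subclass below is only meant to illustrate the `wrap` / `unwrap` contract and the key-value interface, not to be part of the module.

```python
# Toy in-memory subclass, for illustration only (not part of the module).
from typing import Iterator, List, Optional, Sequence, Tuple

from langchain_core.documents import Document as LCDocument


class InMemoryPieDocumentStore(PieDocumentStore):
    def __init__(self):
        self._docs: dict = {}

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        return [self._docs[k] for k in keys if k in self._docs]

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        self._docs.update(dict(items))

    def mdelete(self, keys: Sequence[str]) -> None:
        for k in keys:
            self._docs.pop(k, None)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        return (k for k in self._docs if prefix is None or k.startswith(prefix))

    # SerializableStore hooks omitted for brevity; a real subclass must implement them.
    def _save_to_directory(self, path: str, **kwargs) -> None: ...

    def _load_from_directory(self, path: str, **kwargs) -> None: ...


# usage: wrap a pie document before storing, unwrap (with metadata) after retrieval
# store.mset([(pie_doc.id, store.wrap(pie_doc, use_predicted_annotations=True))])
# pie_doc_again, metadata = store.unwrap_with_metadata(store.mget([pie_doc.id])[0])
```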
src/langchain_modules/qdrant_span_vectorstore.py ADDED
@@ -0,0 +1,349 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ import uuid
7
+ from collections import defaultdict
8
+ from itertools import islice
9
+ from typing import ( # type: ignore[import-not-found]
10
+ Any,
11
+ Dict,
12
+ Generator,
13
+ Iterable,
14
+ Iterator,
15
+ List,
16
+ Optional,
17
+ Sequence,
18
+ Tuple,
19
+ Union,
20
+ )
21
+
22
+ import numpy as np
23
+ from langchain_core.documents import Document as LCDocument
24
+ from langchain_qdrant import QdrantVectorStore, RetrievalMode
25
+ from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
26
+ from qdrant_client import QdrantClient, models
27
+ from qdrant_client.http.models import Record
28
+
29
+ from .span_embeddings import SpanEmbeddings
30
+ from .span_vectorstore import SpanVectorStore
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class QdrantSpanVectorStore(SpanVectorStore, QdrantVectorStore):
36
+ """An implementation of the SpanVectorStore interface that uses Qdrant
37
+ as backend for storing and retrieving span embeddings."""
38
+
39
+ EMBEDDINGS_FILE = "embeddings.npy"
40
+ PAYLOADS_FILE = "payloads.json"
41
+ INDEX_FILE = "index.json"
42
+
43
+ def __init__(
44
+ self,
45
+ client: QdrantClient,
46
+ collection_name: str,
47
+ embedding: SpanEmbeddings,
48
+ vector_params: Optional[Dict[str, Any]] = None,
49
+ **kwargs,
50
+ ):
51
+ if not client.collection_exists(collection_name):
52
+ logger.info(f'Collection "{collection_name}" does not exist. Creating it now.')
53
+ client.create_collection(
54
+ collection_name=collection_name,
55
+ vectors_config=models.VectorParams(size=embedding.embedding_dim, **vector_params),
56
+ )
57
+ else:
58
+ logger.info(f'Collection "{collection_name}" already exists.')
59
+ super().__init__(
60
+ client=client, collection_name=collection_name, embedding=embedding, **kwargs
61
+ )
62
+
63
+ def __len__(self):
64
+ return self.client.get_collection(collection_name=self.collection_name).points_count
65
+
66
+ def get_by_ids_with_vectors(self, ids: Sequence[str | int], /) -> List[LCDocument]:
67
+ results = self.client.retrieve(
68
+ self.collection_name, ids, with_payload=True, with_vectors=True
69
+ )
70
+
71
+ return [
72
+ self._document_from_point(
73
+ result,
74
+ self.collection_name,
75
+ self.content_payload_key,
76
+ self.metadata_payload_key,
77
+ )
78
+ for result in results
79
+ ]
80
+
81
+ def construct_filter(
82
+ self,
83
+ query_span: Union[Span, MultiSpan],
84
+ metadata_doc_id_key: str,
85
+ doc_id_whitelist: Optional[Sequence[str]] = None,
86
+ doc_id_blacklist: Optional[Sequence[str]] = None,
87
+ ) -> Optional[models.Filter]:
88
+ """Construct a filter for the retrieval. It should enforce that:
89
+ - if the span is labeled, the retrieved span has the same label, or
90
+ - if, in addition, a label mapping is provided, the retrieved span has a label that is in the mapping for the query span's label
91
+ - if `doc_id_whitelist` is provided, the retrieved span is from a document in the whitelist
92
+ - if `doc_id_blacklist` is provided, the retrieved span is not from a document in the blacklist
93
+
94
+ Args:
95
+ query_span: The query span.
96
+ metadata_doc_id_key: The key in the metadata that holds the document id.
97
+ doc_id_whitelist: A list of document ids to restrict the retrieval to.
98
+ doc_id_blacklist: A list of document ids to exclude from the retrieval.
99
+
100
+ Returns:
101
+ A filter object.
102
+ """
103
+ filter_kwargs = defaultdict(list)
104
+ # if the span is labeled, enforce that the retrieved span has the same label
105
+ if isinstance(query_span, (LabeledSpan, LabeledMultiSpan)):
106
+ if self.label_mapping is not None:
107
+ target_labels = self.label_mapping.get(query_span.label, [])
108
+ else:
109
+ target_labels = [query_span.label]
110
+ filter_kwargs["must"].append(
111
+ models.FieldCondition(
112
+ key=f"metadata.{self.METADATA_SPAN_KEY}.label",
113
+ match=models.MatchAny(any=target_labels),
114
+ )
115
+ )
116
+ elif self.label_mapping is not None:
117
+ raise TypeError("Label mapping is only supported for labeled spans")
118
+
119
+ if doc_id_blacklist is not None and doc_id_whitelist is not None:
120
+ overlap = set(doc_id_whitelist) & set(doc_id_blacklist)
121
+ if len(overlap) > 0:
122
+ raise ValueError(
123
+ f"Overlap between doc_id_whitelist and doc_id_blacklist: {overlap}"
124
+ )
125
+
126
+ if doc_id_whitelist is not None:
127
+ filter_kwargs["must"].append(
128
+ models.FieldCondition(
129
+ key=f"metadata.{metadata_doc_id_key}",
130
+ match=(
131
+ models.MatchValue(value=doc_id_whitelist[0])
132
+ if len(doc_id_whitelist) == 1
133
+ else models.MatchAny(any=doc_id_whitelist)
134
+ ),
135
+ )
136
+ )
137
+ if doc_id_blacklist is not None:
138
+ filter_kwargs["must_not"].append(
139
+ models.FieldCondition(
140
+ key=f"metadata.{metadata_doc_id_key}",
141
+ match=(
142
+ models.MatchValue(value=doc_id_blacklist[0])
143
+ if len(doc_id_blacklist) == 1
144
+ else models.MatchAny(any=doc_id_blacklist)
145
+ ),
146
+ )
147
+ )
148
+ if len(filter_kwargs) > 0:
149
+ return models.Filter(**filter_kwargs)
150
+ return None
151
+
152
+ @classmethod
153
+ def _document_from_point(
154
+ cls,
155
+ scored_point: Any,
156
+ collection_name: str,
157
+ content_payload_key: str,
158
+ metadata_payload_key: str,
159
+ ) -> LCDocument:
160
+ metadata = scored_point.payload.get(metadata_payload_key) or {}
161
+ metadata["_collection_name"] = collection_name
162
+ if hasattr(scored_point, "score"):
163
+ metadata[cls.RELEVANCE_SCORE_KEY] = scored_point.score
164
+ if hasattr(scored_point, "vector"):
165
+ metadata[cls.METADATA_VECTOR_KEY] = scored_point.vector
166
+ return LCDocument(
167
+ id=scored_point.id,
168
+ page_content=scored_point.payload.get(content_payload_key, ""),
169
+ metadata=metadata,
170
+ )
171
+
172
+ def _build_vectors_with_metadata(
173
+ self,
174
+ texts: Iterable[str],
175
+ metadatas: List[dict],
176
+ ) -> List[models.VectorStruct]:
177
+ starts = [metadata[self.METADATA_SPAN_KEY][self.SPAN_START_KEY] for metadata in metadatas]
178
+ ends = [metadata[self.METADATA_SPAN_KEY][self.SPAN_END_KEY] for metadata in metadatas]
179
+ if self.retrieval_mode == RetrievalMode.DENSE:
180
+ batch_embeddings = self.embeddings.embed_document_spans(list(texts), starts, ends)
181
+ return [
182
+ {
183
+ self.vector_name: vector,
184
+ }
185
+ for vector in batch_embeddings
186
+ ]
187
+
188
+ elif self.retrieval_mode == RetrievalMode.SPARSE:
189
+ raise ValueError("Sparse retrieval mode is not yet implemented.")
190
+
191
+ elif self.retrieval_mode == RetrievalMode.HYBRID:
192
+ raise NotImplementedError("Hybrid retrieval mode is not yet implemented.")
193
+ else:
194
+ raise ValueError(f"Unknown retrieval mode. {self.retrieval_mode} to build vectors.")
195
+
196
+ def _build_payloads_from_metadata(
197
+ self,
198
+ metadatas: Iterable[dict],
199
+ metadata_payload_key: str,
200
+ ) -> List[dict]:
201
+ payloads = [{metadata_payload_key: metadata} for metadata in metadatas]
202
+
203
+ return payloads
204
+
205
+ def _generate_batches(
206
+ self,
207
+ texts: Iterable[str],
208
+ metadatas: Optional[List[dict]] = None,
209
+ ids: Optional[Sequence[str | int]] = None,
210
+ batch_size: int = 64,
211
+ ) -> Generator[tuple[list[str | int], list[models.PointStruct]], Any, None]:
212
+ """Generate batches of points to index. Same as in `QdrantVectorStore` but metadata is used
213
+ to build vectors and payloads."""
214
+
215
+ texts_iterator = iter(texts)
216
+ if metadatas is None:
217
+ raise ValueError("Metadata must be provided to generate batches.")
218
+ metadatas_iterator = iter(metadatas)
219
+ ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
220
+
221
+ while batch_texts := list(islice(texts_iterator, batch_size)):
222
+ batch_metadatas = list(islice(metadatas_iterator, batch_size))
223
+ batch_ids = list(islice(ids_iterator, batch_size))
224
+ points = [
225
+ models.PointStruct(
226
+ id=point_id,
227
+ vector=vector,
228
+ payload=payload,
229
+ )
230
+ for point_id, vector, payload in zip(
231
+ batch_ids,
232
+ self._build_vectors_with_metadata(batch_texts, metadatas=batch_metadatas),
233
+ # we do not save the text in the payload because the text is the full
234
+ # document which is usually already saved in the docstore
235
+ self._build_payloads_from_metadata(
236
+ metadatas=batch_metadatas,
237
+ metadata_payload_key=self.metadata_payload_key,
238
+ ),
239
+ )
240
+ if vector[self.vector_name] is not None
241
+ ]
242
+
243
+ yield [point.id for point in points], points
244
+
245
+ def similarity_search_with_score_by_vector(
246
+ self,
247
+ embedding: List[float],
248
+ k: int = 4,
249
+ filter: Optional[models.Filter] = None,
250
+ search_params: Optional[models.SearchParams] = None,
251
+ offset: int = 0,
252
+ score_threshold: Optional[float] = None,
253
+ consistency: Optional[models.ReadConsistency] = None,
254
+ **kwargs: Any,
255
+ ) -> List[Tuple[LCDocument, float]]:
256
+ """Return docs most similar to query vector.
257
+
258
+ Returns:
259
+ List of documents most similar to the query text and distance for each.
260
+ """
261
+ query_options = {
262
+ "collection_name": self.collection_name,
263
+ "query_filter": filter,
264
+ "search_params": search_params,
265
+ "limit": k,
266
+ "offset": offset,
267
+ "with_payload": True,
268
+ "with_vectors": False,
269
+ "score_threshold": score_threshold,
270
+ "consistency": consistency,
271
+ **kwargs,
272
+ }
273
+
274
+ results = self.client.query_points(
275
+ query=embedding,
276
+ using=self.vector_name,
277
+ **query_options,
278
+ ).points
279
+
280
+ return [
281
+ (
282
+ self._document_from_point(
283
+ result,
284
+ self.collection_name,
285
+ self.content_payload_key,
286
+ self.metadata_payload_key,
287
+ ),
288
+ result.score,
289
+ )
290
+ for result in results
291
+ ]
292
+
293
+ def _as_indices_vectors_payloads(self) -> Tuple[List[str], np.ndarray, List[Any]]:
294
+ data, _ = self.client.scroll(
295
+ collection_name=self.collection_name, with_vectors=True, limit=len(self)
296
+ )
297
+ vectors_np = np.array([point.vector for point in data])
298
+ payloads = [point.payload for point in data]
299
+ emb_ids = [point.id for point in data]
300
+ return emb_ids, vectors_np, payloads
301
+
302
+ # TODO: or use create_snapshot and restore_snapshot?
303
+ def _save_to_directory(self, path: str, **kwargs) -> None:
304
+ indices, vectors, payloads = self._as_indices_vectors_payloads()
305
+ np.save(os.path.join(path, self.EMBEDDINGS_FILE), vectors)
306
+ with open(os.path.join(path, self.PAYLOADS_FILE), "w") as f:
307
+ json.dump(payloads, f, indent=2)
308
+ with open(os.path.join(path, self.INDEX_FILE), "w") as f:
309
+ json.dump(indices, f)
310
+
311
+ def _load_from_directory(self, path: str, **kwargs) -> None:
312
+ with open(os.path.join(path, self.INDEX_FILE), "r") as f:
313
+ index = json.load(f)
314
+ embeddings_np: np.ndarray = np.load(os.path.join(path, self.EMBEDDINGS_FILE))
315
+ with open(os.path.join(path, self.PAYLOADS_FILE), "r") as f:
316
+ payloads = json.load(f)
317
+ points = models.Batch(ids=index, vectors=embeddings_np.tolist(), payloads=payloads)
318
+ self.client.upsert(
319
+ collection_name=self.collection_name,
320
+ points=points,
321
+ )
322
+ logger.info(f"Loaded {len(index)} points into collection {self.collection_name}.")
323
+
324
+ def mget(self, keys: Sequence[str]) -> list[Optional[Record]]:
325
+ return self.client.retrieve(
326
+ self.collection_name, ids=keys, with_payload=True, with_vectors=True
327
+ )
328
+
329
+ def mset(self, key_value_pairs: Sequence[tuple[str, Record]]) -> None:
330
+ self.client.upsert(
331
+ collection_name=self.collection_name,
332
+ points=models.Batch(
333
+ ids=[key for key, _ in key_value_pairs],
334
+ vectors=[value.vector for _, value in key_value_pairs],
335
+ payloads=[value.payload for _, value in key_value_pairs],
336
+ ),
337
+ )
338
+
339
+ def mdelete(self, keys: Sequence[str]) -> None:
340
+ self.client.delete(collection_name=self.collection_name, points_selector=keys)
341
+
342
+ def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
343
+ for point in self.client.scroll(
344
+ collection_name=self.collection_name,
345
+ with_vectors=False,
346
+ with_payload=False,
347
+ limit=len(self),
348
+ )[0]:
349
+ yield point.id
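A hedged sketch of the on-disk format written by `_save_to_directory` above (`embeddings.npy`, `payloads.json`, `index.json`) and how `_load_from_directory` maps it onto a Qdrant upsert, expressed with a plain in-memory `qdrant_client`. The path, collection name, distance metric and the unnamed-vector layout are illustrative simplifications (the store itself uses a named vector).

```python
# Sketch only: path, collection name and distance metric are illustrative.
import json
import os

import numpy as np
from qdrant_client import QdrantClient, models

path = "/tmp/span_vectorstore_dump"
with open(os.path.join(path, "index.json")) as f:
    ids = json.load(f)
vectors = np.load(os.path.join(path, "embeddings.npy"))
with open(os.path.join(path, "payloads.json")) as f:
    payloads = json.load(f)

client = QdrantClient(":memory:")
client.create_collection(
    collection_name="spans",
    vectors_config=models.VectorParams(size=vectors.shape[1], distance=models.Distance.COSINE),
)
client.upsert(
    collection_name="spans",
    points=models.Batch(ids=ids, vectors=vectors.tolist(), payloads=payloads),
)
print(client.get_collection("spans").points_count, "points restored")
```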
src/langchain_modules/serializable_store.py ADDED
@@ -0,0 +1,137 @@
1
+ import logging
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ from abc import ABC, abstractmethod
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ class SerializableStore(ABC):
11
+ """Abstract base class for serializable stores."""
12
+
13
+ @abstractmethod
14
+ def _save_to_directory(self, path: str, **kwargs) -> None:
15
+ """Save the store to a directory.
16
+
17
+ Args:
18
+ path (str): The path to a directory.
19
+ """
20
+
21
+ @abstractmethod
22
+ def _load_from_directory(self, path: str, **kwargs) -> None:
23
+ """Load the store from a directory.
24
+
25
+ Args:
26
+ path (str): The path to the file.
27
+ """
28
+
29
+ def save_to_directory(self, path: str, **kwargs) -> None:
30
+ """Save the store to a directory.
31
+
32
+ Args:
33
+ path (str): The path to a directory.
34
+ """
35
+ os.makedirs(path, exist_ok=True)
36
+ self._save_to_directory(path, **kwargs)
37
+
38
+ def load_from_directory(self, path: str, **kwargs) -> None:
39
+ """Load the store from a directory.
40
+
41
+ Args:
42
+ path (str): The path to the file.
43
+ """
44
+ if not os.path.exists(path):
45
+ raise FileNotFoundError(f'Directory "{path}" not found.')
46
+
47
+ self._load_from_directory(path, **kwargs)
48
+
49
+ def save_to_archive(
50
+ self,
51
+ base_name: str,
52
+ format: str,
53
+ ) -> str:
54
+ """Save the store to an archive.
55
+
56
+ Args:
57
+ base_name (str): The base name of the archive.
58
+ format (str): The format of the archive.
59
+
60
+ Returns:
61
+ str: The path to the archive.
62
+ """
63
+ temp_dir = os.path.join(tempfile.gettempdir(), "retriever_store")
64
+ # remove the temporary directory if it already exists
65
+ if os.path.exists(temp_dir):
66
+ shutil.rmtree(temp_dir)
67
+ # save the documents to the directory
68
+ self.save_to_directory(temp_dir)
69
+ # zip the directory
70
+ result_file_path = shutil.make_archive(
71
+ base_name=base_name, root_dir=temp_dir, format=format
72
+ )
73
+ # remove the temporary directory
74
+ shutil.rmtree(temp_dir)
75
+ return result_file_path
76
+
77
+ def load_from_archive(
78
+ self,
79
+ file_name: str,
80
+ ) -> None:
81
+ """Load the store from an archive.
82
+
83
+ Args:
84
+ file_name (str): The path to the archive.
85
+ """
86
+ if not os.path.exists(file_name):
87
+ raise FileNotFoundError(f'Archive file "{file_name}" not found.')
88
+
89
+ temp_dir = os.path.join(tempfile.gettempdir(), "retriever_store")
90
+ # remove the temporary directory if it already exists
91
+ if os.path.exists(temp_dir):
92
+ shutil.rmtree(temp_dir)
93
+
94
+ # unzip the file
95
+ shutil.unpack_archive(file_name, temp_dir)
96
+ # load the documents from the directory
97
+ self.load_from_directory(temp_dir)
98
+ # remove the temporary directory
99
+ shutil.rmtree(temp_dir)
100
+
101
+ def save_to_disc(self, path: str) -> None:
102
+ """Save the store to disc. Depending on the path, this can be a directory or an archive.
103
+
104
+ Args:
105
+ path (str): The path to a directory or an archive.
106
+ """
107
+ # if it is a zip file, save to archive
108
+ if path.lower().endswith(".zip"):
109
+ logger.info(f"Saving to archive at {path} ...")
110
+ # get base name without extension
111
+ base_name = os.path.splitext(path)[0]
112
+ result_path = self.save_to_archive(base_name, format="zip")
113
+ if not result_path.endswith(path):
114
+ logger.warning(f"Saved to {result_path} instead of {path}.")
115
+ # if it does not have an extension, save to directory
116
+ elif not os.path.splitext(path)[1]:
117
+ logger.info(f"Saving to directory at {path} ...")
118
+ self.save_to_directory(path)
119
+ else:
120
+ raise ValueError("Unsupported file format. Only .zip and directories are supported.")
121
+
122
+ def load_from_disc(self, path: str) -> None:
123
+ """Load the store from disc. Depending on the path, this can be a directory or an archive.
124
+
125
+ Args:
126
+ path (str): The path to a directory or an archive.
127
+ """
128
+ # if it is a zip file, load from archive
129
+ if path.lower().endswith(".zip"):
130
+ logger.info(f"Loading from archive at {path} ...")
131
+ self.load_from_archive(path)
132
+ # if it is a directory, load from directory
133
+ elif os.path.isdir(path):
134
+ logger.info(f"Loading from directory at {path} ...")
135
+ self.load_from_directory(path)
136
+ else:
137
+ raise ValueError("Unsupported file format. Only .zip and directories are supported.")
src/langchain_modules/span_embeddings.py ADDED
@@ -0,0 +1,103 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, List, Optional, Union # type: ignore[import-not-found]
3
+
4
+ from langchain_core.embeddings import Embeddings
5
+ from langchain_core.runnables.config import run_in_executor
6
+
7
+
8
+ class SpanEmbeddings(Embeddings, ABC):
9
+ """Interface for models that embed text spans within documents."""
10
+
11
+ @abstractmethod
12
+ def embed_document_spans(
13
+ self,
14
+ texts: list[str],
15
+ starts: Union[list[int], List[List[int]]],
16
+ ends: Union[list[int], List[List[int]]],
17
+ ) -> list[Optional[list[float]]]:
18
+ """Embed search docs.
19
+
20
+ Args:
21
+ texts: List of text to embed.
22
+ starts: List of start indices or list of lists of start indices (multi-span).
23
+ ends: List of end indices or list of lists of end indices (multi-span).
24
+
25
+ Returns:
26
+ List of embeddings.
27
+ """
28
+
29
+ @abstractmethod
30
+ def embed_query_span(
31
+ self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
32
+ ) -> Optional[list[float]]:
33
+ """Embed query text.
34
+
35
+ Args:
36
+ text: Text to embed.
37
+ start: Start index or list of start indices (multi-span).
38
+ end: End index or list of end indices (multi-span).
39
+
40
+ Returns:
41
+ Embedding.
42
+ """
43
+
44
+ def embed_documents(self, texts: list[str]) -> list[Optional[list[float]]]:
45
+ """Embed search docs.
46
+
47
+ Args:
48
+ texts: List of text to embed.
49
+
50
+ Returns:
51
+ List of embeddings.
52
+ """
53
+ return self.embed_document_spans(texts, [0] * len(texts), [len(text) for text in texts])
54
+
55
+ def embed_query(self, text: str) -> Optional[list[float]]:
56
+ """Embed query text.
57
+
58
+ Args:
59
+ text: Text to embed.
60
+
61
+ Returns:
62
+ Embedding.
63
+ """
64
+ return self.embed_query_span(text, 0, len(text))
65
+
66
+ async def aembed_document_spans(
67
+ self,
68
+ texts: list[str],
69
+ starts: Union[list[int], list[list[int]]],
70
+ ends: Union[list[int], list[list[int]]],
71
+ ) -> list[Optional[list[float]]]:
72
+ """Asynchronous Embed search docs.
73
+
74
+ Args:
75
+ texts: List of text to embed.
76
+ starts: List of start indices or list of lists of start indices (multi-span).
77
+ ends: List of end indices or list of lists of end indices (multi-span).
78
+
79
+ Returns:
80
+ List of embeddings.
81
+ """
82
+ return await run_in_executor(None, self.embed_document_spans, texts, starts, ends)
83
+
84
+ async def aembed_query_spans(
85
+ self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
86
+ ) -> Optional[list[float]]:
87
+ """Asynchronous Embed query text.
88
+
89
+ Args:
90
+ text: Text to embed.
91
+ start: Start index or list of start indices (multi-span).
92
+ end: End index or list of end indices (multi-span).
93
+
94
+ Returns:
95
+ Embedding.
96
+ """
97
+ return await run_in_executor(None, self.embed_query_span, text, start, end)
98
+
99
+ @property
100
+ @abstractmethod
101
+ def embedding_dim(self) -> int:
102
+ """Get the embedding dimension."""
103
+ ...
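A dummy `SpanEmbeddings` implementation, only to show how the generic `embed_documents` / `embed_query` methods fall back to whole-text spans via this interface; the hash-based vectors are illustrative, not a real embedding model.

```python
# Dummy implementation for illustration; the vectors carry no semantic meaning.
from typing import List, Optional, Union


class DummySpanEmbeddings(SpanEmbeddings):
    def embed_document_spans(
        self,
        texts: List[str],
        starts: Union[List[int], List[List[int]]],
        ends: Union[List[int], List[List[int]]],
    ) -> List[Optional[List[float]]]:
        vectors = []
        for text, start, end in zip(texts, starts, ends):
            # handle single spans and multi-spans uniformly
            slices = [(start, end)] if isinstance(start, int) else list(zip(start, end))
            span_text = " ".join(text[s:e] for s, e in slices)
            vectors.append([float(len(span_text)), float(hash(span_text) % 1000)])
        return vectors

    def embed_query_span(
        self, text: str, start: Union[int, List[int]], end: Union[int, List[int]]
    ) -> Optional[List[float]]:
        return self.embed_document_spans([text], [start], [end])[0]

    @property
    def embedding_dim(self) -> int:
        return 2


emb = DummySpanEmbeddings()
print(emb.embed_query("This is a test sentence."))               # whole text as one span
print(emb.embed_query_span("This is a test sentence.", 15, 23))  # just "sentence"
```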
src/langchain_modules/span_retriever.py ADDED
@@ -0,0 +1,860 @@
1
+ import logging
2
+ import os
3
+ import uuid
4
+ from collections import defaultdict
5
+ from copy import copy
6
+ from enum import Enum
7
+ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union
8
+
9
+ from langchain_core.callbacks import (
10
+ AsyncCallbackManagerForRetrieverRun,
11
+ CallbackManagerForRetrieverRun,
12
+ )
13
+ from langchain_core.documents import BaseDocumentCompressor
14
+ from langchain_core.documents import Document as LCDocument
15
+ from langchain_core.retrievers import BaseRetriever
16
+ from pydantic import Field
17
+ from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
18
+ from pytorch_ie.core.document import BaseAnnotationList
19
+ from pytorch_ie.documents import (
20
+ TextBasedDocument,
21
+ TextDocumentWithLabeledMultiSpans,
22
+ TextDocumentWithLabeledSpans,
23
+ TextDocumentWithSpans,
24
+ )
25
+
26
+ from .pie_document_store import PieDocumentStore
27
+ from .serializable_store import SerializableStore
28
+ from .span_vectorstore import SpanVectorStore
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ def _parse_config(config_string: str, format: str) -> Dict[str, Any]:
34
+ """Parse a configuration string."""
35
+ if format == "json":
36
+ import json
37
+
38
+ return json.loads(config_string)
39
+ elif format == "yaml":
40
+ import yaml
41
+
42
+ return yaml.safe_load(config_string)
43
+ else:
44
+ raise ValueError(f"Unsupported format: {format}. Use 'json' or 'yaml'.")
45
+
46
+
47
+ METADATA_KEY_CHILD_ID2IDX = "child_id2idx"
48
+
49
+
50
+ class SpanNotFoundError(ValueError):
51
+ def __init__(self, span_id: str, doc_id: Optional[str] = None, msg: Optional[str] = None):
52
+ if msg is None:
53
+ if doc_id is not None:
54
+ msg = f"Span with id [{span_id}] not found in document [{doc_id}]"
55
+ else:
56
+ msg = f"Span with id [{span_id}] not found in the vectorstore"
57
+ super().__init__(msg)
58
+ self.span_id = span_id
59
+ self.doc_id = doc_id
60
+
61
+
62
+ class DocumentNotFoundError(ValueError):
63
+ def __init__(self, doc_id: str, msg: Optional[str] = None):
64
+ msg = msg or f"Document with id [{doc_id}] not found in the docstore"
65
+ super().__init__(msg)
66
+ self.doc_id = doc_id
67
+
68
+
69
+ class SearchType(str, Enum):
70
+ """Enumerator of the types of search to perform."""
71
+
72
+ similarity = "similarity"
73
+ """Similarity search."""
74
+ similarity_score_threshold = "similarity_score_threshold"
75
+ """Similarity search with a score threshold."""
76
+ mmr = "mmr"
77
+ """Maximal Marginal Relevance reranking of similarity search."""
78
+
79
+
80
+ class DocumentAwareSpanRetriever(BaseRetriever, SerializableStore):
81
+ """Retriever for contextualized text spans, i.e. spans within text documents.
82
+ It accepts spans as queries and retrieves spans with their containing document.
83
+ Note that the query span (and its document) must already be in the retriever's
84
+ store."""
85
+
86
+ pie_document_type: Type[TextBasedDocument]
87
+ """The name of the span annotation layer in the pie document."""
88
+ use_predicted_annotations_key: str = "use_predicted_annotations"
89
+ """Whether to use the predicted annotations or the gold annotations."""
90
+ retrieve_from_same_document: bool = False
91
+ """Whether to retrieve spans exclusively from the same document as the query span."""
92
+ retrieve_from_different_documents: bool = False
93
+ """Whether to retrieve spans exclusively from different documents than the query span."""
94
+
95
+ # content from langchain_core.retrievers.MultiVectorRetriever
96
+ vectorstore: SpanVectorStore
97
+ """The underlying vectorstore to use to store small chunks
98
+ and their embedding vectors"""
99
+ docstore: PieDocumentStore
100
+ """The storage interface for the parent documents"""
101
+ id_key: str = "doc_id"
102
+ """The key to use to track the parent id. This will be stored in the
103
+ metadata of child documents."""
104
+ search_kwargs: dict = Field(default_factory=dict)
105
+ """Keyword arguments to pass to the search function."""
106
+ search_type: SearchType = SearchType.similarity
107
+ """Type of search to perform (similarity / mmr)"""
108
+
109
+ # content taken from langchain_core.retrievers.ParentDocumentRetriever
110
+ child_metadata_fields: Optional[Sequence[str]] = None
111
+ """Metadata fields to leave in child documents. If None, leave all parent document
112
+ metadata.
113
+ """
114
+
115
+ # re-ranking
116
+ compressor: Optional[BaseDocumentCompressor] = None
117
+ """Compressor for compressing retrieved documents."""
118
+ compressor_context_size: int = 50
119
+ """Size of the context to use around the query and retrieved spans when compressing."""
120
+ compressor_query_context_size: Optional[int] = 10
121
+ """Size of the context to use around the query when compressing. If None, will use the
122
+ same value as `compressor_context_size`."""
123
+
124
+ @classmethod
125
+ def instantiate_from_config(
126
+ cls, config: Dict[str, Any], overwrites: Optional[Dict[str, Any]] = None
127
+ ) -> "DocumentAwareSpanRetriever":
128
+ """Instantiate a retriever from a configuration dictionary."""
129
+ from hydra.utils import instantiate
130
+
131
+ return instantiate(config, **(overwrites or {}))
132
+
133
+ @classmethod
134
+ def instantiate_from_config_string(
135
+ cls, config_string: str, format: str, overwrites: Optional[Dict[str, Any]] = None
136
+ ) -> "DocumentAwareSpanRetriever":
137
+ """Instantiate a retriever from a configuration string."""
138
+ return cls.instantiate_from_config(
139
+ _parse_config(config_string, format=format), overwrites=overwrites
140
+ )
141
+
142
+ @classmethod
143
+ def instantiate_from_config_file(
144
+ cls, config_path: str, overwrites: Optional[Dict[str, Any]] = None
145
+ ) -> "DocumentAwareSpanRetriever":
146
+ """Instantiate a retriever from a configuration file."""
147
+ with open(config_path, "r") as file:
148
+ config_string = file.read()
149
+ if config_path.endswith(".json"):
150
+ return cls.instantiate_from_config_string(
151
+ config_string, format="json", overwrites=overwrites
152
+ )
153
+ elif config_path.endswith(".yaml"):
154
+ return cls.instantiate_from_config_string(
155
+ config_string, format="yaml", overwrites=overwrites
156
+ )
157
+ else:
158
+ raise ValueError(f"Unsupported file extension: {config_path}")
159
+
160
+ @property
161
+ def pie_annotation_layer_name(self) -> str:
162
+ if issubclass(self.pie_document_type, TextDocumentWithSpans):
163
+ return "spans"
164
+ elif issubclass(self.pie_document_type, TextDocumentWithLabeledSpans):
165
+ return "labeled_spans"
166
+ elif issubclass(self.pie_document_type, TextDocumentWithLabeledMultiSpans):
167
+ return "labeled_multi_spans"
168
+ else:
169
+ raise ValueError(
170
+ f"Unsupported pie document type: {self.pie_document_type}. "
171
+ "Must be one of TextDocumentWithSpans, TextDocumentWithLabeledSpans, "
172
+ "or TextDocumentWithLabeledMultiSpans."
173
+ )
174
+
175
+ def _span_to_dict(self, span: Union[Span, MultiSpan]) -> dict:
176
+ span_dict = {}
177
+ if isinstance(span, Span):
178
+ span_dict[self.vectorstore.SPAN_START_KEY] = span.start
179
+ span_dict[self.vectorstore.SPAN_END_KEY] = span.end
180
+ span_dict["type"] = "Span"
181
+ elif isinstance(span, MultiSpan):
182
+ starts, ends = zip(*span.slices)
183
+ span_dict[self.vectorstore.SPAN_START_KEY] = starts
184
+ span_dict[self.vectorstore.SPAN_END_KEY] = ends
185
+ span_dict["type"] = "MultiSpan"
186
+ else:
187
+ raise ValueError(f"Unsupported span type: {type(span)}")
188
+ if isinstance(span, (LabeledSpan, LabeledMultiSpan)):
189
+ span_dict["label"] = span.label
190
+ span_dict["score"] = span.score
191
+ return span_dict
192
+
193
+ def _dict_to_span(self, span_dict: dict) -> Union[Span, MultiSpan]:
194
+
195
+ if span_dict["type"] == "Span":
196
+ kwargs = dict(
197
+ start=span_dict[self.vectorstore.SPAN_START_KEY],
198
+ end=span_dict[self.vectorstore.SPAN_END_KEY],
199
+ )
200
+ if "label" in span_dict:
201
+ kwargs["label"] = span_dict["label"]
202
+ kwargs["score"] = span_dict["score"]
203
+ return LabeledSpan(**kwargs)
204
+ else:
205
+ return Span(**kwargs)
206
+ elif span_dict["type"] == "MultiSpan":
207
+ starts = span_dict[self.vectorstore.SPAN_START_KEY]
208
+ ends = span_dict[self.vectorstore.SPAN_END_KEY]
209
+ slices = tuple((start, end) for start, end in zip(starts, ends))
210
+ kwargs = dict(slices=slices)
211
+ if "label" in span_dict:
212
+ kwargs["label"] = span_dict["label"]
213
+ kwargs["score"] = span_dict["score"]
214
+ return LabeledMultiSpan(**kwargs)
215
+ else:
216
+ return MultiSpan(**kwargs)
217
+ else:
218
+ raise ValueError(f"Unsupported span type: {span_dict['type']}")
219
+
220
+ def use_predicted_annotations(self, doc: LCDocument) -> bool:
221
+ """Check if the document uses predicted spans."""
222
+ return doc.metadata.get(self.use_predicted_annotations_key, True)
223
+
224
+ def get_document(self, doc_id: str) -> LCDocument:
225
+ """Get a document by its id."""
226
+ documents = self.docstore.mget([doc_id])
227
+ if len(documents) == 0 or documents[0] is None:
228
+ raise DocumentNotFoundError(doc_id=doc_id)
229
+ if len(documents) > 1:
230
+ raise ValueError(f"Multiple documents found with id: {doc_id}")
231
+ return documents[0]
232
+
233
+ def get_span_document(self, span_id: str, with_vector: bool = False) -> LCDocument:
234
+ """Get a span document by its id."""
235
+ if with_vector:
236
+ span_docs = self.vectorstore.get_by_ids_with_vectors([span_id])
237
+ else:
238
+ span_docs = self.vectorstore.get_by_ids([span_id])
239
+ if len(span_docs) == 0 or span_docs[0] is None:
240
+ raise SpanNotFoundError(span_id=span_id)
241
+ if len(span_docs) > 1:
242
+ raise ValueError(f"Multiple span documents found with id: {span_id}")
243
+ return span_docs[0]
244
+
245
+ def get_base_layer(
246
+ self, pie_document: TextBasedDocument, use_predicted_annotations: bool
247
+ ) -> BaseAnnotationList:
248
+ """Get the base layer of the pie document."""
249
+
250
+ if self.pie_annotation_layer_name not in pie_document:
251
+ raise ValueError(
252
+ f'The pie document must contain the annotation layer "{self.pie_annotation_layer_name}"'
253
+ )
254
+ layer = pie_document[self.pie_annotation_layer_name]
255
+ return layer.predictions if use_predicted_annotations else layer
256
+
257
+ def get_span_by_id(self, span_id: str) -> Union[Span, MultiSpan]:
258
+ """Get a span annotation by its id."""
259
+ span_doc = self.get_span_document(span_id)
260
+ doc_id = span_doc.metadata[self.id_key]
261
+ doc = self.get_document(doc_id)
262
+ return self.get_span_from_doc_by_id(doc=doc, span_id=span_id)
263
+
264
+ def get_span_from_doc_by_id(self, doc: LCDocument, span_id: str) -> Union[Span, MultiSpan]:
265
+ """Get the span of a query."""
266
+ base_layer = self.get_base_layer(
267
+ self.docstore.unwrap(doc),
268
+ use_predicted_annotations=self.use_predicted_annotations(doc),
269
+ )
270
+ span_idx = doc.metadata[METADATA_KEY_CHILD_ID2IDX].get(span_id)
271
+ if span_idx is None:
272
+ raise SpanNotFoundError(span_id=span_id, doc_id=doc.id)
273
+ return base_layer[span_idx]
274
+
275
+ def get_span_id2idx_from_doc(self, doc: Union[LCDocument, str]) -> Dict[str, int]:
276
+ """Get all span ids from a document.
277
+
278
+ Args:
279
+ doc: Document or document id
280
+
281
+ Returns:
282
+ Dictionary mapping span ids to their index in the base layer.
283
+ """
284
+
285
+ if isinstance(doc, str):
286
+ doc = self.get_document(doc)
287
+ return doc.metadata[METADATA_KEY_CHILD_ID2IDX]
288
+
289
+ def prepare_search_kwargs(
290
+ self,
291
+ span_id: str,
292
+ doc_id_whitelist: Optional[List[str]] = None,
293
+ doc_id_blacklist: Optional[List[str]] = None,
294
+ kwargs: Optional[Dict[str, Any]] = None,
295
+ ) -> Tuple[dict, LCDocument]:
296
+ # get the span document
297
+ query_span_doc = self.get_span_document(span_id, with_vector=True)
298
+ query_doc_id = query_span_doc.metadata[self.id_key]
299
+ query_doc = self.get_document(query_doc_id)
300
+
301
+ # TODO: why do we do this? Just to be the same as the result of the search when doing compression?
302
+ # add "pie_document" to the metadata
303
+ query_span_doc.metadata[self.docstore.METADATA_KEY_PIE_DOCUMENT] = self.docstore.unwrap(
304
+ query_doc
305
+ )
306
+
307
+ search_kwargs = copy(self.search_kwargs)
308
+ search_kwargs.update(kwargs or {})
309
+
310
+ query_span = self.get_span_from_doc_by_id(doc=query_doc, span_id=span_id)
311
+
312
+ if self.retrieve_from_different_documents and self.retrieve_from_same_document:
313
+ raise ValueError("Cannot retrieve from both same and different documents")
314
+
315
+ if self.retrieve_from_same_document:
316
+ if doc_id_whitelist is None:
317
+ doc_id_whitelist = [query_doc_id]
318
+ elif query_doc_id not in doc_id_whitelist:
319
+ doc_id_whitelist.append(query_doc_id)
320
+
321
+ if self.retrieve_from_different_documents:
322
+ if doc_id_blacklist is None:
323
+ doc_id_blacklist = [query_doc_id]
324
+ elif query_doc_id not in doc_id_blacklist:
325
+ doc_id_blacklist.append(query_doc_id)
326
+
327
+ query_filter = self.vectorstore.construct_filter(
328
+ query_span=query_span,
329
+ metadata_doc_id_key=self.id_key,
330
+ doc_id_whitelist=doc_id_whitelist,
331
+ doc_id_blacklist=doc_id_blacklist,
332
+ )
333
+ if query_filter is not None:
334
+ search_kwargs["filter"] = query_filter
335
+
336
+ # get the vector of the reference span
337
+ search_kwargs["embedding"] = query_span_doc.metadata[self.vectorstore.METADATA_VECTOR_KEY]
338
+ return search_kwargs, query_span_doc
339
+
340
+ def _prepare_query_for_compression(self, query_doc: LCDocument) -> str:
341
+ return self._prepare_doc_for_compression(
342
+ query_doc, context_size=self.compressor_query_context_size
343
+ ).page_content
344
+
345
+ def _prepare_doc_for_compression(
346
+ self, doc: LCDocument, context_size: Optional[int] = None
347
+ ) -> LCDocument:
348
+ if context_size is None:
349
+ context_size = self.compressor_context_size
350
+ pie_doc: TextBasedDocument = self.docstore.unwrap(doc)
351
+ text = pie_doc.text
352
+ span_dict = doc.metadata[self.vectorstore.METADATA_SPAN_KEY]
353
+ span_start = span_dict[self.vectorstore.SPAN_START_KEY]
354
+ span_end = span_dict[self.vectorstore.SPAN_END_KEY]
355
+ if isinstance(span_start, list):
356
+ span_start = span_start[0]
357
+ if isinstance(span_end, list):
358
+ span_end = span_end[0]
359
+ context_start = span_start - context_size
360
+ context_end = span_end + context_size
361
+ doc.page_content = text[max(0, context_start) : min(context_end, len(text))]
362
+ # save the original relevance score and remove it because otherwise we will not be able to get
363
+ # the reranking relevance score
364
+ if "relevance_score" in doc.metadata:
365
+ doc.metadata["relevance_score_without_reranking"] = doc.metadata.pop("relevance_score")
366
+ return doc
367
+
368
+ def _get_relevant_documents(
369
+ self,
370
+ query: str,
371
+ *,
372
+ run_manager: CallbackManagerForRetrieverRun,
373
+ doc_id_whitelist: Optional[List[str]] = None,
374
+ doc_id_blacklist: Optional[List[str]] = None,
375
+ **kwargs: Any,
376
+ ) -> List[LCDocument]:
377
+ """Get span documents relevant to a query span
378
+ Args:
379
+ query: The span id to find relevant spans for
380
+ run_manager: The callbacks handler to use
381
+ Returns:
382
+ List of relevant span documents with metadata from the parent document
383
+ """
384
+
385
+ search_kwargs, query_span_doc = self.prepare_search_kwargs(
386
+ span_id=query,
387
+ kwargs=kwargs,
388
+ doc_id_whitelist=doc_id_whitelist,
389
+ doc_id_blacklist=doc_id_blacklist,
390
+ )
391
+ if self.search_type == SearchType.mmr:
392
+ span_docs = self.vectorstore.max_marginal_relevance_search_by_vector(**search_kwargs)
393
+ elif self.search_type == SearchType.similarity_score_threshold:
394
+ sub_docs_and_similarities = self.vectorstore.similarity_search_with_score_by_vector(
395
+ **search_kwargs
396
+ )
397
+ span_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
398
+ else:
399
+ span_docs = self.vectorstore.similarity_search_by_vector(**search_kwargs)
400
+
401
+ # We do this to maintain the order of the ids that are returned
402
+ doc_ids = []
403
+ for span_doc in span_docs:
404
+ if self.id_key not in span_doc.metadata:
405
+ raise ValueError(f"Metadata must contain the key {self.id_key}")
406
+ if span_doc.metadata[self.id_key] not in doc_ids:
407
+ doc_ids.append(span_doc.metadata[self.id_key])
408
+ docs = self.docstore.mget(doc_ids)
409
+ doc_id2doc = dict(zip(doc_ids, docs))
410
+ for span_doc in span_docs:
411
+ doc = doc_id2doc[span_doc.metadata[self.id_key]]
412
+ span_doc.metadata.update(doc.metadata)
413
+ span_doc.metadata["attached_span"] = self.get_span_from_doc_by_id(
414
+ doc=doc, span_id=span_doc.id
415
+ )
416
+ span_doc.metadata["query_span_id"] = query
417
+ # filter out the query span doc
418
+ span_docs_filtered = [
419
+ span_doc for span_doc in span_docs if span_doc.id != query_span_doc.id
420
+ ]
421
+ if self.compressor is None:
422
+ return span_docs_filtered
423
+ if span_docs_filtered:
424
+ prepared_docs = [
425
+ self._prepare_doc_for_compression(sub_doc) for sub_doc in span_docs_filtered
426
+ ]
427
+ prepared_query = self._prepare_query_for_compression(query_span_doc)
428
+ compressed_docs = self.compressor.compress_documents(
429
+ documents=prepared_docs, query=prepared_query, callbacks=run_manager.get_child()
430
+ )
431
+ return list(compressed_docs)
432
+ else:
433
+ return []
434
+
435
+ async def _aget_relevant_documents(
436
+ self,
437
+ query: str,
438
+ *,
439
+ run_manager: AsyncCallbackManagerForRetrieverRun,
440
+ doc_id_whitelist: Optional[List[str]] = None,
441
+ doc_id_blacklist: Optional[List[str]] = None,
442
+ **kwargs: Any,
443
+ ) -> List[LCDocument]:
444
+ """Asynchronously get span documents relevant to a query span
445
+ Args:
446
+ query: The span id to find relevant spans for
447
+ run_manager: The callbacks handler to use
448
+ Returns:
449
+ List of relevant span documents with metadata from the parent document
450
+ """
451
+ search_kwargs, query_span_doc = self.prepare_search_kwargs(
452
+ span_id=query,
453
+ kwargs=kwargs,
454
+ doc_id_whitelist=doc_id_whitelist,
455
+ doc_id_blacklist=doc_id_blacklist,
456
+ )
457
+ if self.search_type == SearchType.mmr:
458
+ span_docs = await self.vectorstore.amax_marginal_relevance_search_by_vector(
459
+ **search_kwargs
460
+ )
461
+ elif self.search_type == SearchType.similarity_score_threshold:
462
+ sub_docs_and_similarities = (
463
+ await self.vectorstore.asimilarity_search_with_score_by_vector(**search_kwargs)
464
+ )
465
+ span_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
466
+ else:
467
+ span_docs = await self.vectorstore.asimilarity_search_by_vector(**search_kwargs)
468
+
469
+ # We do this to maintain the order of the ids that are returned
470
+ ids = []
471
+ for span_doc in span_docs:
472
+ if self.id_key not in span_doc.metadata:
473
+ raise ValueError(f"Metadata must contain the key {self.id_key}")
474
+ if span_doc.metadata[self.id_key] not in ids:
475
+ ids.append(span_doc.metadata[self.id_key])
476
+ docs = await self.docstore.amget(ids)
477
+ doc_id2doc = dict(zip(ids, docs))
478
+ for span_doc in span_docs:
479
+ doc = doc_id2doc[span_doc.metadata[self.id_key]]
480
+ span_doc.metadata.update(doc.metadata)
481
+ span_doc.metadata["attached_span"] = self.get_span_from_doc_by_id(
482
+ doc=doc, span_id=span_doc.id
483
+ )
484
+ span_doc.metadata["query_span_id"] = query
485
+ # filter out the query span doc
486
+ span_docs_filtered = [
487
+ span_doc for span_doc in span_docs if span_doc.id != query_span_doc.id
488
+ ]
489
+
490
+ if self.compressor is None:
491
+ return span_docs_filtered
492
+ if span_docs_filtered:
493
+ prepared_docs = [
494
+ self._prepare_doc_for_compression(sub_doc) for sub_doc in span_docs_filtered
495
+ ]
496
+ prepared_query = self._prepare_query_for_compression(query_span_doc)
497
+ compressed_docs = await self.compressor.acompress_documents(
498
+ prepared_docs, query=prepared_query, callbacks=run_manager.get_child()
499
+ )
500
+ return list(compressed_docs)
501
+ else:
502
+ return []
503
+
504
+ def create_span_documents(
505
+ self, documents: List[LCDocument]
506
+ ) -> Tuple[List[LCDocument], Dict[str, int]]:
507
+ span_docs = []
508
+ id2idx = {}
509
+ for i, doc in enumerate(documents):
510
+ pie_doc, metadata = self.docstore.unwrap_with_metadata(doc)
511
+ base_layer = self.get_base_layer(
512
+ pie_doc, use_predicted_annotations=self.use_predicted_annotations(doc)
513
+ )
514
+ if len(base_layer) == 0:
515
+ logger.warning(f"No spans found in document {i} (id: {doc.id})")
516
+ for idx, labeled_span in enumerate(base_layer):
517
+ _metadata = {k: v for k, v in metadata.items() if k != METADATA_KEY_CHILD_ID2IDX}
518
+ # save as dict to avoid serialization issues
519
+ _metadata[self.vectorstore.METADATA_SPAN_KEY] = self._span_to_dict(labeled_span)
520
+ new_doc = LCDocument(
521
+ id=str(uuid.uuid4()), page_content=pie_doc.text, metadata=_metadata
522
+ )
523
+ span_docs.append(new_doc)
524
+ id2idx[new_doc.id] = idx
525
+ return span_docs, id2idx
526
+
527
+ def _split_docs_for_adding(
528
+ self,
529
+ documents: List[LCDocument],
530
+ ids: Optional[List[str]] = None,
531
+ add_to_docstore: bool = True,
532
+ ) -> Tuple[List[LCDocument], List[Tuple[str, LCDocument]]]:
533
+ if ids is None:
534
+ doc_ids = [doc.id for doc in documents]
535
+ if not add_to_docstore:
536
+ raise ValueError("If ids are not passed in, `add_to_docstore` MUST be True")
537
+ else:
538
+ if len(documents) != len(ids):
539
+ raise ValueError(
540
+ "Got uneven list of documents and ids. "
541
+ "If `ids` is provided, should be same length as `documents`."
542
+ )
543
+ doc_ids = ids
544
+
545
+ if len(set(doc_ids)) != len(doc_ids):
546
+ raise ValueError("IDs must be unique")
547
+
548
+ docs = []
549
+ full_docs = []
550
+ for i, doc in enumerate(documents):
551
+ _id = doc_ids[i]
552
+ sub_docs, sub_doc_id2idx = self.create_span_documents([doc])
553
+ if self.child_metadata_fields is not None:
554
+ for sub_doc in sub_docs:
555
+ sub_doc.metadata = {k: sub_doc.metadata[k] for k in self.child_metadata_fields}
556
+ for sub_doc in sub_docs:
557
+ # Add the parent id to the child document id
558
+ sub_doc.metadata[self.id_key] = _id
559
+ docs.extend(sub_docs)
560
+ doc.metadata[METADATA_KEY_CHILD_ID2IDX] = sub_doc_id2idx
561
+ full_docs.append((_id, doc))
562
+
563
+ return docs, full_docs
564
+
565
+ def remove_missing_span_ids_from_document(
566
+ self, document: LCDocument, span_ids: Set[str]
567
+ ) -> LCDocument:
568
+ """Remove invalid span ids from the span to idx mapping
569
+ of the document.
570
+
571
+ Args:
572
+ document: Document to remove invalid span ids from
573
+ span_ids: Set of valid span ids
574
+
575
+ Returns:
576
+ Document with invalid span ids removed
577
+ """
578
+ span_id2idx = document.metadata[METADATA_KEY_CHILD_ID2IDX]
579
+ new_doc = copy(document)
580
+ filtered_span_id2idx = {
581
+ span_id: idx for span_id, idx in span_id2idx.items() if span_id in span_ids
582
+ }
583
+ new_doc.metadata[METADATA_KEY_CHILD_ID2IDX] = filtered_span_id2idx
584
+ missed_span_ids = set(span_id2idx.keys()) - span_ids
585
+ if len(missed_span_ids) > 0:
586
+ layer = self.get_base_layer(
587
+ self.docstore.unwrap(document),
588
+ use_predicted_annotations=self.use_predicted_annotations(document),
589
+ )
590
+ resolved_missed_spans = [
591
+ layer[span_id2idx[span_id]].resolve() for span_id in missed_span_ids
592
+ ]
593
+ logger.warning(
594
+ f"Document {document.id} contains spans that can not be added to the "
595
+ f"vectorstore because no vector could be calculated:\n{resolved_missed_spans}.\n"
596
+ "These spans will be not queryable."
597
+ )
598
+ return document
599
+
600
+ def add_documents(
601
+ self,
602
+ documents: List[LCDocument],
603
+ ids: Optional[List[str]] = None,
604
+ add_to_docstore: bool = True,
605
+ **kwargs: Any,
606
+ ) -> None:
607
+ """Adds documents to the docstore and vectorstores.
608
+
609
+ Args:
610
+ documents: List of documents to add
611
+ ids: Optional list of ids for documents. If provided should be the same
612
+ length as the list of documents. Can be provided if parent documents
613
+ are already in the document store and you don't want to re-add
614
+ to the docstore. If not provided, random UUIDs will be used as
615
+ ids.
616
+ add_to_docstore: Boolean of whether to add documents to docstore.
617
+ This can be false if and only if `ids` are provided. You may want
618
+ to set this to False if the documents are already in the docstore
619
+ and you don't want to re-add them.
620
+ """
621
+ docs, full_docs = self._split_docs_for_adding(documents, ids, add_to_docstore)
622
+ added_span_ids = self.vectorstore.add_documents(docs, **kwargs)
623
+ full_docs = [
624
+ (doc_id, self.remove_missing_span_ids_from_document(doc, set(added_span_ids)))
625
+ for doc_id, doc in full_docs
626
+ ]
627
+ if add_to_docstore:
628
+ self.docstore.mset(full_docs)
629
+
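Note: a minimal usage sketch of the two modes above; `retriever` and `lc_docs` are placeholders for an already constructed retriever and LangChain documents wrapping PIE documents:

    # first-time ingestion: span vectors are added and the parent documents are stored
    retriever.add_documents(lc_docs)
    # refresh only the span vectors for parents that are already in the docstore
    retriever.add_documents(lc_docs, ids=[doc.id for doc in lc_docs], add_to_docstore=False)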
630
+ async def aadd_documents(
631
+ self,
632
+ documents: List[LCDocument],
633
+ ids: Optional[List[str]] = None,
634
+ add_to_docstore: bool = True,
635
+ **kwargs: Any,
636
+ ) -> None:
637
+ docs, full_docs = self._split_docs_for_adding(documents, ids, add_to_docstore)
638
+ added_span_ids = await self.vectorstore.aadd_documents(docs, **kwargs)
639
+ full_docs = [
640
+ (doc_id, self.remove_missing_span_ids_from_document(doc, set(added_span_ids)))
641
+ for doc_id, doc in full_docs
642
+ ]
643
+ if add_to_docstore:
644
+ await self.docstore.amset(full_docs)
645
+
646
+ def delete_documents(self, ids: List[str]) -> None:
647
+ """Remove documents from the docstore and vectorstores.
648
+
649
+ Args:
650
+ ids: List of ids to remove
651
+ """
652
+ # get all child ids
653
+ child_ids = []
654
+ for doc in self.docstore.mget(ids):
655
+ child_ids.extend(doc.metadata[METADATA_KEY_CHILD_ID2IDX])
656
+
657
+ self.vectorstore.delete(child_ids)
658
+ self.docstore.mdelete(ids)
659
+
660
+ async def adelete_documents(self, ids: List[str]) -> None:
661
+ """Asynchronously remove documents from the docstore and vectorstores.
662
+
663
+ Args:
664
+ ids: List of ids to remove
665
+ """
666
+ # get all child ids
667
+ child_ids = []
668
+ docs: List[LCDocument] = await self.docstore.amget(ids)
669
+ for doc in docs:
670
+ child_ids.extend(doc.metadata[METADATA_KEY_CHILD_ID2IDX])
671
+
672
+ await self.vectorstore.adelete(child_ids)
673
+ await self.docstore.amdelete(ids)
674
+
675
+ def add_pie_documents(
676
+ self,
677
+ documents: List[TextBasedDocument],
678
+ use_predicted_annotations: bool,
679
+ metadata: Optional[Dict[str, Any]] = None,
680
+ ) -> None:
681
+ """Add pie documents to the retriever.
682
+
683
+ Args:
684
+ documents: List of pie documents to add
685
+ use_predicted_annotations: Whether to use the predicted annotations or the gold annotations
686
+ metadata: Optional metadata to add to each document
687
+ """
688
+ metadata = metadata or {}
689
+ metadata = copy(metadata)
690
+ metadata[self.use_predicted_annotations_key] = use_predicted_annotations
691
+ docs = [self.docstore.wrap(doc, **metadata) for doc in documents]
692
+
693
+ # delete any existing documents with the same ids (simply overwriting would keep the spans)
694
+ new_docs_ids = [doc.id for doc in docs]
695
+ existing_docs = self.docstore.mget(new_docs_ids)
696
+ existing_doc_ids = [doc.id for doc in existing_docs]
697
+ self.delete_documents(existing_doc_ids)
698
+
699
+ self.add_documents(docs)
700
+
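Note: for annotated PIE documents, `add_pie_documents` is the convenient entry point; a sketch with placeholder names (`retriever`, `pie_docs`):

    # use_predicted_annotations=True indexes the model predictions instead of the gold annotations
    retriever.add_pie_documents(pie_docs, use_predicted_annotations=True, metadata={"source": "demo"})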
701
+ def _save_to_directory(self, path: str, **kwargs) -> None:
702
+ logger.info(f'Saving docstore and vectorstore to "{path}" ...')
703
+ self.docstore.save_to_directory(os.path.join(path, "docstore"))
704
+ self.vectorstore.save_to_directory(os.path.join(path, "vectorstore"))
705
+
706
+ def _load_from_directory(self, path: str, **kwargs) -> None:
707
+ logger.info(f'Loading docstore and vectorstore from "{path}" ...')
708
+ self.docstore.load_from_directory(os.path.join(path, "docstore"))
709
+ self.vectorstore.load_from_directory(os.path.join(path, "vectorstore"))
710
+
711
+
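Note: persisting the retriever state writes the docstore and vectorstore into sibling subdirectories; a sketch that assumes serializable_store.py exposes public save_to_directory / load_from_directory wrappers around the private hooks above (the path is a placeholder):

    retriever.save_to_directory("data/retriever_state")   # writes docstore/ and vectorstore/ subfolders
    retriever.load_from_directory("data/retriever_state") # restores both stores later on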
712
+ METADATA_KEY_RELATION_LABEL2TAILS_WITH_SCORES = "relation_label2tails_with_scores"
713
+
714
+
715
+ class DocumentAwareSpanRetrieverWithRelations(DocumentAwareSpanRetriever):
716
+ """Retriever for related contextualized text spans, i.e. spans linked by relations
717
+ to reference spans that are similar to the query span. It accepts spans as queries and
718
+ retrieves spans with their containing document and the reference span."""
719
+
720
+ relation_layer_name: str = "binary_relations"
721
+ """The name of the relation annotation layer in the pie document."""
722
+ relation_labels: Optional[List[str]] = None
723
+ """The list of relation labels to consider."""
724
+ span_labels: Optional[List[str]] = None
725
+ """The list of span labels to consider."""
726
+ reversed_relations_suffix: Optional[str] = None
727
+ """Whether to consider reverse relations as well."""
728
+
729
+ def get_relation_layer(
730
+ self, pie_document: TextBasedDocument, use_predicted_annotations: bool
731
+ ) -> BaseAnnotationList:
732
+ """Get the relation layer of the pie document."""
733
+ if self.relation_layer_name not in pie_document:
734
+ raise ValueError(
735
+ f'The pie document must contain the annotation layer "{self.relation_layer_name}"'
736
+ )
737
+ layer = pie_document[self.relation_layer_name]
738
+ return layer.predictions if use_predicted_annotations else layer
739
+
740
+ def create_span_documents(
741
+ self, documents: List[LCDocument]
742
+ ) -> Tuple[List[LCDocument], Dict[str, int]]:
743
+ span_docs = []
744
+ id2idx = {}
745
+ for i, doc in enumerate(documents):
746
+ pie_doc, metadata = self.docstore.unwrap_with_metadata(doc)
747
+ use_predicted_annotations = self.use_predicted_annotations(doc)
748
+ base_layer = self.get_base_layer(
749
+ pie_doc, use_predicted_annotations=use_predicted_annotations
750
+ )
751
+ if len(base_layer) == 0:
752
+ logger.warning(f"No spans found in document {i} (id: {doc.id})")
753
+ id2span = {str(uuid.uuid4()): span for span in base_layer}
754
+ span2id = {span: span_id for span_id, span in id2span.items()}
755
+ if len(id2span) != len(span2id):
756
+ raise ValueError("Span ids and spans must be unique")
757
+ relations = self.get_relation_layer(
758
+ pie_doc, use_predicted_annotations=use_predicted_annotations
759
+ )
760
+ head2label2tails_with_scores: Dict[str, Dict[str, List[Tuple[str, float]]]] = (
761
+ defaultdict(lambda: defaultdict(list))
762
+ )
763
+
764
+ for relation in relations:
765
+ if self.relation_labels is None or relation.label in self.relation_labels:
766
+ head2label2tails_with_scores[span2id[relation.head]][relation.label].append(
767
+ (span2id[relation.tail], relation.score)
768
+ )
769
+ if self.reversed_relations_suffix is not None:
770
+ reversed_label = f"{relation.label}{self.reversed_relations_suffix}"
771
+ if self.relation_labels is None or reversed_label in self.relation_labels:
772
+ head2label2tails_with_scores[span2id[relation.tail]][
773
+ reversed_label
774
+ ].append((span2id[relation.head], relation.score))
775
+
776
+ for idx, span in enumerate(base_layer):
777
+ span_id = span2id[span]
778
+ _metadata = {k: v for k, v in metadata.items() if k != METADATA_KEY_CHILD_ID2IDX}
779
+ # save as dict to avoid serialization issues
780
+ _metadata[self.vectorstore.METADATA_SPAN_KEY] = self._span_to_dict(span)
781
+ relation_label2tails_with_scores = head2label2tails_with_scores[span_id]
782
+ _metadata[METADATA_KEY_RELATION_LABEL2TAILS_WITH_SCORES] = dict(
783
+ relation_label2tails_with_scores
784
+ )
785
+ new_doc = LCDocument(id=span_id, page_content=pie_doc.text, metadata=_metadata)
786
+ span_docs.append(new_doc)
787
+ id2idx[span_id] = idx
788
+ return span_docs, id2idx
789
+
790
+ def _get_relevant_documents(
791
+ self,
792
+ query: str,
793
+ return_related: bool = False,
794
+ *,
795
+ run_manager: CallbackManagerForRetrieverRun,
796
+ **kwargs: Any,
797
+ ) -> List[LCDocument]:
798
+ """Get span documents relevant to a query span. We follow one hop of relations.
799
+
800
+ Args:
801
+ query: The span id to find relevant spans for
802
+ return_related: Whether to return related spans
803
+ run_manager: The callbacks handler to use
804
+ Returns:
805
+ List of relevant span documents with metadata from the parent document
806
+ """
807
+ similar_span_docs = super()._get_relevant_documents(
808
+ query=query, run_manager=run_manager, **kwargs
809
+ )
810
+ if not return_related:
811
+ return similar_span_docs
812
+
813
+ related_docs = []
814
+ for head_span_doc in similar_span_docs:
815
+ doc_id = head_span_doc.metadata[self.id_key]
816
+ doc = self.get_document(doc_id)
817
+ query_span_id = head_span_doc.metadata["query_span_id"]
818
+
819
+ for relation_label, tails_with_score in head_span_doc.metadata[
820
+ METADATA_KEY_RELATION_LABEL2TAILS_WITH_SCORES
821
+ ].items():
822
+ for tail_id, relation_score in tails_with_score:
823
+ # in the case that we query against the same document,
824
+ # we don't want to return the same span as the query span
825
+ if tail_id == query_span_id:
826
+ continue
827
+
828
+ try:
829
+ attached_tail_span = self.get_span_from_doc_by_id(doc=doc, span_id=tail_id)
830
+ # this may happen if the tail span could not be added to the vectorstore, e.g. because
831
+ # the token span length is zero and no vector could be calculated
832
+ except SpanNotFoundError:
833
+ logger.warning(
834
+ f"Tail span with id [{tail_id}] not found in the vectorstore. Skipping."
835
+ )
836
+ continue
837
+
838
+ # TODO: handle via filter? see vectorstore.construct_filter
839
+ if self.span_labels is not None:
840
+ if not isinstance(attached_tail_span, (LabeledSpan, LabeledMultiSpan)):
841
+ raise ValueError(
842
+ "Span must must be a labeled span if span_labels is provided"
843
+ )
844
+ if attached_tail_span.label not in self.span_labels:
845
+ continue
846
+
847
+ related_docs.append(
848
+ LCDocument(
849
+ id=tail_id,
850
+ page_content="",
851
+ metadata={
852
+ "relation_score": relation_score,
853
+ "head_id": head_span_doc.id,
854
+ "relation_label": relation_label,
855
+ "attached_tail_span": attached_tail_span,
856
+ **head_span_doc.metadata,
857
+ },
858
+ )
859
+ )
860
+ return related_docs
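Note: a retrieval sketch for the relation-aware variant; the query is the id of a span that was added before, and `return_related=True` switches from returning the similar spans themselves to the tails reachable via one relation hop (names are placeholders; extra kwargs are forwarded by LangChain's `BaseRetriever.invoke`):

    similar_spans = retriever.invoke(query_span_id)
    related_spans = retriever.invoke(query_span_id, return_related=True)
    for doc in related_spans:
        # each result carries the relation label and the resolved tail span in its metadata
        print(doc.metadata["relation_label"], doc.metadata["attached_tail_span"])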
src/langchain_modules/span_vectorstore.py ADDED
@@ -0,0 +1,131 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, List, Optional, Sequence, Union
3
+
4
+ from langchain_core.documents import Document as LCDocument
5
+ from langchain_core.runnables import run_in_executor
6
+ from langchain_core.stores import BaseStore
7
+ from langchain_core.vectorstores import VectorStore
8
+ from pytorch_ie.annotations import MultiSpan, Span
9
+
10
+ from .serializable_store import SerializableStore
11
+ from .span_embeddings import SpanEmbeddings
12
+
13
+
14
+ class SpanVectorStore(VectorStore, BaseStore, SerializableStore, ABC):
15
+ """Abstract base class for vector stores specialized in storing
16
+ and retrieving embeddings for text spans within documents."""
17
+
18
+ METADATA_SPAN_KEY: str = "pie_labeled_span"
19
+ """Key for the span data in the (langchain) document metadata."""
20
+ SPAN_START_KEY: str = "start"
21
+ """Key for the start of the span in the span data."""
22
+ SPAN_END_KEY: str = "end"
23
+ """Key for the end of the span in the span data."""
24
+ METADATA_VECTOR_KEY: str = "vector"
25
+ """Key for the vector in the (langchain) document metadata."""
26
+ RELEVANCE_SCORE_KEY: str = "relevance_score"
27
+ """Key for the relevance score in the (langchain) document metadata."""
28
+
29
+ def __init__(
30
+ self,
31
+ label_mapping: Optional[Dict[str, List[str]]] = None,
32
+ **kwargs: Any,
33
+ ):
34
+ """Initialize the SpanVectorStore.
35
+
36
+ Args:
37
+ label_mapping: Mapping from query span labels to target span labels. If provided,
38
+ only spans with a label in the mapping for the query span's label are retrieved.
39
+ **kwargs: Additional arguments.
40
+ """
41
+ self.label_mapping = label_mapping
42
+ super().__init__(**kwargs)
43
+
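Note: the label_mapping constrains what `construct_filter` (below) lets through; a hypothetical mapping for an argumentative label set, meaning a query span labeled "claim" only retrieves spans labeled "claim" or "background_claim":

    label_mapping = {
        "claim": ["claim", "background_claim"],
        "data": ["data"],
    }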
44
+ @property
45
+ def embeddings(self) -> SpanEmbeddings:
46
+ """Get the dense embeddings instance that is being used.
47
+
48
+ Raises:
49
+ ValueError: If embeddings are `None`.
50
+
51
+ Returns:
52
+ Embeddings: An instance of `Embeddings`.
53
+ """
54
+ result = super().embeddings
55
+ if not isinstance(result, SpanEmbeddings):
56
+ raise ValueError(f"Embeddings must be of type SpanEmbeddings, but got: {result}")
57
+ return result
58
+
59
+ @abstractmethod
60
+ def get_by_ids_with_vectors(self, ids: Sequence[Union[str, int]], /) -> List[LCDocument]:
61
+ """Get documents by their ids.
62
+
63
+ Args:
64
+ ids: List of document ids.
65
+
66
+ Returns:
67
+ List of documents including their vectors in the metadata at key `METADATA_VECTOR_KEY`.
68
+ """
69
+ ...
70
+
71
+ @abstractmethod
72
+ def construct_filter(
73
+ self,
74
+ query_span: Union[Span, MultiSpan],
75
+ metadata_doc_id_key: str,
76
+ doc_id_whitelist: Optional[Sequence[str]] = None,
77
+ doc_id_blacklist: Optional[Sequence[str]] = None,
78
+ ) -> Any:
79
+ """Construct a filter for the retrieval. It should enforce that:
80
+ - if the span is labeled, the retrieved span has the same label, or
81
+ - if, in addition, a label mapping is provided, the retrieved span has a label that is in the mapping for the query span's label
82
+ - if `doc_id_whitelist` is provided, the retrieved span is from a document in the whitelist
83
+ - if `doc_id_blacklist` is provided, the retrieved span is not from a document in the blacklist
84
+
85
+ Args:
86
+ query_span: The query span.
87
+ metadata_doc_id_key: The key in the metadata that holds the document id.
88
+ doc_id_whitelist: A list of document ids to restrict the retrieval to.
89
+ doc_id_blacklist: A list of document ids to exclude from the retrieval.
90
+
91
+ Returns:
92
+ A filter object.
93
+ """
94
+ ...
95
+
96
+ @abstractmethod
97
+ def similarity_search_with_score_by_vector(
98
+ self, embedding: list[float], k: int = 4, **kwargs: Any
99
+ ) -> list[tuple[LCDocument, float]]:
100
+ """Return docs most similar to embedding vector.
101
+
102
+ Args:
103
+ embedding: Embedding to look up documents similar to.
104
+ k: Number of Documents to return. Defaults to 4.
105
+ **kwargs: Arguments to pass to the search method.
106
+
107
+ Returns:
108
+ List of tuples of (document, similarity score), most similar first.
109
+ """
110
+ ...
111
+
112
+ async def asimilarity_search_with_score_by_vector(
113
+ self, embedding: list[float], k: int = 4, **kwargs: Any
114
+ ) -> list[tuple[LCDocument, float]]:
115
+ """Async return docs most similar to embedding vector.
116
+
117
+ Args:
118
+ embedding: Embedding to look up documents similar to.
119
+ k: Number of Documents to return. Defaults to 4.
120
+ **kwargs: Arguments to pass to the search method.
121
+
122
+ Returns:
123
+ List of tuples of (document, similarity score), most similar first.
124
+ """
125
+
126
+ # This is a temporary workaround to make the similarity search
127
+ # asynchronous. The proper solution is to make the similarity search
128
+ # asynchronous in the vector store implementations.
129
+ return await run_in_executor(
130
+ None, self.similarity_search_with_score_by_vector, embedding, k=k, **kwargs
131
+ )
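Note: the retriever consumes the "with_score" variants as (document, score) pairs; a small sketch of unpacking them, with `store` and `embedding` as placeholders for a concrete SpanVectorStore and a query vector:

    results = store.similarity_search_with_score_by_vector(embedding, k=10)
    for doc, score in results:
        # each hit is a span document plus its similarity score
        print(doc.id, score)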