# NOTE: removed non-Python artifact lines ("Spaces:" / "Runtime error") that
# were paste/extraction residue and would be syntax errors at module level.
"""Wrapper around the Milvus vector database.""" | |
from __future__ import annotations | |
import uuid | |
from typing import Any, Iterable, List, Optional, Tuple | |
import numpy as np | |
from langchain.docstore.document import Document | |
from langchain.embeddings.base import Embeddings | |
from langchain.vectorstores.base import VectorStore | |
from langchain.vectorstores.utils import maximal_marginal_relevance | |
class Milvus(VectorStore):
    """Wrapper around the Milvus vector database."""

    def __init__(
        self,
        embedding_function: Embeddings,
        connection_args: dict,
        collection_name: str,
        text_field: str,
    ):
        """Initialize wrapper around the milvus vector database.

        In order to use this you need to have `pymilvus` installed and a
        running Milvus instance.

        See the following documentation for how to run a Milvus instance:
        https://milvus.io/docs/install_standalone-docker.md

        Args:
            embedding_function (Embeddings): Function used to embed the text.
            connection_args (dict): Arguments for pymilvus connections.connect().
            collection_name (str): The name of the collection to search.
            text_field (str): The field in Milvus schema where the
                original text is stored.

        Raises:
            ValueError: If the pymilvus python package is not installed.
        """
        try:
            from pymilvus import Collection, DataType, connections
        except ImportError:
            raise ValueError(
                "Could not import pymilvus python package. "
                "Please install it with `pip install pymilvus`."
            )
        # Connecting to Milvus instance; reuse an already-open default
        # connection rather than reconnecting on every instantiation.
        if not connections.has_connection("default"):
            connections.connect(**connection_args)
        self.embedding_func = embedding_function
        self.collection_name = collection_name
        self.text_field = text_field
        self.auto_id = False
        # Filled in below from the collection schema.
        self.primary_field: Optional[str] = None
        self.vector_field: Optional[str] = None
        self.fields: List[str] = []
        self.col = Collection(self.collection_name)
        schema = self.col.schema
        # Grabbing the fields for the existing collection.
        for field in schema.fields:
            self.fields.append(field.name)
            # auto_id fields are generated server-side by Milvus and must not
            # be supplied on insert, so drop them from the insertable fields.
            if field.auto_id:
                self.fields.remove(field.name)
            if field.is_primary:
                self.primary_field = field.name
            if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
                self.vector_field = field.name
        # Default search params used when the caller does not provide one,
        # keyed by the index type reported by the collection.
        self.index_params = {
            "IVF_FLAT": {"params": {"nprobe": 10}},
            "IVF_SQ8": {"params": {"nprobe": 10}},
            "IVF_PQ": {"params": {"nprobe": 10}},
            "HNSW": {"params": {"ef": 10}},
            "RHNSW_FLAT": {"params": {"ef": 10}},
            "RHNSW_SQ": {"params": {"ef": 10}},
            "RHNSW_PQ": {"params": {"ef": 10}},
            "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
            "ANNOY": {"params": {"search_k": 10}},
        }

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        partition_name: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Insert text data into Milvus.

        When using add_texts() it is assumed that a collection has already
        been made and indexed. If metadata is included, it is assumed that
        it is ordered correctly to match the schema provided to the Collection
        and that the embedding vector is the first schema field.

        Args:
            texts (Iterable[str]): The text being embedded and inserted.
            metadatas (Optional[List[dict]], optional): The metadata that
                corresponds to each insert. Defaults to None.
            partition_name (str, optional): The partition of the collection
                to insert data into. Defaults to None.
            timeout: specified timeout.

        Returns:
            List[str]: The resulting keys for each inserted element.
        """
        # Materialize once: `texts` may be a one-shot generator, and it is
        # consumed twice below (once for the text field, once for embedding).
        texts = list(texts)
        insert_dict: Any = {self.text_field: texts}
        try:
            insert_dict[self.vector_field] = self.embedding_func.embed_documents(
                texts
            )
        except NotImplementedError:
            # Embedding implementations without a batch API fall back to
            # per-query embedding.
            insert_dict[self.vector_field] = [
                self.embedding_func.embed_query(x) for x in texts
            ]
        # Collect the metadata into the insert dict. Only keys that exist in
        # the collection schema are kept; unknown keys are silently ignored.
        if len(self.fields) > 2 and metadatas is not None:
            for d in metadatas:
                for key, value in d.items():
                    if key in self.fields:
                        insert_dict.setdefault(key, []).append(value)
        # Convert dict to list of lists, ordered to match the schema fields.
        insert_list = [insert_dict[x] for x in self.fields]
        # Insert into the collection.
        res = self.col.insert(
            insert_list, partition_name=partition_name, timeout=timeout
        )
        # Flush to make sure newly inserted is immediately searchable.
        self.col.flush()
        return res.primary_keys

    def _worker_search(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        partition_names: Optional[List[str]] = None,
        round_decimal: int = -1,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]:
        """Embed the query and run a vector search against the collection.

        Returns:
            Tuple of (query embedding, list of (Document, distance, pk)).
        """
        # Load the collection into memory for searching.
        self.col.load()
        # Decide to use default params if not passed in, based on the
        # collection's first index type.
        if param is None:
            index_type = self.col.indexes[0].params["index_type"]
            param = self.index_params[index_type]
        # Embed the query text.
        data = [self.embedding_func.embed_query(query)]
        # Determine result metadata fields: everything except the vector.
        output_fields = self.fields[:]
        output_fields.remove(self.vector_field)
        # Perform the search.
        res = self.col.search(
            data,
            self.vector_field,
            param,
            k,
            expr=expr,
            output_fields=output_fields,
            partition_names=partition_names,
            round_decimal=round_decimal,
            timeout=timeout,
            **kwargs,
        )
        # Organize results: pop the stored text out of the metadata so it
        # becomes the Document's page_content.
        ret = []
        for result in res[0]:
            meta = {x: result.entity.get(x) for x in output_fields}
            ret.append(
                (
                    Document(page_content=meta.pop(self.text_field), metadata=meta),
                    result.distance,
                    result.id,
                )
            )
        return data[0], ret

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        partition_names: Optional[List[str]] = None,
        round_decimal: int = -1,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on a query string and return results with score.

        Args:
            query (str): The text being searched.
            k (int, optional): The amount of results to return. Defaults to 4.
            param (dict, optional): The search params for the specified index.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            partition_names (List[str], optional): Partitions to search through.
                Defaults to None.
            round_decimal (int, optional): Round the resulting distance. Defaults
                to -1.
            timeout (int, optional): Amount to wait before timeout error. Defaults
                to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Tuple[Document, float]]: (Document, distance) results.
        """
        _, result = self._worker_search(
            query, k, param, expr, partition_names, round_decimal, timeout, **kwargs
        )
        # Drop the primary key, keeping only (Document, distance).
        return [(x, y) for x, y, _ in result]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        partition_names: Optional[List[str]] = None,
        round_decimal: int = -1,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results that are reordered by MMR.

        Args:
            query (str): The text being searched.
            k (int, optional): How many results to give. Defaults to 4.
            fetch_k (int, optional): Total results to select k from.
                Defaults to 20.
            param (dict, optional): The search params for the specified index.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            partition_names (List[str], optional): What partitions to search.
                Defaults to None.
            round_decimal (int, optional): Round the resulting distance. Defaults
                to -1.
            timeout (int, optional): Amount to wait before timeout error. Defaults
                to None.

        Returns:
            List[Document]: Document results for search.
        """
        data, res = self._worker_search(
            query,
            fetch_k,
            param,
            expr,
            partition_names,
            round_decimal,
            timeout,
            **kwargs,
        )
        # Extract result IDs.
        ids = [x for _, _, x in res]
        # Get the raw vectors back from Milvus; search results do not include
        # the vector field, so a second query is required for MMR.
        vectors = self.col.query(
            expr=f"{self.primary_field} in {ids}",
            output_fields=[self.primary_field, self.vector_field],
        )
        # Reorganize the results from query to match the search result order
        # (query() does not preserve the order of `ids`).
        vectors = {x[self.primary_field]: x[self.vector_field] for x in vectors}
        ordered_result_embeddings = [vectors[x] for x in ids]
        # Get the new order of results.
        new_ordering = maximal_marginal_relevance(
            np.array(data), ordered_result_embeddings, k=k
        )
        # Reorder the values and return; -1 marks the end of usable indices.
        ret = []
        for x in new_ordering:
            if x == -1:
                break
            ret.append(res[x][0])
        return ret

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        partition_names: Optional[List[str]] = None,
        round_decimal: int = -1,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against the query string.

        Args:
            query (str): The text to search.
            k (int, optional): How many results to return. Defaults to 4.
            param (dict, optional): The search params for the index type.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            partition_names (List[str], optional): What partitions to search.
                Defaults to None.
            round_decimal (int, optional): What decimal point to round to.
                Defaults to -1.
            timeout (int, optional): How long to wait before timeout error.
                Defaults to None.

        Returns:
            List[Document]: Document results for search.
        """
        _, docs_and_scores = self._worker_search(
            query, k, param, expr, partition_names, round_decimal, timeout, **kwargs
        )
        return [doc for doc, _, _ in docs_and_scores]

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> Milvus:
        """Create a Milvus collection, indexes it with HNSW, and insert data.

        Args:
            texts (List[str]): Text to insert.
            embedding (Embeddings): Embedding function to use.
            metadatas (Optional[List[dict]], optional): Dict metadata.
                Defaults to None.

        Returns:
            VectorStore: The Milvus vector store.

        Raises:
            ValueError: If pymilvus is not installed, if `texts` is empty, or
                if metadata entries have mismatched keys or unknown datatypes.
        """
        try:
            from pymilvus import (
                Collection,
                CollectionSchema,
                DataType,
                FieldSchema,
                connections,
            )
            from pymilvus.orm.types import infer_dtype_bydata
        except ImportError:
            raise ValueError(
                "Could not import pymilvus python package. "
                "Please install it with `pip install pymilvus`."
            )
        if not texts:
            raise ValueError("No texts provided to create the collection from.")
        # Single source of truth for the connection args, used both to connect
        # and to construct the returned wrapper.
        connection_args = kwargs.get("connection_args", {"port": 19530})
        # Connect to Milvus instance.
        if not connections.has_connection("default"):
            connections.connect(**connection_args)
        # Determine embedding dim from one sample embedding.
        dim = len(embedding.embed_query(texts[0]))
        # Generate unique names; the "c" prefix keeps them valid identifiers.
        primary_field = "c" + str(uuid.uuid4().hex)
        vector_field = "c" + str(uuid.uuid4().hex)
        text_field = "c" + str(uuid.uuid4().hex)
        collection_name = "c" + str(uuid.uuid4().hex)
        fields = []
        # Determine metadata schema.
        if metadatas:
            # Check if all metadata keys line up.
            key = metadatas[0].keys()
            for x in metadatas:
                if key != x.keys():
                    raise ValueError(
                        "Mismatched metadata. "
                        "Make sure all metadata has the same keys and datatype."
                    )
            # Create FieldSchema for each entry in singular metadata.
            for key, value in metadatas[0].items():
                # Infer the corresponding datatype of the metadata.
                dtype = infer_dtype_bydata(value)
                if dtype == DataType.UNKNOWN:
                    raise ValueError(f"Unrecognized datatype for {key}.")
                elif dtype == DataType.VARCHAR:
                    # VARCHAR fields need an explicit max_length; size it to
                    # the longest value present (+1 for safety margin).
                    max_length = 0
                    for subvalues in metadatas:
                        max_length = max(max_length, len(subvalues[key]))
                    fields.append(
                        FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1)
                    )
                else:
                    fields.append(FieldSchema(key, dtype))
        # Find out max length of texts.
        max_length = 0
        for y in texts:
            max_length = max(max_length, len(y))
        # Create the text field.
        fields.append(
            FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1)
        )
        # Create the primary key field.
        fields.append(
            FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
        )
        # Create the vector field.
        fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
        # Create the schema for the collection.
        schema = CollectionSchema(fields)
        # Create the collection.
        collection = Collection(collection_name, schema)
        # Index parameters for the collection.
        index = {
            "index_type": "HNSW",
            "metric_type": "L2",
            "params": {"M": 8, "efConstruction": 64},
        }
        # Create the index.
        collection.create_index(vector_field, index)
        # Create the VectorStore.
        milvus = cls(
            embedding,
            connection_args,
            collection_name,
            text_field,
        )
        # Add the texts.
        milvus.add_texts(texts, metadatas)
        return milvus