"""Wrapper around Qdrant vector database.""" import uuid from operator import itemgetter from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.vectorstores import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance MetadataFilter = Dict[str, Union[str, int, bool]] class Qdrant(VectorStore): """Wrapper around Qdrant vector database. To use you should have the ``qdrant-client`` package installed. Example: .. code-block:: python from langchain import Qdrant client = QdrantClient() collection_name = "MyCollection" qdrant = Qdrant(client, collection_name, embedding_function) """ CONTENT_KEY = "page_content" METADATA_KEY = "metadata" def __init__( self, client: Any, collection_name: str, embedding_function: Callable, content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, ): """Initialize with necessary components.""" try: import qdrant_client except ImportError: raise ValueError( "Could not import qdrant-client python package. " "Please install it with `pip install qdrant-client`." ) if not isinstance(client, qdrant_client.QdrantClient): raise ValueError( f"client should be an instance of qdrant_client.QdrantClient, " f"got {type(client)}" ) self.client: qdrant_client.QdrantClient = client self.collection_name = collection_name self.embedding_function = embedding_function self.content_payload_key = content_payload_key or self.CONTENT_KEY self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore. Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. Returns: List of ids from adding the texts into the vectorstore. """ from qdrant_client.http import models as rest ids = [uuid.uuid4().hex for _ in texts] self.client.upsert( collection_name=self.collection_name, points=rest.Batch( ids=ids, vectors=[self.embedding_function(text) for text in texts], payloads=self._build_payloads( texts, metadatas, self.content_payload_key, self.metadata_payload_key, ), ), ) return ids def similarity_search( self, query: str, k: int = 4, filter: Optional[MetadataFilter] = None, **kwargs: Any, ) -> List[Document]: """Return docs most similar to query. Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. filter: Filter by metadata. Defaults to None. Returns: List of Documents most similar to the query. """ results = self.similarity_search_with_score(query, k, filter) return list(map(itemgetter(0), results)) def similarity_search_with_score( self, query: str, k: int = 4, filter: Optional[MetadataFilter] = None ) -> List[Tuple[Document, float]]: """Return docs most similar to query. Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. filter: Filter by metadata. Defaults to None. Returns: List of Documents most similar to the query and score for each """ embedding = self.embedding_function(query) results = self.client.search( collection_name=self.collection_name, query_vector=embedding, query_filter=self._qdrant_filter_from_dict(filter), with_payload=True, limit=k, ) return [ ( self._document_from_scored_point( result, self.content_payload_key, self.metadata_payload_key ), result.score, ) for result in results ] def max_marginal_relevance_search( self, query: str, k: int = 4, fetch_k: int = 20 ) -> List[Document]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents. Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. fetch_k: Number of Documents to fetch to pass to MMR algorithm. Returns: List of Documents selected by maximal marginal relevance. """ embedding = self.embedding_function(query) results = self.client.search( collection_name=self.collection_name, query_vector=embedding, with_payload=True, with_vectors=True, limit=k, ) embeddings = [result.vector for result in results] mmr_selected = maximal_marginal_relevance(embedding, embeddings, k=k) return [ self._document_from_scored_point( results[i], self.content_payload_key, self.metadata_payload_key ) for i in mmr_selected ] @classmethod def from_documents( cls, documents: List[Document], embedding: Embeddings, url: Optional[str] = None, port: Optional[int] = 6333, grpc_port: int = 6334, prefer_grpc: bool = False, https: Optional[bool] = None, api_key: Optional[str] = None, prefix: Optional[str] = None, timeout: Optional[float] = None, host: Optional[str] = None, collection_name: Optional[str] = None, distance_func: str = "Cosine", content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, **kwargs: Any, ) -> "Qdrant": return cast( Qdrant, super().from_documents( documents, embedding, url=url, port=port, grpc_port=grpc_port, prefer_grpc=prefer_grpc, https=https, api_key=api_key, prefix=prefix, timeout=timeout, host=host, collection_name=collection_name, distance_func=distance_func, content_payload_key=content_payload_key, metadata_payload_key=metadata_payload_key, **kwargs, ), ) @classmethod def from_texts( cls, texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, url: Optional[str] = None, port: Optional[int] = 6333, grpc_port: int = 6334, prefer_grpc: bool = False, https: Optional[bool] = None, api_key: Optional[str] = None, prefix: Optional[str] = None, timeout: Optional[float] = None, host: Optional[str] = None, collection_name: Optional[str] = None, distance_func: str = "Cosine", content_payload_key: str = CONTENT_KEY, metadata_payload_key: str = METADATA_KEY, **kwargs: Any, ) -> "Qdrant": """Construct Qdrant wrapper from raw documents. Args: texts: A list of texts to be indexed in Qdrant. embedding: A subclass of `Embeddings`, responsible for text vectorization. metadatas: An optional list of metadata. If provided it has to be of the same length as a list of texts. url: either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Default: `None` port: Port of the REST API interface. Default: 6333 grpc_port: Port of the gRPC interface. Default: 6334 prefer_grpc: If `true` - use gPRC interface whenever possible in custom methods. https: If `true` - use HTTPS(SSL) protocol. Default: `None` api_key: API key for authentication in Qdrant Cloud. Default: `None` prefix: If not `None` - add `prefix` to the REST URL path. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}` for REST API. Default: `None` timeout: Timeout for REST and gRPC API requests. Default: 5.0 seconds for REST and unlimited for gRPC host: Host name of Qdrant service. If url and host are None, set to 'localhost'. Default: `None` collection_name: Name of the Qdrant collection to be used. If not provided, will be created randomly. distance_func: Distance function. One of the: "Cosine" / "Euclid" / "Dot". content_payload_key: A payload key used to store the content of the document. metadata_payload_key: A payload key used to store the metadata of the document. **kwargs: Additional arguments passed directly into REST client initialization This is a user friendly interface that: 1. Embeds documents. 2. Creates an in memory docstore 3. Initializes the Qdrant database This is intended to be a quick way to get started. Example: .. code-block:: python from langchain import Qdrant from langchain.embeddings import OpenAIEmbeddings embeddings = OpenAIEmbeddings() qdrant = Qdrant.from_texts(texts, embeddings, "localhost") """ try: import qdrant_client except ImportError: raise ValueError( "Could not import qdrant-client python package. " "Please install it with `pip install qdrant-client`." ) from qdrant_client.http import models as rest # Just do a single quick embedding to get vector size partial_embeddings = embedding.embed_documents(texts[:1]) vector_size = len(partial_embeddings[0]) collection_name = collection_name or uuid.uuid4().hex distance_func = distance_func.upper() client = qdrant_client.QdrantClient( url=url, port=port, grpc_port=grpc_port, prefer_grpc=prefer_grpc, https=https, api_key=api_key, prefix=prefix, timeout=timeout, host=host, **kwargs, ) client.recreate_collection( collection_name=collection_name, vectors_config=rest.VectorParams( size=vector_size, distance=rest.Distance[distance_func], ), ) # Now generate the embeddings for all the texts embeddings = embedding.embed_documents(texts) client.upsert( collection_name=collection_name, points=rest.Batch( ids=[uuid.uuid4().hex for _ in texts], vectors=embeddings, payloads=cls._build_payloads( texts, metadatas, content_payload_key, metadata_payload_key ), ), ) return cls( client=client, collection_name=collection_name, embedding_function=embedding.embed_query, content_payload_key=content_payload_key, metadata_payload_key=metadata_payload_key, ) @classmethod def _build_payloads( cls, texts: Iterable[str], metadatas: Optional[List[dict]], content_payload_key: str, metadata_payload_key: str, ) -> List[dict]: payloads = [] for i, text in enumerate(texts): if text is None: raise ValueError( "At least one of the texts is None. Please remove it before " "calling .from_texts or .add_texts on Qdrant instance." ) metadata = metadatas[i] if metadatas is not None else None payloads.append( { content_payload_key: text, metadata_payload_key: metadata, } ) return payloads @classmethod def _document_from_scored_point( cls, scored_point: Any, content_payload_key: str, metadata_payload_key: str, ) -> Document: return Document( page_content=scored_point.payload.get(content_payload_key), metadata=scored_point.payload.get(metadata_payload_key) or {}, ) def _qdrant_filter_from_dict(self, filter: Optional[MetadataFilter]) -> Any: if filter is None or 0 == len(filter): return None from qdrant_client.http import models as rest return rest.Filter( must=[ rest.FieldCondition( key=f"{self.metadata_payload_key}.{key}", match=rest.MatchValue(value=value), ) for key, value in filter.items() ] )