refactor: excel parse

Blizzard
2026-04-16 10:01:11 +08:00
parent 680ecc320f
commit f62f95ec02
7941 changed files with 2899112 additions and 0 deletions
@@ -0,0 +1,2 @@
from .async_qdrant_client import AsyncQdrantClient as AsyncQdrantClient
from .qdrant_client import QdrantClient as QdrantClient
@@ -0,0 +1,69 @@
import json
from typing import Any, Type, TypeVar
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
Model = TypeVar("Model", bound="BaseModel")
if PYDANTIC_V2:
import pydantic_core
to_jsonable_python = pydantic_core.to_jsonable_python
else:
from pydantic.json import ENCODERS_BY_TYPE
def to_jsonable_python(x: Any) -> Any:
return ENCODERS_BY_TYPE[type(x)](x)
def update_forward_refs(model_class: Type[BaseModel], *args: Any, **kwargs: Any) -> None:
if PYDANTIC_V2:
model_class.model_rebuild(*args, **kwargs)
else:
model_class.update_forward_refs(*args, **kwargs)
def construct(model_class: Type[Model], *args: Any, **kwargs: Any) -> Model:
if PYDANTIC_V2:
return model_class.model_construct(*args, **kwargs)
else:
return model_class.construct(*args, **kwargs)
def to_dict(model: BaseModel, *args: Any, **kwargs: Any) -> dict[Any, Any]:
if PYDANTIC_V2:
return model.model_dump(*args, **kwargs)
else:
return model.dict(*args, **kwargs)
def model_fields_set(model: BaseModel) -> set:
if PYDANTIC_V2:
return model.model_fields_set
else:
return model.__fields_set__
def model_fields(model: Type[BaseModel]) -> dict:
if PYDANTIC_V2:
return model.model_fields # type: ignore # pydantic type issue
else:
return model.__fields__
def model_json_schema(model: Type[BaseModel], *args: Any, **kwargs: Any) -> dict[str, Any]:
if PYDANTIC_V2:
return model.model_json_schema(*args, **kwargs)
else:
return json.loads(model.schema_json(*args, **kwargs))
def model_config(model: Type[BaseModel]) -> dict[str, Any]:
if PYDANTIC_V2:
return model.model_config
else:
return dict(vars(model.__config__))
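A quick illustration of how these shims keep call sites identical across pydantic v1 and v2; the `Point` model below is hypothetical and assumes the helpers above are available in the same module:

from pydantic import BaseModel

class Point(BaseModel):
    x: int = 0
    y: int = 0

point = construct(Point, x=1, y=2)       # model_construct() on v2, construct() on v1
print(to_dict(point))                    # {'x': 1, 'y': 2}
print(sorted(model_fields(Point)))       # ['x', 'y']
print(model_json_schema(Point)["type"])  # 'object'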
@@ -0,0 +1,386 @@
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
#
# This file is autogenerated. Do not edit it manually.
# To regenerate this file, use
#
# ```
# bash -x tools/generate_async_client.sh
# ```
#
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
from typing import Any, Iterable, Mapping, Optional, Sequence, Union
from qdrant_client.conversions import common_types as types
class AsyncQdrantBase:
def __init__(self, **kwargs: Any):
pass
async def search_matrix_offsets(
self,
collection_name: str,
query_filter: Optional[types.Filter] = None,
limit: int = 3,
sample: int = 10,
using: Optional[str] = None,
**kwargs: Any,
) -> types.SearchMatrixOffsetsResponse:
raise NotImplementedError()
async def search_matrix_pairs(
self,
collection_name: str,
query_filter: Optional[types.Filter] = None,
limit: int = 3,
sample: int = 10,
using: Optional[str] = None,
**kwargs: Any,
) -> types.SearchMatrixPairsResponse:
raise NotImplementedError()
async def query_batch_points(
self, collection_name: str, requests: Sequence[types.QueryRequest], **kwargs: Any
) -> list[types.QueryResponse]:
raise NotImplementedError()
async def query_points(
self,
collection_name: str,
query: Union[
types.PointId,
list[float],
list[list[float]],
types.SparseVector,
types.Query,
types.NumpyArray,
types.Document,
types.Image,
types.InferenceObject,
None,
] = None,
using: Optional[str] = None,
prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
query_filter: Optional[types.Filter] = None,
search_params: Optional[types.SearchParams] = None,
limit: int = 10,
offset: Optional[int] = None,
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
score_threshold: Optional[float] = None,
lookup_from: Optional[types.LookupLocation] = None,
**kwargs: Any,
) -> types.QueryResponse:
raise NotImplementedError()
async def query_points_groups(
self,
collection_name: str,
group_by: str,
query: Union[
types.PointId,
list[float],
list[list[float]],
types.SparseVector,
types.Query,
types.NumpyArray,
types.Document,
types.Image,
types.InferenceObject,
None,
] = None,
using: Optional[str] = None,
prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
query_filter: Optional[types.Filter] = None,
search_params: Optional[types.SearchParams] = None,
limit: int = 10,
group_size: int = 3,
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
score_threshold: Optional[float] = None,
with_lookup: Optional[types.WithLookupInterface] = None,
lookup_from: Optional[types.LookupLocation] = None,
**kwargs: Any,
) -> types.GroupsResult:
raise NotImplementedError()
async def scroll(
self,
collection_name: str,
scroll_filter: Optional[types.Filter] = None,
limit: int = 10,
order_by: Optional[types.OrderBy] = None,
offset: Optional[types.PointId] = None,
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
**kwargs: Any,
) -> tuple[list[types.Record], Optional[types.PointId]]:
raise NotImplementedError()
async def count(
self,
collection_name: str,
count_filter: Optional[types.Filter] = None,
exact: bool = True,
**kwargs: Any,
) -> types.CountResult:
raise NotImplementedError()
async def facet(
self,
collection_name: str,
key: str,
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
exact: bool = False,
**kwargs: Any,
) -> types.FacetResponse:
raise NotImplementedError()
async def upsert(
self, collection_name: str, points: types.Points, **kwargs: Any
) -> types.UpdateResult:
raise NotImplementedError()
async def update_vectors(
self, collection_name: str, points: Sequence[types.PointVectors], **kwargs: Any
) -> types.UpdateResult:
raise NotImplementedError()
async def delete_vectors(
self,
collection_name: str,
vectors: Sequence[str],
points: types.PointsSelector,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
async def retrieve(
self,
collection_name: str,
ids: Sequence[types.PointId],
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
**kwargs: Any,
) -> list[types.Record]:
raise NotImplementedError()
async def delete(
self, collection_name: str, points_selector: types.PointsSelector, **kwargs: Any
) -> types.UpdateResult:
raise NotImplementedError()
async def set_payload(
self,
collection_name: str,
payload: types.Payload,
points: types.PointsSelector,
key: Optional[str] = None,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
async def overwrite_payload(
self,
collection_name: str,
payload: types.Payload,
points: types.PointsSelector,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
async def delete_payload(
self,
collection_name: str,
keys: Sequence[str],
points: types.PointsSelector,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
async def clear_payload(
self, collection_name: str, points_selector: types.PointsSelector, **kwargs: Any
) -> types.UpdateResult:
raise NotImplementedError()
async def batch_update_points(
self,
collection_name: str,
update_operations: Sequence[types.UpdateOperation],
**kwargs: Any,
) -> list[types.UpdateResult]:
raise NotImplementedError()
async def update_collection_aliases(
self, change_aliases_operations: Sequence[types.AliasOperations], **kwargs: Any
) -> bool:
raise NotImplementedError()
async def get_collection_aliases(
self, collection_name: str, **kwargs: Any
) -> types.CollectionsAliasesResponse:
raise NotImplementedError()
async def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
raise NotImplementedError()
async def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
raise NotImplementedError()
async def get_collection(self, collection_name: str, **kwargs: Any) -> types.CollectionInfo:
raise NotImplementedError()
async def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
raise NotImplementedError()
async def update_collection(self, collection_name: str, **kwargs: Any) -> bool:
raise NotImplementedError()
async def delete_collection(self, collection_name: str, **kwargs: Any) -> bool:
raise NotImplementedError()
async def create_collection(
self,
collection_name: str,
vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
**kwargs: Any,
) -> bool:
raise NotImplementedError()
async def recreate_collection(
self,
collection_name: str,
vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
**kwargs: Any,
) -> bool:
raise NotImplementedError()
def upload_points(
self, collection_name: str, points: Iterable[types.PointStruct], **kwargs: Any
) -> None:
raise NotImplementedError()
def upload_collection(
self,
collection_name: str,
vectors: Union[
dict[str, types.NumpyArray], types.NumpyArray, Iterable[types.VectorStruct]
],
payload: Optional[Iterable[dict[Any, Any]]] = None,
ids: Optional[Iterable[types.PointId]] = None,
**kwargs: Any,
) -> None:
raise NotImplementedError()
async def create_payload_index(
self,
collection_name: str,
field_name: str,
field_schema: Optional[types.PayloadSchemaType] = None,
field_type: Optional[types.PayloadSchemaType] = None,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
async def delete_payload_index(
self, collection_name: str, field_name: str, **kwargs: Any
) -> types.UpdateResult:
raise NotImplementedError()
async def list_snapshots(
self, collection_name: str, **kwargs: Any
) -> list[types.SnapshotDescription]:
raise NotImplementedError()
async def create_snapshot(
self, collection_name: str, **kwargs: Any
) -> Optional[types.SnapshotDescription]:
raise NotImplementedError()
async def delete_snapshot(
self, collection_name: str, snapshot_name: str, **kwargs: Any
) -> Optional[bool]:
raise NotImplementedError()
async def list_full_snapshots(self, **kwargs: Any) -> list[types.SnapshotDescription]:
raise NotImplementedError()
async def create_full_snapshot(self, **kwargs: Any) -> Optional[types.SnapshotDescription]:
raise NotImplementedError()
async def delete_full_snapshot(self, snapshot_name: str, **kwargs: Any) -> Optional[bool]:
raise NotImplementedError()
async def recover_snapshot(
self, collection_name: str, location: str, **kwargs: Any
) -> Optional[bool]:
raise NotImplementedError()
async def list_shard_snapshots(
self, collection_name: str, shard_id: int, **kwargs: Any
) -> list[types.SnapshotDescription]:
raise NotImplementedError()
async def create_shard_snapshot(
self, collection_name: str, shard_id: int, **kwargs: Any
) -> Optional[types.SnapshotDescription]:
raise NotImplementedError()
async def delete_shard_snapshot(
self, collection_name: str, shard_id: int, snapshot_name: str, **kwargs: Any
) -> Optional[bool]:
raise NotImplementedError()
async def recover_shard_snapshot(
self, collection_name: str, shard_id: int, location: str, **kwargs: Any
) -> Optional[bool]:
raise NotImplementedError()
async def close(self, **kwargs: Any) -> None:
pass
def migrate(
self,
dest_client: "AsyncQdrantBase",
collection_names: Optional[list[str]] = None,
batch_size: int = 100,
recreate_on_collision: bool = False,
) -> None:
raise NotImplementedError()
async def create_shard_key(
self,
collection_name: str,
shard_key: types.ShardKey,
shards_number: Optional[int] = None,
replication_factor: Optional[int] = None,
placement: Optional[list[int]] = None,
**kwargs: Any,
) -> bool:
raise NotImplementedError()
async def delete_shard_key(
self, collection_name: str, shard_key: types.ShardKey, **kwargs: Any
) -> bool:
raise NotImplementedError()
async def info(self) -> types.VersionInfo:
raise NotImplementedError()
async def cluster_collection_update(
self, collection_name: str, cluster_operation: types.ClusterOperations, **kwargs: Any
) -> bool:
raise NotImplementedError()
async def collection_cluster_info(self, collection_name: str) -> types.CollectionClusterInfo:
raise NotImplementedError()
async def cluster_status(self) -> types.ClusterStatus:
raise NotImplementedError()
async def recover_current_peer(self) -> bool:
raise NotImplementedError()
async def remove_peer(self, peer_id: int, **kwargs: Any) -> bool:
raise NotImplementedError()
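Every coroutine here is a stub; the concrete `AsyncQdrantClient` exported by the package implements them. A minimal usage sketch, assuming a Qdrant instance is reachable at localhost:6333:

import asyncio
from qdrant_client import AsyncQdrantClient, models

async def main() -> None:
    client = AsyncQdrantClient(url="http://localhost:6333")
    if not await client.collection_exists("demo"):
        await client.create_collection(
            collection_name="demo",
            vectors_config=models.VectorParams(size=4, distance=models.Distance.COSINE),
        )
    await client.upsert(
        collection_name="demo",
        points=[models.PointStruct(id=1, vector=[0.1, 0.2, 0.3, 0.4], payload={"tag": "a"})],
    )
    result = await client.query_points(collection_name="demo", query=[0.1, 0.2, 0.3, 0.4], limit=3)
    print(result.points)
    await client.close()

asyncio.run(main())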
File diff suppressed because it is too large
@@ -0,0 +1,842 @@
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
#
# This file is autogenerated. Do not edit it manually.
# To regenerate this file, use
#
# ```
# bash -x tools/generate_async_client.sh
# ```
#
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
import uuid
from itertools import tee
from typing import Any, Iterable, Optional, Sequence, Union, get_args
from copy import deepcopy
import numpy as np
from pydantic import BaseModel
from qdrant_client import grpc
from qdrant_client.common.client_warnings import show_warning, show_warning_once
from qdrant_client.async_client_base import AsyncQdrantBase
from qdrant_client.embed.embedder import Embedder
from qdrant_client.embed.model_embedder import ModelEmbedder
from qdrant_client.http import models
from qdrant_client.conversions import common_types as types
from qdrant_client.conversions.conversion import GrpcToRest
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
from qdrant_client.embed.schema_parser import ModelSchemaParser
from qdrant_client.hybrid.fusion import reciprocal_rank_fusion
from qdrant_client.fastembed_common import FastEmbedMisc, OnnxProvider
from qdrant_client.fastembed_common import (
QueryResponse,
TextEmbedding,
SparseTextEmbedding,
IDF_EMBEDDING_MODELS,
)
class AsyncQdrantFastembedMixin(AsyncQdrantBase):
DEFAULT_EMBEDDING_MODEL = "BAAI/bge-small-en"
DEFAULT_BATCH_SIZE = 8
_FASTEMBED_INSTALLED: bool
def __init__(
self, parser: ModelSchemaParser, is_local_mode: bool, server_version: Optional[str]
):
self.__class__._FASTEMBED_INSTALLED = FastEmbedMisc.is_installed()
self._embedding_model_name: Optional[str] = None
self._sparse_embedding_model_name: Optional[str] = None
self._model_embedder = ModelEmbedder(
parser=parser, is_local_mode=is_local_mode, server_version=server_version
)
super().__init__()
@classmethod
async def list_text_models(cls) -> dict[str, tuple[int, models.Distance]]:
"""Lists the supported dense text models.
Returns:
dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
"""
return FastEmbedMisc.list_text_models()
@classmethod
async def list_image_models(cls) -> dict[str, tuple[int, models.Distance]]:
"""Lists the supported image dense models.
Returns:
dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
"""
return FastEmbedMisc.list_image_models()
@classmethod
async def list_late_interaction_text_models(cls) -> dict[str, tuple[int, models.Distance]]:
"""Lists the supported late interaction text models.
Returns:
dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
"""
return FastEmbedMisc.list_late_interaction_text_models()
@classmethod
async def list_late_interaction_multimodal_models(
cls,
) -> dict[str, tuple[int, models.Distance]]:
"""Lists the supported late interaction multimodal models.
Returns:
dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
"""
return FastEmbedMisc.list_late_interaction_multimodal_models()
@classmethod
async def list_sparse_models(cls) -> dict[str, dict[str, Any]]:
"""Lists the supported sparse text models.
Returns:
dict[str, dict[str, Any]]: A dict of model names and their descriptions.
"""
return FastEmbedMisc.list_sparse_models()
@property
def embedding_model_name(self) -> str:
if self._embedding_model_name is None:
self._embedding_model_name = self.DEFAULT_EMBEDDING_MODEL
return self._embedding_model_name
@property
def sparse_embedding_model_name(self) -> Optional[str]:
return self._sparse_embedding_model_name
def set_model(
self,
embedding_model_name: str,
max_length: Optional[int] = None,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence["OnnxProvider"]] = None,
cuda: bool = False,
device_ids: Optional[list[int]] = None,
lazy_load: bool = False,
**kwargs: Any,
) -> None:
"""
Set embedding model to use for encoding documents and queries.
Args:
embedding_model_name: One of the supported embedding models. See `SUPPORTED_EMBEDDING_MODELS` for details.
max_length (int, optional): Deprecated. Defaults to None.
cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
providers: The list of onnx providers (with or without options) to use. Defaults to None.
Example configuration:
https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
Defaults to False.
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
Raises:
ValueError: If embedding model is not supported.
ImportError: If fastembed is not installed.
Returns:
None
"""
if max_length is not None:
show_warning(
message="max_length parameter is deprecated and will be removed in the future. It's not used by fastembed models.",
category=DeprecationWarning,
stacklevel=3,
)
self._get_or_init_model(
model_name=embedding_model_name,
cache_dir=cache_dir,
threads=threads,
providers=providers,
cuda=cuda,
device_ids=device_ids,
lazy_load=lazy_load,
deprecated=True,
**kwargs,
)
self._embedding_model_name = embedding_model_name
def set_sparse_model(
self,
embedding_model_name: Optional[str],
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence["OnnxProvider"]] = None,
cuda: bool = False,
device_ids: Optional[list[int]] = None,
lazy_load: bool = False,
**kwargs: Any,
) -> None:
"""
Set sparse embedding model to use for hybrid search over documents in combination with dense embeddings.
Args:
embedding_model_name: One of the supported sparse embedding models. See `SUPPORTED_SPARSE_EMBEDDING_MODELS` for details.
If None, sparse embeddings will not be used.
cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
providers: The list of onnx providers (with or without options) to use. Defaults to None.
Example configuration:
https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
Defaults to False.
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
Raises:
ValueError: If embedding model is not supported.
ImportError: If fastembed is not installed.
Returns:
None
"""
if embedding_model_name is not None:
self._get_or_init_sparse_model(
model_name=embedding_model_name,
cache_dir=cache_dir,
threads=threads,
providers=providers,
cuda=cuda,
device_ids=device_ids,
lazy_load=lazy_load,
deprecated=True,
**kwargs,
)
self._sparse_embedding_model_name = embedding_model_name
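A short sketch of selecting models, assuming fastembed is installed; shown with the synchronous `QdrantClient` for brevity (this async mixin exposes the same methods). The sparse model name "Qdrant/bm25" is used only for illustration; valid names come from `list_sparse_models()`:

from qdrant_client import QdrantClient

client = QdrantClient(url="http://localhost:6333")
client.set_model("BAAI/bge-small-en")         # dense model (the class default)
client.set_sparse_model("Qdrant/bm25")        # optional sparse model for hybrid search
print(client.get_vector_field_name())         # "fast-bge-small-en"
print(client.get_sparse_vector_field_name())  # "fast-sparse-bm25"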
@classmethod
def _get_model_params(cls, model_name: str) -> tuple[int, models.Distance]:
FastEmbedMisc.import_fastembed()
for descriptions in (
FastEmbedMisc.list_text_models(),
FastEmbedMisc.list_image_models(),
FastEmbedMisc.list_late_interaction_text_models(),
FastEmbedMisc.list_late_interaction_multimodal_models(),
):
if params := descriptions.get(model_name):
return params
if model_name in FastEmbedMisc.list_sparse_models():
raise ValueError(
"Sparse embeddings do not return fixed embedding size and distance type"
)
raise ValueError(f"Unsupported embedding model: {model_name}")
def _get_or_init_model(
self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence["OnnxProvider"]] = None,
deprecated: bool = False,
**kwargs: Any,
) -> "TextEmbedding":
FastEmbedMisc.import_fastembed()
assert isinstance(self._model_embedder.embedder, Embedder)
return self._model_embedder.embedder.get_or_init_model(
model_name=model_name,
cache_dir=cache_dir,
threads=threads,
providers=providers,
deprecated=deprecated,
**kwargs,
)
def _get_or_init_sparse_model(
self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence["OnnxProvider"]] = None,
deprecated: bool = False,
**kwargs: Any,
) -> "SparseTextEmbedding":
FastEmbedMisc.import_fastembed()
assert isinstance(self._model_embedder.embedder, Embedder)
return self._model_embedder.embedder.get_or_init_sparse_model(
model_name=model_name,
cache_dir=cache_dir,
threads=threads,
providers=providers,
deprecated=deprecated,
**kwargs,
)
def _embed_documents(
self,
documents: Iterable[str],
embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
batch_size: int = 32,
embed_type: str = "default",
parallel: Optional[int] = None,
) -> Iterable[tuple[str, list[float]]]:
embedding_model = self._get_or_init_model(model_name=embedding_model_name, deprecated=True)
(documents_a, documents_b) = tee(documents, 2)
if embed_type == "passage":
vectors_iter = embedding_model.passage_embed(
documents_a, batch_size=batch_size, parallel=parallel
)
elif embed_type == "query":
vectors_iter = (
list(embedding_model.query_embed(query=query))[0] for query in documents_a
)
elif embed_type == "default":
vectors_iter = embedding_model.embed(
documents_a, batch_size=batch_size, parallel=parallel
)
else:
raise ValueError(f"Unknown embed type: {embed_type}")
for vector, doc in zip(vectors_iter, documents_b):
yield (doc, vector.tolist())
def _sparse_embed_documents(
self,
documents: Iterable[str],
embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
batch_size: int = 32,
parallel: Optional[int] = None,
) -> Iterable[types.SparseVector]:
sparse_embedding_model = self._get_or_init_sparse_model(
model_name=embedding_model_name, deprecated=True
)
vectors_iter = sparse_embedding_model.embed(
documents, batch_size=batch_size, parallel=parallel
)
for sparse_vector in vectors_iter:
yield types.SparseVector(
indices=sparse_vector.indices.tolist(), values=sparse_vector.values.tolist()
)
def get_vector_field_name(self) -> str:
"""
Returns the name of the vector field in the qdrant collection used by the current fastembed model.
Returns:
Name of the vector field.
"""
model_name = self.embedding_model_name.split("/")[-1].lower()
return f"fast-{model_name}"
def get_sparse_vector_field_name(self) -> Optional[str]:
"""
Returns the name of the sparse vector field in the qdrant collection used by the current fastembed sparse model.
Returns:
Name of the sparse vector field, or None if no sparse model is set.
"""
if self.sparse_embedding_model_name is not None:
model_name = self.sparse_embedding_model_name.split("/")[-1].lower()
return f"fast-sparse-{model_name}"
return None
def _scored_points_to_query_responses(
self, scored_points: list[types.ScoredPoint]
) -> list[QueryResponse]:
response = []
vector_field_name = self.get_vector_field_name()
sparse_vector_field_name = self.get_sparse_vector_field_name()
for scored_point in scored_points:
embedding = (
scored_point.vector.get(vector_field_name, None)
if isinstance(scored_point.vector, dict)
else None
)
sparse_embedding = None
if sparse_vector_field_name is not None:
sparse_embedding = (
scored_point.vector.get(sparse_vector_field_name, None)
if isinstance(scored_point.vector, dict)
else None
)
response.append(
QueryResponse(
id=scored_point.id,
embedding=embedding,
sparse_embedding=sparse_embedding,
metadata=scored_point.payload,
document=scored_point.payload.get("document", ""),
score=scored_point.score,
)
)
return response
def _points_iterator(
self,
ids: Optional[Iterable[models.ExtendedPointId]],
metadata: Optional[Iterable[dict[str, Any]]],
encoded_docs: Iterable[tuple[str, list[float]]],
ids_accumulator: list,
sparse_vectors: Optional[Iterable[types.SparseVector]] = None,
) -> Iterable[models.PointStruct]:
if ids is None:
ids = iter(lambda: uuid.uuid4().hex, None)
if metadata is None:
metadata = iter(lambda: {}, None)
if sparse_vectors is None:
sparse_vectors = iter(lambda: None, True)
vector_name = self.get_vector_field_name()
sparse_vector_name = self.get_sparse_vector_field_name()
for idx, meta, (doc, vector), sparse_vector in zip(
ids, metadata, encoded_docs, sparse_vectors
):
ids_accumulator.append(idx)
payload = {"document": doc, **meta}
point_vector: dict[str, models.Vector] = {vector_name: vector}
if sparse_vector_name is not None and sparse_vector is not None:
point_vector[sparse_vector_name] = sparse_vector
yield models.PointStruct(id=idx, payload=payload, vector=point_vector)
def _validate_collection_info(self, collection_info: models.CollectionInfo) -> None:
(embeddings_size, distance) = self._get_model_params(model_name=self.embedding_model_name)
vector_field_name = self.get_vector_field_name()
assert isinstance(
collection_info.config.params.vectors, dict
), f"Collection have incompatible vector params: {collection_info.config.params.vectors}"
assert (
vector_field_name in collection_info.config.params.vectors
), f"Collection have incompatible vector params: {collection_info.config.params.vectors}, expected {vector_field_name}"
vector_params = collection_info.config.params.vectors[vector_field_name]
assert (
embeddings_size == vector_params.size
), f"Embedding size mismatch: {embeddings_size} != {vector_params.size}"
assert (
distance == vector_params.distance
), f"Distance mismatch: {distance} != {vector_params.distance}"
sparse_vector_field_name = self.get_sparse_vector_field_name()
if sparse_vector_field_name is not None:
assert (
sparse_vector_field_name in collection_info.config.params.sparse_vectors
), f"Collection have incompatible vector params: {collection_info.config.params.vectors}"
if self.sparse_embedding_model_name in IDF_EMBEDDING_MODELS:
modifier = collection_info.config.params.sparse_vectors[
sparse_vector_field_name
].modifier
assert (
modifier == models.Modifier.IDF
), f"{self.sparse_embedding_model_name} requires modifier IDF, current modifier is {modifier}"
def get_embedding_size(self, model_name: Optional[str] = None) -> int:
"""Get the size of the embeddings produced by the specified model.
Args:
model_name: optional, the name of the model to get the embedding size for. If None, the default model will
be used.
Returns:
int: the size of the embeddings produced by the model.
Raises:
ValueError: If sparse model name is passed or model is not found in the supported models.
"""
model_name = model_name or self.embedding_model_name
(embeddings_size, _) = self._get_model_params(model_name=model_name)
return embeddings_size
def get_fastembed_vector_params(
self,
on_disk: Optional[bool] = None,
quantization_config: Optional[models.QuantizationConfig] = None,
hnsw_config: Optional[models.HnswConfigDiff] = None,
) -> dict[str, models.VectorParams]:
"""
Generates vector configuration, compatible with fastembed models.
Args:
on_disk: if True, vectors will be stored on disk. If None, default value will be used.
quantization_config: Quantization configuration. If None, quantization will be disabled.
hnsw_config: HNSW configuration. If None, default configuration will be used.
Returns:
Configuration for `vectors_config` argument in `create_collection` method.
"""
vector_field_name = self.get_vector_field_name()
(embeddings_size, distance) = self._get_model_params(model_name=self.embedding_model_name)
return {
vector_field_name: models.VectorParams(
size=embeddings_size,
distance=distance,
on_disk=on_disk,
quantization_config=quantization_config,
hnsw_config=hnsw_config,
)
}
def get_fastembed_sparse_vector_params(
self, on_disk: Optional[bool] = None, modifier: Optional[models.Modifier] = None
) -> Optional[dict[str, models.SparseVectorParams]]:
"""
Generates vector configuration, compatible with fastembed sparse models.
Args:
on_disk: if True, vectors will be stored on disk. If None, default value will be used.
modifier: Sparse vector queries modifier. E.g. Modifier.IDF for idf-based rescoring. Default: None.
Returns:
Configuration for `vectors_config` argument in `create_collection` method.
"""
vector_field_name = self.get_sparse_vector_field_name()
if self.sparse_embedding_model_name in IDF_EMBEDDING_MODELS:
modifier = models.Modifier.IDF if modifier is None else modifier
if vector_field_name is None:
return None
return {
vector_field_name: models.SparseVectorParams(
index=models.SparseIndexParams(on_disk=on_disk), modifier=modifier
)
}
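A sketch of using these helpers when creating a collection, mirroring what `add()` below does internally; `client` is assumed to be the configured client from the sketch above:

if not client.collection_exists("docs"):
    client.create_collection(
        collection_name="docs",
        vectors_config=client.get_fastembed_vector_params(),
        sparse_vectors_config=client.get_fastembed_sparse_vector_params(),
    )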
async def add(
self,
collection_name: str,
documents: Iterable[str],
metadata: Optional[Iterable[dict[str, Any]]] = None,
ids: Optional[Iterable[models.ExtendedPointId]] = None,
batch_size: int = 32,
parallel: Optional[int] = None,
**kwargs: Any,
) -> list[Union[str, int]]:
"""
Adds text documents into qdrant collection.
If collection does not exist, it will be created with default parameters.
Metadata in combination with documents will be added as payload.
Documents will be embedded using the specified embedding model.
If you want to use your own vectors, use `upsert` method instead.
Args:
collection_name (str):
Name of the collection to add documents to.
documents (Iterable[str]):
List of documents to embed and add to the collection.
metadata (Iterable[dict[str, Any]], optional):
List of metadata dicts. Defaults to None.
ids (Iterable[models.ExtendedPointId], optional):
List of ids to assign to documents.
If not specified, UUIDs will be generated. Defaults to None.
batch_size (int, optional):
How many documents to embed and upload in single request. Defaults to 32.
parallel (Optional[int], optional):
How many parallel workers to use for embedding. Defaults to None.
If a number is specified, a data-parallel process will be used.
Raises:
ImportError: If fastembed is not installed.
Returns:
List of IDs of added documents. If no ids are provided, UUIDs will be randomly generated on the client side.
"""
show_warning_once(
"`add` method has been deprecated and will be removed in 1.17. Instead, inference can be done internally within regular methods like `upsert` by wrapping data into `models.Document` or `models.Image`."
)
encoded_docs = self._embed_documents(
documents=documents,
embedding_model_name=self.embedding_model_name,
batch_size=batch_size,
embed_type="passage",
parallel=parallel,
)
encoded_sparse_docs = None
if self.sparse_embedding_model_name is not None:
encoded_sparse_docs = self._sparse_embed_documents(
documents=documents,
embedding_model_name=self.sparse_embedding_model_name,
batch_size=batch_size,
parallel=parallel,
)
try:
collection_info = await self.get_collection(collection_name=collection_name)
except Exception:
await self.create_collection(
collection_name=collection_name,
vectors_config=self.get_fastembed_vector_params(),
sparse_vectors_config=self.get_fastembed_sparse_vector_params(),
)
collection_info = await self.get_collection(collection_name=collection_name)
self._validate_collection_info(collection_info)
inserted_ids: list = []
points = self._points_iterator(
ids=ids,
metadata=metadata,
encoded_docs=encoded_docs,
ids_accumulator=inserted_ids,
sparse_vectors=encoded_sparse_docs,
)
self.upload_points(
collection_name=collection_name,
points=points,
wait=True,
parallel=parallel or 1,
batch_size=batch_size,
**kwargs,
)
return inserted_ids
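A usage sketch for the (deprecated) `add()` helper, continuing with the same assumed synchronous `client` (with `AsyncQdrantClient` the call would be awaited); without explicit ids, the returned ids are client-generated UUID hex strings:

ids = client.add(
    collection_name="docs",
    documents=["Qdrant is a vector database", "FastEmbed generates embeddings"],
    metadata=[{"source": "readme"}, {"source": "docs"}],
)
print(ids)  # e.g. two UUID hex strings, in insertion order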
async def query(
self,
collection_name: str,
query_text: str,
query_filter: Optional[models.Filter] = None,
limit: int = 10,
**kwargs: Any,
) -> list[QueryResponse]:
"""
Search for documents in a collection.
This method automatically embeds the query text using the specified embedding model.
If you want to use your own query vector, use `search` method instead.
Args:
collection_name: Collection to search in
query_text:
Text to search for. This text will be embedded using the specified embedding model
and then used as a query vector.
query_filter:
- Exclude vectors which don't fit the given conditions.
- If `None` - search among all vectors
limit: How many results to return
**kwargs: Additional search parameters. See `qdrant_client.models.QueryRequest` for details.
Returns:
list[QueryResponse]: List of query responses.
"""
show_warning_once(
"`query` method has been deprecated and will be removed in 1.17. Instead, inference can be done internally within regular methods like `query_points` by wrapping data into `models.Document` or `models.Image`."
)
embedding_model_inst = self._get_or_init_model(
model_name=self.embedding_model_name, deprecated=True
)
embeddings = list(embedding_model_inst.query_embed(query=query_text))
query_vector = embeddings[0].tolist()
if self.sparse_embedding_model_name is None:
return self._scored_points_to_query_responses(
(
await self.query_points(
collection_name=collection_name,
query=query_vector,
using=self.get_vector_field_name(),
query_filter=query_filter,
limit=limit,
with_payload=True,
**kwargs,
)
).points
)
sparse_embedding_model_inst = self._get_or_init_sparse_model(
model_name=self.sparse_embedding_model_name, deprecated=True
)
sparse_vector = list(sparse_embedding_model_inst.query_embed(query=query_text))[0]
sparse_query_vector = models.SparseVector(
indices=sparse_vector.indices.tolist(), values=sparse_vector.values.tolist()
)
dense_request = models.QueryRequest(
query=query_vector,
using=self.get_vector_field_name(),
filter=query_filter,
limit=limit,
with_payload=True,
**kwargs,
)
sparse_request = models.QueryRequest(
query=sparse_query_vector,
using=self.get_sparse_vector_field_name(),
filter=query_filter,
limit=limit,
with_payload=True,
**kwargs,
)
(dense_request_response, sparse_request_response) = await self.query_batch_points(
collection_name=collection_name, requests=[dense_request, sparse_request]
)
return self._scored_points_to_query_responses(
reciprocal_rank_fusion(
[dense_request_response.points, sparse_request_response.points], limit=limit
)
)
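A sketch of the (deprecated) `query()` helper with the same assumed `client`: the text is embedded with the configured dense model and, when a sparse model is set, dense and sparse results are fused with reciprocal rank fusion:

hits = client.query(collection_name="docs", query_text="vector database", limit=3)
for hit in hits:
    print(hit.id, hit.score, hit.document)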
async def query_batch(
self,
collection_name: str,
query_texts: list[str],
query_filter: Optional[models.Filter] = None,
limit: int = 10,
**kwargs: Any,
) -> list[list[QueryResponse]]:
"""
Search for documents in a collection with batched query.
This method automatically embeds the query text using the specified embedding model.
Args:
collection_name: Collection to search in
query_texts:
A list of texts to search for. Each text will be embedded using the specified embedding model
and then used as a query vector for a separate search request.
query_filter:
- Exclude vectors which don't fit the given conditions.
- If `None` - search among all vectors
This filter will be applied to all search requests.
limit: How many results to return
**kwargs: Additional search parameters. See `qdrant_client.models.QueryRequest` for details.
Returns:
list[list[QueryResponse]]: List of lists of responses for each query text.
"""
show_warning_once(
"`query_batch` method has been deprecated and will be removed in 1.17. Instead, inference can be done internally within regular methods like `query_batch_points` by wrapping data into `models.Document` or `models.Image`."
)
embedding_model_inst = self._get_or_init_model(
model_name=self.embedding_model_name, deprecated=True
)
query_vectors = list(embedding_model_inst.query_embed(query=query_texts))
requests = []
for vector in query_vectors:
request = models.QueryRequest(
query=vector.tolist(),
using=self.get_vector_field_name(),
filter=query_filter,
limit=limit,
with_payload=True,
**kwargs,
)
requests.append(request)
if self.sparse_embedding_model_name is None:
responses = await self.query_batch_points(
collection_name=collection_name, requests=requests
)
return [
self._scored_points_to_query_responses(response.points) for response in responses
]
sparse_embedding_model_inst = self._get_or_init_sparse_model(
model_name=self.sparse_embedding_model_name, deprecated=True
)
sparse_query_vectors = [
models.SparseVector(
indices=sparse_vector.indices.tolist(), values=sparse_vector.values.tolist()
)
for sparse_vector in sparse_embedding_model_inst.embed(documents=query_texts)
]
for sparse_vector in sparse_query_vectors:
request = models.QueryRequest(
using=self.get_sparse_vector_field_name(),
query=sparse_vector,
filter=query_filter,
limit=limit,
with_payload=True,
**kwargs,
)
requests.append(request)
responses = await self.query_batch_points(
collection_name=collection_name, requests=requests
)
dense_responses = responses[: len(query_texts)]
sparse_responses = responses[len(query_texts) :]
responses = [
reciprocal_rank_fusion([dense_response.points, sparse_response.points], limit=limit)
for (dense_response, sparse_response) in zip(dense_responses, sparse_responses)
]
return [self._scored_points_to_query_responses(response) for response in responses]
@classmethod
def _resolve_query(
cls,
query: Union[
types.PointId,
list[float],
list[list[float]],
types.SparseVector,
types.Query,
types.NumpyArray,
models.Document,
models.Image,
models.InferenceObject,
None,
],
) -> Optional[models.Query]:
"""Resolves query interface into a models.Query object
Args:
query: models.QueryInterface - query as a model or a plain structure like list[float]
Returns:
Optional[models.Query]: query as it was, models.Query(nearest=query) or None
Raises:
ValueError: if query is not of a supported type
"""
if isinstance(query, get_args(types.Query)):
return query
if isinstance(query, types.SparseVector):
return models.NearestQuery(nearest=query)
if isinstance(query, np.ndarray):
return models.NearestQuery(nearest=query.tolist())
if isinstance(query, list):
return models.NearestQuery(nearest=query)
if isinstance(query, get_args(types.PointId)):
query = (
GrpcToRest.convert_point_id(query) if isinstance(query, grpc.PointId) else query
)
return models.NearestQuery(nearest=query)
if isinstance(query, get_args(INFERENCE_OBJECT_TYPES)):
return models.NearestQuery(nearest=query)
if query is None:
return None
raise ValueError(f"Unsupported query type: {type(query)}")
def _resolve_query_request(self, query: models.QueryRequest) -> models.QueryRequest:
"""Resolve QueryRequest query field
Args:
query: models.QueryRequest - query request to resolve
Returns:
models.QueryRequest: A deepcopy of the query request with resolved query field
"""
query = deepcopy(query)
query.query = self._resolve_query(query.query)
return query
def _resolve_query_batch_request(
self, requests: Sequence[models.QueryRequest]
) -> Sequence[models.QueryRequest]:
"""Resolve query field for each query request in a batch
Args:
requests: Sequence[models.QueryRequest] - query requests to resolve
Returns:
Sequence[models.QueryRequest]: A list of deep copied query requests with resolved query fields
"""
return [self._resolve_query_request(query) for query in requests]
def _embed_models(
self,
raw_models: Union[BaseModel, Iterable[BaseModel]],
is_query: bool = False,
batch_size: Optional[int] = None,
) -> Iterable[BaseModel]:
yield from self._model_embedder.embed_models(
raw_models=raw_models,
is_query=is_query,
batch_size=batch_size or self.DEFAULT_BATCH_SIZE,
)
def _embed_models_strict(
self,
raw_models: Iterable[Union[dict[str, BaseModel], BaseModel]],
batch_size: Optional[int] = None,
parallel: Optional[int] = None,
) -> Iterable[BaseModel]:
yield from self._model_embedder.embed_models_strict(
raw_models=raw_models,
batch_size=batch_size or self.DEFAULT_BATCH_SIZE,
parallel=parallel,
)
File diff suppressed because it is too large
@@ -0,0 +1 @@
from qdrant_client.auth.bearer_auth import BearerAuth
@@ -0,0 +1,42 @@
import asyncio
from typing import Awaitable, Callable, Optional, Union
import httpx
class BearerAuth(httpx.Auth):
def __init__(
self,
auth_token_provider: Union[Callable[[], str], Callable[[], Awaitable[str]]],
):
self.async_token: Optional[Callable[[], Awaitable[str]]] = None
self.sync_token: Optional[Callable[[], str]] = None
if asyncio.iscoroutinefunction(auth_token_provider):
self.async_token = auth_token_provider
else:
if callable(auth_token_provider):
self.sync_token = auth_token_provider # type: ignore
else:
raise ValueError("auth_token_provider must be a callable or awaitable")
def _sync_get_token(self) -> str:
if self.sync_token is None:
raise ValueError("Synchronous token provider is not set.")
return self.sync_token()
def sync_auth_flow(self, request: httpx.Request) -> httpx.Request:
token = self._sync_get_token()
request.headers["Authorization"] = f"Bearer {token}"
yield request
async def _async_get_token(self) -> str:
if self.async_token is not None:
return await self.async_token() # type: ignore
# Fallback to synchronous token if asynchronous token is not available
return self._sync_get_token()
async def async_auth_flow(self, request: httpx.Request) -> httpx.Request:
token = await self._async_get_token()
request.headers["Authorization"] = f"Bearer {token}"
yield request
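A sketch of wiring `BearerAuth` into an httpx client with a synchronous token provider; `fetch_token` is a hypothetical callable and the URL assumes a local Qdrant REST endpoint:

import httpx

def fetch_token() -> str:
    return "my-short-lived-token"

auth = BearerAuth(auth_token_provider=fetch_token)
with httpx.Client(auth=auth) as http:
    response = http.get("http://localhost:6333/collections")
    print(response.status_code)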
@@ -0,0 +1,416 @@
from typing import Any, Iterable, Mapping, Optional, Sequence, Union
from qdrant_client.conversions import common_types as types
class QdrantBase:
def __init__(self, **kwargs: Any):
pass
def search_matrix_offsets(
self,
collection_name: str,
query_filter: Optional[types.Filter] = None,
limit: int = 3,
sample: int = 10,
using: Optional[str] = None,
**kwargs: Any,
) -> types.SearchMatrixOffsetsResponse:
raise NotImplementedError()
def search_matrix_pairs(
self,
collection_name: str,
query_filter: Optional[types.Filter] = None,
limit: int = 3,
sample: int = 10,
using: Optional[str] = None,
**kwargs: Any,
) -> types.SearchMatrixPairsResponse:
raise NotImplementedError()
def query_batch_points(
self,
collection_name: str,
requests: Sequence[types.QueryRequest],
**kwargs: Any,
) -> list[types.QueryResponse]:
raise NotImplementedError()
def query_points(
self,
collection_name: str,
query: Union[
types.PointId,
list[float],
list[list[float]],
types.SparseVector,
types.Query,
types.NumpyArray,
types.Document,
types.Image,
types.InferenceObject,
None,
] = None,
using: Optional[str] = None,
prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
query_filter: Optional[types.Filter] = None,
search_params: Optional[types.SearchParams] = None,
limit: int = 10,
offset: Optional[int] = None,
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
score_threshold: Optional[float] = None,
lookup_from: Optional[types.LookupLocation] = None,
**kwargs: Any,
) -> types.QueryResponse:
raise NotImplementedError()
def query_points_groups(
self,
collection_name: str,
group_by: str,
query: Union[
types.PointId,
list[float],
list[list[float]],
types.SparseVector,
types.Query,
types.NumpyArray,
types.Document,
types.Image,
types.InferenceObject,
None,
] = None,
using: Optional[str] = None,
prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
query_filter: Optional[types.Filter] = None,
search_params: Optional[types.SearchParams] = None,
limit: int = 10,
group_size: int = 3,
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
score_threshold: Optional[float] = None,
with_lookup: Optional[types.WithLookupInterface] = None,
lookup_from: Optional[types.LookupLocation] = None,
**kwargs: Any,
) -> types.GroupsResult:
raise NotImplementedError()
def scroll(
self,
collection_name: str,
scroll_filter: Optional[types.Filter] = None,
limit: int = 10,
order_by: Optional[types.OrderBy] = None,
offset: Optional[types.PointId] = None,
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
**kwargs: Any,
) -> tuple[list[types.Record], Optional[types.PointId]]:
raise NotImplementedError()
def count(
self,
collection_name: str,
count_filter: Optional[types.Filter] = None,
exact: bool = True,
**kwargs: Any,
) -> types.CountResult:
raise NotImplementedError()
def facet(
self,
collection_name: str,
key: str,
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
exact: bool = False,
**kwargs: Any,
) -> types.FacetResponse:
raise NotImplementedError()
def upsert(
self,
collection_name: str,
points: types.Points,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
def update_vectors(
self,
collection_name: str,
points: Sequence[types.PointVectors],
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
def delete_vectors(
self,
collection_name: str,
vectors: Sequence[str],
points: types.PointsSelector,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
def retrieve(
self,
collection_name: str,
ids: Sequence[types.PointId],
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
**kwargs: Any,
) -> list[types.Record]:
raise NotImplementedError()
def delete(
self,
collection_name: str,
points_selector: types.PointsSelector,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
def set_payload(
self,
collection_name: str,
payload: types.Payload,
points: types.PointsSelector,
key: Optional[str] = None,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
def overwrite_payload(
self,
collection_name: str,
payload: types.Payload,
points: types.PointsSelector,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
def delete_payload(
self,
collection_name: str,
keys: Sequence[str],
points: types.PointsSelector,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
def clear_payload(
self,
collection_name: str,
points_selector: types.PointsSelector,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
def batch_update_points(
self,
collection_name: str,
update_operations: Sequence[types.UpdateOperation],
**kwargs: Any,
) -> list[types.UpdateResult]:
raise NotImplementedError()
def update_collection_aliases(
self,
change_aliases_operations: Sequence[types.AliasOperations],
**kwargs: Any,
) -> bool:
raise NotImplementedError()
def get_collection_aliases(
self, collection_name: str, **kwargs: Any
) -> types.CollectionsAliasesResponse:
raise NotImplementedError()
def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
raise NotImplementedError()
def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
raise NotImplementedError()
def get_collection(self, collection_name: str, **kwargs: Any) -> types.CollectionInfo:
raise NotImplementedError()
def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
raise NotImplementedError()
def update_collection(
self,
collection_name: str,
**kwargs: Any,
) -> bool:
raise NotImplementedError()
def delete_collection(self, collection_name: str, **kwargs: Any) -> bool:
raise NotImplementedError()
def create_collection(
self,
collection_name: str,
vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
**kwargs: Any,
) -> bool:
raise NotImplementedError()
def recreate_collection(
self,
collection_name: str,
vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
**kwargs: Any,
) -> bool:
raise NotImplementedError()
def upload_points(
self,
collection_name: str,
points: Iterable[types.PointStruct],
**kwargs: Any,
) -> None:
raise NotImplementedError()
def upload_collection(
self,
collection_name: str,
vectors: Union[
dict[str, types.NumpyArray], types.NumpyArray, Iterable[types.VectorStruct]
],
payload: Optional[Iterable[dict[Any, Any]]] = None,
ids: Optional[Iterable[types.PointId]] = None,
**kwargs: Any,
) -> None:
raise NotImplementedError()
def create_payload_index(
self,
collection_name: str,
field_name: str,
field_schema: Optional[types.PayloadSchemaType] = None,
field_type: Optional[types.PayloadSchemaType] = None,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
def delete_payload_index(
self,
collection_name: str,
field_name: str,
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
def list_snapshots(
self, collection_name: str, **kwargs: Any
) -> list[types.SnapshotDescription]:
raise NotImplementedError()
def create_snapshot(
self, collection_name: str, **kwargs: Any
) -> Optional[types.SnapshotDescription]:
raise NotImplementedError()
def delete_snapshot(
self, collection_name: str, snapshot_name: str, **kwargs: Any
) -> Optional[bool]:
raise NotImplementedError()
def list_full_snapshots(self, **kwargs: Any) -> list[types.SnapshotDescription]:
raise NotImplementedError()
def create_full_snapshot(self, **kwargs: Any) -> Optional[types.SnapshotDescription]:
raise NotImplementedError()
def delete_full_snapshot(self, snapshot_name: str, **kwargs: Any) -> Optional[bool]:
raise NotImplementedError()
def recover_snapshot(
self,
collection_name: str,
location: str,
**kwargs: Any,
) -> Optional[bool]:
raise NotImplementedError()
def list_shard_snapshots(
self, collection_name: str, shard_id: int, **kwargs: Any
) -> list[types.SnapshotDescription]:
raise NotImplementedError()
def create_shard_snapshot(
self, collection_name: str, shard_id: int, **kwargs: Any
) -> Optional[types.SnapshotDescription]:
raise NotImplementedError()
def delete_shard_snapshot(
self, collection_name: str, shard_id: int, snapshot_name: str, **kwargs: Any
) -> Optional[bool]:
raise NotImplementedError()
def recover_shard_snapshot(
self,
collection_name: str,
shard_id: int,
location: str,
**kwargs: Any,
) -> Optional[bool]:
raise NotImplementedError()
def close(self, **kwargs: Any) -> None:
pass
def migrate(
self,
dest_client: "QdrantBase",
collection_names: Optional[list[str]] = None,
batch_size: int = 100,
recreate_on_collision: bool = False,
) -> None:
raise NotImplementedError()
def create_shard_key(
self,
collection_name: str,
shard_key: types.ShardKey,
shards_number: Optional[int] = None,
replication_factor: Optional[int] = None,
placement: Optional[list[int]] = None,
**kwargs: Any,
) -> bool:
raise NotImplementedError()
def delete_shard_key(
self,
collection_name: str,
shard_key: types.ShardKey,
**kwargs: Any,
) -> bool:
raise NotImplementedError()
def info(self) -> types.VersionInfo:
raise NotImplementedError()
def cluster_collection_update(
self,
collection_name: str,
cluster_operation: types.ClusterOperations,
**kwargs: Any,
) -> bool:
raise NotImplementedError()
def collection_cluster_info(self, collection_name: str) -> types.CollectionClusterInfo:
raise NotImplementedError()
def cluster_status(self) -> types.ClusterStatus:
raise NotImplementedError()
def recover_current_peer(self) -> bool:
raise NotImplementedError()
def remove_peer(self, peer_id: int, **kwargs: Any) -> bool:
raise NotImplementedError()
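This is the synchronous counterpart of the async base above; `QdrantClient` implements it. A minimal sketch, again assuming a local Qdrant instance:

from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")
if not client.collection_exists("demo-sync"):
    client.create_collection(
        collection_name="demo-sync",
        vectors_config=models.VectorParams(size=4, distance=models.Distance.COSINE),
    )
client.upsert(
    collection_name="demo-sync",
    points=[models.PointStruct(id=1, vector=[0.1, 0.2, 0.3, 0.4])],
)
print(client.query_points(collection_name="demo-sync", query=[0.1, 0.2, 0.3, 0.4], limit=3).points)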
@@ -0,0 +1,16 @@
class QdrantException(Exception):
"""Base class"""
class ResourceExhaustedResponse(QdrantException):
def __init__(self, message: str, retry_after_s: int) -> None:
self.message = message if message else "Resource Exhausted Response"
try:
self.retry_after_s = int(retry_after_s)
except Exception as ex:
raise QdrantException(
f"Retry-After header value is not a valid integer: {retry_after_s}"
) from ex
def __str__(self) -> str:
return self.message.strip()
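A sketch of handling this error, assuming `client` and `points` exist in the surrounding code; `retry_after_s` is parsed from the server's retry-after metadata by the interceptors further below:

import time

try:
    client.upsert(collection_name="docs", points=points)
except ResourceExhaustedResponse as ex:
    time.sleep(ex.retry_after_s)                           # wait as long as the server asked
    client.upsert(collection_name="docs", points=points)   # single retry, for illustration only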
@@ -0,0 +1,24 @@
import warnings
from typing import Optional
SEEN_MESSAGES = set()
def show_warning(message: str, category: type[Warning] = UserWarning, stacklevel: int = 2) -> None:
warnings.warn(message, category, stacklevel=stacklevel)
def show_warning_once(
message: str,
category: type[Warning] = UserWarning,
idx: Optional[str] = None,
stacklevel: int = 1,
) -> None:
"""
Show a warning of the specified category only once per program run.
"""
key = idx if idx is not None else message
if key not in SEEN_MESSAGES:
SEEN_MESSAGES.add(key)
show_warning(message, category, stacklevel)
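For illustration, deduplication is keyed by the message text, or by `idx` when it is given:

show_warning_once("collection uses an old storage format")            # emitted
show_warning_once("collection uses an old storage format")            # skipped, already seen
show_warning_once("first message", idx="storage-format")              # emitted, keyed by idx
show_warning_once("different text, same key", idx="storage-format")   # skipped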
@@ -0,0 +1,65 @@
import logging
from typing import Any, Optional
from collections import namedtuple
import httpx
from qdrant_client.auth import BearerAuth
Version = namedtuple("Version", ["major", "minor", "rest"])
def get_server_version(
rest_uri: str, rest_headers: dict[str, Any], auth_provider: Optional[BearerAuth]
) -> Optional[str]:
response = httpx.get(rest_uri, headers=rest_headers, auth=auth_provider)
if response.status_code == 200:
version_info = response.json().get("version", None)
if not version_info:
logging.debug(
f"Unable to parse response from server: {response}, server version defaults to None"
)
return version_info
else:
logging.debug(
f"Unexpected response from server: {response}, server version defaults to None"
)
return None
def parse_version(version: str) -> Version:
if not version:
raise ValueError("Version is None")
try:
major, minor, *rest = version.split(".")
return Version(int(major), int(minor), rest)
except ValueError as er:
raise ValueError(
f"Unable to parse version, expected format: x.y.z, found: {version}"
) from er
def is_compatible(client_version: Optional[str], server_version: Optional[str]) -> bool:
if not client_version:
logging.debug(f"Unable to compare with client version {client_version}")
return False
if not server_version:
logging.debug(f"Unable to compare with server version {server_version}")
return False
if client_version == server_version:
return True
try:
parsed_server_version = parse_version(server_version)
parsed_client_version = parse_version(client_version)
except ValueError as er:
logging.debug(f"Unable to compare versions: {er}")
return False
major_dif = abs(parsed_server_version.major - parsed_client_version.major)
if major_dif >= 1:
return False
return abs(parsed_server_version.minor - parsed_client_version.minor) <= 1
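Worked examples of the rule implemented above: identical versions always match; otherwise the major parts must be equal and the minor parts may differ by at most one:

print(parse_version("1.12.4"))            # Version(major=1, minor=12, rest=['4'])
print(is_compatible("1.12.0", "1.13.2"))  # True  (same major, minor differs by 1)
print(is_compatible("1.10.0", "1.13.2"))  # False (minor differs by 3)
print(is_compatible("1.12.0", "2.0.1"))   # False (major differs)
print(is_compatible(None, "1.12.0"))      # False (unknown client version)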
@@ -0,0 +1,353 @@
import asyncio
import collections
from typing import Any, Awaitable, Callable, Optional, Union
import grpc
from qdrant_client.common.client_exceptions import ResourceExhaustedResponse
from qdrant_client.common.client_warnings import show_warning_once
# type: ignore # noqa: F401
# Source <https://github.com/grpc/grpc/blob/master/examples/python/interceptors/headers/generic_client_interceptor.py>
class _GenericClientInterceptor(
grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor,
grpc.StreamStreamClientInterceptor,
):
def __init__(self, interceptor_function: Callable):
self._fn = interceptor_function
def intercept_unary_unary(
self, continuation: Any, client_call_details: Any, request: Any
) -> Any:
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, iter((request,)), False, False
)
response = continuation(new_details, next(new_request_iterator))
return postprocess(response) if postprocess else response
def intercept_unary_stream(
self, continuation: Any, client_call_details: Any, request: Any
) -> Any:
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, iter((request,)), False, True
)
response_it = continuation(new_details, next(new_request_iterator))
return postprocess(response_it) if postprocess else response_it
def intercept_stream_unary(
self, continuation: Any, client_call_details: Any, request_iterator: Any
) -> Any:
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, request_iterator, True, False
)
response = continuation(new_details, new_request_iterator)
return postprocess(response) if postprocess else response
def intercept_stream_stream(
self, continuation: Any, client_call_details: Any, request_iterator: Any
) -> Any:
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, request_iterator, True, True
)
response_it = continuation(new_details, new_request_iterator)
return postprocess(response_it) if postprocess else response_it
class _GenericAsyncClientInterceptor(
grpc.aio.UnaryUnaryClientInterceptor,
grpc.aio.UnaryStreamClientInterceptor,
grpc.aio.StreamUnaryClientInterceptor,
grpc.aio.StreamStreamClientInterceptor,
):
def __init__(self, interceptor_function: Callable):
self._fn = interceptor_function
async def intercept_unary_unary(
self, continuation: Any, client_call_details: Any, request: Any
) -> Any:
new_details, new_request_iterator, postprocess = await self._fn(
client_call_details, iter((request,)), False, False
)
next_request = next(new_request_iterator)
response = await continuation(new_details, next_request)
return await postprocess(response) if postprocess else response
async def intercept_unary_stream(
self, continuation: Any, client_call_details: Any, request: Any
) -> Any:
new_details, new_request_iterator, postprocess = await self._fn(
client_call_details, iter((request,)), False, True
)
response_it = await continuation(new_details, next(new_request_iterator))
return await postprocess(response_it) if postprocess else response_it
async def intercept_stream_unary(
self, continuation: Any, client_call_details: Any, request_iterator: Any
) -> Any:
new_details, new_request_iterator, postprocess = await self._fn(
client_call_details, request_iterator, True, False
)
response = await continuation(new_details, new_request_iterator)
return await postprocess(response) if postprocess else response
async def intercept_stream_stream(
self, continuation: Any, client_call_details: Any, request_iterator: Any
) -> Any:
new_details, new_request_iterator, postprocess = await self._fn(
client_call_details, request_iterator, True, True
)
response_it = await continuation(new_details, new_request_iterator)
return await postprocess(response_it) if postprocess else response_it
def create_generic_client_interceptor(intercept_call: Any) -> _GenericClientInterceptor:
return _GenericClientInterceptor(intercept_call)
def create_generic_async_client_interceptor(
intercept_call: Any,
) -> _GenericAsyncClientInterceptor:
return _GenericAsyncClientInterceptor(intercept_call)
# Source:
# <https://github.com/grpc/grpc/blob/master/examples/python/interceptors/headers/header_manipulator_client_interceptor.py>
class _ClientCallDetails(
collections.namedtuple("_ClientCallDetails", ("method", "timeout", "metadata", "credentials")),
grpc.ClientCallDetails,
):
pass
class _ClientAsyncCallDetails(
collections.namedtuple("_ClientCallDetails", ("method", "timeout", "metadata", "credentials")),
grpc.aio.ClientCallDetails,
):
pass
def header_adder_interceptor(
new_metadata: list[tuple[str, str]],
auth_token_provider: Optional[Callable[[], str]] = None,
) -> _GenericClientInterceptor:
def process_response(response: Any) -> Any:
if response.code() == grpc.StatusCode.RESOURCE_EXHAUSTED:
retry_after = None
for item in response.trailing_metadata():
if item.key == "retry-after":
try:
retry_after = int(item.value)
except Exception:
retry_after = None
break
reason_phrase = response.details() if response.details() else ""
if retry_after:
raise ResourceExhaustedResponse(message=reason_phrase, retry_after_s=retry_after)
return response
def intercept_call(
client_call_details: _ClientCallDetails,
request_iterator: Any,
_request_streaming: Any,
_response_streaming: Any,
) -> tuple[_ClientCallDetails, Any, Any]:
metadata = []
if client_call_details.metadata is not None:
metadata = list(client_call_details.metadata)
for header, value in new_metadata:
metadata.append(
(
header,
value,
)
)
if auth_token_provider:
if not asyncio.iscoroutinefunction(auth_token_provider):
metadata.append(("authorization", f"Bearer {auth_token_provider()}"))
else:
raise ValueError("Synchronous channel requires synchronous auth token provider.")
client_call_details = _ClientCallDetails(
client_call_details.method,
client_call_details.timeout,
metadata,
client_call_details.credentials,
)
return client_call_details, request_iterator, process_response
return create_generic_client_interceptor(intercept_call)
def header_adder_async_interceptor(
new_metadata: list[tuple[str, str]],
auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None,
) -> _GenericAsyncClientInterceptor:
async def process_response(call: Any) -> Any:
try:
return await call
except grpc.aio.AioRpcError as er:
if er.code() == grpc.StatusCode.RESOURCE_EXHAUSTED:
retry_after = None
for item in er.trailing_metadata():
if item[0] == "retry-after":
try:
retry_after = int(item[1])
except Exception:
retry_after = None
break
reason_phrase = er.details() if er.details() else ""
if retry_after:
raise ResourceExhaustedResponse(
message=reason_phrase, retry_after_s=retry_after
) from er
raise
async def intercept_call(
client_call_details: grpc.aio.ClientCallDetails,
request_iterator: Any,
_request_streaming: Any,
_response_streaming: Any,
) -> tuple[_ClientAsyncCallDetails, Any, Any]:
metadata = []
if client_call_details.metadata is not None:
metadata = list(client_call_details.metadata)
for header, value in new_metadata:
metadata.append(
(
header,
value,
)
)
if auth_token_provider:
if asyncio.iscoroutinefunction(auth_token_provider):
token = await auth_token_provider()
else:
token = auth_token_provider()
metadata.append(("authorization", f"Bearer {token}"))
client_call_details = client_call_details._replace(metadata=metadata)
return client_call_details, request_iterator, process_response
return create_generic_async_client_interceptor(intercept_call)
def parse_channel_options(options: Optional[dict[str, Any]] = None) -> list[tuple[str, Any]]:
default_options: list[tuple[str, Any]] = [
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
]
if options is None:
return default_options
_options = [(option_name, option_value) for option_name, option_value in options.items()]
for option_name, option_value in default_options:
if option_name not in options:
_options.append((option_name, option_value))
return _options
def parse_ssl_credentials(options: Optional[dict[str, Any]] = None) -> dict[str, Optional[bytes]]:
"""Parse ssl credentials to create `grpc.ssl_channel_credentials` for `grpc.secure_channel`
WARN: Directly modifies input `options`
Return:
dict[str, Optional[bytes]]: dict(root_certificates=..., private_key=..., certificate_chain=...)
"""
ssl_options: dict[str, Optional[bytes]] = dict(
root_certificates=None, private_key=None, certificate_chain=None
)
if options is None:
return ssl_options
for ssl_option_name in ssl_options:
option_value: Any = options.pop(ssl_option_name, None)
if f"grpc.{ssl_option_name}" in options:
show_warning_once(
f"`{ssl_option_name}` is supposed to be used without `grpc.` prefix",
idx=f"grpc.{ssl_option_name}",
stacklevel=10,
)
if option_value is None:
continue
if not isinstance(option_value, bytes):
raise TypeError(f"{ssl_option_name} must be a byte string")
ssl_options[ssl_option_name] = option_value
return ssl_options
def get_channel(
host: str,
port: int,
ssl: bool,
metadata: Optional[list[tuple[str, str]]] = None,
options: Optional[dict[str, Any]] = None,
compression: Optional[grpc.Compression] = None,
auth_token_provider: Optional[Callable[[], str]] = None,
) -> grpc.Channel:
# Parse gRPC client options
_copied_options = (
options.copy() if options is not None else None
) # we're changing options inplace
_ssl_cred_options = parse_ssl_credentials(_copied_options)
_options = parse_channel_options(_copied_options)
metadata_interceptor = header_adder_interceptor(
new_metadata=metadata or [], auth_token_provider=auth_token_provider
)
if ssl:
ssl_creds = grpc.ssl_channel_credentials(**_ssl_cred_options)
channel = grpc.secure_channel(f"{host}:{port}", ssl_creds, _options, compression)
return grpc.intercept_channel(channel, metadata_interceptor)
else:
channel = grpc.insecure_channel(f"{host}:{port}", _options, compression)
return grpc.intercept_channel(channel, metadata_interceptor)
def get_async_channel(
host: str,
port: int,
ssl: bool,
metadata: Optional[list[tuple[str, str]]] = None,
options: Optional[dict[str, Any]] = None,
compression: Optional[grpc.Compression] = None,
auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None,
) -> grpc.aio.Channel:
# Parse gRPC client options
_copied_options = (
options.copy() if options is not None else None
) # we're changing options inplace
_ssl_cred_options = parse_ssl_credentials(_copied_options)
_options = parse_channel_options(_copied_options)
# Create metadata interceptor
metadata_interceptor = header_adder_async_interceptor(
new_metadata=metadata or [], auth_token_provider=auth_token_provider
)
if ssl:
ssl_creds = grpc.ssl_channel_credentials(**_ssl_cred_options)
return grpc.aio.secure_channel(
f"{host}:{port}",
ssl_creds,
_options,
compression,
interceptors=[metadata_interceptor],
)
else:
return grpc.aio.insecure_channel(
f"{host}:{port}", _options, compression, interceptors=[metadata_interceptor]
)
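# A minimal usage sketch (illustrative): "localhost", port 6334 and the "api-key"
# header below are assumptions for a local Qdrant gRPC endpoint, not values defined
# in this module. gRPC channel creation is lazy, so no connection is attempted until
# the first RPC is issued.
if __name__ == "__main__":
    channel = get_channel(
        host="localhost",
        port=6334,
        ssl=False,
        metadata=[("api-key", "my-secret-key")],
    )
    channel.close()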
@@ -0,0 +1,163 @@
import sys
import numpy as np
import numpy.typing as npt
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from typing import Union, get_args, Sequence
from uuid import UUID
from qdrant_client import grpc
from qdrant_client.http import models as rest
typing_remap = {
rest.StrictStr: str,
rest.StrictInt: int,
rest.StrictFloat: float,
rest.StrictBool: bool,
}
def remap_type(tp: type) -> type:
"""Remap type to a type that can be used in type annotations
Pydantic uses custom types for strict types, so we need to remap them to standard types
so that they can be used in type annotations and isinstance checks
"""
return typing_remap.get(tp, tp)
def get_args_subscribed(tp): # type: ignore
"""Get type arguments with all substitutions performed. Supports subscripted generics having __origin__
Args:
tp: type to get arguments from. Can be either a type or a subscripted generic
Returns:
tuple of type arguments
"""
return tuple(
remap_type(arg if not hasattr(arg, "__origin__") else arg.__origin__)
for arg in get_args(tp)
)
Filter = Union[rest.Filter, grpc.Filter]
SearchParams = Union[rest.SearchParams, grpc.SearchParams]
PayloadSelector = Union[rest.PayloadSelector, grpc.WithPayloadSelector]
Distance = Union[rest.Distance, int] # type(grpc.Distance) == int
HnswConfigDiff = Union[rest.HnswConfigDiff, grpc.HnswConfigDiff]
VectorsConfigDiff = Union[rest.VectorsConfigDiff, grpc.VectorsConfigDiff]
QuantizationConfigDiff = Union[rest.QuantizationConfigDiff, grpc.QuantizationConfigDiff]
OptimizersConfigDiff = Union[rest.OptimizersConfigDiff, grpc.OptimizersConfigDiff]
CollectionParamsDiff = Union[rest.CollectionParamsDiff, grpc.CollectionParamsDiff]
WalConfigDiff = Union[rest.WalConfigDiff, grpc.WalConfigDiff]
QuantizationConfig = Union[rest.QuantizationConfig, grpc.QuantizationConfig]
PointId = Union[int, str, UUID, grpc.PointId]
PayloadSchemaType = Union[
rest.PayloadSchemaType,
rest.PayloadSchemaParams,
int,
grpc.PayloadIndexParams,
] # type(grpc.PayloadSchemaType) == int
PointStruct: TypeAlias = rest.PointStruct
Batch: TypeAlias = rest.Batch
Points = Union[Batch, Sequence[Union[rest.PointStruct, grpc.PointStruct]]]
PointsSelector = Union[
list[PointId],
rest.Filter,
grpc.Filter,
rest.PointsSelector,
grpc.PointsSelector,
]
LookupLocation = Union[rest.LookupLocation, grpc.LookupLocation]
RecommendStrategy: TypeAlias = rest.RecommendStrategy
OrderBy = Union[rest.OrderByInterface, grpc.OrderBy]
ShardingMethod: TypeAlias = rest.ShardingMethod
ShardKey: TypeAlias = rest.ShardKey
ShardKeySelector: TypeAlias = rest.ShardKeySelector
AliasOperations = Union[
rest.CreateAliasOperation,
rest.RenameAliasOperation,
rest.DeleteAliasOperation,
grpc.AliasOperations,
]
Payload: TypeAlias = rest.Payload
ScoredPoint: TypeAlias = rest.ScoredPoint
UpdateResult: TypeAlias = rest.UpdateResult
Record: TypeAlias = rest.Record
CollectionsResponse: TypeAlias = rest.CollectionsResponse
CollectionInfo: TypeAlias = rest.CollectionInfo
CountResult: TypeAlias = rest.CountResult
SnapshotDescription: TypeAlias = rest.SnapshotDescription
NamedVector: TypeAlias = rest.NamedVector
NamedSparseVector: TypeAlias = rest.NamedSparseVector
SparseVector: TypeAlias = rest.SparseVector
PointVectors: TypeAlias = rest.PointVectors
Vector: TypeAlias = rest.Vector
VectorInput: TypeAlias = rest.VectorInput
VectorStruct: TypeAlias = rest.VectorStruct
VectorParams: TypeAlias = rest.VectorParams
SparseVectorParams: TypeAlias = rest.SparseVectorParams
SnapshotPriority: TypeAlias = rest.SnapshotPriority
CollectionsAliasesResponse: TypeAlias = rest.CollectionsAliasesResponse
UpdateOperation: TypeAlias = rest.UpdateOperation
Query: TypeAlias = rest.Query
Prefetch: TypeAlias = rest.Prefetch
Document: TypeAlias = rest.Document
Image: TypeAlias = rest.Image
InferenceObject: TypeAlias = rest.InferenceObject
StrictModeConfig: TypeAlias = rest.StrictModeConfig
QueryRequest: TypeAlias = rest.QueryRequest
Mmr: TypeAlias = rest.Mmr
ReadConsistency: TypeAlias = rest.ReadConsistency
WriteOrdering: TypeAlias = rest.WriteOrdering
WithLookupInterface: TypeAlias = rest.WithLookupInterface
GroupsResult: TypeAlias = rest.GroupsResult
QueryResponse: TypeAlias = rest.QueryResponse
FacetValue: TypeAlias = rest.FacetValue
FacetResponse: TypeAlias = rest.FacetResponse
SearchMatrixRequest = Union[rest.SearchMatrixRequest, grpc.SearchMatrixPoints]
SearchMatrixOffsetsResponse: TypeAlias = rest.SearchMatrixOffsetsResponse
SearchMatrixPairsResponse: TypeAlias = rest.SearchMatrixPairsResponse
SearchMatrixPair: TypeAlias = rest.SearchMatrixPair
VersionInfo: TypeAlias = rest.VersionInfo
ReplicaState: TypeAlias = rest.ReplicaState
ClusterOperations: TypeAlias = rest.ClusterOperations
ClusterStatus: TypeAlias = rest.ClusterStatus
CollectionClusterInfo: TypeAlias = rest.CollectionClusterInfo
# we can't use `nptyping` package due to numpy/python-version incompatibilities
# thus we need to define precise type annotations while we support python3.7
_np_numeric = Union[
np.bool_, # pylance can't handle np.bool8 alias
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.intp,
np.uintp,
np.float16,
np.float32,
np.float64,
np.longdouble, # np.float96 and np.float128 are platform dependant aliases for longdouble
]
NumpyArray: TypeAlias = npt.NDArray[_np_numeric]
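# A minimal usage sketch (illustrative): the unions above accept both the REST and
# the gRPC flavour of a type, and `NumpyArray` covers the numeric dtypes listed in
# `_np_numeric`. The concrete values are placeholders.
if __name__ == "__main__":
    rest_filter: Filter = rest.Filter(must=[])
    grpc_filter: Filter = grpc.Filter()
    vector: NumpyArray = np.array([0.1, 0.2, 0.3], dtype=np.float32)
    print(type(rest_filter).__name__, type(grpc_filter).__name__, vector.dtype)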
@@ -0,0 +1,94 @@
from typing import Optional, Any
from qdrant_client.http import models
from qdrant_client.embed.models import NumericVector
class BuiltinEmbedder:
_SUPPORTED_MODELS = ("Qdrant/Bm25",)
def __init__(self, **kwargs: Any) -> None:
pass
def embed(
self,
model_name: str,
texts: Optional[list[str]] = None,
options: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> NumericVector:
if texts is None:
if "images" in kwargs:
raise ValueError(
"Image processing is only available with cloud inference of FastEmbed"
)
raise ValueError("Texts must be provided for the inference")
if not self.is_supported_sparse_model(model_name):
raise ValueError(
f"Model {model_name} is not supported in {self.__class__.__name__}. "
f"Did you forget to enable cloud inference or install FastEmbed for local inference?"
)
return [models.Document(text=text, options=options, model=model_name) for text in texts]
@classmethod
def is_supported_text_model(cls, model_name: str) -> bool:
"""Mock embedder interface, only sparse text model Qdrant/Bm25 is supported
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
return False # currently only Qdrant/Bm25 is supported
@classmethod
def is_supported_image_model(cls, model_name: str) -> bool:
"""Mock embedder interface, only sparse text model Qdrant/Bm25 is supported
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
return False # currently only Qdrant/Bm25 is supported
@classmethod
def is_supported_late_interaction_text_model(cls, model_name: str) -> bool:
"""Mock embedder interface, only sparse text model Qdrant/Bm25 is supported
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
return False # currently only Qdrant/Bm25 is supported
@classmethod
def is_supported_late_interaction_multimodal_model(cls, model_name: str) -> bool:
"""Mock embedder interface, only sparse text model Qdrant/Bm25 is supported
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
return False # currently only Qdrant/Bm25 is supported
@classmethod
def is_supported_sparse_model(cls, model_name: str) -> bool:
"""Checks if the model is supported. Only `Qdrant/Bm25` is supported
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
return model_name.lower() in [model.lower() for model in cls._SUPPORTED_MODELS]
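# A minimal usage sketch (illustrative): the builtin embedder does not compute
# vectors locally, it only wraps texts into `models.Document` objects so that the
# sparse Qdrant/Bm25 model can be applied by the server (cloud inference).
if __name__ == "__main__":
    embedder = BuiltinEmbedder()
    documents = embedder.embed(model_name="Qdrant/Bm25", texts=["hello world"])
    print(documents)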
@@ -0,0 +1,6 @@
from typing import Union
from qdrant_client.http import models
INFERENCE_OBJECT_NAMES: set[str] = {"Document", "Image", "InferenceObject"}
INFERENCE_OBJECT_TYPES = Union[models.Document, models.Image, models.InferenceObject]
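# A minimal usage sketch (illustrative): `INFERENCE_OBJECT_TYPES` is a Union, so
# runtime checks go through `typing.get_args`, mirroring how the inspectors in this
# package use it.
if __name__ == "__main__":
    from typing import get_args

    doc = models.Document(text="hello", model="Qdrant/Bm25")
    print(isinstance(doc, get_args(INFERENCE_OBJECT_TYPES)))  # True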
@@ -0,0 +1,176 @@
from copy import copy
from typing import Union, Optional, Iterable, get_args
from pydantic import BaseModel
from qdrant_client._pydantic_compat import model_fields_set
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
from qdrant_client.embed.schema_parser import ModelSchemaParser
from qdrant_client.embed.utils import convert_paths, FieldPath
class InspectorEmbed:
"""Inspector which collects paths to objects requiring inference in the received models
Attributes:
parser: ModelSchemaParser instance
"""
def __init__(self, parser: Optional[ModelSchemaParser] = None) -> None:
self.parser = ModelSchemaParser() if parser is None else parser
def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> list[FieldPath]:
"""Looks for all the paths to objects requiring inference in the received models
Args:
points: models to inspect
Returns:
list of FieldPath objects
"""
paths = []
if isinstance(points, BaseModel):
self.parser.parse_model(points.__class__)
paths.extend(self._inspect_model(points))
elif isinstance(points, dict):
for value in points.values():
paths.extend(self.inspect(value))
elif isinstance(points, Iterable):
for point in points:
if isinstance(point, BaseModel):
self.parser.parse_model(point.__class__)
paths.extend(self._inspect_model(point))
paths = sorted(set(paths))
return convert_paths(paths)
def _inspect_model(
self, mod: BaseModel, paths: Optional[list[FieldPath]] = None, accum: Optional[str] = None
) -> list[str]:
"""Looks for all the paths to objects requiring inference in the received model
Args:
mod: model to inspect
paths: list of paths to the fields possibly containing objects for inference
accum: accumulator for the path. Path is a dot separated string of field names which we assemble recursively
Returns:
list of paths to the model fields containing objects for inference
"""
paths = self.parser.path_cache.get(mod.__class__.__name__, []) if paths is None else paths
found_paths = []
for path in paths:
found_paths.extend(
self._inspect_inner_models(
mod, path.current, path.tail if path.tail else [], accum
)
)
return found_paths
def _inspect_inner_models(
self,
original_model: BaseModel,
current_path: str,
tail: list[FieldPath],
accum: Optional[str] = None,
) -> list[str]:
"""Looks for all the paths to objects requiring inference in the received model
Args:
original_model: model to inspect
current_path: the field to inspect on the current iteration
tail: list of FieldPath objects to the fields possibly containing objects for inference
accum: accumulator for the path. Path is a dot separated string of field names which we assemble recursively
Returns:
list of paths to the model fields containing objects for inference
"""
found_paths = []
if accum is None:
accum = current_path
else:
accum += f".{current_path}"
def inspect_recursive(member: BaseModel, accumulator: str) -> list[str]:
"""Iterates over the set model fields, expand recursive ones and find paths to objects requiring inference
Args:
member: currently inspected model, which may or may not contain recursive fields
accumulator: accumulator for the path, which is a dot separated string assembled recursively
"""
recursive_paths = []
for field in model_fields_set(member):
if field in self.parser.name_recursive_ref_mapping:
mapped_field = self.parser.name_recursive_ref_mapping[field]
recursive_paths.extend(self.parser.path_cache[mapped_field])
return self._inspect_model(member, copy(recursive_paths), accumulator)
model = getattr(original_model, current_path, None)
if model is None:
return []
if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)):
return [accum]
if isinstance(model, BaseModel):
found_paths.extend(inspect_recursive(model, accum))
for next_path in tail:
found_paths.extend(
self._inspect_inner_models(
model, next_path.current, next_path.tail if next_path.tail else [], accum
)
)
return found_paths
elif isinstance(model, list):
for current_model in model:
if not isinstance(current_model, BaseModel):
continue
if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)):
found_paths.append(accum)
found_paths.extend(inspect_recursive(current_model, accum))
for next_path in tail:
for current_model in model:
found_paths.extend(
self._inspect_inner_models(
current_model,
next_path.current,
next_path.tail if next_path.tail else [],
accum,
)
)
return found_paths
elif isinstance(model, dict):
found_paths = []
for key, values in model.items():
values = [values] if not isinstance(values, list) else values
for current_model in values:
if not isinstance(current_model, BaseModel):
continue
if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)):
found_paths.append(accum)
found_paths.extend(inspect_recursive(current_model, accum))
for next_path in tail:
for current_model in values:
found_paths.extend(
self._inspect_inner_models(
current_model,
next_path.current,
next_path.tail if next_path.tail else [],
accum,
)
)
return found_paths
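# A minimal usage sketch (illustrative; assumes a client version where
# `PointStruct.vector` accepts a raw `models.Document`): inspecting a point whose
# vector still needs inference yields the path to that field.
if __name__ == "__main__":
    from qdrant_client.http import models

    point = models.PointStruct(
        id=1, vector=models.Document(text="hello", model="Qdrant/Bm25")
    )
    print(InspectorEmbed().inspect([point]))  # e.g. a FieldPath pointing at "vector"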
@@ -0,0 +1,447 @@
from collections import defaultdict
from typing import Optional, Sequence, Any, TypeVar, Generic
from pydantic import BaseModel
from qdrant_client.http import models
from qdrant_client.embed.models import NumericVector
from qdrant_client.fastembed_common import (
OnnxProvider,
ImageInput,
TextEmbedding,
SparseTextEmbedding,
LateInteractionTextEmbedding,
LateInteractionMultimodalEmbedding,
ImageEmbedding,
FastEmbedMisc,
)
T = TypeVar("T")
class ModelInstance(BaseModel, Generic[T], arbitrary_types_allowed=True): # type: ignore[call-arg]
model: T
options: dict[str, Any]
deprecated: bool = False
class Embedder:
def __init__(self, threads: Optional[int] = None, **kwargs: Any) -> None:
self.embedding_models: dict[str, list[ModelInstance[TextEmbedding]]] = defaultdict(list)
self.sparse_embedding_models: dict[str, list[ModelInstance[SparseTextEmbedding]]] = (
defaultdict(list)
)
self.late_interaction_embedding_models: dict[
str, list[ModelInstance[LateInteractionTextEmbedding]]
] = defaultdict(list)
self.image_embedding_models: dict[str, list[ModelInstance[ImageEmbedding]]] = defaultdict(
list
)
self.late_interaction_multimodal_embedding_models: dict[
str, list[ModelInstance[LateInteractionMultimodalEmbedding]]
] = defaultdict(list)
self._threads = threads
def get_or_init_model(
self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence["OnnxProvider"]] = None,
cuda: bool = False,
device_ids: Optional[list[int]] = None,
deprecated: bool = False,
**kwargs: Any,
) -> TextEmbedding:
if not FastEmbedMisc.is_supported_text_model(model_name):
raise ValueError(
f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_text_models()}"
)
options = {
"cache_dir": cache_dir,
"threads": threads or self._threads,
"providers": providers,
"cuda": cuda,
"device_ids": device_ids,
**kwargs,
}
for instance in self.embedding_models[model_name]:
if (deprecated and instance.deprecated) or (
not deprecated and instance.options == options
):
return instance.model
model = TextEmbedding(model_name=model_name, **options)
model_instance: ModelInstance[TextEmbedding] = ModelInstance(
model=model, options=options, deprecated=deprecated
)
self.embedding_models[model_name].append(model_instance)
return model
def get_or_init_sparse_model(
self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence["OnnxProvider"]] = None,
cuda: bool = False,
device_ids: Optional[list[int]] = None,
deprecated: bool = False,
**kwargs: Any,
) -> SparseTextEmbedding:
if not FastEmbedMisc.is_supported_sparse_model(model_name):
raise ValueError(
f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_sparse_models()}"
)
options = {
"cache_dir": cache_dir,
"threads": threads or self._threads,
"providers": providers,
"cuda": cuda,
"device_ids": device_ids,
**kwargs,
}
for instance in self.sparse_embedding_models[model_name]:
if (deprecated and instance.deprecated) or (
not deprecated and instance.options == options
):
return instance.model
model = SparseTextEmbedding(model_name=model_name, **options)
model_instance: ModelInstance[SparseTextEmbedding] = ModelInstance(
model=model, options=options, deprecated=deprecated
)
self.sparse_embedding_models[model_name].append(model_instance)
return model
def get_or_init_late_interaction_model(
self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence["OnnxProvider"]] = None,
cuda: bool = False,
device_ids: Optional[list[int]] = None,
**kwargs: Any,
) -> LateInteractionTextEmbedding:
if not FastEmbedMisc.is_supported_late_interaction_text_model(model_name):
raise ValueError(
f"Unsupported embedding model: {model_name}. "
f"Supported models: {FastEmbedMisc.list_late_interaction_text_models()}"
)
options = {
"cache_dir": cache_dir,
"threads": threads or self._threads,
"providers": providers,
"cuda": cuda,
"device_ids": device_ids,
**kwargs,
}
for instance in self.late_interaction_embedding_models[model_name]:
if instance.options == options:
return instance.model
model = LateInteractionTextEmbedding(model_name=model_name, **options)
model_instance: ModelInstance[LateInteractionTextEmbedding] = ModelInstance(
model=model, options=options
)
self.late_interaction_embedding_models[model_name].append(model_instance)
return model
def get_or_init_late_interaction_multimodal_model(
self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence["OnnxProvider"]] = None,
cuda: bool = False,
device_ids: Optional[list[int]] = None,
**kwargs: Any,
) -> LateInteractionMultimodalEmbedding:
if not FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
raise ValueError(
f"Unsupported embedding model: {model_name}. "
f"Supported models: {FastEmbedMisc.list_late_interaction_multimodal_models()}"
)
options = {
"cache_dir": cache_dir,
"threads": threads or self._threads,
"providers": providers,
"cuda": cuda,
"device_ids": device_ids,
**kwargs,
}
for instance in self.late_interaction_multimodal_embedding_models[model_name]:
if instance.options == options:
return instance.model
model = LateInteractionMultimodalEmbedding(model_name=model_name, **options)
model_instance: ModelInstance[LateInteractionMultimodalEmbedding] = ModelInstance(
model=model, options=options
)
self.late_interaction_multimodal_embedding_models[model_name].append(model_instance)
return model
def get_or_init_image_model(
self,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence["OnnxProvider"]] = None,
cuda: bool = False,
device_ids: Optional[list[int]] = None,
**kwargs: Any,
) -> ImageEmbedding:
if not FastEmbedMisc.is_supported_image_model(model_name):
raise ValueError(
f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_image_models()}"
)
options = {
"cache_dir": cache_dir,
"threads": threads or self._threads,
"providers": providers,
"cuda": cuda,
"device_ids": device_ids,
**kwargs,
}
for instance in self.image_embedding_models[model_name]:
if instance.options == options:
return instance.model
model = ImageEmbedding(model_name=model_name, **options)
model_instance: ModelInstance[ImageEmbedding] = ModelInstance(model=model, options=options)
self.image_embedding_models[model_name].append(model_instance)
return model
def embed(
self,
model_name: str,
texts: Optional[list[str]] = None,
images: Optional[list[ImageInput]] = None,
options: Optional[dict[str, Any]] = None,
is_query: bool = False,
batch_size: int = 8,
) -> NumericVector:
if (texts is None) is (images is None):
raise ValueError("Either documents or images should be provided")
embeddings: NumericVector # define type for a static type checker
if texts is not None:
if FastEmbedMisc.is_supported_text_model(model_name):
embeddings = self._embed_dense_text(
texts, model_name, options, is_query, batch_size
)
elif FastEmbedMisc.is_supported_sparse_model(model_name):
embeddings = self._embed_sparse_text(
texts, model_name, options, is_query, batch_size
)
elif FastEmbedMisc.is_supported_late_interaction_text_model(model_name):
embeddings = self._embed_late_interaction_text(
texts, model_name, options, is_query, batch_size
)
elif FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
embeddings = self._embed_late_interaction_multimodal_text(
texts, model_name, options, batch_size
)
else:
raise ValueError(f"Unsupported embedding model: {model_name}")
else:
assert (
images is not None
) # just to satisfy mypy which can't infer it from the previous conditions
if FastEmbedMisc.is_supported_image_model(model_name):
embeddings = self._embed_dense_image(images, model_name, options, batch_size)
elif FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
embeddings = self._embed_late_interaction_multimodal_image(
images, model_name, options, batch_size
)
else:
raise ValueError(f"Unsupported embedding model: {model_name}")
return embeddings
def _embed_dense_text(
self,
texts: list[str],
model_name: str,
options: Optional[dict[str, Any]],
is_query: bool,
batch_size: int,
) -> list[list[float]]:
embedding_model_inst = self.get_or_init_model(model_name=model_name, **options or {})
if not is_query:
embeddings = [
embedding.tolist()
for embedding in embedding_model_inst.embed(documents=texts, batch_size=batch_size)
]
else:
embeddings = [
embedding.tolist() for embedding in embedding_model_inst.query_embed(query=texts)
]
return embeddings
def _embed_sparse_text(
self,
texts: list[str],
model_name: str,
options: Optional[dict[str, Any]],
is_query: bool,
batch_size: int,
) -> list[models.SparseVector]:
embedding_model_inst = self.get_or_init_sparse_model(
model_name=model_name, **options or {}
)
if not is_query:
embeddings = [
models.SparseVector(
indices=sparse_embedding.indices.tolist(),
values=sparse_embedding.values.tolist(),
)
for sparse_embedding in embedding_model_inst.embed(
documents=texts, batch_size=batch_size
)
]
else:
embeddings = [
models.SparseVector(
indices=sparse_embedding.indices.tolist(),
values=sparse_embedding.values.tolist(),
)
for sparse_embedding in embedding_model_inst.query_embed(query=texts)
]
return embeddings
def _embed_late_interaction_text(
self,
texts: list[str],
model_name: str,
options: Optional[dict[str, Any]],
is_query: bool,
batch_size: int,
) -> list[list[list[float]]]:
embedding_model_inst = self.get_or_init_late_interaction_model(
model_name=model_name, **options or {}
)
if not is_query:
embeddings = [
embedding.tolist()
for embedding in embedding_model_inst.embed(documents=texts, batch_size=batch_size)
]
else:
embeddings = [
embedding.tolist() for embedding in embedding_model_inst.query_embed(query=texts)
]
return embeddings
def _embed_late_interaction_multimodal_text(
self,
texts: list[str],
model_name: str,
options: Optional[dict[str, Any]],
batch_size: int,
) -> list[list[list[float]]]:
embedding_model_inst = self.get_or_init_late_interaction_multimodal_model(
model_name=model_name, **options or {}
)
return [
embedding.tolist()
for embedding in embedding_model_inst.embed_text(
documents=texts, batch_size=batch_size
)
]
def _embed_late_interaction_multimodal_image(
self,
images: list[ImageInput],
model_name: str,
options: Optional[dict[str, Any]],
batch_size: int,
) -> list[list[list[float]]]:
embedding_model_inst = self.get_or_init_late_interaction_multimodal_model(
model_name=model_name, **options or {}
)
return [
embedding.tolist()
for embedding in embedding_model_inst.embed_image(images=images, batch_size=batch_size)
]
def _embed_dense_image(
self,
images: list[ImageInput],
model_name: str,
options: Optional[dict[str, Any]],
batch_size: int,
) -> list[list[float]]:
embedding_model_inst = self.get_or_init_image_model(model_name=model_name, **options or {})
embeddings = [
embedding.tolist()
for embedding in embedding_model_inst.embed(images=images, batch_size=batch_size)
]
return embeddings
@classmethod
def is_supported_text_model(cls, model_name: str) -> bool:
"""Check if model is supported by fastembed
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
return FastEmbedMisc.is_supported_text_model(model_name)
@classmethod
def is_supported_image_model(cls, model_name: str) -> bool:
"""Check if model is supported by fastembed
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
return FastEmbedMisc.is_supported_image_model(model_name)
@classmethod
def is_supported_late_interaction_text_model(cls, model_name: str) -> bool:
"""Check if model is supported by fastembed
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
return FastEmbedMisc.is_supported_late_interaction_text_model(model_name)
@classmethod
def is_supported_late_interaction_multimodal_model(cls, model_name: str) -> bool:
"""Check if model is supported by fastembed
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
return FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name)
@classmethod
def is_supported_sparse_model(cls, model_name: str) -> bool:
"""Check if model is supported by fastembed
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
return FastEmbedMisc.is_supported_sparse_model(model_name)
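# A minimal usage sketch (illustrative): requires the optional `fastembed` dependency
# and downloads model weights on first use. The model name below is an assumption;
# any model reported by `FastEmbedMisc.list_text_models()` works.
if __name__ == "__main__":
    embedder = Embedder()
    vectors = embedder.embed(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        texts=["hello world"],
        batch_size=8,
    )
    print(len(vectors), len(vectors[0]))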
@@ -0,0 +1,498 @@
import os
from collections import defaultdict
from copy import deepcopy
from multiprocessing import get_all_start_methods
from typing import Optional, Union, Iterable, Any, Type, get_args
from pydantic import BaseModel
from qdrant_client.embed.builtin_embedder import BuiltinEmbedder
from qdrant_client.http import models
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
from qdrant_client.embed.embed_inspector import InspectorEmbed
from qdrant_client.embed.embedder import Embedder
from qdrant_client.embed.models import NumericVector, NumericVectorStruct
from qdrant_client.embed.schema_parser import ModelSchemaParser
from qdrant_client.embed.utils import FieldPath
from qdrant_client.fastembed_common import FastEmbedMisc
from qdrant_client.parallel_processor import ParallelWorkerPool, Worker
from qdrant_client.uploader.uploader import iter_batch
class ModelEmbedderWorker(Worker):
def __init__(self, batch_size: int, **kwargs: Any):
self.model_embedder = ModelEmbedder(**kwargs)
self.batch_size = batch_size
@classmethod
def start(cls, batch_size: int, **kwargs: Any) -> "ModelEmbedderWorker":
return cls(threads=1, batch_size=batch_size, **kwargs)
def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
for idx, batch in items:
yield (
idx,
list(
self.model_embedder.embed_models_batch(
batch, inference_batch_size=self.batch_size
)
),
)
class ModelEmbedder:
MAX_INTERNAL_BATCH_SIZE = 64
def __init__(
self,
parser: Optional[ModelSchemaParser] = None,
is_local_mode: bool = False,
server_version: Optional[str] = None,
**kwargs: Any,
):
self._batch_accumulator: dict[str, list[INFERENCE_OBJECT_TYPES]] = {}
self._embed_storage: dict[str, list[NumericVector]] = {}
self._embed_inspector = InspectorEmbed(parser=parser)
self._is_builtin_embedder_available = self._check_builtin_embedder_availability(
is_local_mode, server_version
)
self.embedder = (
Embedder(**kwargs) if FastEmbedMisc.is_installed() else BuiltinEmbedder(**kwargs)
)
@staticmethod
def _check_builtin_embedder_availability(
is_local_mode: bool, server_version: Optional[str]
) -> bool:
if is_local_mode:
return False
if (
server_version is None
):  # failed to detect the server version; this can happen due to security or network
# problems even on supported server versions, so we do not block usage of BuiltinEmbedder
return True
try:
major, minor, patch = server_version.split(".")
patch = patch.split("-")[0]
if (int(major), int(minor), int(patch)) >= (1, 15, 3):
return True
return False
except Exception:
return True
def embed_models(
self,
raw_models: Union[BaseModel, Iterable[BaseModel]],
is_query: bool = False,
batch_size: int = 8,
) -> Iterable[BaseModel]:
"""Embed raw data fields in models and return models with vectors
If any of the model fields require inference, a deepcopy of the model with computed embeddings is returned,
otherwise the original models are returned.
Args:
raw_models: Iterable[BaseModel] - models which can contain fields with raw data
is_query: bool - flag to determine which embed method to use. Defaults to False.
batch_size: int - batch size for inference
Returns:
list[BaseModel]: models with embedded fields
"""
if not self._is_builtin_embedder_available:
FastEmbedMisc.import_fastembed() # fail fast if fastembed is required
if isinstance(raw_models, BaseModel):
raw_models = [raw_models]
for raw_models_batch in iter_batch(raw_models, batch_size):
yield from self.embed_models_batch(
raw_models_batch, is_query, inference_batch_size=batch_size
)
def embed_models_strict(
self,
raw_models: Iterable[Union[dict[str, BaseModel], BaseModel]],
batch_size: int = 8,
parallel: Optional[int] = None,
) -> Iterable[Union[dict[str, BaseModel], BaseModel]]:
"""Embed raw data fields in models and return models with vectors
Requires every element of the input sequence to contain raw data fields for inference.
Does not accept ready vectors.
Args:
raw_models: Iterable[BaseModel] - models which contain fields with raw data for inference
batch_size: int - batch size for inference
parallel: int - number of parallel processes to use. Defaults to None.
Returns:
Iterable[Union[dict[str, BaseModel], BaseModel]]: models with embedded fields
"""
if not self._is_builtin_embedder_available:
FastEmbedMisc.import_fastembed() # fail fast if fastembed is required
is_small = False
if isinstance(raw_models, list):
if len(raw_models) < batch_size:
is_small = True
if (
isinstance(self.embedder, BuiltinEmbedder)
or parallel is None
or parallel == 1
or is_small
):
for batch in iter_batch(raw_models, batch_size):
yield from self.embed_models_batch(batch, inference_batch_size=batch_size)
else:
multiprocessing_batch_size = 1  # larger batch sizes do not help with data parallelism
# on CPU. TODO: adjust when multi-GPU is available
raw_models_batches = iter_batch(raw_models, size=multiprocessing_batch_size)
if parallel == 0:
parallel = os.cpu_count()
start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
assert parallel is not None # just a mypy complaint
pool = ParallelWorkerPool(
num_workers=parallel,
worker=self._get_worker_class(),
start_method=start_method,
max_internal_batch_size=self.MAX_INTERNAL_BATCH_SIZE,
)
for batch in pool.ordered_map(
raw_models_batches, batch_size=multiprocessing_batch_size
):
yield from batch
def embed_models_batch(
self,
raw_models: list[Union[dict[str, BaseModel], BaseModel]],
is_query: bool = False,
inference_batch_size: int = 8,
) -> Iterable[BaseModel]:
"""Embed a batch of models with raw data fields and return models with vectors
If any of the model fields require inference, a deepcopy of the model with computed embeddings is returned,
otherwise the original models are returned.
Args:
raw_models: list[Union[dict[str, BaseModel], BaseModel]] - models which can contain fields with raw data
is_query: bool - flag to determine which embed method to use. Defaults to False.
inference_batch_size: int - batch size for inference
Returns:
Iterable[BaseModel]: models with embedded fields
"""
if not self._is_builtin_embedder_available:
FastEmbedMisc.import_fastembed() # fail fast if fastembed is required
for raw_model in raw_models:
self._process_model(raw_model, is_query=is_query, accumulating=True)
if not self._batch_accumulator:
yield from raw_models
else:
yield from (
self._process_model(
raw_model,
is_query=is_query,
accumulating=False,
inference_batch_size=inference_batch_size,
)
for raw_model in raw_models
)
def _process_model(
self,
model: Union[dict[str, BaseModel], BaseModel],
paths: Optional[list[FieldPath]] = None,
is_query: bool = False,
accumulating: bool = False,
inference_batch_size: Optional[int] = None,
) -> Union[dict[str, BaseModel], dict[str, NumericVector], BaseModel, NumericVector]:
"""Embed model's fields requiring inference
Args:
model: Qdrant http model containing fields to embed
paths: Path to fields to embed. E.g. [FieldPath(current="recommend", tail=[FieldPath(current="negative", tail=None)])]
is_query: Flag to determine which embed method to use. Defaults to False.
accumulating: Flag to determine if we are accumulating models for batch embedding. Defaults to False.
inference_batch_size: Optional[int] - batch size for inference
Returns:
A deepcopy of the model with embedded fields
"""
if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)):
if accumulating:
self._accumulate(model) # type: ignore
else:
assert (
inference_batch_size is not None
), "inference_batch_size should be passed for inference"
return self._drain_accumulator(
model, # type: ignore
is_query=is_query,
inference_batch_size=inference_batch_size,
)
if paths is None:
model = deepcopy(model) if not accumulating else model
if isinstance(model, dict):
for key, value in model.items():
if accumulating:
self._process_model(value, paths, accumulating=True)
else:
model[key] = self._process_model(
value,
paths,
is_query=is_query,
accumulating=False,
inference_batch_size=inference_batch_size,
)
return model
paths = paths if paths is not None else self._embed_inspector.inspect(model)
for path in paths:
list_model = [model] if not isinstance(model, list) else model
for item in list_model:
current_model = getattr(item, path.current, None)
if current_model is None:
continue
if path.tail:
self._process_model(
current_model,
path.tail,
is_query=is_query,
accumulating=accumulating,
inference_batch_size=inference_batch_size,
)
else:
was_list = isinstance(current_model, list)
current_model = current_model if was_list else [current_model]
if not accumulating:
assert (
inference_batch_size is not None
), "inference_batch_size should be passed for inference"
embeddings = [
self._drain_accumulator(
data, is_query=is_query, inference_batch_size=inference_batch_size
)
for data in current_model
]
if was_list:
setattr(item, path.current, embeddings)
else:
setattr(item, path.current, embeddings[0])
else:
for data in current_model:
self._accumulate(data)
return model
def _accumulate(self, data: models.VectorStruct) -> None:
"""Add data to batch accumulator
Args:
data: models.VectorStruct - any vector struct data; if `data` contains inference object instances, add them
to the accumulator, otherwise do nothing. `InferenceObject` instances are converted to proper types.
Returns:
None
"""
if isinstance(data, dict):
for value in data.values():
self._accumulate(value)
return None
if isinstance(data, list):
for value in data:
if not isinstance(value, get_args(INFERENCE_OBJECT_TYPES)): # if value is a vector
return None
self._accumulate(value)
if not isinstance(data, get_args(INFERENCE_OBJECT_TYPES)):
return None
data = self._resolve_inference_object(data)
if data.model not in self._batch_accumulator:
self._batch_accumulator[data.model] = []
self._batch_accumulator[data.model].append(data)
return None
def _drain_accumulator(
self, data: models.VectorStruct, is_query: bool, inference_batch_size: int = 8
) -> NumericVectorStruct:
"""Drain accumulator and replaces inference objects with computed embeddings
It is assumed objects are traversed in the same order as they were added to the accumulator
Args:
data: models.VectorStruct - any vector struct data, if inference object types instances in `data` - replace
them with computed embeddings. If embeddings haven't yet been computed - compute them and then replace
inference objects.
inference_batch_size: int - batch size for inference
Returns:
NumericVectorStruct: data with replaced inference objects
"""
if isinstance(data, dict):
for key, value in data.items():
data[key] = self._drain_accumulator(
value, is_query=is_query, inference_batch_size=inference_batch_size
)
return data
if isinstance(data, list):
for i, value in enumerate(data):
if not isinstance(value, get_args(INFERENCE_OBJECT_TYPES)): # if value is vector
return data
data[i] = self._drain_accumulator(
value, is_query=is_query, inference_batch_size=inference_batch_size
)
return data
if not isinstance(
data, get_args(INFERENCE_OBJECT_TYPES)
):  # the IDE type checker does not narrow on `not isinstance` here and complains
return data # type: ignore
if not self._embed_storage or not self._embed_storage.get(data.model, None):
self._embed_accumulator(is_query=is_query, inference_batch_size=inference_batch_size)
return self._next_embed(data.model)
def _embed_accumulator(self, is_query: bool = False, inference_batch_size: int = 8) -> None:
"""Embed all accumulated objects for all models
Args:
is_query: bool - flag to determine which embed method to use. Defaults to False.
inference_batch_size: int - batch size for inference
Returns:
None
"""
def embed(
objects: list[INFERENCE_OBJECT_TYPES], model_name: str, batch_size: int
) -> list[NumericVector]:
"""
Assemble batches grouped by options and data type, embed them and return the embeddings in the original order
"""
unique_options: list[dict[str, Any]] = []
unique_options_is_text: list[bool] = [] # multimodal models can have both text
# and image data, we need to track which data we process to construct separate batches for texts and images
batches: list[Any] = []
group_indices: dict[int, list[int]] = defaultdict(list)
for i, obj in enumerate(objects):
is_text = isinstance(obj, models.Document)
for j, (options, options_is_text) in enumerate(
zip(unique_options, unique_options_is_text)
):
if options == obj.options and is_text == options_is_text:
group_indices[j].append(i)
batches[j].append(obj.text if is_text else obj.image)
break
else:
# Create a new group if no match was found
group_indices[len(unique_options)] = [i]
unique_options.append(obj.options)
unique_options_is_text.append(is_text)
batches.append([obj.text if is_text else obj.image])
embeddings = []
for i, (options, is_text) in enumerate(zip(unique_options, unique_options_is_text)):
embeddings.extend(
[
embedding
for embedding in self.embedder.embed(
model_name=model_name,
texts=batches[i] if is_text else None,
images=batches[i] if not is_text else None,
is_query=is_query,
options=options or {},
batch_size=batch_size,
)
]
)
iter_embeddings = iter(embeddings)
ordered_embeddings: list[list[NumericVector]] = [[]] * len(objects)
for indices in group_indices.values():
for index in indices:
ordered_embeddings[index] = next(iter_embeddings)
return ordered_embeddings
for model in self._batch_accumulator:
if not any(
(
self.embedder.is_supported_text_model(model),
self.embedder.is_supported_sparse_model(model),
self.embedder.is_supported_late_interaction_text_model(model),
self.embedder.is_supported_image_model(model),
self.embedder.is_supported_late_interaction_multimodal_model(model),
)
):
if isinstance(self.embedder, BuiltinEmbedder):
raise ValueError(
f"{model} is not among supported models. "
f"Have you forgotten to set `cloud_inference` or install `fastembed` for local inference?"
)
else:
raise ValueError(f"{model} is not among supported models")
for model, data in self._batch_accumulator.items():
self._embed_storage[model] = embed(
objects=data, model_name=model, batch_size=inference_batch_size
)
self._batch_accumulator.clear()
def _next_embed(self, model_name: str) -> NumericVector:
"""Get next computed embedding from embedded batch
Args:
model_name: str - retrieve embedding from the storage by this model name
Returns:
NumericVector: computed embedding
"""
return self._embed_storage[model_name].pop(0)
def _resolve_inference_object(self, data: models.VectorStruct) -> models.VectorStruct:
"""Resolve inference object into a model
Args:
data: models.VectorStruct - data to resolve, if it's an inference object, convert it to a proper type,
otherwise - keep unchanged
Returns:
models.VectorStruct: resolved data
"""
if not isinstance(data, models.InferenceObject):
return data
model_name = data.model
value = data.object
options = data.options
if any(
(
self.embedder.is_supported_text_model(model_name),
self.embedder.is_supported_sparse_model(model_name),
self.embedder.is_supported_late_interaction_text_model(model_name),
)
):
return models.Document(model=model_name, text=value, options=options)
if self.embedder.is_supported_image_model(model_name):
return models.Image(model=model_name, image=value, options=options)
if self.embedder.is_supported_late_interaction_multimodal_model(model_name):
raise ValueError(f"{model_name} does not support `InferenceObject` interface")
raise ValueError(f"{model_name} is not among supported models")
@classmethod
def _get_worker_class(cls) -> Type[ModelEmbedderWorker]:
return ModelEmbedderWorker
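# A minimal usage sketch (illustrative; assumes a client version where
# `PointStruct.vector` accepts a raw `models.Document`): the raw document is replaced
# either by a locally computed vector (when `fastembed` is installed) or by a
# server-side inference object (builtin embedder). The model name is an assumption.
if __name__ == "__main__":
    point = models.PointStruct(
        id=1, vector=models.Document(text="hello world", model="Qdrant/Bm25")
    )
    embedded = list(ModelEmbedder().embed_models([point]))
    print(embedded[0].vector)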
@@ -0,0 +1,25 @@
from typing import Union
from pydantic import StrictFloat, StrictStr
from qdrant_client.http.models import ExtendedPointId, SparseVector
NumericVector = Union[
list[StrictFloat],
SparseVector,
list[list[StrictFloat]],
]
NumericVectorInput = Union[
list[StrictFloat],
SparseVector,
list[list[StrictFloat]],
ExtendedPointId,
]
NumericVectorStruct = Union[
list[StrictFloat],
list[list[StrictFloat]],
dict[StrictStr, NumericVector],
]
__all__ = ["NumericVector", "NumericVectorInput", "NumericVectorStruct"]
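# A minimal usage sketch (illustrative): dense, sparse and multi-vector values are
# all valid `NumericVector` instances; the numbers below are placeholders.
if __name__ == "__main__":
    dense: NumericVector = [0.1, 0.2, 0.3]
    sparse: NumericVector = SparseVector(indices=[1, 5], values=[0.5, 0.7])
    multi: NumericVector = [[0.1, 0.2], [0.3, 0.4]]
    print(dense, sparse, multi)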
@@ -0,0 +1,305 @@
from copy import copy, deepcopy
from pathlib import Path
from typing import Type, Union, Any, Optional
from pydantic import BaseModel
from qdrant_client._pydantic_compat import model_json_schema
from qdrant_client.embed.utils import FieldPath, convert_paths
try:
from qdrant_client.embed._inspection_cache import (
DEFS,
CACHE_STR_PATH,
RECURSIVE_REFS,
EXCLUDED_RECURSIVE_REFS,
INCLUDED_RECURSIVE_REFS,
NAME_RECURSIVE_REF_MAPPING,
)
except ImportError as e:
DEFS = {}
CACHE_STR_PATH = {}
RECURSIVE_REFS = set() # type: ignore
EXCLUDED_RECURSIVE_REFS = {"Filter"} # type: ignore
INCLUDED_RECURSIVE_REFS = set() # type: ignore
NAME_RECURSIVE_REF_MAPPING = {}
class ModelSchemaParser:
"""Model schema parser. Parses json schemas to retrieve paths to objects requiring inference.
The parser is stateful; it accumulates the results of parsing in its internal structures.
Attributes:
_defs: definitions extracted from json schemas
_recursive_refs: set of recursive refs found in the processed schemas, e.g.:
{"Filter", "Prefetch"}
_excluded_recursive_refs: predefined time-consuming recursive refs which don't have inference objects, e.g.:
{"Filter"}
_included_recursive_refs: set of recursive refs which have inference objects, e.g.:
{"Prefetch"}
_cache: cache of string paths for models containing objects for inference, e.g.:
{"Prefetch": ['prefetch.query', 'prefetch.query.context.negative', ...]}
path_cache: cache of FieldPath objects for models containing objects for inference, e.g.:
{
"Prefetch": [
FieldPath(
current="prefetch",
tail=[
FieldPath(
current="query",
tail=[
FieldPath(
current="recommend",
tail=[
FieldPath(current="negative", tail=None),
FieldPath(current="positive", tail=None),
],
),
...,
],
),
],
)
]
}
name_recursive_ref_mapping: mapping of model field names to ref names, e.g.:
{"prefetch": "Prefetch"}
"""
CACHE_PATH = "_inspection_cache.py"
INFERENCE_OBJECT_NAMES = {"Document", "Image", "InferenceObject"}
def __init__(self) -> None:
# self._defs does not include the whole schema, but only the part with the structures used in $defs
self._defs: dict[str, Union[dict[str, Any], list[dict[str, Any]]]] = deepcopy(DEFS) # type: ignore[arg-type]
self._cache: dict[str, list[str]] = deepcopy(CACHE_STR_PATH)
self._recursive_refs: set[str] = set(RECURSIVE_REFS)
self._excluded_recursive_refs: set[str] = set(EXCLUDED_RECURSIVE_REFS)
self._included_recursive_refs: set[str] = set(INCLUDED_RECURSIVE_REFS)
self.name_recursive_ref_mapping: dict[str, str] = {
k: v for k, v in NAME_RECURSIVE_REF_MAPPING.items()
}
self.path_cache: dict[str, list[FieldPath]] = {
model: convert_paths(paths) for model, paths in self._cache.items()
}
self._processed_recursive_defs: dict[str, Any] = {}
def _replace_refs(
self,
schema: Union[dict[str, Any], list[dict[str, Any]]],
parent: Optional[str] = None,
seen_refs: Optional[set] = None,
) -> Union[dict[str, Any], list[dict[str, Any]]]:
"""Replace refs in schema with their definitions
Args:
schema: schema to parse
parent: previous level key
seen_refs: set of seen refs to spot recursive paths
Returns:
schema with replaced refs
"""
parent = parent if parent else None
seen_refs = seen_refs if seen_refs else set()
if isinstance(schema, dict):
if "$ref" in schema:
ref_path = schema["$ref"]
def_key = ref_path.split("/")[-1]
if def_key in self._processed_recursive_defs:
return self._processed_recursive_defs[def_key]
if def_key == parent or def_key in seen_refs:
self._recursive_refs.add(def_key)
self._processed_recursive_defs[def_key] = schema
return schema
seen_refs.add(def_key)
return self._replace_refs(
self._defs[def_key], parent=def_key, seen_refs=copy(seen_refs)
)
schemes = {}
if "properties" in schema:
for k, v in schema.items():
if k == "properties":
schemes[k] = self._replace_refs(
schema=v, parent=parent, seen_refs=copy(seen_refs)
)
else:
schemes[k] = v
else:
for k, v in schema.items():
parent_key = k if isinstance(v, dict) and "properties" in v else parent
schemes[k] = self._replace_refs(
schema=v, parent=parent_key, seen_refs=copy(seen_refs)
)
return schemes
elif isinstance(schema, list):
return [
self._replace_refs(schema=item, parent=parent, seen_refs=copy(seen_refs)) # type: ignore
for item in schema
]
else:
return schema
def _find_document_paths(
self,
schema: Union[dict[str, Any], list[dict[str, Any]]],
current_path: str = "",
after_properties: bool = False,
seen_refs: Optional[set] = None,
) -> list[str]:
"""Read a schema and find paths to objects requiring inference
Populates the mapping of model field names to ref names
Args:
schema: schema to parse
current_path: current path in the schema
after_properties: flag indicating if the current path is after "properties" key
seen_refs: set of seen refs to spot recursive paths
Returns:
List of string dot separated paths to objects requiring inference
"""
document_paths: list[str] = []
seen_recursive_refs = seen_refs if seen_refs is not None else set()
parts = current_path.split(".")
if len(parts) != len(set(parts)): # check for recursive paths
return document_paths
if not isinstance(schema, dict):
return document_paths
if "title" in schema and schema["title"] in self.INFERENCE_OBJECT_NAMES:
document_paths.append(current_path)
return document_paths
for key, value in schema.items():
if key == "$defs":
continue
if key == "$ref":
model_name = value.split("/")[-1]
value = self._defs[model_name]
if model_name in seen_recursive_refs:
continue
if (
model_name in self._excluded_recursive_refs
): # on the first run it might be empty
continue
if (
model_name in self._recursive_refs
): # included and excluded refs might not be filled up yet, we're looking in all recursive refs
# we would need to clean up name recursive ref mapping later and delete excluded refs from there
seen_recursive_refs.add(model_name)
self.name_recursive_ref_mapping[current_path.split(".")[-1]] = model_name
if after_properties: # field name seen in pydantic models comes after "properties" key
if current_path:
new_path = f"{current_path}.{key}"
else:
new_path = key
else:
new_path = current_path
if isinstance(value, dict):
document_paths.extend(
self._find_document_paths(
value, new_path, key == "properties", seen_refs=seen_recursive_refs
)
)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
document_paths.extend(
self._find_document_paths(
item,
new_path,
key == "properties",
seen_refs=seen_recursive_refs,
)
)
return sorted(set(document_paths))
def parse_model(self, model: Type[BaseModel]) -> None:
"""Parse model schema to retrieve paths to objects requiring inference.
Checks model json schema, extracts definitions and finds paths to objects requiring inference.
No parsing happens if model has already been processed.
Args:
model: model to parse
Returns:
None
"""
model_name = model.__name__
if model_name in self._cache:
return None
schema = model_json_schema(model)
for k, v in schema.get("$defs", {}).items():
if k not in self._defs:
self._defs[k] = v
if "$defs" in schema:
raw_refs = (
{"$ref": schema["$ref"]}
if "$ref" in schema
else {"properties": schema["properties"]}
)
refs = self._replace_refs(raw_refs)
self._cache[model_name] = self._find_document_paths(refs)
else:
self._cache[model_name] = []
for ref in self._recursive_refs:
if ref in self._excluded_recursive_refs or ref in self._included_recursive_refs:
continue
if self._find_document_paths(self._defs[ref]):
self._included_recursive_refs.add(ref)
else:
self._excluded_recursive_refs.add(ref)
self.name_recursive_ref_mapping = {
k: v
for k, v in self.name_recursive_ref_mapping.items()
if v not in self._excluded_recursive_refs
}
        # convert str paths to FieldPath objects, which group path parts and reduce traversal time
self.path_cache = {model: convert_paths(paths) for model, paths in self._cache.items()}
def _persist(self, output_path: Union[Path, str] = CACHE_PATH) -> None:
"""Persist the parser state to a file
Args:
output_path: path to the file to save the parser state
Returns:
None
"""
with open(output_path, "w") as f:
f.write(f"CACHE_STR_PATH = {self._cache}\n")
f.write(f"DEFS = {self._defs}\n")
        # `sorted` is required to use `diff` in comparisons
f.write(f"RECURSIVE_REFS = {sorted(self._recursive_refs)}\n")
f.write(f"INCLUDED_RECURSIVE_REFS = {sorted(self._included_recursive_refs)}\n")
f.write(f"EXCLUDED_RECURSIVE_REFS = {sorted(self._excluded_recursive_refs)}\n")
f.write(f"NAME_RECURSIVE_REF_MAPPING = {self.name_recursive_ref_mapping}\n")
@@ -0,0 +1,149 @@
from typing import Union, Optional, Iterable, get_args
from pydantic import BaseModel
from qdrant_client._pydantic_compat import model_fields_set
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
from qdrant_client.embed.schema_parser import ModelSchemaParser
from qdrant_client.embed.utils import FieldPath
class Inspector:
"""Inspector which tries to find at least one occurrence of an object requiring inference
Inspector is stateful and accumulates parsed model schemes in its parser.
Attributes:
parser: ModelSchemaParser instance to inspect model json schemas
"""
def __init__(self, parser: Optional[ModelSchemaParser] = None) -> None:
self.parser = ModelSchemaParser() if parser is None else parser
def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> bool:
"""Looks for at least one occurrence of an object requiring inference in the received models
Args:
points: models to inspect
Returns:
True if at least one object requiring inference is found, False otherwise
"""
if isinstance(points, BaseModel):
self.parser.parse_model(points.__class__)
return self._inspect_model(points)
elif isinstance(points, dict):
for value in points.values():
if self.inspect(value):
return True
elif isinstance(points, Iterable):
for point in points:
if isinstance(point, BaseModel):
self.parser.parse_model(point.__class__)
if self._inspect_model(point):
return True
else:
return False
return False
def _inspect_model(self, model: BaseModel, paths: Optional[list[FieldPath]] = None) -> bool:
if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)):
return True
paths = (
self.parser.path_cache.get(model.__class__.__name__, []) if paths is None else paths
)
for path in paths:
type_found = self._inspect_inner_models(
model, path.current, path.tail if path.tail else []
)
if type_found:
return True
return False
def _inspect_inner_models(
self, original_model: BaseModel, current_path: str, tail: list[FieldPath]
) -> bool:
def inspect_recursive(member: BaseModel) -> bool:
recursive_paths = []
for field_name in model_fields_set(member):
if field_name in self.parser.name_recursive_ref_mapping:
mapped_model_name = self.parser.name_recursive_ref_mapping[field_name]
recursive_paths.extend(self.parser.path_cache[mapped_model_name])
if recursive_paths:
found = self._inspect_model(member, recursive_paths)
if found:
return True
return False
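        # Note (illustrative): inspect_recursive handles self-referencing models
        # (e.g. a Prefetch nested inside another Prefetch) by mapping the set field
        # names back to their recursive ref names and re-applying the cached paths.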
model = getattr(original_model, current_path, None)
if model is None:
return False
if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)):
return True
if isinstance(model, BaseModel):
type_found = inspect_recursive(model)
if type_found:
return True
for next_path in tail:
type_found = self._inspect_inner_models(
model, next_path.current, next_path.tail if next_path.tail else []
)
if type_found:
return True
return False
elif isinstance(model, list):
for current_model in model:
if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)):
return True
if not isinstance(current_model, BaseModel):
continue
type_found = inspect_recursive(current_model)
if type_found:
return True
for next_path in tail:
for current_model in model:
type_found = self._inspect_inner_models(
current_model, next_path.current, next_path.tail if next_path.tail else []
)
if type_found:
return True
return False
elif isinstance(model, dict):
for key, values in model.items():
values = [values] if not isinstance(values, list) else values
for current_model in values:
if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)):
return True
if not isinstance(current_model, BaseModel):
continue
found_type = inspect_recursive(current_model)
if found_type:
return True
for next_path in tail:
for current_model in values:
found_type = self._inspect_inner_models(
current_model,
next_path.current,
next_path.tail if next_path.tail else [],
)
if found_type:
return True
return False
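# Usage sketch (illustrative, not part of the original file); the model name passed to
# Document below is only an example.
if __name__ == "__main__":
    from qdrant_client.http import models

    _inspector = Inspector()
    # A bare Document is itself an inference object, so inspect() returns True.
    print(_inspector.inspect(models.Document(text="hello", model="BAAI/bge-small-en-v1.5")))
    # A plain Filter contains nothing requiring inference, so inspect() returns False.
    print(_inspector.inspect(models.Filter()))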
@@ -0,0 +1,79 @@
import base64
from pathlib import Path
from typing import Optional, Union
from pydantic import BaseModel, Field
class FieldPath(BaseModel):
current: str
tail: Optional[list["FieldPath"]] = Field(default=None)
def as_str_list(self) -> list[str]:
"""
>>> FieldPath(current='a', tail=[FieldPath(current='b', tail=[FieldPath(current='c'), FieldPath(current='d')])]).as_str_list()
['a.b.c', 'a.b.d']
"""
# Recursive function to collect all paths
def collect_paths(path: FieldPath, prefix: str = "") -> list[str]:
current_path = prefix + path.current
if not path.tail:
return [current_path]
else:
paths = []
for sub_path in path.tail:
paths.extend(collect_paths(sub_path, current_path + "."))
return paths
# Collect all paths starting from this object
return collect_paths(self)
def convert_paths(paths: list[str]) -> list[FieldPath]:
"""Convert string paths into FieldPath objects
Paths which share the same root are grouped together.
Args:
paths: List[str]: List of str paths containing "." as separator
Returns:
List[FieldPath]: List of FieldPath objects
"""
sorted_paths = sorted(paths)
prev_root = None
converted_paths = []
for path in sorted_paths:
parts = path.split(".")
root = parts[0]
if root != prev_root:
converted_paths.append(FieldPath(current=root))
prev_root = root
current = converted_paths[-1]
for part in parts[1:]:
if current.tail is None:
current.tail = []
found = False
for tail in current.tail:
if tail.current == part:
current = tail
found = True
break
if not found:
new_tail = FieldPath(current=part)
assert current.tail is not None
current.tail.append(new_tail)
current = new_tail
return converted_paths
def read_base64(file_path: Union[str, Path]) -> str:
"""Convert a file path to a base64 encoded string."""
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"The file {path} does not exist.")
with open(path, "rb") as file:
file_content = file.read()
return base64.b64encode(file_content).decode("utf-8")
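# Usage sketch (illustrative, not part of the original file): string paths sharing the
# same root are grouped into a single FieldPath tree.
if __name__ == "__main__":
    _grouped = convert_paths(["prefetch.query", "prefetch.prefetch.query", "query"])
    for _field_path in _grouped:
        print(_field_path.as_str_list())
    # -> ['prefetch.prefetch.query', 'prefetch.query']
    # -> ['query']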
@@ -0,0 +1,323 @@
from typing import Any, Optional, Union
from pydantic import BaseModel, Field
from qdrant_client.conversions.common_types import SparseVector
from qdrant_client.http import models
try:
from fastembed import (
TextEmbedding,
SparseTextEmbedding,
ImageEmbedding,
LateInteractionTextEmbedding,
LateInteractionMultimodalEmbedding,
)
from fastembed.common import OnnxProvider, ImageInput
except ImportError:
TextEmbedding = None
SparseTextEmbedding = None
ImageEmbedding = None
LateInteractionTextEmbedding = None
LateInteractionMultimodalEmbedding = None
OnnxProvider = None
ImageInput = None
class QueryResponse(BaseModel, extra="forbid"): # type: ignore
id: Union[str, int]
embedding: Optional[list[float]]
sparse_embedding: Optional[SparseVector] = Field(default=None)
metadata: dict[str, Any]
document: str
score: float
class FastEmbedMisc:
IS_INSTALLED: bool = False
_TEXT_MODELS: set[str] = set()
_IMAGE_MODELS: set[str] = set()
_LATE_INTERACTION_TEXT_MODELS: set[str] = set()
_LATE_INTERACTION_MULTIMODAL_MODELS: set[str] = set()
_SPARSE_MODELS: set[str] = set()
@classmethod
def is_installed(cls) -> bool:
if cls.IS_INSTALLED:
return cls.IS_INSTALLED
try:
from fastembed import (
SparseTextEmbedding,
TextEmbedding,
ImageEmbedding,
LateInteractionMultimodalEmbedding,
LateInteractionTextEmbedding,
)
assert len(SparseTextEmbedding.list_supported_models()) > 0
assert len(TextEmbedding.list_supported_models()) > 0
assert len(ImageEmbedding.list_supported_models()) > 0
assert len(LateInteractionTextEmbedding.list_supported_models()) > 0
assert len(LateInteractionMultimodalEmbedding.list_supported_models()) > 0
cls.IS_INSTALLED = True
except ImportError:
cls.IS_INSTALLED = False
return cls.IS_INSTALLED
@classmethod
def import_fastembed(cls) -> None:
if cls.IS_INSTALLED:
return
        # If fastembed is not installed, ask the user to install it
        raise ImportError(
            "fastembed is not installed."
            " Please install it with `pip install fastembed` to enable fast vector indexing."
)
@classmethod
def list_text_models(cls) -> dict[str, tuple[int, models.Distance]]:
"""Lists the supported dense text models.
        TextEmbedding.list_supported_models() is invoked on each call so that custom models are also included.
Returns:
dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
"""
return (
{
model["model"]: (model["dim"], models.Distance.COSINE)
for model in TextEmbedding.list_supported_models()
}
if TextEmbedding
else {}
)
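        # Illustrative return shape (exact models and dimensions depend on the installed
        # fastembed version), e.g.:
        #   {"sentence-transformers/all-MiniLM-L6-v2": (384, models.Distance.COSINE), ...}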
@classmethod
def list_image_models(cls) -> dict[str, tuple[int, models.Distance]]:
"""Lists the supported image dense models.
        Custom image models are not supported yet, but ImageEmbedding.list_supported_models() is still called each
        time to preserve the same style as with TextEmbedding.
Returns:
dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
"""
return (
{
model["model"]: (model["dim"], models.Distance.COSINE)
for model in ImageEmbedding.list_supported_models()
}
if ImageEmbedding
else {}
)
@classmethod
def list_late_interaction_text_models(cls) -> dict[str, tuple[int, models.Distance]]:
"""Lists the supported late interaction text models.
        Custom late interaction models are not supported yet, but
        LateInteractionTextEmbedding.list_supported_models()
        is still called each time to preserve the same style as with TextEmbedding.
Returns:
dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
"""
return (
{
model["model"]: (model["dim"], models.Distance.COSINE)
for model in LateInteractionTextEmbedding.list_supported_models()
}
if LateInteractionTextEmbedding
else {}
)
@classmethod
def list_late_interaction_multimodal_models(cls) -> dict[str, tuple[int, models.Distance]]:
"""Lists the supported late interaction multimodal models.
        Custom late interaction multimodal models are not supported yet, but
        LateInteractionMultimodalEmbedding.list_supported_models()
        is still called each time to preserve the same style as with TextEmbedding.
Returns:
dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
"""
return (
{
model["model"]: (model["dim"], models.Distance.COSINE)
for model in LateInteractionMultimodalEmbedding.list_supported_models()
}
if LateInteractionMultimodalEmbedding
else {}
)
@classmethod
def list_sparse_models(cls) -> dict[str, dict[str, Any]]:
"""Lists the supported sparse models.
        Custom sparse models are not supported yet, but
        SparseTextEmbedding.list_supported_models()
        is still called each time to preserve the same style as with TextEmbedding.
Returns:
dict[str, dict[str, Any]]: A dict of model names and their descriptions.
"""
descriptions = {}
if SparseTextEmbedding:
for description in SparseTextEmbedding.list_supported_models():
descriptions[description.pop("model")] = description
return descriptions
@classmethod
def is_supported_text_model(cls, model_name: str) -> bool:
"""Checks if the model is supported by fastembed.
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
if model_name.lower() in cls._TEXT_MODELS:
return True
# update cached list in case custom models were added
cls._TEXT_MODELS = {model.lower() for model in cls.list_text_models()}
if model_name.lower() in cls._TEXT_MODELS:
return True
return False
@classmethod
def is_supported_image_model(cls, model_name: str) -> bool:
"""Checks if the model is supported by fastembed.
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
if model_name.lower() in cls._IMAGE_MODELS:
return True
# update cached list in case custom models were added
cls._IMAGE_MODELS = {model.lower() for model in cls.list_image_models()}
if model_name.lower() in cls._IMAGE_MODELS:
return True
return False
@classmethod
def is_supported_late_interaction_text_model(cls, model_name: str) -> bool:
"""Checks if the model is supported by fastembed.
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
if model_name.lower() in cls._LATE_INTERACTION_TEXT_MODELS:
return True
# update cached list in case custom models were added
cls._LATE_INTERACTION_TEXT_MODELS = {
model.lower() for model in cls.list_late_interaction_text_models()
}
if model_name.lower() in cls._LATE_INTERACTION_TEXT_MODELS:
return True
return False
@classmethod
def is_supported_late_interaction_multimodal_model(cls, model_name: str) -> bool:
"""Checks if the model is supported by fastembed.
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
if model_name.lower() in cls._LATE_INTERACTION_MULTIMODAL_MODELS:
return True
# update cached list in case custom models were added
cls._LATE_INTERACTION_MULTIMODAL_MODELS = {
model.lower() for model in cls.list_late_interaction_multimodal_models()
}
if model_name.lower() in cls._LATE_INTERACTION_MULTIMODAL_MODELS:
return True
return False
@classmethod
def is_supported_sparse_model(cls, model_name: str) -> bool:
"""Checks if the model is supported by fastembed.
Args:
model_name (str): The name of the model to check.
Returns:
bool: True if the model is supported, False otherwise.
"""
if model_name.lower() in cls._SPARSE_MODELS:
return True
# update cached list in case custom models were added
cls._SPARSE_MODELS = {model.lower() for model in cls.list_sparse_models()}
if model_name.lower() in cls._SPARSE_MODELS:
return True
return False
# region deprecated
# prefer the methods built into QdrantClient, e.g. list_supported_text_models, list_supported_idf_models, etc.
SUPPORTED_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = (
{
model["model"]: (model["dim"], models.Distance.COSINE)
for model in TextEmbedding.list_supported_models()
}
if TextEmbedding
else {}
)
SUPPORTED_SPARSE_EMBEDDING_MODELS: dict[str, dict[str, Any]] = (
{model["model"]: model for model in SparseTextEmbedding.list_supported_models()}
if SparseTextEmbedding
else {}
)
IDF_EMBEDDING_MODELS: set[str] = (
{
model_config["model"]
for model_config in SparseTextEmbedding.list_supported_models()
if model_config.get("requires_idf", None)
}
if SparseTextEmbedding
else set()
)
_LATE_INTERACTION_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = (
{
model["model"]: (model["dim"], models.Distance.COSINE)
for model in LateInteractionTextEmbedding.list_supported_models()
}
if LateInteractionTextEmbedding
else {}
)
_IMAGE_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = (
{
model["model"]: (model["dim"], models.Distance.COSINE)
for model in ImageEmbedding.list_supported_models()
}
if ImageEmbedding
else {}
)
_LATE_INTERACTION_MULTIMODAL_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = (
{
model["model"]: (model["dim"], models.Distance.COSINE)
for model in LateInteractionMultimodalEmbedding.list_supported_models()
}
if LateInteractionMultimodalEmbedding
else {}
)
# endregion
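# Usage sketch (illustrative, not part of the original file); the model names below are
# examples and may be absent from a given fastembed release.
if __name__ == "__main__":
    if FastEmbedMisc.is_installed():
        dense_models = FastEmbedMisc.list_text_models()
        print(dense_models.get("sentence-transformers/all-MiniLM-L6-v2"))  # e.g. (384, Distance.COSINE)
        print(FastEmbedMisc.is_supported_sparse_model("Qdrant/bm25"))
    else:
        print("fastembed is not installed; the model listing helpers return empty results")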
@@ -0,0 +1,10 @@
from .points_pb2 import *
from .collections_pb2 import *
from .qdrant_common_pb2 import *
from .snapshots_service_pb2 import *
from .json_with_int_pb2 import *
from .collections_service_pb2_grpc import *
from .points_service_pb2_grpc import *
from .snapshots_service_pb2_grpc import *
from .qdrant_pb2 import *
from .qdrant_pb2_grpc import *
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
@@ -0,0 +1,4 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: collections_service.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from . import collections_pb2 as collections__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19\x63ollections_service.proto\x12\x06qdrant\x1a\x11\x63ollections.proto2\xe2\x08\n\x0b\x43ollections\x12L\n\x03Get\x12 .qdrant.GetCollectionInfoRequest\x1a!.qdrant.GetCollectionInfoResponse\"\x00\x12I\n\x04List\x12\x1e.qdrant.ListCollectionsRequest\x1a\x1f.qdrant.ListCollectionsResponse\"\x00\x12I\n\x06\x43reate\x12\x18.qdrant.CreateCollection\x1a#.qdrant.CollectionOperationResponse\"\x00\x12I\n\x06Update\x12\x18.qdrant.UpdateCollection\x1a#.qdrant.CollectionOperationResponse\"\x00\x12I\n\x06\x44\x65lete\x12\x18.qdrant.DeleteCollection\x1a#.qdrant.CollectionOperationResponse\"\x00\x12M\n\rUpdateAliases\x12\x15.qdrant.ChangeAliases\x1a#.qdrant.CollectionOperationResponse\"\x00\x12\\\n\x15ListCollectionAliases\x12$.qdrant.ListCollectionAliasesRequest\x1a\x1b.qdrant.ListAliasesResponse\"\x00\x12H\n\x0bListAliases\x12\x1a.qdrant.ListAliasesRequest\x1a\x1b.qdrant.ListAliasesResponse\"\x00\x12\x66\n\x15\x43ollectionClusterInfo\x12$.qdrant.CollectionClusterInfoRequest\x1a%.qdrant.CollectionClusterInfoResponse\"\x00\x12W\n\x10\x43ollectionExists\x12\x1f.qdrant.CollectionExistsRequest\x1a .qdrant.CollectionExistsResponse\"\x00\x12{\n\x1cUpdateCollectionClusterSetup\x12+.qdrant.UpdateCollectionClusterSetupRequest\x1a,.qdrant.UpdateCollectionClusterSetupResponse\"\x00\x12Q\n\x0e\x43reateShardKey\x12\x1d.qdrant.CreateShardKeyRequest\x1a\x1e.qdrant.CreateShardKeyResponse\"\x00\x12Q\n\x0e\x44\x65leteShardKey\x12\x1d.qdrant.DeleteShardKeyRequest\x1a\x1e.qdrant.DeleteShardKeyResponse\"\x00\x42\x15\xaa\x02\x12Qdrant.Client.Grpcb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'collections_service_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'\252\002\022Qdrant.Client.Grpc'
_globals['_COLLECTIONS']._serialized_start=57
_globals['_COLLECTIONS']._serialized_end=1179
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,7 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import google.protobuf.descriptor
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@@ -0,0 +1,488 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
from . import collections_pb2 as collections__pb2
class CollectionsStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Get = channel.unary_unary(
'/qdrant.Collections/Get',
request_serializer=collections__pb2.GetCollectionInfoRequest.SerializeToString,
response_deserializer=collections__pb2.GetCollectionInfoResponse.FromString,
)
self.List = channel.unary_unary(
'/qdrant.Collections/List',
request_serializer=collections__pb2.ListCollectionsRequest.SerializeToString,
response_deserializer=collections__pb2.ListCollectionsResponse.FromString,
)
self.Create = channel.unary_unary(
'/qdrant.Collections/Create',
request_serializer=collections__pb2.CreateCollection.SerializeToString,
response_deserializer=collections__pb2.CollectionOperationResponse.FromString,
)
self.Update = channel.unary_unary(
'/qdrant.Collections/Update',
request_serializer=collections__pb2.UpdateCollection.SerializeToString,
response_deserializer=collections__pb2.CollectionOperationResponse.FromString,
)
self.Delete = channel.unary_unary(
'/qdrant.Collections/Delete',
request_serializer=collections__pb2.DeleteCollection.SerializeToString,
response_deserializer=collections__pb2.CollectionOperationResponse.FromString,
)
self.UpdateAliases = channel.unary_unary(
'/qdrant.Collections/UpdateAliases',
request_serializer=collections__pb2.ChangeAliases.SerializeToString,
response_deserializer=collections__pb2.CollectionOperationResponse.FromString,
)
self.ListCollectionAliases = channel.unary_unary(
'/qdrant.Collections/ListCollectionAliases',
request_serializer=collections__pb2.ListCollectionAliasesRequest.SerializeToString,
response_deserializer=collections__pb2.ListAliasesResponse.FromString,
)
self.ListAliases = channel.unary_unary(
'/qdrant.Collections/ListAliases',
request_serializer=collections__pb2.ListAliasesRequest.SerializeToString,
response_deserializer=collections__pb2.ListAliasesResponse.FromString,
)
self.CollectionClusterInfo = channel.unary_unary(
'/qdrant.Collections/CollectionClusterInfo',
request_serializer=collections__pb2.CollectionClusterInfoRequest.SerializeToString,
response_deserializer=collections__pb2.CollectionClusterInfoResponse.FromString,
)
self.CollectionExists = channel.unary_unary(
'/qdrant.Collections/CollectionExists',
request_serializer=collections__pb2.CollectionExistsRequest.SerializeToString,
response_deserializer=collections__pb2.CollectionExistsResponse.FromString,
)
self.UpdateCollectionClusterSetup = channel.unary_unary(
'/qdrant.Collections/UpdateCollectionClusterSetup',
request_serializer=collections__pb2.UpdateCollectionClusterSetupRequest.SerializeToString,
response_deserializer=collections__pb2.UpdateCollectionClusterSetupResponse.FromString,
)
self.CreateShardKey = channel.unary_unary(
'/qdrant.Collections/CreateShardKey',
request_serializer=collections__pb2.CreateShardKeyRequest.SerializeToString,
response_deserializer=collections__pb2.CreateShardKeyResponse.FromString,
)
self.DeleteShardKey = channel.unary_unary(
'/qdrant.Collections/DeleteShardKey',
request_serializer=collections__pb2.DeleteShardKeyRequest.SerializeToString,
response_deserializer=collections__pb2.DeleteShardKeyResponse.FromString,
)
class CollectionsServicer(object):
"""Missing associated documentation comment in .proto file."""
def Get(self, request, context):
"""
Get detailed information about specified existing collection
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def List(self, request, context):
"""
Get list name of all existing collections
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Create(self, request, context):
"""
Create new collection with given parameters
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Update(self, request, context):
"""
Update parameters of the existing collection
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Delete(self, request, context):
"""
Drop collection and all associated data
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def UpdateAliases(self, request, context):
"""
Update Aliases of the existing collection
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ListCollectionAliases(self, request, context):
"""
Get list of all aliases for a collection
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ListAliases(self, request, context):
"""
Get list of all aliases for all existing collections
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def CollectionClusterInfo(self, request, context):
"""
Get cluster information for a collection
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def CollectionExists(self, request, context):
"""
Check the existence of a collection
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def UpdateCollectionClusterSetup(self, request, context):
"""
Update cluster setup for a collection
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def CreateShardKey(self, request, context):
"""
Create shard key
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def DeleteShardKey(self, request, context):
"""
Delete shard key
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_CollectionsServicer_to_server(servicer, server):
rpc_method_handlers = {
'Get': grpc.unary_unary_rpc_method_handler(
servicer.Get,
request_deserializer=collections__pb2.GetCollectionInfoRequest.FromString,
response_serializer=collections__pb2.GetCollectionInfoResponse.SerializeToString,
),
'List': grpc.unary_unary_rpc_method_handler(
servicer.List,
request_deserializer=collections__pb2.ListCollectionsRequest.FromString,
response_serializer=collections__pb2.ListCollectionsResponse.SerializeToString,
),
'Create': grpc.unary_unary_rpc_method_handler(
servicer.Create,
request_deserializer=collections__pb2.CreateCollection.FromString,
response_serializer=collections__pb2.CollectionOperationResponse.SerializeToString,
),
'Update': grpc.unary_unary_rpc_method_handler(
servicer.Update,
request_deserializer=collections__pb2.UpdateCollection.FromString,
response_serializer=collections__pb2.CollectionOperationResponse.SerializeToString,
),
'Delete': grpc.unary_unary_rpc_method_handler(
servicer.Delete,
request_deserializer=collections__pb2.DeleteCollection.FromString,
response_serializer=collections__pb2.CollectionOperationResponse.SerializeToString,
),
'UpdateAliases': grpc.unary_unary_rpc_method_handler(
servicer.UpdateAliases,
request_deserializer=collections__pb2.ChangeAliases.FromString,
response_serializer=collections__pb2.CollectionOperationResponse.SerializeToString,
),
'ListCollectionAliases': grpc.unary_unary_rpc_method_handler(
servicer.ListCollectionAliases,
request_deserializer=collections__pb2.ListCollectionAliasesRequest.FromString,
response_serializer=collections__pb2.ListAliasesResponse.SerializeToString,
),
'ListAliases': grpc.unary_unary_rpc_method_handler(
servicer.ListAliases,
request_deserializer=collections__pb2.ListAliasesRequest.FromString,
response_serializer=collections__pb2.ListAliasesResponse.SerializeToString,
),
'CollectionClusterInfo': grpc.unary_unary_rpc_method_handler(
servicer.CollectionClusterInfo,
request_deserializer=collections__pb2.CollectionClusterInfoRequest.FromString,
response_serializer=collections__pb2.CollectionClusterInfoResponse.SerializeToString,
),
'CollectionExists': grpc.unary_unary_rpc_method_handler(
servicer.CollectionExists,
request_deserializer=collections__pb2.CollectionExistsRequest.FromString,
response_serializer=collections__pb2.CollectionExistsResponse.SerializeToString,
),
'UpdateCollectionClusterSetup': grpc.unary_unary_rpc_method_handler(
servicer.UpdateCollectionClusterSetup,
request_deserializer=collections__pb2.UpdateCollectionClusterSetupRequest.FromString,
response_serializer=collections__pb2.UpdateCollectionClusterSetupResponse.SerializeToString,
),
'CreateShardKey': grpc.unary_unary_rpc_method_handler(
servicer.CreateShardKey,
request_deserializer=collections__pb2.CreateShardKeyRequest.FromString,
response_serializer=collections__pb2.CreateShardKeyResponse.SerializeToString,
),
'DeleteShardKey': grpc.unary_unary_rpc_method_handler(
servicer.DeleteShardKey,
request_deserializer=collections__pb2.DeleteShardKeyRequest.FromString,
response_serializer=collections__pb2.DeleteShardKeyResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'qdrant.Collections', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class Collections(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def Get(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/Get',
collections__pb2.GetCollectionInfoRequest.SerializeToString,
collections__pb2.GetCollectionInfoResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def List(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/List',
collections__pb2.ListCollectionsRequest.SerializeToString,
collections__pb2.ListCollectionsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Create(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/Create',
collections__pb2.CreateCollection.SerializeToString,
collections__pb2.CollectionOperationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Update(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/Update',
collections__pb2.UpdateCollection.SerializeToString,
collections__pb2.CollectionOperationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Delete(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/Delete',
collections__pb2.DeleteCollection.SerializeToString,
collections__pb2.CollectionOperationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def UpdateAliases(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/UpdateAliases',
collections__pb2.ChangeAliases.SerializeToString,
collections__pb2.CollectionOperationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def ListCollectionAliases(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/ListCollectionAliases',
collections__pb2.ListCollectionAliasesRequest.SerializeToString,
collections__pb2.ListAliasesResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def ListAliases(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/ListAliases',
collections__pb2.ListAliasesRequest.SerializeToString,
collections__pb2.ListAliasesResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def CollectionClusterInfo(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/CollectionClusterInfo',
collections__pb2.CollectionClusterInfoRequest.SerializeToString,
collections__pb2.CollectionClusterInfoResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def CollectionExists(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/CollectionExists',
collections__pb2.CollectionExistsRequest.SerializeToString,
collections__pb2.CollectionExistsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def UpdateCollectionClusterSetup(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/UpdateCollectionClusterSetup',
collections__pb2.UpdateCollectionClusterSetupRequest.SerializeToString,
collections__pb2.UpdateCollectionClusterSetupResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def CreateShardKey(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/CreateShardKey',
collections__pb2.CreateShardKeyRequest.SerializeToString,
collections__pb2.CreateShardKeyResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def DeleteShardKey(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/DeleteShardKey',
collections__pb2.DeleteShardKeyRequest.SerializeToString,
collections__pb2.DeleteShardKeyResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: json_with_int.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13json_with_int.proto\x12\x06qdrant\"r\n\x06Struct\x12*\n\x06\x66ields\x18\x01 \x03(\x0b\x32\x1a.qdrant.Struct.FieldsEntry\x1a<\n\x0b\x46ieldsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1c\n\x05value\x18\x02 \x01(\x0b\x32\r.qdrant.Value:\x02\x38\x01\"\xe8\x01\n\x05Value\x12\'\n\nnull_value\x18\x01 \x01(\x0e\x32\x11.qdrant.NullValueH\x00\x12\x16\n\x0c\x64ouble_value\x18\x02 \x01(\x01H\x00\x12\x17\n\rinteger_value\x18\x03 \x01(\x03H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbool_value\x18\x05 \x01(\x08H\x00\x12&\n\x0cstruct_value\x18\x06 \x01(\x0b\x32\x0e.qdrant.StructH\x00\x12\'\n\nlist_value\x18\x07 \x01(\x0b\x32\x11.qdrant.ListValueH\x00\x42\x06\n\x04kind\"*\n\tListValue\x12\x1d\n\x06values\x18\x01 \x03(\x0b\x32\r.qdrant.Value*\x1b\n\tNullValue\x12\x0e\n\nNULL_VALUE\x10\x00\x42\x15\xaa\x02\x12Qdrant.Client.Grpcb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'json_with_int_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'\252\002\022Qdrant.Client.Grpc'
_globals['_STRUCT_FIELDSENTRY']._options = None
_globals['_STRUCT_FIELDSENTRY']._serialized_options = b'8\001'
_globals['_NULLVALUE']._serialized_start=426
_globals['_NULLVALUE']._serialized_end=453
_globals['_STRUCT']._serialized_start=31
_globals['_STRUCT']._serialized_end=145
_globals['_STRUCT_FIELDSENTRY']._serialized_start=85
_globals['_STRUCT_FIELDSENTRY']._serialized_end=145
_globals['_VALUE']._serialized_start=148
_globals['_VALUE']._serialized_end=380
_globals['_LISTVALUE']._serialized_start=382
_globals['_LISTVALUE']._serialized_end=424
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,154 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Fork of the google.protobuf.Value with explicit support for integer values"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
class _NullValue:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _NullValueEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_NullValue.ValueType], builtins.type): # noqa: F821
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NULL_VALUE: _NullValue.ValueType # 0
"""Null value."""
class NullValue(_NullValue, metaclass=_NullValueEnumTypeWrapper):
"""`NullValue` is a singleton enumeration to represent the null value for the
`Value` type union.
The JSON representation for `NullValue` is JSON `null`.
"""
NULL_VALUE: NullValue.ValueType # 0
"""Null value."""
global___NullValue = NullValue
class Struct(google.protobuf.message.Message):
"""`Struct` represents a structured data value, consisting of fields
which map to dynamically typed values. In some languages, `Struct`
might be supported by a native representation. For example, in
scripting languages like JS a struct is represented as an
object. The details of that representation are described together
with the proto support for the language.
The JSON representation for `Struct` is a JSON object.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class FieldsEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
@property
def value(self) -> global___Value: ...
def __init__(
self,
*,
key: builtins.str = ...,
value: global___Value | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]) -> None: ...
FIELDS_FIELD_NUMBER: builtins.int
@property
def fields(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___Value]:
"""Unordered map of dynamically typed values."""
def __init__(
self,
*,
fields: collections.abc.Mapping[builtins.str, global___Value] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["fields", b"fields"]) -> None: ...
global___Struct = Struct
class Value(google.protobuf.message.Message):
"""`Value` represents a dynamically typed value which can be either
null, a number, a string, a boolean, a recursive struct value, or a
list of values. A producer of value is expected to set one of those
variants, absence of any variant indicates an error.
The JSON representation for `Value` is a JSON value.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NULL_VALUE_FIELD_NUMBER: builtins.int
DOUBLE_VALUE_FIELD_NUMBER: builtins.int
INTEGER_VALUE_FIELD_NUMBER: builtins.int
STRING_VALUE_FIELD_NUMBER: builtins.int
BOOL_VALUE_FIELD_NUMBER: builtins.int
STRUCT_VALUE_FIELD_NUMBER: builtins.int
LIST_VALUE_FIELD_NUMBER: builtins.int
null_value: global___NullValue.ValueType
"""Represents a null value."""
double_value: builtins.float
"""Represents a double value."""
integer_value: builtins.int
"""Represents an integer value"""
string_value: builtins.str
"""Represents a string value."""
bool_value: builtins.bool
"""Represents a boolean value."""
@property
def struct_value(self) -> global___Struct:
"""Represents a structured value."""
@property
def list_value(self) -> global___ListValue:
"""Represents a repeated `Value`."""
def __init__(
self,
*,
null_value: global___NullValue.ValueType = ...,
double_value: builtins.float = ...,
integer_value: builtins.int = ...,
string_value: builtins.str = ...,
bool_value: builtins.bool = ...,
struct_value: global___Struct | None = ...,
list_value: global___ListValue | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["bool_value", b"bool_value", "double_value", b"double_value", "integer_value", b"integer_value", "kind", b"kind", "list_value", b"list_value", "null_value", b"null_value", "string_value", b"string_value", "struct_value", b"struct_value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["bool_value", b"bool_value", "double_value", b"double_value", "integer_value", b"integer_value", "kind", b"kind", "list_value", b"list_value", "null_value", b"null_value", "string_value", b"string_value", "struct_value", b"struct_value"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions.Literal["kind", b"kind"]) -> typing_extensions.Literal["null_value", "double_value", "integer_value", "string_value", "bool_value", "struct_value", "list_value"] | None: ...
global___Value = Value
class ListValue(google.protobuf.message.Message):
"""`ListValue` is a wrapper around a repeated field of values.
The JSON representation for `ListValue` is a JSON array.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
VALUES_FIELD_NUMBER: builtins.int
@property
def values(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Value]:
"""Repeated field of dynamically typed values."""
def __init__(
self,
*,
values: collections.abc.Iterable[global___Value] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["values", b"values"]) -> None: ...
global___ListValue = ListValue
@@ -0,0 +1,4 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
@@ -0,0 +1,4 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: points_service.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from . import points_pb2 as points__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14points_service.proto\x12\x06qdrant\x1a\x0cpoints.proto2\xfd\x0f\n\x06Points\x12\x41\n\x06Upsert\x12\x14.qdrant.UpsertPoints\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12\x41\n\x06\x44\x65lete\x12\x14.qdrant.DeletePoints\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12/\n\x03Get\x12\x11.qdrant.GetPoints\x1a\x13.qdrant.GetResponse\"\x00\x12N\n\rUpdateVectors\x12\x1a.qdrant.UpdatePointVectors\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12N\n\rDeleteVectors\x12\x1a.qdrant.DeletePointVectors\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12I\n\nSetPayload\x12\x18.qdrant.SetPayloadPoints\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12O\n\x10OverwritePayload\x12\x18.qdrant.SetPayloadPoints\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12O\n\rDeletePayload\x12\x1b.qdrant.DeletePayloadPoints\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12M\n\x0c\x43learPayload\x12\x1a.qdrant.ClearPayloadPoints\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12Y\n\x10\x43reateFieldIndex\x12\".qdrant.CreateFieldIndexCollection\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12Y\n\x10\x44\x65leteFieldIndex\x12\".qdrant.DeleteFieldIndexCollection\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12\x38\n\x06Search\x12\x14.qdrant.SearchPoints\x1a\x16.qdrant.SearchResponse\"\x00\x12G\n\x0bSearchBatch\x12\x19.qdrant.SearchBatchPoints\x1a\x1b.qdrant.SearchBatchResponse\"\x00\x12I\n\x0cSearchGroups\x12\x19.qdrant.SearchPointGroups\x1a\x1c.qdrant.SearchGroupsResponse\"\x00\x12\x38\n\x06Scroll\x12\x14.qdrant.ScrollPoints\x1a\x16.qdrant.ScrollResponse\"\x00\x12\x41\n\tRecommend\x12\x17.qdrant.RecommendPoints\x1a\x19.qdrant.RecommendResponse\"\x00\x12P\n\x0eRecommendBatch\x12\x1c.qdrant.RecommendBatchPoints\x1a\x1e.qdrant.RecommendBatchResponse\"\x00\x12R\n\x0fRecommendGroups\x12\x1c.qdrant.RecommendPointGroups\x1a\x1f.qdrant.RecommendGroupsResponse\"\x00\x12>\n\x08\x44iscover\x12\x16.qdrant.DiscoverPoints\x1a\x18.qdrant.DiscoverResponse\"\x00\x12M\n\rDiscoverBatch\x12\x1b.qdrant.DiscoverBatchPoints\x1a\x1d.qdrant.DiscoverBatchResponse\"\x00\x12\x35\n\x05\x43ount\x12\x13.qdrant.CountPoints\x1a\x15.qdrant.CountResponse\"\x00\x12G\n\x0bUpdateBatch\x12\x19.qdrant.UpdateBatchPoints\x1a\x1b.qdrant.UpdateBatchResponse\"\x00\x12\x35\n\x05Query\x12\x13.qdrant.QueryPoints\x1a\x15.qdrant.QueryResponse\"\x00\x12\x44\n\nQueryBatch\x12\x18.qdrant.QueryBatchPoints\x1a\x1a.qdrant.QueryBatchResponse\"\x00\x12\x46\n\x0bQueryGroups\x12\x18.qdrant.QueryPointGroups\x1a\x1b.qdrant.QueryGroupsResponse\"\x00\x12\x35\n\x05\x46\x61\x63\x65t\x12\x13.qdrant.FacetCounts\x1a\x15.qdrant.FacetResponse\"\x00\x12T\n\x11SearchMatrixPairs\x12\x1a.qdrant.SearchMatrixPoints\x1a!.qdrant.SearchMatrixPairsResponse\"\x00\x12X\n\x13SearchMatrixOffsets\x12\x1a.qdrant.SearchMatrixPoints\x1a#.qdrant.SearchMatrixOffsetsResponse\"\x00\x42\x15\xaa\x02\x12Qdrant.Client.Grpcb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'points_service_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'\252\002\022Qdrant.Client.Grpc'
_globals['_POINTS']._serialized_start=47
_globals['_POINTS']._serialized_end=2092
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,7 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import google.protobuf.descriptor
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: qdrant_common.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13qdrant_common.proto\x12\x06qdrant\x1a\x1fgoogle/protobuf/timestamp.proto\"<\n\x07PointId\x12\r\n\x03num\x18\x01 \x01(\x04H\x00\x12\x0e\n\x04uuid\x18\x02 \x01(\tH\x00\x42\x12\n\x10point_id_options\"$\n\x08GeoPoint\x12\x0b\n\x03lon\x18\x01 \x01(\x01\x12\x0b\n\x03lat\x18\x02 \x01(\x01\"\xac\x01\n\x06\x46ilter\x12!\n\x06should\x18\x01 \x03(\x0b\x32\x11.qdrant.Condition\x12\x1f\n\x04must\x18\x02 \x03(\x0b\x32\x11.qdrant.Condition\x12#\n\x08must_not\x18\x03 \x03(\x0b\x32\x11.qdrant.Condition\x12*\n\nmin_should\x18\x04 \x01(\x0b\x32\x11.qdrant.MinShouldH\x00\x88\x01\x01\x42\r\n\x0b_min_should\"E\n\tMinShould\x12%\n\nconditions\x18\x01 \x03(\x0b\x32\x11.qdrant.Condition\x12\x11\n\tmin_count\x18\x02 \x01(\x04\"\xcb\x02\n\tCondition\x12\'\n\x05\x66ield\x18\x01 \x01(\x0b\x32\x16.qdrant.FieldConditionH\x00\x12,\n\x08is_empty\x18\x02 \x01(\x0b\x32\x18.qdrant.IsEmptyConditionH\x00\x12(\n\x06has_id\x18\x03 \x01(\x0b\x32\x16.qdrant.HasIdConditionH\x00\x12 \n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x0e.qdrant.FilterH\x00\x12*\n\x07is_null\x18\x05 \x01(\x0b\x32\x17.qdrant.IsNullConditionH\x00\x12)\n\x06nested\x18\x06 \x01(\x0b\x32\x17.qdrant.NestedConditionH\x00\x12\x30\n\nhas_vector\x18\x07 \x01(\x0b\x32\x1a.qdrant.HasVectorConditionH\x00\x42\x12\n\x10\x63ondition_one_of\"\x1f\n\x10IsEmptyCondition\x12\x0b\n\x03key\x18\x01 \x01(\t\"\x1e\n\x0fIsNullCondition\x12\x0b\n\x03key\x18\x01 \x01(\t\"1\n\x0eHasIdCondition\x12\x1f\n\x06has_id\x18\x01 \x03(\x0b\x32\x0f.qdrant.PointId\"(\n\x12HasVectorCondition\x12\x12\n\nhas_vector\x18\x01 \x01(\t\">\n\x0fNestedCondition\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1e\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x0e.qdrant.Filter\"\xfb\x02\n\x0e\x46ieldCondition\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1c\n\x05match\x18\x02 \x01(\x0b\x32\r.qdrant.Match\x12\x1c\n\x05range\x18\x03 \x01(\x0b\x32\r.qdrant.Range\x12\x30\n\x10geo_bounding_box\x18\x04 \x01(\x0b\x32\x16.qdrant.GeoBoundingBox\x12%\n\ngeo_radius\x18\x05 \x01(\x0b\x32\x11.qdrant.GeoRadius\x12)\n\x0cvalues_count\x18\x06 \x01(\x0b\x32\x13.qdrant.ValuesCount\x12\'\n\x0bgeo_polygon\x18\x07 \x01(\x0b\x32\x12.qdrant.GeoPolygon\x12-\n\x0e\x64\x61tetime_range\x18\x08 \x01(\x0b\x32\x15.qdrant.DatetimeRange\x12\x15\n\x08is_empty\x18\t \x01(\x08H\x00\x88\x01\x01\x12\x14\n\x07is_null\x18\n \x01(\x08H\x01\x88\x01\x01\x42\x0b\n\t_is_emptyB\n\n\x08_is_null\"\xc9\x02\n\x05Match\x12\x11\n\x07keyword\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x03H\x00\x12\x11\n\x07\x62oolean\x18\x03 \x01(\x08H\x00\x12\x0e\n\x04text\x18\x04 \x01(\tH\x00\x12+\n\x08keywords\x18\x05 \x01(\x0b\x32\x17.qdrant.RepeatedStringsH\x00\x12,\n\x08integers\x18\x06 \x01(\x0b\x32\x18.qdrant.RepeatedIntegersH\x00\x12\x33\n\x0f\x65xcept_integers\x18\x07 \x01(\x0b\x32\x18.qdrant.RepeatedIntegersH\x00\x12\x32\n\x0f\x65xcept_keywords\x18\x08 \x01(\x0b\x32\x17.qdrant.RepeatedStringsH\x00\x12\x10\n\x06phrase\x18\t \x01(\tH\x00\x12\x12\n\x08text_any\x18\n \x01(\tH\x00\x42\r\n\x0bmatch_value\"\"\n\x0fRepeatedStrings\x12\x0f\n\x07strings\x18\x01 \x03(\t\"$\n\x10RepeatedIntegers\x12\x10\n\x08integers\x18\x01 \x03(\x03\"k\n\x05Range\x12\x0f\n\x02lt\x18\x01 \x01(\x01H\x00\x88\x01\x01\x12\x0f\n\x02gt\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12\x10\n\x03gte\x18\x03 \x01(\x01H\x02\x88\x01\x01\x12\x10\n\x03lte\x18\x04 \x01(\x01H\x03\x88\x01\x01\x42\x05\n\x03_ltB\x05\n\x03_gtB\x06\n\x04_gteB\x06\n\x04_lte\"\xe3\x01\n\rDatetimeRange\x12+\n\x02lt\x18\x01 
\x01(\x0b\x32\x1a.google.protobuf.TimestampH\x00\x88\x01\x01\x12+\n\x02gt\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x01\x88\x01\x01\x12,\n\x03gte\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x02\x88\x01\x01\x12,\n\x03lte\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x03\x88\x01\x01\x42\x05\n\x03_ltB\x05\n\x03_gtB\x06\n\x04_gteB\x06\n\x04_lte\"\\\n\x0eGeoBoundingBox\x12\"\n\x08top_left\x18\x01 \x01(\x0b\x32\x10.qdrant.GeoPoint\x12&\n\x0c\x62ottom_right\x18\x02 \x01(\x0b\x32\x10.qdrant.GeoPoint\"=\n\tGeoRadius\x12 \n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x10.qdrant.GeoPoint\x12\x0e\n\x06radius\x18\x02 \x01(\x02\"1\n\rGeoLineString\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.qdrant.GeoPoint\"_\n\nGeoPolygon\x12\'\n\x08\x65xterior\x18\x01 \x01(\x0b\x32\x15.qdrant.GeoLineString\x12(\n\tinteriors\x18\x02 \x03(\x0b\x32\x15.qdrant.GeoLineString\"q\n\x0bValuesCount\x12\x0f\n\x02lt\x18\x01 \x01(\x04H\x00\x88\x01\x01\x12\x0f\n\x02gt\x18\x02 \x01(\x04H\x01\x88\x01\x01\x12\x10\n\x03gte\x18\x03 \x01(\x04H\x02\x88\x01\x01\x12\x10\n\x03lte\x18\x04 \x01(\x04H\x03\x88\x01\x01\x42\x05\n\x03_ltB\x05\n\x03_gtB\x06\n\x04_gteB\x06\n\x04_lteB\x1d\x42\x06\x43ommon\xaa\x02\x12Qdrant.Client.Grpcb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'qdrant_common_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'B\006Common\252\002\022Qdrant.Client.Grpc'
_globals['_POINTID']._serialized_start=64
_globals['_POINTID']._serialized_end=124
_globals['_GEOPOINT']._serialized_start=126
_globals['_GEOPOINT']._serialized_end=162
_globals['_FILTER']._serialized_start=165
_globals['_FILTER']._serialized_end=337
_globals['_MINSHOULD']._serialized_start=339
_globals['_MINSHOULD']._serialized_end=408
_globals['_CONDITION']._serialized_start=411
_globals['_CONDITION']._serialized_end=742
_globals['_ISEMPTYCONDITION']._serialized_start=744
_globals['_ISEMPTYCONDITION']._serialized_end=775
_globals['_ISNULLCONDITION']._serialized_start=777
_globals['_ISNULLCONDITION']._serialized_end=807
_globals['_HASIDCONDITION']._serialized_start=809
_globals['_HASIDCONDITION']._serialized_end=858
_globals['_HASVECTORCONDITION']._serialized_start=860
_globals['_HASVECTORCONDITION']._serialized_end=900
_globals['_NESTEDCONDITION']._serialized_start=902
_globals['_NESTEDCONDITION']._serialized_end=964
_globals['_FIELDCONDITION']._serialized_start=967
_globals['_FIELDCONDITION']._serialized_end=1346
_globals['_MATCH']._serialized_start=1349
_globals['_MATCH']._serialized_end=1678
_globals['_REPEATEDSTRINGS']._serialized_start=1680
_globals['_REPEATEDSTRINGS']._serialized_end=1714
_globals['_REPEATEDINTEGERS']._serialized_start=1716
_globals['_REPEATEDINTEGERS']._serialized_end=1752
_globals['_RANGE']._serialized_start=1754
_globals['_RANGE']._serialized_end=1861
_globals['_DATETIMERANGE']._serialized_start=1864
_globals['_DATETIMERANGE']._serialized_end=2091
_globals['_GEOBOUNDINGBOX']._serialized_start=2093
_globals['_GEOBOUNDINGBOX']._serialized_end=2185
_globals['_GEORADIUS']._serialized_start=2187
_globals['_GEORADIUS']._serialized_end=2248
_globals['_GEOLINESTRING']._serialized_start=2250
_globals['_GEOLINESTRING']._serialized_end=2299
_globals['_GEOPOLYGON']._serialized_start=2301
_globals['_GEOPOLYGON']._serialized_end=2396
_globals['_VALUESCOUNT']._serialized_start=2398
_globals['_VALUESCOUNT']._serialized_end=2511
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,561 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import google.protobuf.timestamp_pb2
import sys
import typing
if sys.version_info >= (3, 8):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
class PointId(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NUM_FIELD_NUMBER: builtins.int
UUID_FIELD_NUMBER: builtins.int
num: builtins.int
"""Numerical ID of the point"""
uuid: builtins.str
"""UUID"""
def __init__(
self,
*,
num: builtins.int = ...,
uuid: builtins.str = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["num", b"num", "point_id_options", b"point_id_options", "uuid", b"uuid"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["num", b"num", "point_id_options", b"point_id_options", "uuid", b"uuid"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions.Literal["point_id_options", b"point_id_options"]) -> typing_extensions.Literal["num", "uuid"] | None: ...
global___PointId = PointId
class GeoPoint(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
LON_FIELD_NUMBER: builtins.int
LAT_FIELD_NUMBER: builtins.int
lon: builtins.float
lat: builtins.float
def __init__(
self,
*,
lon: builtins.float = ...,
lat: builtins.float = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["lat", b"lat", "lon", b"lon"]) -> None: ...
global___GeoPoint = GeoPoint
class Filter(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SHOULD_FIELD_NUMBER: builtins.int
MUST_FIELD_NUMBER: builtins.int
MUST_NOT_FIELD_NUMBER: builtins.int
MIN_SHOULD_FIELD_NUMBER: builtins.int
@property
def should(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Condition]:
"""At least one of those conditions should match"""
@property
def must(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Condition]:
"""All conditions must match"""
@property
def must_not(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Condition]:
"""All conditions must NOT match"""
@property
def min_should(self) -> global___MinShould:
"""At least minimum amount of given conditions should match"""
def __init__(
self,
*,
should: collections.abc.Iterable[global___Condition] | None = ...,
must: collections.abc.Iterable[global___Condition] | None = ...,
must_not: collections.abc.Iterable[global___Condition] | None = ...,
min_should: global___MinShould | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_min_should", b"_min_should", "min_should", b"min_should"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_min_should", b"_min_should", "min_should", b"min_should", "must", b"must", "must_not", b"must_not", "should", b"should"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions.Literal["_min_should", b"_min_should"]) -> typing_extensions.Literal["min_should"] | None: ...
global___Filter = Filter
class MinShould(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
CONDITIONS_FIELD_NUMBER: builtins.int
MIN_COUNT_FIELD_NUMBER: builtins.int
@property
def conditions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Condition]: ...
min_count: builtins.int
def __init__(
self,
*,
conditions: collections.abc.Iterable[global___Condition] | None = ...,
min_count: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["conditions", b"conditions", "min_count", b"min_count"]) -> None: ...
global___MinShould = MinShould
class Condition(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
FIELD_FIELD_NUMBER: builtins.int
IS_EMPTY_FIELD_NUMBER: builtins.int
HAS_ID_FIELD_NUMBER: builtins.int
FILTER_FIELD_NUMBER: builtins.int
IS_NULL_FIELD_NUMBER: builtins.int
NESTED_FIELD_NUMBER: builtins.int
HAS_VECTOR_FIELD_NUMBER: builtins.int
@property
def field(self) -> global___FieldCondition: ...
@property
def is_empty(self) -> global___IsEmptyCondition: ...
@property
def has_id(self) -> global___HasIdCondition: ...
@property
def filter(self) -> global___Filter: ...
@property
def is_null(self) -> global___IsNullCondition: ...
@property
def nested(self) -> global___NestedCondition: ...
@property
def has_vector(self) -> global___HasVectorCondition: ...
def __init__(
self,
*,
field: global___FieldCondition | None = ...,
is_empty: global___IsEmptyCondition | None = ...,
has_id: global___HasIdCondition | None = ...,
filter: global___Filter | None = ...,
is_null: global___IsNullCondition | None = ...,
nested: global___NestedCondition | None = ...,
has_vector: global___HasVectorCondition | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["condition_one_of", b"condition_one_of", "field", b"field", "filter", b"filter", "has_id", b"has_id", "has_vector", b"has_vector", "is_empty", b"is_empty", "is_null", b"is_null", "nested", b"nested"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["condition_one_of", b"condition_one_of", "field", b"field", "filter", b"filter", "has_id", b"has_id", "has_vector", b"has_vector", "is_empty", b"is_empty", "is_null", b"is_null", "nested", b"nested"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions.Literal["condition_one_of", b"condition_one_of"]) -> typing_extensions.Literal["field", "is_empty", "has_id", "filter", "is_null", "nested", "has_vector"] | None: ...
global___Condition = Condition
class IsEmptyCondition(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
key: builtins.str
def __init__(
self,
*,
key: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["key", b"key"]) -> None: ...
global___IsEmptyCondition = IsEmptyCondition
class IsNullCondition(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
key: builtins.str
def __init__(
self,
*,
key: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["key", b"key"]) -> None: ...
global___IsNullCondition = IsNullCondition
class HasIdCondition(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
HAS_ID_FIELD_NUMBER: builtins.int
@property
def has_id(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PointId]: ...
def __init__(
self,
*,
has_id: collections.abc.Iterable[global___PointId] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["has_id", b"has_id"]) -> None: ...
global___HasIdCondition = HasIdCondition
class HasVectorCondition(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
HAS_VECTOR_FIELD_NUMBER: builtins.int
has_vector: builtins.str
def __init__(
self,
*,
has_vector: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["has_vector", b"has_vector"]) -> None: ...
global___HasVectorCondition = HasVectorCondition
class NestedCondition(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
FILTER_FIELD_NUMBER: builtins.int
key: builtins.str
"""Path to nested object"""
@property
def filter(self) -> global___Filter:
"""Filter condition"""
def __init__(
self,
*,
key: builtins.str = ...,
filter: global___Filter | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["filter", b"filter"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["filter", b"filter", "key", b"key"]) -> None: ...
global___NestedCondition = NestedCondition
class FieldCondition(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
MATCH_FIELD_NUMBER: builtins.int
RANGE_FIELD_NUMBER: builtins.int
GEO_BOUNDING_BOX_FIELD_NUMBER: builtins.int
GEO_RADIUS_FIELD_NUMBER: builtins.int
VALUES_COUNT_FIELD_NUMBER: builtins.int
GEO_POLYGON_FIELD_NUMBER: builtins.int
DATETIME_RANGE_FIELD_NUMBER: builtins.int
IS_EMPTY_FIELD_NUMBER: builtins.int
IS_NULL_FIELD_NUMBER: builtins.int
key: builtins.str
@property
def match(self) -> global___Match:
"""Check if point has field with a given value"""
@property
def range(self) -> global___Range:
"""Check if points value lies in a given range"""
@property
def geo_bounding_box(self) -> global___GeoBoundingBox:
"""Check if points geolocation lies in a given area"""
@property
def geo_radius(self) -> global___GeoRadius:
"""Check if geo point is within a given radius"""
@property
def values_count(self) -> global___ValuesCount:
"""Check number of values for a specific field"""
@property
def geo_polygon(self) -> global___GeoPolygon:
"""Check if geo point is within a given polygon"""
@property
def datetime_range(self) -> global___DatetimeRange:
"""Check if datetime is within a given range"""
is_empty: builtins.bool
"""Check if field is empty"""
is_null: builtins.bool
"""Check if field is null"""
def __init__(
self,
*,
key: builtins.str = ...,
match: global___Match | None = ...,
range: global___Range | None = ...,
geo_bounding_box: global___GeoBoundingBox | None = ...,
geo_radius: global___GeoRadius | None = ...,
values_count: global___ValuesCount | None = ...,
geo_polygon: global___GeoPolygon | None = ...,
datetime_range: global___DatetimeRange | None = ...,
is_empty: builtins.bool | None = ...,
is_null: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_is_empty", b"_is_empty", "_is_null", b"_is_null", "datetime_range", b"datetime_range", "geo_bounding_box", b"geo_bounding_box", "geo_polygon", b"geo_polygon", "geo_radius", b"geo_radius", "is_empty", b"is_empty", "is_null", b"is_null", "match", b"match", "range", b"range", "values_count", b"values_count"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_is_empty", b"_is_empty", "_is_null", b"_is_null", "datetime_range", b"datetime_range", "geo_bounding_box", b"geo_bounding_box", "geo_polygon", b"geo_polygon", "geo_radius", b"geo_radius", "is_empty", b"is_empty", "is_null", b"is_null", "key", b"key", "match", b"match", "range", b"range", "values_count", b"values_count"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_is_empty", b"_is_empty"]) -> typing_extensions.Literal["is_empty"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_is_null", b"_is_null"]) -> typing_extensions.Literal["is_null"] | None: ...
global___FieldCondition = FieldCondition
class Match(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEYWORD_FIELD_NUMBER: builtins.int
INTEGER_FIELD_NUMBER: builtins.int
BOOLEAN_FIELD_NUMBER: builtins.int
TEXT_FIELD_NUMBER: builtins.int
KEYWORDS_FIELD_NUMBER: builtins.int
INTEGERS_FIELD_NUMBER: builtins.int
EXCEPT_INTEGERS_FIELD_NUMBER: builtins.int
EXCEPT_KEYWORDS_FIELD_NUMBER: builtins.int
PHRASE_FIELD_NUMBER: builtins.int
TEXT_ANY_FIELD_NUMBER: builtins.int
keyword: builtins.str
"""Match string keyword"""
integer: builtins.int
"""Match integer"""
boolean: builtins.bool
"""Match boolean"""
text: builtins.str
"""Match text"""
@property
def keywords(self) -> global___RepeatedStrings:
"""Match multiple keywords"""
@property
def integers(self) -> global___RepeatedIntegers:
"""Match multiple integers"""
@property
def except_integers(self) -> global___RepeatedIntegers:
"""Match any other value except those integers"""
@property
def except_keywords(self) -> global___RepeatedStrings:
"""Match any other value except those keywords"""
phrase: builtins.str
"""Match phrase text"""
text_any: builtins.str
"""Match any word in the text"""
def __init__(
self,
*,
keyword: builtins.str = ...,
integer: builtins.int = ...,
boolean: builtins.bool = ...,
text: builtins.str = ...,
keywords: global___RepeatedStrings | None = ...,
integers: global___RepeatedIntegers | None = ...,
except_integers: global___RepeatedIntegers | None = ...,
except_keywords: global___RepeatedStrings | None = ...,
phrase: builtins.str = ...,
text_any: builtins.str = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["boolean", b"boolean", "except_integers", b"except_integers", "except_keywords", b"except_keywords", "integer", b"integer", "integers", b"integers", "keyword", b"keyword", "keywords", b"keywords", "match_value", b"match_value", "phrase", b"phrase", "text", b"text", "text_any", b"text_any"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["boolean", b"boolean", "except_integers", b"except_integers", "except_keywords", b"except_keywords", "integer", b"integer", "integers", b"integers", "keyword", b"keyword", "keywords", b"keywords", "match_value", b"match_value", "phrase", b"phrase", "text", b"text", "text_any", b"text_any"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions.Literal["match_value", b"match_value"]) -> typing_extensions.Literal["keyword", "integer", "boolean", "text", "keywords", "integers", "except_integers", "except_keywords", "phrase", "text_any"] | None: ...
global___Match = Match
class RepeatedStrings(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
STRINGS_FIELD_NUMBER: builtins.int
@property
def strings(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
strings: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["strings", b"strings"]) -> None: ...
global___RepeatedStrings = RepeatedStrings
class RepeatedIntegers(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
INTEGERS_FIELD_NUMBER: builtins.int
@property
def integers(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def __init__(
self,
*,
integers: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["integers", b"integers"]) -> None: ...
global___RepeatedIntegers = RepeatedIntegers
class Range(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
LT_FIELD_NUMBER: builtins.int
GT_FIELD_NUMBER: builtins.int
GTE_FIELD_NUMBER: builtins.int
LTE_FIELD_NUMBER: builtins.int
lt: builtins.float
gt: builtins.float
gte: builtins.float
lte: builtins.float
def __init__(
self,
*,
lt: builtins.float | None = ...,
gt: builtins.float | None = ...,
gte: builtins.float | None = ...,
lte: builtins.float | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_gt", b"_gt", "_gte", b"_gte", "_lt", b"_lt", "_lte", b"_lte", "gt", b"gt", "gte", b"gte", "lt", b"lt", "lte", b"lte"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_gt", b"_gt", "_gte", b"_gte", "_lt", b"_lt", "_lte", b"_lte", "gt", b"gt", "gte", b"gte", "lt", b"lt", "lte", b"lte"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_gt", b"_gt"]) -> typing_extensions.Literal["gt"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_gte", b"_gte"]) -> typing_extensions.Literal["gte"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_lt", b"_lt"]) -> typing_extensions.Literal["lt"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_lte", b"_lte"]) -> typing_extensions.Literal["lte"] | None: ...
global___Range = Range
class DatetimeRange(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
LT_FIELD_NUMBER: builtins.int
GT_FIELD_NUMBER: builtins.int
GTE_FIELD_NUMBER: builtins.int
LTE_FIELD_NUMBER: builtins.int
@property
def lt(self) -> google.protobuf.timestamp_pb2.Timestamp: ...
@property
def gt(self) -> google.protobuf.timestamp_pb2.Timestamp: ...
@property
def gte(self) -> google.protobuf.timestamp_pb2.Timestamp: ...
@property
def lte(self) -> google.protobuf.timestamp_pb2.Timestamp: ...
def __init__(
self,
*,
lt: google.protobuf.timestamp_pb2.Timestamp | None = ...,
gt: google.protobuf.timestamp_pb2.Timestamp | None = ...,
gte: google.protobuf.timestamp_pb2.Timestamp | None = ...,
lte: google.protobuf.timestamp_pb2.Timestamp | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_gt", b"_gt", "_gte", b"_gte", "_lt", b"_lt", "_lte", b"_lte", "gt", b"gt", "gte", b"gte", "lt", b"lt", "lte", b"lte"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_gt", b"_gt", "_gte", b"_gte", "_lt", b"_lt", "_lte", b"_lte", "gt", b"gt", "gte", b"gte", "lt", b"lt", "lte", b"lte"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_gt", b"_gt"]) -> typing_extensions.Literal["gt"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_gte", b"_gte"]) -> typing_extensions.Literal["gte"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_lt", b"_lt"]) -> typing_extensions.Literal["lt"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_lte", b"_lte"]) -> typing_extensions.Literal["lte"] | None: ...
global___DatetimeRange = DatetimeRange
class GeoBoundingBox(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOP_LEFT_FIELD_NUMBER: builtins.int
BOTTOM_RIGHT_FIELD_NUMBER: builtins.int
@property
def top_left(self) -> global___GeoPoint:
"""north-west corner"""
@property
def bottom_right(self) -> global___GeoPoint:
"""south-east corner"""
def __init__(
self,
*,
top_left: global___GeoPoint | None = ...,
bottom_right: global___GeoPoint | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["bottom_right", b"bottom_right", "top_left", b"top_left"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["bottom_right", b"bottom_right", "top_left", b"top_left"]) -> None: ...
global___GeoBoundingBox = GeoBoundingBox
class GeoRadius(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
CENTER_FIELD_NUMBER: builtins.int
RADIUS_FIELD_NUMBER: builtins.int
@property
def center(self) -> global___GeoPoint:
"""Center of the circle"""
radius: builtins.float
"""In meters"""
def __init__(
self,
*,
center: global___GeoPoint | None = ...,
radius: builtins.float = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["center", b"center"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["center", b"center", "radius", b"radius"]) -> None: ...
global___GeoRadius = GeoRadius
class GeoLineString(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
POINTS_FIELD_NUMBER: builtins.int
@property
def points(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___GeoPoint]:
"""Ordered sequence of GeoPoints representing the line"""
def __init__(
self,
*,
points: collections.abc.Iterable[global___GeoPoint] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["points", b"points"]) -> None: ...
global___GeoLineString = GeoLineString
class GeoPolygon(google.protobuf.message.Message):
"""For a valid GeoPolygon, both the exterior and interior GeoLineStrings must consist of a minimum of 4 points.
Additionally, the first and last points of each GeoLineString must be the same.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
EXTERIOR_FIELD_NUMBER: builtins.int
INTERIORS_FIELD_NUMBER: builtins.int
@property
def exterior(self) -> global___GeoLineString:
"""The exterior line bounds the surface"""
@property
def interiors(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___GeoLineString]:
"""Interior lines (if present) bound holes within the surface"""
def __init__(
self,
*,
exterior: global___GeoLineString | None = ...,
interiors: collections.abc.Iterable[global___GeoLineString] | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["exterior", b"exterior"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["exterior", b"exterior", "interiors", b"interiors"]) -> None: ...
global___GeoPolygon = GeoPolygon
class ValuesCount(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
LT_FIELD_NUMBER: builtins.int
GT_FIELD_NUMBER: builtins.int
GTE_FIELD_NUMBER: builtins.int
LTE_FIELD_NUMBER: builtins.int
lt: builtins.int
gt: builtins.int
gte: builtins.int
lte: builtins.int
def __init__(
self,
*,
lt: builtins.int | None = ...,
gt: builtins.int | None = ...,
gte: builtins.int | None = ...,
lte: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_gt", b"_gt", "_gte", b"_gte", "_lt", b"_lt", "_lte", b"_lte", "gt", b"gt", "gte", b"gte", "lt", b"lt", "lte", b"lte"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_gt", b"_gt", "_gte", b"_gte", "_lt", b"_lt", "_lte", b"_lte", "gt", b"gt", "gte", b"gte", "lt", b"lt", "lte", b"lte"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_gt", b"_gt"]) -> typing_extensions.Literal["gt"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_gte", b"_gte"]) -> typing_extensions.Literal["gte"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_lt", b"_lt"]) -> typing_extensions.Literal["lt"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_lte", b"_lte"]) -> typing_extensions.Literal["lte"] | None: ...
global___ValuesCount = ValuesCount
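A minimal sketch of how the filtering messages stubbed above (Filter, Condition, FieldCondition, Match) compose at runtime. It assumes the generated `*_pb2` classes are re-exported as `qdrant_client.grpc`; adjust the import path if your package layout differs.

```python
from qdrant_client import grpc as qgrpc  # assumed re-export of the generated *_pb2 modules

# Match points whose "city" payload value equals the keyword "London".
city_filter = qgrpc.Filter(
    must=[
        qgrpc.Condition(
            field=qgrpc.FieldCondition(
                key="city",
                match=qgrpc.Match(keyword="London"),
            )
        )
    ]
)

# The generated oneof accessor reports which Match variant is populated.
assert city_filter.must[0].field.match.WhichOneof("match_value") == "keyword"
```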
@@ -0,0 +1,4 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: qdrant.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from . import collections_service_pb2 as collections__service__pb2
from . import points_service_pb2 as points__service__pb2
from . import snapshots_service_pb2 as snapshots__service__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cqdrant.proto\x12\x06qdrant\x1a\x19\x63ollections_service.proto\x1a\x14points_service.proto\x1a\x17snapshots_service.proto\"\x14\n\x12HealthCheckRequest\"R\n\x10HealthCheckReply\x12\r\n\x05title\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x13\n\x06\x63ommit\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_commit2O\n\x06Qdrant\x12\x45\n\x0bHealthCheck\x12\x1a.qdrant.HealthCheckRequest\x1a\x18.qdrant.HealthCheckReply\"\x00\x42\x15\xaa\x02\x12Qdrant.Client.Grpcb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'qdrant_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'\252\002\022Qdrant.Client.Grpc'
_globals['_HEALTHCHECKREQUEST']._serialized_start=98
_globals['_HEALTHCHECKREQUEST']._serialized_end=118
_globals['_HEALTHCHECKREPLY']._serialized_start=120
_globals['_HEALTHCHECKREPLY']._serialized_end=202
_globals['_QDRANT']._serialized_start=204
_globals['_QDRANT']._serialized_end=283
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,46 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import google.protobuf.descriptor
import google.protobuf.message
import sys
if sys.version_info >= (3, 8):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
class HealthCheckRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(
self,
) -> None: ...
global___HealthCheckRequest = HealthCheckRequest
class HealthCheckReply(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TITLE_FIELD_NUMBER: builtins.int
VERSION_FIELD_NUMBER: builtins.int
COMMIT_FIELD_NUMBER: builtins.int
title: builtins.str
version: builtins.str
commit: builtins.str
def __init__(
self,
*,
title: builtins.str = ...,
version: builtins.str = ...,
commit: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_commit", b"_commit", "commit", b"commit"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_commit", b"_commit", "commit", b"commit", "title", b"title", "version", b"version"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions.Literal["_commit", b"_commit"]) -> typing_extensions.Literal["commit"] | None: ...
global___HealthCheckReply = HealthCheckReply
@@ -0,0 +1,66 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
from . import qdrant_pb2 as qdrant__pb2
class QdrantStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.HealthCheck = channel.unary_unary(
'/qdrant.Qdrant/HealthCheck',
request_serializer=qdrant__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=qdrant__pb2.HealthCheckReply.FromString,
)
class QdrantServicer(object):
"""Missing associated documentation comment in .proto file."""
def HealthCheck(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_QdrantServicer_to_server(servicer, server):
rpc_method_handlers = {
'HealthCheck': grpc.unary_unary_rpc_method_handler(
servicer.HealthCheck,
request_deserializer=qdrant__pb2.HealthCheckRequest.FromString,
response_serializer=qdrant__pb2.HealthCheckReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'qdrant.Qdrant', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class Qdrant(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def HealthCheck(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Qdrant/HealthCheck',
qdrant__pb2.HealthCheckRequest.SerializeToString,
qdrant__pb2.HealthCheckReply.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
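`QdrantStub` wraps the single `HealthCheck` RPC. A short usage sketch over a plain-text channel, assuming a Qdrant gRPC endpoint on `localhost:6334` and that the generated modules are importable from `qdrant_client.grpc`:

```python
import grpc

from qdrant_client.grpc import qdrant_pb2, qdrant_pb2_grpc

# Address and port are assumptions for a locally running instance.
with grpc.insecure_channel("localhost:6334") as channel:
    stub = qdrant_pb2_grpc.QdrantStub(channel)
    reply = stub.HealthCheck(qdrant_pb2.HealthCheckRequest())
    print(reply.title, reply.version)
```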
@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: snapshots_service.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17snapshots_service.proto\x12\x06qdrant\x1a\x1fgoogle/protobuf/timestamp.proto\"\x1b\n\x19\x43reateFullSnapshotRequest\"\x1a\n\x18ListFullSnapshotsRequest\"2\n\x19\x44\x65leteFullSnapshotRequest\x12\x15\n\rsnapshot_name\x18\x01 \x01(\t\"0\n\x15\x43reateSnapshotRequest\x12\x17\n\x0f\x63ollection_name\x18\x01 \x01(\t\"/\n\x14ListSnapshotsRequest\x12\x17\n\x0f\x63ollection_name\x18\x01 \x01(\t\"G\n\x15\x44\x65leteSnapshotRequest\x12\x17\n\x0f\x63ollection_name\x18\x01 \x01(\t\x12\x15\n\rsnapshot_name\x18\x02 \x01(\t\"\x88\x01\n\x13SnapshotDescription\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\rcreation_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04size\x18\x03 \x01(\x03\x12\x15\n\x08\x63hecksum\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\x0b\n\t_checksum\"a\n\x16\x43reateSnapshotResponse\x12\x39\n\x14snapshot_description\x18\x01 \x01(\x0b\x32\x1b.qdrant.SnapshotDescription\x12\x0c\n\x04time\x18\x02 \x01(\x01\"a\n\x15ListSnapshotsResponse\x12:\n\x15snapshot_descriptions\x18\x01 \x03(\x0b\x32\x1b.qdrant.SnapshotDescription\x12\x0c\n\x04time\x18\x02 \x01(\x01\"&\n\x16\x44\x65leteSnapshotResponse\x12\x0c\n\x04time\x18\x01 \x01(\x01\x32\xdd\x03\n\tSnapshots\x12I\n\x06\x43reate\x12\x1d.qdrant.CreateSnapshotRequest\x1a\x1e.qdrant.CreateSnapshotResponse\"\x00\x12\x45\n\x04List\x12\x1c.qdrant.ListSnapshotsRequest\x1a\x1d.qdrant.ListSnapshotsResponse\"\x00\x12I\n\x06\x44\x65lete\x12\x1d.qdrant.DeleteSnapshotRequest\x1a\x1e.qdrant.DeleteSnapshotResponse\"\x00\x12Q\n\nCreateFull\x12!.qdrant.CreateFullSnapshotRequest\x1a\x1e.qdrant.CreateSnapshotResponse\"\x00\x12M\n\x08ListFull\x12 .qdrant.ListFullSnapshotsRequest\x1a\x1d.qdrant.ListSnapshotsResponse\"\x00\x12Q\n\nDeleteFull\x12!.qdrant.DeleteFullSnapshotRequest\x1a\x1e.qdrant.DeleteSnapshotResponse\"\x00\x42\x15\xaa\x02\x12Qdrant.Client.Grpcb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'snapshots_service_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'\252\002\022Qdrant.Client.Grpc'
_globals['_CREATEFULLSNAPSHOTREQUEST']._serialized_start=68
_globals['_CREATEFULLSNAPSHOTREQUEST']._serialized_end=95
_globals['_LISTFULLSNAPSHOTSREQUEST']._serialized_start=97
_globals['_LISTFULLSNAPSHOTSREQUEST']._serialized_end=123
_globals['_DELETEFULLSNAPSHOTREQUEST']._serialized_start=125
_globals['_DELETEFULLSNAPSHOTREQUEST']._serialized_end=175
_globals['_CREATESNAPSHOTREQUEST']._serialized_start=177
_globals['_CREATESNAPSHOTREQUEST']._serialized_end=225
_globals['_LISTSNAPSHOTSREQUEST']._serialized_start=227
_globals['_LISTSNAPSHOTSREQUEST']._serialized_end=274
_globals['_DELETESNAPSHOTREQUEST']._serialized_start=276
_globals['_DELETESNAPSHOTREQUEST']._serialized_end=347
_globals['_SNAPSHOTDESCRIPTION']._serialized_start=350
_globals['_SNAPSHOTDESCRIPTION']._serialized_end=486
_globals['_CREATESNAPSHOTRESPONSE']._serialized_start=488
_globals['_CREATESNAPSHOTRESPONSE']._serialized_end=585
_globals['_LISTSNAPSHOTSRESPONSE']._serialized_start=587
_globals['_LISTSNAPSHOTSRESPONSE']._serialized_end=684
_globals['_DELETESNAPSHOTRESPONSE']._serialized_start=686
_globals['_DELETESNAPSHOTRESPONSE']._serialized_end=724
_globals['_SNAPSHOTS']._serialized_start=727
_globals['_SNAPSHOTS']._serialized_end=1204
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,184 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import google.protobuf.timestamp_pb2
import sys
if sys.version_info >= (3, 8):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
class CreateFullSnapshotRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(
self,
) -> None: ...
global___CreateFullSnapshotRequest = CreateFullSnapshotRequest
class ListFullSnapshotsRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(
self,
) -> None: ...
global___ListFullSnapshotsRequest = ListFullSnapshotsRequest
class DeleteFullSnapshotRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SNAPSHOT_NAME_FIELD_NUMBER: builtins.int
snapshot_name: builtins.str
"""Name of the full snapshot"""
def __init__(
self,
*,
snapshot_name: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["snapshot_name", b"snapshot_name"]) -> None: ...
global___DeleteFullSnapshotRequest = DeleteFullSnapshotRequest
class CreateSnapshotRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
COLLECTION_NAME_FIELD_NUMBER: builtins.int
collection_name: builtins.str
"""Name of the collection"""
def __init__(
self,
*,
collection_name: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["collection_name", b"collection_name"]) -> None: ...
global___CreateSnapshotRequest = CreateSnapshotRequest
class ListSnapshotsRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
COLLECTION_NAME_FIELD_NUMBER: builtins.int
collection_name: builtins.str
"""Name of the collection"""
def __init__(
self,
*,
collection_name: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["collection_name", b"collection_name"]) -> None: ...
global___ListSnapshotsRequest = ListSnapshotsRequest
class DeleteSnapshotRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
COLLECTION_NAME_FIELD_NUMBER: builtins.int
SNAPSHOT_NAME_FIELD_NUMBER: builtins.int
collection_name: builtins.str
"""Name of the collection"""
snapshot_name: builtins.str
"""Name of the collection snapshot"""
def __init__(
self,
*,
collection_name: builtins.str = ...,
snapshot_name: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["collection_name", b"collection_name", "snapshot_name", b"snapshot_name"]) -> None: ...
global___DeleteSnapshotRequest = DeleteSnapshotRequest
class SnapshotDescription(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NAME_FIELD_NUMBER: builtins.int
CREATION_TIME_FIELD_NUMBER: builtins.int
SIZE_FIELD_NUMBER: builtins.int
CHECKSUM_FIELD_NUMBER: builtins.int
name: builtins.str
"""Name of the snapshot"""
@property
def creation_time(self) -> google.protobuf.timestamp_pb2.Timestamp:
"""Creation time of the snapshot"""
size: builtins.int
"""Size of the snapshot in bytes"""
checksum: builtins.str
"""SHA256 digest of the snapshot file"""
def __init__(
self,
*,
name: builtins.str = ...,
creation_time: google.protobuf.timestamp_pb2.Timestamp | None = ...,
size: builtins.int = ...,
checksum: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_checksum", b"_checksum", "checksum", b"checksum", "creation_time", b"creation_time"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_checksum", b"_checksum", "checksum", b"checksum", "creation_time", b"creation_time", "name", b"name", "size", b"size"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions.Literal["_checksum", b"_checksum"]) -> typing_extensions.Literal["checksum"] | None: ...
global___SnapshotDescription = SnapshotDescription
class CreateSnapshotResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SNAPSHOT_DESCRIPTION_FIELD_NUMBER: builtins.int
TIME_FIELD_NUMBER: builtins.int
@property
def snapshot_description(self) -> global___SnapshotDescription: ...
time: builtins.float
"""Time spent to process"""
def __init__(
self,
*,
snapshot_description: global___SnapshotDescription | None = ...,
time: builtins.float = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["snapshot_description", b"snapshot_description"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["snapshot_description", b"snapshot_description", "time", b"time"]) -> None: ...
global___CreateSnapshotResponse = CreateSnapshotResponse
class ListSnapshotsResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SNAPSHOT_DESCRIPTIONS_FIELD_NUMBER: builtins.int
TIME_FIELD_NUMBER: builtins.int
@property
def snapshot_descriptions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___SnapshotDescription]: ...
time: builtins.float
"""Time spent to process"""
def __init__(
self,
*,
snapshot_descriptions: collections.abc.Iterable[global___SnapshotDescription] | None = ...,
time: builtins.float = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["snapshot_descriptions", b"snapshot_descriptions", "time", b"time"]) -> None: ...
global___ListSnapshotsResponse = ListSnapshotsResponse
class DeleteSnapshotResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TIME_FIELD_NUMBER: builtins.int
time: builtins.float
"""Time spent to process"""
def __init__(
self,
*,
time: builtins.float = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["time", b"time"]) -> None: ...
global___DeleteSnapshotResponse = DeleteSnapshotResponse
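In these stubs the optional `checksum` field is modelled through the synthetic `_checksum` oneof, so presence is checked with `HasField` rather than by comparing against an empty string. A small sketch, assuming the module is importable as shown:

```python
from qdrant_client.grpc import snapshots_service_pb2 as snapshots

desc = snapshots.SnapshotDescription(name="my_collection-2024.snapshot", size=1024)

# `checksum` was never set, so HasField reports absence even though reading
# desc.checksum would still return the default empty string.
assert not desc.HasField("checksum")

desc.checksum = "0" * 64
assert desc.HasField("checksum")
```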
@@ -0,0 +1,243 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
from . import snapshots_service_pb2 as snapshots__service__pb2
class SnapshotsStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Create = channel.unary_unary(
'/qdrant.Snapshots/Create',
request_serializer=snapshots__service__pb2.CreateSnapshotRequest.SerializeToString,
response_deserializer=snapshots__service__pb2.CreateSnapshotResponse.FromString,
)
self.List = channel.unary_unary(
'/qdrant.Snapshots/List',
request_serializer=snapshots__service__pb2.ListSnapshotsRequest.SerializeToString,
response_deserializer=snapshots__service__pb2.ListSnapshotsResponse.FromString,
)
self.Delete = channel.unary_unary(
'/qdrant.Snapshots/Delete',
request_serializer=snapshots__service__pb2.DeleteSnapshotRequest.SerializeToString,
response_deserializer=snapshots__service__pb2.DeleteSnapshotResponse.FromString,
)
self.CreateFull = channel.unary_unary(
'/qdrant.Snapshots/CreateFull',
request_serializer=snapshots__service__pb2.CreateFullSnapshotRequest.SerializeToString,
response_deserializer=snapshots__service__pb2.CreateSnapshotResponse.FromString,
)
self.ListFull = channel.unary_unary(
'/qdrant.Snapshots/ListFull',
request_serializer=snapshots__service__pb2.ListFullSnapshotsRequest.SerializeToString,
response_deserializer=snapshots__service__pb2.ListSnapshotsResponse.FromString,
)
self.DeleteFull = channel.unary_unary(
'/qdrant.Snapshots/DeleteFull',
request_serializer=snapshots__service__pb2.DeleteFullSnapshotRequest.SerializeToString,
response_deserializer=snapshots__service__pb2.DeleteSnapshotResponse.FromString,
)
class SnapshotsServicer(object):
"""Missing associated documentation comment in .proto file."""
def Create(self, request, context):
"""
Create collection snapshot
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def List(self, request, context):
"""
List collection snapshots
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Delete(self, request, context):
"""
Delete collection snapshot
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def CreateFull(self, request, context):
"""
Create full storage snapshot
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ListFull(self, request, context):
"""
List full storage snapshots
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def DeleteFull(self, request, context):
"""
Delete full storage snapshot
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_SnapshotsServicer_to_server(servicer, server):
rpc_method_handlers = {
'Create': grpc.unary_unary_rpc_method_handler(
servicer.Create,
request_deserializer=snapshots__service__pb2.CreateSnapshotRequest.FromString,
response_serializer=snapshots__service__pb2.CreateSnapshotResponse.SerializeToString,
),
'List': grpc.unary_unary_rpc_method_handler(
servicer.List,
request_deserializer=snapshots__service__pb2.ListSnapshotsRequest.FromString,
response_serializer=snapshots__service__pb2.ListSnapshotsResponse.SerializeToString,
),
'Delete': grpc.unary_unary_rpc_method_handler(
servicer.Delete,
request_deserializer=snapshots__service__pb2.DeleteSnapshotRequest.FromString,
response_serializer=snapshots__service__pb2.DeleteSnapshotResponse.SerializeToString,
),
'CreateFull': grpc.unary_unary_rpc_method_handler(
servicer.CreateFull,
request_deserializer=snapshots__service__pb2.CreateFullSnapshotRequest.FromString,
response_serializer=snapshots__service__pb2.CreateSnapshotResponse.SerializeToString,
),
'ListFull': grpc.unary_unary_rpc_method_handler(
servicer.ListFull,
request_deserializer=snapshots__service__pb2.ListFullSnapshotsRequest.FromString,
response_serializer=snapshots__service__pb2.ListSnapshotsResponse.SerializeToString,
),
'DeleteFull': grpc.unary_unary_rpc_method_handler(
servicer.DeleteFull,
request_deserializer=snapshots__service__pb2.DeleteFullSnapshotRequest.FromString,
response_serializer=snapshots__service__pb2.DeleteSnapshotResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'qdrant.Snapshots', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class Snapshots(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def Create(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Snapshots/Create',
snapshots__service__pb2.CreateSnapshotRequest.SerializeToString,
snapshots__service__pb2.CreateSnapshotResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def List(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Snapshots/List',
snapshots__service__pb2.ListSnapshotsRequest.SerializeToString,
snapshots__service__pb2.ListSnapshotsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Delete(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Snapshots/Delete',
snapshots__service__pb2.DeleteSnapshotRequest.SerializeToString,
snapshots__service__pb2.DeleteSnapshotResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def CreateFull(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Snapshots/CreateFull',
snapshots__service__pb2.CreateFullSnapshotRequest.SerializeToString,
snapshots__service__pb2.CreateSnapshotResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def ListFull(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Snapshots/ListFull',
snapshots__service__pb2.ListFullSnapshotsRequest.SerializeToString,
snapshots__service__pb2.ListSnapshotsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def DeleteFull(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/qdrant.Snapshots/DeleteFull',
snapshots__service__pb2.DeleteFullSnapshotRequest.SerializeToString,
snapshots__service__pb2.DeleteSnapshotResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
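`SnapshotsStub` exposes the six snapshot RPCs defined above. A hedged sketch that creates and then lists snapshots for one collection; the endpoint address and import paths are assumptions:

```python
import grpc

from qdrant_client.grpc import snapshots_service_pb2 as snapshots
from qdrant_client.grpc import snapshots_service_pb2_grpc as snapshots_grpc

with grpc.insecure_channel("localhost:6334") as channel:
    stub = snapshots_grpc.SnapshotsStub(channel)

    # Trigger a collection snapshot and report its generated name.
    created = stub.Create(snapshots.CreateSnapshotRequest(collection_name="my_collection"))
    print("created:", created.snapshot_description.name)

    # Enumerate all snapshots currently stored for the collection.
    listed = stub.List(snapshots.ListSnapshotsRequest(collection_name="my_collection"))
    for description in listed.snapshot_descriptions:
        print(description.name, description.size)
```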
@@ -0,0 +1,17 @@
import inspect
from pydantic import BaseModel
from qdrant_client._pydantic_compat import update_forward_refs
from qdrant_client.http.api_client import ( # noqa F401
ApiClient as ApiClient,
AsyncApiClient as AsyncApiClient,
AsyncApis as AsyncApis,
SyncApis as SyncApis,
)
from qdrant_client.http.models import models as models # noqa F401
for model in inspect.getmembers(models, inspect.isclass):
if model[1].__module__ == "qdrant_client.http.models.models":
model_class = model[1]
if issubclass(model_class, BaseModel):
update_forward_refs(model_class)
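The loop above rebuilds every generated model so that string ("forward") annotations resolve once all classes exist. A self-contained illustration of the same mechanism with a hypothetical pair of models (not part of the generated schema), using the compat helper imported above:

```python
from typing import Optional

from pydantic import BaseModel

from qdrant_client._pydantic_compat import update_forward_refs


class Node(BaseModel):
    value: int
    # Forward reference: "Tree" is not defined yet at class-creation time.
    subtree: Optional["Tree"] = None


class Tree(BaseModel):
    nodes: list[Node] = []


# Resolve the "Tree" annotation now that both classes exist; the compat shim
# dispatches to model_rebuild() on pydantic v2 and update_forward_refs() on v1.
update_forward_refs(Node)

print(Node(value=1, subtree=Tree()))
```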
@@ -0,0 +1,171 @@
# flake8: noqa E501
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
from pydantic import BaseModel
from pydantic.main import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from qdrant_client.http.models import *
from qdrant_client.http.models import models as m
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
Model = TypeVar("Model", bound="BaseModel")
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
file = None
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
if PYDANTIC_V2:
return model.model_dump_json(*args, **kwargs)
else:
return model.json(*args, **kwargs)
def jsonable_encoder(
obj: Any,
include: Union[SetIntStr, DictIntStrAny] = None,
exclude=None,
by_alias: bool = True,
skip_defaults: bool = None,
exclude_unset: bool = True,
exclude_none: bool = True,
):
if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
return to_json(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=bool(exclude_unset or skip_defaults),
exclude_none=exclude_none,
)
return obj
if TYPE_CHECKING:
from qdrant_client.http.api_client import ApiClient
class _AliasesApi:
def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
self.api_client = api_client
def _build_for_get_collection_aliases(
self,
collection_name: str,
):
"""
Get list of all aliases for a collection
"""
path_params = {
"collection_name": str(collection_name),
}
headers = {}
return self.api_client.request(
type_=m.InlineResponse2008,
method="GET",
url="/collections/{collection_name}/aliases",
headers=headers if headers else None,
path_params=path_params,
)
def _build_for_get_collections_aliases(
self,
):
"""
Get list of all existing collections aliases
"""
headers = {}
return self.api_client.request(
type_=m.InlineResponse2008,
method="GET",
url="/aliases",
headers=headers if headers else None,
)
def _build_for_update_aliases(
self,
timeout: int = None,
change_aliases_operation: m.ChangeAliasesOperation = None,
):
query_params = {}
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(change_aliases_operation)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse200,
method="POST",
url="/collections/aliases",
headers=headers if headers else None,
params=query_params,
content=body,
)
class AsyncAliasesApi(_AliasesApi):
async def get_collection_aliases(
self,
collection_name: str,
) -> m.InlineResponse2008:
"""
Get list of all aliases for a collection
"""
return await self._build_for_get_collection_aliases(
collection_name=collection_name,
)
async def get_collections_aliases(
self,
) -> m.InlineResponse2008:
"""
Get list of all existing collections aliases
"""
return await self._build_for_get_collections_aliases()
async def update_aliases(
self,
timeout: int = None,
change_aliases_operation: m.ChangeAliasesOperation = None,
) -> m.InlineResponse200:
return await self._build_for_update_aliases(
timeout=timeout,
change_aliases_operation=change_aliases_operation,
)
class SyncAliasesApi(_AliasesApi):
def get_collection_aliases(
self,
collection_name: str,
) -> m.InlineResponse2008:
"""
Get list of all aliases for a collection
"""
return self._build_for_get_collection_aliases(
collection_name=collection_name,
)
def get_collections_aliases(
self,
) -> m.InlineResponse2008:
"""
Get list of all existing collections aliases
"""
return self._build_for_get_collections_aliases()
def update_aliases(
self,
timeout: int = None,
change_aliases_operation: m.ChangeAliasesOperation = None,
) -> m.InlineResponse200:
return self._build_for_update_aliases(
timeout=timeout,
change_aliases_operation=change_aliases_operation,
)
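These builders are thin wrappers around `ApiClient.request`; they are normally reached through the `SyncApis`/`AsyncApis` aggregators, but the facades can also be used directly. A hedged sketch against a local REST endpoint, where the `ApiClient(host=...)` constructor keyword, the import paths, and the `.result` attribute of the response model are assumptions rather than something defined in this file:

```python
from qdrant_client.http.api.aliases_api import SyncAliasesApi
from qdrant_client.http.api_client import ApiClient

# Low-level HTTP client; the `host` keyword is an assumption about its signature.
client = ApiClient(host="http://localhost:6333")
aliases = SyncAliasesApi(client)

# List the aliases registered for one collection.
response = aliases.get_collection_aliases("my_collection")
print(response.result)  # `.result` per the generated InlineResponse2008 model (assumed)
```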
@@ -0,0 +1,116 @@
# flake8: noqa E501
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
from pydantic import BaseModel
from pydantic.main import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from qdrant_client.http.models import *
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
Model = TypeVar("Model", bound="BaseModel")
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
file = None
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
if PYDANTIC_V2:
return model.model_dump_json(*args, **kwargs)
else:
return model.json(*args, **kwargs)
def jsonable_encoder(
obj: Any,
include: Union[SetIntStr, DictIntStrAny] = None,
exclude=None,
by_alias: bool = True,
skip_defaults: bool = None,
exclude_unset: bool = True,
exclude_none: bool = True,
):
if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
return to_json(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=bool(exclude_unset or skip_defaults),
exclude_none=exclude_none,
)
return obj
if TYPE_CHECKING:
from qdrant_client.http.api_client import ApiClient
class _BetaApi:
def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
self.api_client = api_client
def _build_for_clear_issues(
self,
):
"""
Removes all issues reported so far
"""
headers = {}
return self.api_client.request(
type_=bool,
method="DELETE",
url="/issues",
headers=headers if headers else None,
)
def _build_for_get_issues(
self,
):
"""
Get a report of performance issues and configuration suggestions
"""
headers = {}
return self.api_client.request(
type_=object,
method="GET",
url="/issues",
headers=headers if headers else None,
)
class AsyncBetaApi(_BetaApi):
async def clear_issues(
self,
) -> bool:
"""
Removes all issues reported so far
"""
return await self._build_for_clear_issues()
async def get_issues(
self,
) -> object:
"""
Get a report of performance issues and configuration suggestions
"""
return await self._build_for_get_issues()
class SyncBetaApi(_BetaApi):
def clear_issues(
self,
) -> bool:
"""
Removes all issues reported so far
"""
return self._build_for_clear_issues()
def get_issues(
self,
) -> object:
"""
Get a report of performance issues and configuration suggestions
"""
return self._build_for_get_issues()
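The beta issues endpoints follow the same pattern; a short sketch, with the same assumptions about `ApiClient` and the import paths as in the aliases example:

```python
from qdrant_client.http.api.beta_api import SyncBetaApi
from qdrant_client.http.api_client import ApiClient

client = ApiClient(host="http://localhost:6333")  # `host` keyword assumed, as above
beta = SyncBetaApi(client)

print(beta.get_issues())    # report of performance issues and configuration hints
print(beta.clear_issues())  # True once the issue log has been wiped
```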
@@ -0,0 +1,345 @@
# flake8: noqa E501
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
from pydantic import BaseModel
from pydantic.main import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from qdrant_client.http.models import *
from qdrant_client.http.models import models as m
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
Model = TypeVar("Model", bound="BaseModel")
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
file = None
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
if PYDANTIC_V2:
return model.model_dump_json(*args, **kwargs)
else:
return model.json(*args, **kwargs)
def jsonable_encoder(
obj: Any,
include: Union[SetIntStr, DictIntStrAny] = None,
exclude=None,
by_alias: bool = True,
skip_defaults: bool = None,
exclude_unset: bool = True,
exclude_none: bool = True,
):
if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
return to_json(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=bool(exclude_unset or skip_defaults),
exclude_none=exclude_none,
)
return obj
if TYPE_CHECKING:
from qdrant_client.http.api_client import ApiClient
class _CollectionsApi:
def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
self.api_client = api_client
def _build_for_collection_exists(
self,
collection_name: str,
):
"""
Returns \"true\" if the given collection name exists, and \"false\" otherwise
"""
path_params = {
"collection_name": str(collection_name),
}
headers = {}
return self.api_client.request(
type_=m.InlineResponse2006,
method="GET",
url="/collections/{collection_name}/exists",
headers=headers if headers else None,
path_params=path_params,
)
def _build_for_create_collection(
self,
collection_name: str,
timeout: int = None,
create_collection: m.CreateCollection = None,
):
"""
Create new collection with given parameters
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(create_collection)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse200,
method="PUT",
url="/collections/{collection_name}",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_delete_collection(
self,
collection_name: str,
timeout: int = None,
):
"""
Drop collection and all associated data
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
return self.api_client.request(
type_=m.InlineResponse200,
method="DELETE",
url="/collections/{collection_name}",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
)
def _build_for_get_collection(
self,
collection_name: str,
):
"""
Get detailed information about specified existing collection
"""
path_params = {
"collection_name": str(collection_name),
}
headers = {}
return self.api_client.request(
type_=m.InlineResponse2004,
method="GET",
url="/collections/{collection_name}",
headers=headers if headers else None,
path_params=path_params,
)
def _build_for_get_collections(
self,
):
"""
        Get the names of all existing collections
"""
headers = {}
return self.api_client.request(
type_=m.InlineResponse2003,
method="GET",
url="/collections",
headers=headers if headers else None,
)
def _build_for_update_collection(
self,
collection_name: str,
timeout: int = None,
update_collection: m.UpdateCollection = None,
):
"""
Update parameters of the existing collection
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(update_collection)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse200,
method="PATCH",
url="/collections/{collection_name}",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
class AsyncCollectionsApi(_CollectionsApi):
async def collection_exists(
self,
collection_name: str,
) -> m.InlineResponse2006:
"""
Returns \"true\" if the given collection name exists, and \"false\" otherwise
"""
return await self._build_for_collection_exists(
collection_name=collection_name,
)
async def create_collection(
self,
collection_name: str,
timeout: int = None,
create_collection: m.CreateCollection = None,
) -> m.InlineResponse200:
"""
Create new collection with given parameters
"""
return await self._build_for_create_collection(
collection_name=collection_name,
timeout=timeout,
create_collection=create_collection,
)
async def delete_collection(
self,
collection_name: str,
timeout: int = None,
) -> m.InlineResponse200:
"""
Drop collection and all associated data
"""
return await self._build_for_delete_collection(
collection_name=collection_name,
timeout=timeout,
)
async def get_collection(
self,
collection_name: str,
) -> m.InlineResponse2004:
"""
Get detailed information about specified existing collection
"""
return await self._build_for_get_collection(
collection_name=collection_name,
)
async def get_collections(
self,
) -> m.InlineResponse2003:
"""
Get the names of all existing collections
"""
return await self._build_for_get_collections()
async def update_collection(
self,
collection_name: str,
timeout: int = None,
update_collection: m.UpdateCollection = None,
) -> m.InlineResponse200:
"""
Update parameters of the existing collection
"""
return await self._build_for_update_collection(
collection_name=collection_name,
timeout=timeout,
update_collection=update_collection,
)
class SyncCollectionsApi(_CollectionsApi):
def collection_exists(
self,
collection_name: str,
) -> m.InlineResponse2006:
"""
Returns "true" if the given collection name exists, and "false" otherwise
"""
return self._build_for_collection_exists(
collection_name=collection_name,
)
def create_collection(
self,
collection_name: str,
timeout: int = None,
create_collection: m.CreateCollection = None,
) -> m.InlineResponse200:
"""
Create new collection with given parameters
"""
return self._build_for_create_collection(
collection_name=collection_name,
timeout=timeout,
create_collection=create_collection,
)
def delete_collection(
self,
collection_name: str,
timeout: int = None,
) -> m.InlineResponse200:
"""
Drop collection and all associated data
"""
return self._build_for_delete_collection(
collection_name=collection_name,
timeout=timeout,
)
def get_collection(
self,
collection_name: str,
) -> m.InlineResponse2004:
"""
Get detailed information about specified existing collection
"""
return self._build_for_get_collection(
collection_name=collection_name,
)
def get_collections(
self,
) -> m.InlineResponse2003:
"""
Get the names of all existing collections
"""
return self._build_for_get_collections()
def update_collection(
self,
collection_name: str,
timeout: int = None,
update_collection: m.UpdateCollection = None,
) -> m.InlineResponse200:
"""
Update parameters of the existing collection
"""
return self._build_for_update_collection(
collection_name=collection_name,
timeout=timeout,
update_collection=update_collection,
)
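# --- Usage sketch (editor's addition, not produced by the generator) ---
# Minimal example of driving SyncCollectionsApi directly. Assumptions: the upstream
# ApiClient accepts a `host` base URL, and m.CreateCollection / m.VectorParams /
# m.Distance exist in qdrant_client.http.models with the fields used below.
# "demo" and the local URL are placeholders; verify against the installed qdrant-client.
def _example_collections_roundtrip() -> None:  # hypothetical helper, illustration only
    # Imported lazily: ApiClient is only type-imported above, likely to avoid a circular import.
    from qdrant_client.http.api_client import ApiClient

    api = SyncCollectionsApi(ApiClient(host="http://localhost:6333"))
    api.create_collection(
        collection_name="demo",
        create_collection=m.CreateCollection(
            vectors=m.VectorParams(size=4, distance=m.Distance.COSINE)
        ),
    )
    print(api.collection_exists("demo"))
    print(api.get_collections())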
@@ -0,0 +1,365 @@
# flake8: noqa E501
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from qdrant_client.http.models import *
from qdrant_client.http.models import models as m
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
Model = TypeVar("Model", bound="BaseModel")
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
file = None
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
if PYDANTIC_V2:
return model.model_dump_json(*args, **kwargs)
else:
return model.json(*args, **kwargs)
def jsonable_encoder(
obj: Any,
include: Union[SetIntStr, DictIntStrAny] = None,
exclude=None,
by_alias: bool = True,
skip_defaults: bool = None,
exclude_unset: bool = True,
exclude_none: bool = True,
):
if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
return to_json(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=bool(exclude_unset or skip_defaults),
exclude_none=exclude_none,
)
return obj
if TYPE_CHECKING:
    from qdrant_client.http.api_client import ApiClient, AsyncApiClient
class _DistributedApi:
def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
self.api_client = api_client
def _build_for_cluster_status(
self,
):
"""
Get information about the current state and composition of the cluster
"""
headers = {}
return self.api_client.request(
type_=m.InlineResponse2002,
method="GET",
url="/cluster",
headers=headers if headers else None,
)
def _build_for_collection_cluster_info(
self,
collection_name: str,
):
"""
Get cluster information for a collection
"""
path_params = {
"collection_name": str(collection_name),
}
headers = {}
return self.api_client.request(
type_=m.InlineResponse2007,
method="GET",
url="/collections/{collection_name}/cluster",
headers=headers if headers else None,
path_params=path_params,
)
def _build_for_create_shard_key(
self,
collection_name: str,
timeout: int = None,
create_sharding_key: m.CreateShardingKey = None,
):
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(create_sharding_key)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse200,
method="PUT",
url="/collections/{collection_name}/shards",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_delete_shard_key(
self,
collection_name: str,
timeout: int = None,
drop_sharding_key: m.DropShardingKey = None,
):
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(drop_sharding_key)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse200,
method="POST",
url="/collections/{collection_name}/shards/delete",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_recover_current_peer(
self,
):
headers = {}
return self.api_client.request(
type_=m.InlineResponse200,
method="POST",
url="/cluster/recover",
headers=headers if headers else None,
)
def _build_for_remove_peer(
self,
peer_id: int,
timeout: int = None,
force: bool = None,
):
"""
Tries to remove peer from the cluster. Will return an error if peer has shards on it.
"""
path_params = {
"peer_id": str(peer_id),
}
query_params = {}
if timeout is not None:
query_params["timeout"] = str(timeout)
if force is not None:
query_params["force"] = str(force).lower()
headers = {}
return self.api_client.request(
type_=m.InlineResponse200,
method="DELETE",
url="/cluster/peer/{peer_id}",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
)
def _build_for_update_collection_cluster(
self,
collection_name: str,
timeout: int = None,
cluster_operations: m.ClusterOperations = None,
):
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(cluster_operations)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse200,
method="POST",
url="/collections/{collection_name}/cluster",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
class AsyncDistributedApi(_DistributedApi):
async def cluster_status(
self,
) -> m.InlineResponse2002:
"""
Get information about the current state and composition of the cluster
"""
return await self._build_for_cluster_status()
async def collection_cluster_info(
self,
collection_name: str,
) -> m.InlineResponse2007:
"""
Get cluster information for a collection
"""
return await self._build_for_collection_cluster_info(
collection_name=collection_name,
)
async def create_shard_key(
self,
collection_name: str,
timeout: int = None,
create_sharding_key: m.CreateShardingKey = None,
) -> m.InlineResponse200:
return await self._build_for_create_shard_key(
collection_name=collection_name,
timeout=timeout,
create_sharding_key=create_sharding_key,
)
async def delete_shard_key(
self,
collection_name: str,
timeout: int = None,
drop_sharding_key: m.DropShardingKey = None,
) -> m.InlineResponse200:
return await self._build_for_delete_shard_key(
collection_name=collection_name,
timeout=timeout,
drop_sharding_key=drop_sharding_key,
)
async def recover_current_peer(
self,
) -> m.InlineResponse200:
return await self._build_for_recover_current_peer()
async def remove_peer(
self,
peer_id: int,
timeout: int = None,
force: bool = None,
) -> m.InlineResponse200:
"""
Tries to remove peer from the cluster. Will return an error if peer has shards on it.
"""
return await self._build_for_remove_peer(
peer_id=peer_id,
timeout=timeout,
force=force,
)
async def update_collection_cluster(
self,
collection_name: str,
timeout: int = None,
cluster_operations: m.ClusterOperations = None,
) -> m.InlineResponse200:
return await self._build_for_update_collection_cluster(
collection_name=collection_name,
timeout=timeout,
cluster_operations=cluster_operations,
)
class SyncDistributedApi(_DistributedApi):
def cluster_status(
self,
) -> m.InlineResponse2002:
"""
Get information about the current state and composition of the cluster
"""
return self._build_for_cluster_status()
def collection_cluster_info(
self,
collection_name: str,
) -> m.InlineResponse2007:
"""
Get cluster information for a collection
"""
return self._build_for_collection_cluster_info(
collection_name=collection_name,
)
def create_shard_key(
self,
collection_name: str,
timeout: int = None,
create_sharding_key: m.CreateShardingKey = None,
) -> m.InlineResponse200:
return self._build_for_create_shard_key(
collection_name=collection_name,
timeout=timeout,
create_sharding_key=create_sharding_key,
)
def delete_shard_key(
self,
collection_name: str,
timeout: int = None,
drop_sharding_key: m.DropShardingKey = None,
) -> m.InlineResponse200:
return self._build_for_delete_shard_key(
collection_name=collection_name,
timeout=timeout,
drop_sharding_key=drop_sharding_key,
)
def recover_current_peer(
self,
) -> m.InlineResponse200:
return self._build_for_recover_current_peer()
def remove_peer(
self,
peer_id: int,
timeout: int = None,
force: bool = None,
) -> m.InlineResponse200:
"""
Tries to remove peer from the cluster. Will return an error if peer has shards on it.
"""
return self._build_for_remove_peer(
peer_id=peer_id,
timeout=timeout,
force=force,
)
def update_collection_cluster(
self,
collection_name: str,
timeout: int = None,
cluster_operations: m.ClusterOperations = None,
) -> m.InlineResponse200:
return self._build_for_update_collection_cluster(
collection_name=collection_name,
timeout=timeout,
cluster_operations=cluster_operations,
)
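# --- Usage sketch (editor's addition, not produced by the generator) ---
# Checking cluster state with SyncDistributedApi. Assumptions: the upstream ApiClient
# accepts a `host` base URL; the collection name and URL below are placeholders.
# Method names come from the class defined above.
def _example_cluster_status() -> None:  # hypothetical helper, illustration only
    # Imported lazily: ApiClient is only type-imported above, likely to avoid a circular import.
    from qdrant_client.http.api_client import ApiClient

    api = SyncDistributedApi(ApiClient(host="http://localhost:6333"))
    print(api.cluster_status())
    print(api.collection_cluster_info("demo"))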
@@ -0,0 +1,190 @@
# flake8: noqa E501
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from qdrant_client.http.models import *
from qdrant_client.http.models import models as m
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
Model = TypeVar("Model", bound="BaseModel")
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
file = None
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
if PYDANTIC_V2:
return model.model_dump_json(*args, **kwargs)
else:
return model.json(*args, **kwargs)
def jsonable_encoder(
obj: Any,
include: Union[SetIntStr, DictIntStrAny] = None,
exclude=None,
by_alias: bool = True,
skip_defaults: bool = None,
exclude_unset: bool = True,
exclude_none: bool = True,
):
if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
return to_json(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=bool(exclude_unset or skip_defaults),
exclude_none=exclude_none,
)
return obj
if TYPE_CHECKING:
    from qdrant_client.http.api_client import ApiClient, AsyncApiClient
class _IndexesApi:
def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
self.api_client = api_client
def _build_for_create_field_index(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
create_field_index: m.CreateFieldIndex = None,
):
"""
Create index for field in collection
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if ordering is not None:
query_params["ordering"] = str(ordering)
headers = {}
body = jsonable_encoder(create_field_index)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse2005,
method="PUT",
url="/collections/{collection_name}/index",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_delete_field_index(
self,
collection_name: str,
field_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
):
"""
Delete field index for collection
"""
path_params = {
"collection_name": str(collection_name),
"field_name": str(field_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if ordering is not None:
query_params["ordering"] = str(ordering)
headers = {}
return self.api_client.request(
type_=m.InlineResponse2005,
method="DELETE",
url="/collections/{collection_name}/index/{field_name}",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
)
class AsyncIndexesApi(_IndexesApi):
async def create_field_index(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
create_field_index: m.CreateFieldIndex = None,
) -> m.InlineResponse2005:
"""
Create index for field in collection
"""
return await self._build_for_create_field_index(
collection_name=collection_name,
wait=wait,
ordering=ordering,
create_field_index=create_field_index,
)
async def delete_field_index(
self,
collection_name: str,
field_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
) -> m.InlineResponse2005:
"""
Delete field index for collection
"""
return await self._build_for_delete_field_index(
collection_name=collection_name,
field_name=field_name,
wait=wait,
ordering=ordering,
)
class SyncIndexesApi(_IndexesApi):
def create_field_index(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
create_field_index: m.CreateFieldIndex = None,
) -> m.InlineResponse2005:
"""
Create index for field in collection
"""
return self._build_for_create_field_index(
collection_name=collection_name,
wait=wait,
ordering=ordering,
create_field_index=create_field_index,
)
def delete_field_index(
self,
collection_name: str,
field_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
) -> m.InlineResponse2005:
"""
Delete field index for collection
"""
return self._build_for_delete_field_index(
collection_name=collection_name,
field_name=field_name,
wait=wait,
ordering=ordering,
)
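# --- Usage sketch (editor's addition, not produced by the generator) ---
# Creating and dropping a payload index with SyncIndexesApi. Assumptions: the upstream
# ApiClient accepts a `host` base URL, and m.CreateFieldIndex accepts `field_name` and a
# string `field_schema` such as "keyword"; collection/field names are placeholders.
def _example_field_index() -> None:  # hypothetical helper, illustration only
    # Imported lazily: ApiClient is only type-imported above, likely to avoid a circular import.
    from qdrant_client.http.api_client import ApiClient

    api = SyncIndexesApi(ApiClient(host="http://localhost:6333"))
    api.create_field_index(
        collection_name="demo",
        wait=True,
        create_field_index=m.CreateFieldIndex(field_name="city", field_schema="keyword"),
    )
    api.delete_field_index(collection_name="demo", field_name="city", wait=True)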
@@ -0,0 +1,999 @@
# flake8: noqa E501
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from qdrant_client.http.models import *
from qdrant_client.http.models import models as m
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
Model = TypeVar("Model", bound="BaseModel")
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
file = None
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
if PYDANTIC_V2:
return model.model_dump_json(*args, **kwargs)
else:
return model.json(*args, **kwargs)
def jsonable_encoder(
obj: Any,
include: Union[SetIntStr, DictIntStrAny] = None,
exclude=None,
by_alias: bool = True,
skip_defaults: bool = None,
exclude_unset: bool = True,
exclude_none: bool = True,
):
if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
return to_json(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=bool(exclude_unset or skip_defaults),
exclude_none=exclude_none,
)
return obj
if TYPE_CHECKING:
    from qdrant_client.http.api_client import ApiClient, AsyncApiClient
class _PointsApi:
def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
self.api_client = api_client
def _build_for_batch_update(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
update_operations: m.UpdateOperations = None,
):
"""
Apply a series of update operations for points, vectors and payloads
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if ordering is not None:
query_params["ordering"] = str(ordering)
headers = {}
body = jsonable_encoder(update_operations)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20014,
method="POST",
url="/collections/{collection_name}/points/batch",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_clear_payload(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
points_selector: m.PointsSelector = None,
):
"""
Remove all payload for specified points
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if ordering is not None:
query_params["ordering"] = str(ordering)
headers = {}
body = jsonable_encoder(points_selector)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse2005,
method="POST",
url="/collections/{collection_name}/points/payload/clear",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_count_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
count_request: m.CountRequest = None,
):
"""
Count points that match the given filtering condition
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(count_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20019,
method="POST",
url="/collections/{collection_name}/points/count",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_delete_payload(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
delete_payload: m.DeletePayload = None,
):
"""
Delete specified key payload for points
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if ordering is not None:
query_params["ordering"] = str(ordering)
headers = {}
body = jsonable_encoder(delete_payload)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse2005,
method="POST",
url="/collections/{collection_name}/points/payload/delete",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_delete_points(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
points_selector: m.PointsSelector = None,
):
"""
Delete points
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if ordering is not None:
query_params["ordering"] = str(ordering)
headers = {}
body = jsonable_encoder(points_selector)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse2005,
method="POST",
url="/collections/{collection_name}/points/delete",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_delete_vectors(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
delete_vectors: m.DeleteVectors = None,
):
"""
Delete named vectors from the given points.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if ordering is not None:
query_params["ordering"] = str(ordering)
headers = {}
body = jsonable_encoder(delete_vectors)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse2005,
method="POST",
url="/collections/{collection_name}/points/vectors/delete",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_facet(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
facet_request: m.FacetRequest = None,
):
"""
Count points that satisfy the given filter for each unique value of a payload key.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(facet_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20020,
method="POST",
url="/collections/{collection_name}/facet",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_get_point(
self,
collection_name: str,
id: m.ExtendedPointId,
consistency: m.ReadConsistency = None,
):
"""
Retrieve full information of single point by id
"""
path_params = {
"collection_name": str(collection_name),
"id": str(id),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
headers = {}
return self.api_client.request(
type_=m.InlineResponse20012,
method="GET",
url="/collections/{collection_name}/points/{id}",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
)
def _build_for_get_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
point_request: m.PointRequest = None,
):
"""
Retrieve multiple points by specified IDs
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(point_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20013,
method="POST",
url="/collections/{collection_name}/points",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_overwrite_payload(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
set_payload: m.SetPayload = None,
):
"""
Replace full payload of points with new one
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if ordering is not None:
query_params["ordering"] = str(ordering)
headers = {}
body = jsonable_encoder(set_payload)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse2005,
method="PUT",
url="/collections/{collection_name}/points/payload",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_scroll_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
scroll_request: m.ScrollRequest = None,
):
"""
Scroll request - paginate over all points that match the given filtering condition
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(scroll_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20015,
method="POST",
url="/collections/{collection_name}/points/scroll",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_set_payload(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
set_payload: m.SetPayload = None,
):
"""
Set payload values for points
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if ordering is not None:
query_params["ordering"] = str(ordering)
headers = {}
body = jsonable_encoder(set_payload)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse2005,
method="POST",
url="/collections/{collection_name}/points/payload",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_update_vectors(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
update_vectors: m.UpdateVectors = None,
):
"""
Update specified named vectors on points, keep unspecified vectors intact.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if ordering is not None:
query_params["ordering"] = str(ordering)
headers = {}
body = jsonable_encoder(update_vectors)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse2005,
method="PUT",
url="/collections/{collection_name}/points/vectors",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_upsert_points(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
point_insert_operations: m.PointInsertOperations = None,
):
"""
Perform insert + updates on points. If point with given ID already exists - it will be overwritten.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if ordering is not None:
query_params["ordering"] = str(ordering)
headers = {}
body = jsonable_encoder(point_insert_operations)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse2005,
method="PUT",
url="/collections/{collection_name}/points",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
class AsyncPointsApi(_PointsApi):
async def batch_update(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
update_operations: m.UpdateOperations = None,
) -> m.InlineResponse20014:
"""
Apply a series of update operations for points, vectors and payloads
"""
return await self._build_for_batch_update(
collection_name=collection_name,
wait=wait,
ordering=ordering,
update_operations=update_operations,
)
async def clear_payload(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
points_selector: m.PointsSelector = None,
) -> m.InlineResponse2005:
"""
Remove all payload for specified points
"""
return await self._build_for_clear_payload(
collection_name=collection_name,
wait=wait,
ordering=ordering,
points_selector=points_selector,
)
async def count_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
count_request: m.CountRequest = None,
) -> m.InlineResponse20019:
"""
Count points that match the given filtering condition
"""
return await self._build_for_count_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
count_request=count_request,
)
async def delete_payload(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
delete_payload: m.DeletePayload = None,
) -> m.InlineResponse2005:
"""
Delete specified key payload for points
"""
return await self._build_for_delete_payload(
collection_name=collection_name,
wait=wait,
ordering=ordering,
delete_payload=delete_payload,
)
async def delete_points(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
points_selector: m.PointsSelector = None,
) -> m.InlineResponse2005:
"""
Delete points
"""
return await self._build_for_delete_points(
collection_name=collection_name,
wait=wait,
ordering=ordering,
points_selector=points_selector,
)
async def delete_vectors(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
delete_vectors: m.DeleteVectors = None,
) -> m.InlineResponse2005:
"""
Delete named vectors from the given points.
"""
return await self._build_for_delete_vectors(
collection_name=collection_name,
wait=wait,
ordering=ordering,
delete_vectors=delete_vectors,
)
async def facet(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
facet_request: m.FacetRequest = None,
) -> m.InlineResponse20020:
"""
Count points that satisfy the given filter for each unique value of a payload key.
"""
return await self._build_for_facet(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
facet_request=facet_request,
)
async def get_point(
self,
collection_name: str,
id: m.ExtendedPointId,
consistency: m.ReadConsistency = None,
) -> m.InlineResponse20012:
"""
Retrieve full information of single point by id
"""
return await self._build_for_get_point(
collection_name=collection_name,
id=id,
consistency=consistency,
)
async def get_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
point_request: m.PointRequest = None,
) -> m.InlineResponse20013:
"""
Retrieve multiple points by specified IDs
"""
return await self._build_for_get_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
point_request=point_request,
)
async def overwrite_payload(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
set_payload: m.SetPayload = None,
) -> m.InlineResponse2005:
"""
Replace full payload of points with new one
"""
return await self._build_for_overwrite_payload(
collection_name=collection_name,
wait=wait,
ordering=ordering,
set_payload=set_payload,
)
async def scroll_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
scroll_request: m.ScrollRequest = None,
) -> m.InlineResponse20015:
"""
Scroll request - paginate over all points that match the given filtering condition
"""
return await self._build_for_scroll_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
scroll_request=scroll_request,
)
async def set_payload(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
set_payload: m.SetPayload = None,
) -> m.InlineResponse2005:
"""
Set payload values for points
"""
return await self._build_for_set_payload(
collection_name=collection_name,
wait=wait,
ordering=ordering,
set_payload=set_payload,
)
async def update_vectors(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
update_vectors: m.UpdateVectors = None,
) -> m.InlineResponse2005:
"""
Update specified named vectors on points, keep unspecified vectors intact.
"""
return await self._build_for_update_vectors(
collection_name=collection_name,
wait=wait,
ordering=ordering,
update_vectors=update_vectors,
)
async def upsert_points(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
point_insert_operations: m.PointInsertOperations = None,
) -> m.InlineResponse2005:
"""
Perform insert + updates on points. If point with given ID already exists - it will be overwritten.
"""
return await self._build_for_upsert_points(
collection_name=collection_name,
wait=wait,
ordering=ordering,
point_insert_operations=point_insert_operations,
)
class SyncPointsApi(_PointsApi):
def batch_update(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
update_operations: m.UpdateOperations = None,
) -> m.InlineResponse20014:
"""
Apply a series of update operations for points, vectors and payloads
"""
return self._build_for_batch_update(
collection_name=collection_name,
wait=wait,
ordering=ordering,
update_operations=update_operations,
)
def clear_payload(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
points_selector: m.PointsSelector = None,
) -> m.InlineResponse2005:
"""
Remove all payload for specified points
"""
return self._build_for_clear_payload(
collection_name=collection_name,
wait=wait,
ordering=ordering,
points_selector=points_selector,
)
def count_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
count_request: m.CountRequest = None,
) -> m.InlineResponse20019:
"""
Count points that match the given filtering condition
"""
return self._build_for_count_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
count_request=count_request,
)
def delete_payload(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
delete_payload: m.DeletePayload = None,
) -> m.InlineResponse2005:
"""
Delete specified key payload for points
"""
return self._build_for_delete_payload(
collection_name=collection_name,
wait=wait,
ordering=ordering,
delete_payload=delete_payload,
)
def delete_points(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
points_selector: m.PointsSelector = None,
) -> m.InlineResponse2005:
"""
Delete points
"""
return self._build_for_delete_points(
collection_name=collection_name,
wait=wait,
ordering=ordering,
points_selector=points_selector,
)
def delete_vectors(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
delete_vectors: m.DeleteVectors = None,
) -> m.InlineResponse2005:
"""
Delete named vectors from the given points.
"""
return self._build_for_delete_vectors(
collection_name=collection_name,
wait=wait,
ordering=ordering,
delete_vectors=delete_vectors,
)
def facet(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
facet_request: m.FacetRequest = None,
) -> m.InlineResponse20020:
"""
Count points that satisfy the given filter for each unique value of a payload key.
"""
return self._build_for_facet(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
facet_request=facet_request,
)
def get_point(
self,
collection_name: str,
id: m.ExtendedPointId,
consistency: m.ReadConsistency = None,
) -> m.InlineResponse20012:
"""
Retrieve full information of single point by id
"""
return self._build_for_get_point(
collection_name=collection_name,
id=id,
consistency=consistency,
)
def get_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
point_request: m.PointRequest = None,
) -> m.InlineResponse20013:
"""
Retrieve multiple points by specified IDs
"""
return self._build_for_get_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
point_request=point_request,
)
def overwrite_payload(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
set_payload: m.SetPayload = None,
) -> m.InlineResponse2005:
"""
Replace full payload of points with new one
"""
return self._build_for_overwrite_payload(
collection_name=collection_name,
wait=wait,
ordering=ordering,
set_payload=set_payload,
)
def scroll_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
scroll_request: m.ScrollRequest = None,
) -> m.InlineResponse20015:
"""
Scroll request - paginate over all points that match the given filtering condition
"""
return self._build_for_scroll_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
scroll_request=scroll_request,
)
def set_payload(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
set_payload: m.SetPayload = None,
) -> m.InlineResponse2005:
"""
Set payload values for points
"""
return self._build_for_set_payload(
collection_name=collection_name,
wait=wait,
ordering=ordering,
set_payload=set_payload,
)
def update_vectors(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
update_vectors: m.UpdateVectors = None,
) -> m.InlineResponse2005:
"""
Update specified named vectors on points, keep unspecified vectors intact.
"""
return self._build_for_update_vectors(
collection_name=collection_name,
wait=wait,
ordering=ordering,
update_vectors=update_vectors,
)
def upsert_points(
self,
collection_name: str,
wait: bool = None,
ordering: WriteOrdering = None,
point_insert_operations: m.PointInsertOperations = None,
) -> m.InlineResponse2005:
"""
Perform insert + updates on points. If point with given ID already exists - it will be overwritten.
"""
return self._build_for_upsert_points(
collection_name=collection_name,
wait=wait,
ordering=ordering,
point_insert_operations=point_insert_operations,
)
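# --- Usage sketch (editor's addition, not produced by the generator) ---
# Upserting and reading a point back through AsyncPointsApi. Assumptions: the upstream
# AsyncApiClient accepts a `host` base URL, and m.PointsList / m.PointStruct exist in
# qdrant_client.http.models with the fields used below; names and values are placeholders.
async def _example_points_roundtrip() -> None:  # hypothetical helper, illustration only
    # Imported lazily to avoid a likely circular import with qdrant_client.http.api_client.
    from qdrant_client.http.api_client import AsyncApiClient

    api = AsyncPointsApi(AsyncApiClient(host="http://localhost:6333"))
    await api.upsert_points(
        collection_name="demo",
        wait=True,
        point_insert_operations=m.PointsList(
            points=[m.PointStruct(id=1, vector=[0.1, 0.2, 0.3, 0.4], payload={"city": "Berlin"})]
        ),
    )
    fetched = await api.get_point(collection_name="demo", id=1)
    print(fetched)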
@@ -0,0 +1,941 @@
# flake8: noqa E501
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from qdrant_client.http.models import *
from qdrant_client.http.models import models as m
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
Model = TypeVar("Model", bound="BaseModel")
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
file = None
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
if PYDANTIC_V2:
return model.model_dump_json(*args, **kwargs)
else:
return model.json(*args, **kwargs)
def jsonable_encoder(
obj: Any,
include: Union[SetIntStr, DictIntStrAny] = None,
exclude=None,
by_alias: bool = True,
skip_defaults: bool = None,
exclude_unset: bool = True,
exclude_none: bool = True,
):
if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
return to_json(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=bool(exclude_unset or skip_defaults),
exclude_none=exclude_none,
)
return obj
if TYPE_CHECKING:
    from qdrant_client.http.api_client import ApiClient, AsyncApiClient
class _SearchApi:
def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
self.api_client = api_client
def _build_for_discover_batch_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
discover_request_batch: m.DiscoverRequestBatch = None,
):
"""
Look for points based on target and/or positive and negative example pairs, in batch.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(discover_request_batch)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20017,
method="POST",
url="/collections/{collection_name}/points/discover/batch",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_discover_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
discover_request: m.DiscoverRequest = None,
):
"""
Use context and a target to find the points most similar to the target, constrained by the context.
When using only the context (without a target), a special search - called context search - is performed, where pairs of points are used to generate a loss that guides the search towards the zone where most positive examples overlap. This means the score minimizes the scenario of finding a point closer to a negative than to a positive part of a pair. Since the score of a context relates to loss, the maximum score a point can get is 0.0, and it is normal for many points to have a score of 0.0.
When using a target (with or without context), the score behaves a little differently: the integer part of the score represents the rank with respect to the context, while the decimal part relates to the distance to the target. The context part of the score for each pair is +1 if the point is closer to a positive than to a negative part of a pair, and -1 otherwise.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(discover_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20016,
method="POST",
url="/collections/{collection_name}/points/discover",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_query_batch_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
query_request_batch: m.QueryRequestBatch = None,
):
"""
Universally query points in batch. This endpoint covers all capabilities of search, recommend, discover, filters. But also enables hybrid and multi-stage queries.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(query_request_batch)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20022,
method="POST",
url="/collections/{collection_name}/points/query/batch",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_query_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
query_request: m.QueryRequest = None,
):
"""
Universally query points. This endpoint covers all capabilities of search, recommend, discover, filters. But also enables hybrid and multi-stage queries.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(query_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20021,
method="POST",
url="/collections/{collection_name}/points/query",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_query_points_groups(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
query_groups_request: m.QueryGroupsRequest = None,
):
"""
Universally query points, grouped by a given payload field
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(query_groups_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20018,
method="POST",
url="/collections/{collection_name}/points/query/groups",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_recommend_batch_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
recommend_request_batch: m.RecommendRequestBatch = None,
):
"""
Look for points that are closer to stored positive examples and at the same time farther from negative examples.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(recommend_request_batch)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20017,
method="POST",
url="/collections/{collection_name}/points/recommend/batch",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_recommend_point_groups(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
recommend_groups_request: m.RecommendGroupsRequest = None,
):
"""
Look for points that are closer to stored positive examples and at the same time farther from negative examples, grouped by a given payload field.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(recommend_groups_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20018,
method="POST",
url="/collections/{collection_name}/points/recommend/groups",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_recommend_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
recommend_request: m.RecommendRequest = None,
):
"""
Look for points that are closer to stored positive examples and at the same time farther from negative examples.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(recommend_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20016,
method="POST",
url="/collections/{collection_name}/points/recommend",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_search_batch_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_request_batch: m.SearchRequestBatch = None,
):
"""
Retrieve by batch the closest points based on vector similarity and given filtering conditions
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(search_request_batch)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20017,
method="POST",
url="/collections/{collection_name}/points/search/batch",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_search_matrix_offsets(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_matrix_request: m.SearchMatrixRequest = None,
):
"""
Compute distance matrix for sampled points with an offset based output format
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(search_matrix_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20024,
method="POST",
url="/collections/{collection_name}/points/search/matrix/offsets",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_search_matrix_pairs(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_matrix_request: m.SearchMatrixRequest = None,
):
"""
Compute distance matrix for sampled points with a pair based output format
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(search_matrix_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20023,
method="POST",
url="/collections/{collection_name}/points/search/matrix/pairs",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_search_point_groups(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_groups_request: m.SearchGroupsRequest = None,
):
"""
Retrieve closest points based on vector similarity and given filtering conditions, grouped by a given payload field
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(search_groups_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20018,
method="POST",
url="/collections/{collection_name}/points/search/groups",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_search_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_request: m.SearchRequest = None,
):
"""
Retrieve closest points based on vector similarity and given filtering conditions
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if consistency is not None:
query_params["consistency"] = str(consistency)
if timeout is not None:
query_params["timeout"] = str(timeout)
headers = {}
body = jsonable_encoder(search_request)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse20016,
method="POST",
url="/collections/{collection_name}/points/search",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
class AsyncSearchApi(_SearchApi):
async def discover_batch_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
discover_request_batch: m.DiscoverRequestBatch = None,
) -> m.InlineResponse20017:
"""
Look for points based on target and/or positive and negative example pairs, in batch.
"""
return await self._build_for_discover_batch_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
discover_request_batch=discover_request_batch,
)
async def discover_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
discover_request: m.DiscoverRequest = None,
) -> m.InlineResponse20016:
"""
Use context and a target to find the points most similar to the target, constrained by the context.
When using only the context (without a target), a special search - called context search - is performed, where pairs of points are used to generate a loss that guides the search towards the zone where most positive examples overlap. This means the score minimizes the scenario of finding a point closer to a negative than to a positive part of a pair. Since the score of a context relates to loss, the maximum score a point can get is 0.0, and it is normal for many points to have a score of 0.0.
When using a target (with or without context), the score behaves a little differently: the integer part of the score represents the rank with respect to the context, while the decimal part relates to the distance to the target. The context part of the score for each pair is +1 if the point is closer to a positive than to a negative part of a pair, and -1 otherwise.
"""
return await self._build_for_discover_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
discover_request=discover_request,
)
async def query_batch_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
query_request_batch: m.QueryRequestBatch = None,
) -> m.InlineResponse20022:
"""
Universally query points in batch. This endpoint covers all capabilities of search, recommend, discover, filters. But also enables hybrid and multi-stage queries.
"""
return await self._build_for_query_batch_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
query_request_batch=query_request_batch,
)
async def query_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
query_request: m.QueryRequest = None,
) -> m.InlineResponse20021:
"""
Universally query points. This endpoint covers all capabilities of search, recommend, discover, filters. But also enables hybrid and multi-stage queries.
"""
return await self._build_for_query_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
query_request=query_request,
)
async def query_points_groups(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
query_groups_request: m.QueryGroupsRequest = None,
) -> m.InlineResponse20018:
"""
Universally query points, grouped by a given payload field
"""
return await self._build_for_query_points_groups(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
query_groups_request=query_groups_request,
)
async def recommend_batch_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
recommend_request_batch: m.RecommendRequestBatch = None,
) -> m.InlineResponse20017:
"""
Look for points that are closer to stored positive examples and at the same time farther from negative examples.
"""
return await self._build_for_recommend_batch_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
recommend_request_batch=recommend_request_batch,
)
async def recommend_point_groups(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
recommend_groups_request: m.RecommendGroupsRequest = None,
) -> m.InlineResponse20018:
"""
Look for points that are closer to stored positive examples and at the same time farther from negative examples, grouped by a given payload field.
"""
return await self._build_for_recommend_point_groups(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
recommend_groups_request=recommend_groups_request,
)
async def recommend_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
recommend_request: m.RecommendRequest = None,
) -> m.InlineResponse20016:
"""
Look for points that are closer to stored positive examples and at the same time farther from negative examples.
"""
return await self._build_for_recommend_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
recommend_request=recommend_request,
)
async def search_batch_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_request_batch: m.SearchRequestBatch = None,
) -> m.InlineResponse20017:
"""
Retrieve by batch the closest points based on vector similarity and given filtering conditions
"""
return await self._build_for_search_batch_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
search_request_batch=search_request_batch,
)
async def search_matrix_offsets(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_matrix_request: m.SearchMatrixRequest = None,
) -> m.InlineResponse20024:
"""
Compute distance matrix for sampled points with an offset based output format
"""
return await self._build_for_search_matrix_offsets(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
search_matrix_request=search_matrix_request,
)
async def search_matrix_pairs(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_matrix_request: m.SearchMatrixRequest = None,
) -> m.InlineResponse20023:
"""
Compute distance matrix for sampled points with a pair based output format
"""
return await self._build_for_search_matrix_pairs(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
search_matrix_request=search_matrix_request,
)
async def search_point_groups(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_groups_request: m.SearchGroupsRequest = None,
) -> m.InlineResponse20018:
"""
Retrieve closest points based on vector similarity and given filtering conditions, grouped by a given payload field
"""
return await self._build_for_search_point_groups(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
search_groups_request=search_groups_request,
)
async def search_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_request: m.SearchRequest = None,
) -> m.InlineResponse20016:
"""
Retrieve closest points based on vector similarity and given filtering conditions
"""
return await self._build_for_search_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
search_request=search_request,
)
class SyncSearchApi(_SearchApi):
def discover_batch_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
discover_request_batch: m.DiscoverRequestBatch = None,
) -> m.InlineResponse20017:
"""
Look for points based on target and/or positive and negative example pairs, in batch.
"""
return self._build_for_discover_batch_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
discover_request_batch=discover_request_batch,
)
def discover_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
discover_request: m.DiscoverRequest = None,
) -> m.InlineResponse20016:
"""
        Use context and a target to find the most similar points to the target, constrained by the context. When using only the context (without a target), a special search, called context search, is performed, where pairs of points are used to generate a loss that guides the search towards the zone where most positive examples overlap. This means that the score minimizes the scenario of finding a point closer to a negative than to a positive part of a pair. Since the score of a context relates to loss, the maximum score a point can get is 0.0, and it is normal for many points to have a score of 0.0. When using a target (with or without context), the score behaves a little differently: the integer part of the score represents the rank with respect to the context, while the decimal part of the score relates to the distance to the target. The context part of the score for each pair is calculated as +1 if the point is closer to a positive than to a negative part of a pair, and -1 otherwise.
"""
return self._build_for_discover_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
discover_request=discover_request,
)
def query_batch_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
query_request_batch: m.QueryRequestBatch = None,
) -> m.InlineResponse20022:
"""
        Universally query points in batch. This endpoint covers all capabilities of search, recommend, discover, and filtering, and also enables hybrid and multi-stage queries.
"""
return self._build_for_query_batch_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
query_request_batch=query_request_batch,
)
def query_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
query_request: m.QueryRequest = None,
) -> m.InlineResponse20021:
"""
        Universally query points. This endpoint covers all capabilities of search, recommend, discover, and filtering, and also enables hybrid and multi-stage queries.
"""
return self._build_for_query_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
query_request=query_request,
)
def query_points_groups(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
query_groups_request: m.QueryGroupsRequest = None,
) -> m.InlineResponse20018:
"""
Universally query points, grouped by a given payload field
"""
return self._build_for_query_points_groups(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
query_groups_request=query_groups_request,
)
def recommend_batch_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
recommend_request_batch: m.RecommendRequestBatch = None,
) -> m.InlineResponse20017:
"""
        Look for the points which are closer to the stored positive examples and, at the same time, farther from the negative examples.
"""
return self._build_for_recommend_batch_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
recommend_request_batch=recommend_request_batch,
)
def recommend_point_groups(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
recommend_groups_request: m.RecommendGroupsRequest = None,
) -> m.InlineResponse20018:
"""
        Look for the points which are closer to the stored positive examples and, at the same time, farther from the negative examples, grouped by a given payload field.
"""
return self._build_for_recommend_point_groups(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
recommend_groups_request=recommend_groups_request,
)
def recommend_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
recommend_request: m.RecommendRequest = None,
) -> m.InlineResponse20016:
"""
        Look for the points which are closer to the stored positive examples and, at the same time, farther from the negative examples.
"""
return self._build_for_recommend_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
recommend_request=recommend_request,
)
def search_batch_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_request_batch: m.SearchRequestBatch = None,
) -> m.InlineResponse20017:
"""
        Retrieve, in batch, the closest points based on vector similarity and the given filtering conditions
"""
return self._build_for_search_batch_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
search_request_batch=search_request_batch,
)
def search_matrix_offsets(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_matrix_request: m.SearchMatrixRequest = None,
) -> m.InlineResponse20024:
"""
        Compute distance matrix for sampled points with an offset-based output format
"""
return self._build_for_search_matrix_offsets(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
search_matrix_request=search_matrix_request,
)
def search_matrix_pairs(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_matrix_request: m.SearchMatrixRequest = None,
) -> m.InlineResponse20023:
"""
        Compute distance matrix for sampled points with a pair-based output format
"""
return self._build_for_search_matrix_pairs(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
search_matrix_request=search_matrix_request,
)
def search_point_groups(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_groups_request: m.SearchGroupsRequest = None,
) -> m.InlineResponse20018:
"""
Retrieve closest points based on vector similarity and given filtering conditions, grouped by a given payload field
"""
return self._build_for_search_point_groups(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
search_groups_request=search_groups_request,
)
def search_points(
self,
collection_name: str,
consistency: m.ReadConsistency = None,
timeout: int = None,
search_request: m.SearchRequest = None,
) -> m.InlineResponse20016:
"""
Retrieve closest points based on vector similarity and given filtering conditions
"""
return self._build_for_search_points(
collection_name=collection_name,
consistency=consistency,
timeout=timeout,
search_request=search_request,
)
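# --- Editor's sketch (illustrative, not part of the generated client) ---
# A minimal example of calling the synchronous search API defined above.
# The host URL, collection name, query vector, and the QueryRequest fields
# used here are assumptions for the example, not taken from this file.
def _example_query_points_usage_sketch() -> None:
    from qdrant_client.http.api_client import SyncApis

    apis = SyncApis(host="http://localhost:6333")  # assumed local Qdrant instance
    response = apis.search_api.query_points(
        collection_name="demo",  # assumed collection name
        query_request=m.QueryRequest(query=[0.1, 0.2, 0.3], limit=5),  # assumed fields
    )
    print(response)
    apis.close()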
@@ -0,0 +1,268 @@
# flake8: noqa E501
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
from pydantic import BaseModel
from pydantic.main import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from qdrant_client.http.models import *
from qdrant_client.http.models import models as m
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
Model = TypeVar("Model", bound="BaseModel")
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
file = None  # placeholder type used by the generated code for raw file responses
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
if PYDANTIC_V2:
return model.model_dump_json(*args, **kwargs)
else:
return model.json(*args, **kwargs)
def jsonable_encoder(
obj: Any,
include: Union[SetIntStr, DictIntStrAny] = None,
exclude=None,
by_alias: bool = True,
skip_defaults: bool = None,
exclude_unset: bool = True,
exclude_none: bool = True,
):
if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
return to_json(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=bool(exclude_unset or skip_defaults),
exclude_none=exclude_none,
)
return obj
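# Editor's sketch (assumption-labeled illustration, not part of the generated
# module): jsonable_encoder above serializes pydantic models to a JSON string
# via to_json and returns any other value unchanged. Filter() is used here only
# because it can be constructed without arguments.
def _example_jsonable_encoder_sketch() -> None:
    assert isinstance(jsonable_encoder(Filter()), str)  # model -> JSON string
    assert jsonable_encoder({"plain": 1}) == {"plain": 1}  # non-model -> returned as-is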
if TYPE_CHECKING:
from qdrant_client.http.api_client import ApiClient
class _ServiceApi:
def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
self.api_client = api_client
def _build_for_healthz(
self,
):
"""
An endpoint for health checking used in Kubernetes.
"""
headers = {}
return self.api_client.request(
type_=str,
method="GET",
url="/healthz",
headers=headers if headers else None,
)
def _build_for_livez(
self,
):
"""
An endpoint for health checking used in Kubernetes.
"""
headers = {}
return self.api_client.request(
type_=str,
method="GET",
url="/livez",
headers=headers if headers else None,
)
def _build_for_metrics(
self,
anonymize: bool = None,
):
"""
Collect metrics data including app info, collections info, cluster info and statistics
"""
query_params = {}
if anonymize is not None:
query_params["anonymize"] = str(anonymize).lower()
headers = {}
return self.api_client.request(
type_=str,
method="GET",
url="/metrics",
headers=headers if headers else None,
params=query_params,
)
def _build_for_readyz(
self,
):
"""
An endpoint for health checking used in Kubernetes.
"""
headers = {}
return self.api_client.request(
type_=str,
method="GET",
url="/readyz",
headers=headers if headers else None,
)
def _build_for_root(
self,
):
"""
        Returns information about the running Qdrant instance, such as version and commit id
"""
headers = {}
return self.api_client.request(
type_=m.VersionInfo,
method="GET",
url="/",
headers=headers if headers else None,
)
def _build_for_telemetry(
self,
anonymize: bool = None,
details_level: int = None,
):
"""
Collect telemetry data including app info, system info, collections info, cluster info, configs and statistics
"""
query_params = {}
if anonymize is not None:
query_params["anonymize"] = str(anonymize).lower()
if details_level is not None:
query_params["details_level"] = str(details_level)
headers = {}
return self.api_client.request(
type_=m.InlineResponse2001,
method="GET",
url="/telemetry",
headers=headers if headers else None,
params=query_params,
)
class AsyncServiceApi(_ServiceApi):
async def healthz(
self,
) -> str:
"""
An endpoint for health checking used in Kubernetes.
"""
return await self._build_for_healthz()
async def livez(
self,
) -> str:
"""
An endpoint for health checking used in Kubernetes.
"""
return await self._build_for_livez()
async def metrics(
self,
anonymize: bool = None,
) -> str:
"""
Collect metrics data including app info, collections info, cluster info and statistics
"""
return await self._build_for_metrics(
anonymize=anonymize,
)
async def readyz(
self,
) -> str:
"""
An endpoint for health checking used in Kubernetes.
"""
return await self._build_for_readyz()
async def root(
self,
) -> m.VersionInfo:
"""
        Returns information about the running Qdrant instance, such as version and commit id
"""
return await self._build_for_root()
async def telemetry(
self,
anonymize: bool = None,
details_level: int = None,
) -> m.InlineResponse2001:
"""
Collect telemetry data including app info, system info, collections info, cluster info, configs and statistics
"""
return await self._build_for_telemetry(
anonymize=anonymize,
details_level=details_level,
)
class SyncServiceApi(_ServiceApi):
def healthz(
self,
) -> str:
"""
An endpoint for health checking used in Kubernetes.
"""
return self._build_for_healthz()
def livez(
self,
) -> str:
"""
An endpoint for health checking used in Kubernetes.
"""
return self._build_for_livez()
def metrics(
self,
anonymize: bool = None,
) -> str:
"""
Collect metrics data including app info, collections info, cluster info and statistics
"""
return self._build_for_metrics(
anonymize=anonymize,
)
def readyz(
self,
) -> str:
"""
An endpoint for health checking used in Kubernetes.
"""
return self._build_for_readyz()
def root(
self,
) -> m.VersionInfo:
"""
        Returns information about the running Qdrant instance, such as version and commit id
"""
return self._build_for_root()
def telemetry(
self,
anonymize: bool = None,
details_level: int = None,
) -> m.InlineResponse2001:
"""
Collect telemetry data including app info, system info, collections info, cluster info, configs and statistics
"""
return self._build_for_telemetry(
anonymize=anonymize,
details_level=details_level,
)
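# --- Editor's sketch (illustrative, not part of the generated client) ---
# How the synchronous service endpoints above might be called. The host URL is
# an assumption, and the `version` attribute is assumed from the VersionInfo model.
def _example_service_api_usage_sketch() -> None:
    from qdrant_client.http.api_client import SyncApis

    apis = SyncApis(host="http://localhost:6333")  # assumed local Qdrant instance
    print(apis.service_api.healthz())                 # plain-text health probe
    print(apis.service_api.root().version)            # assumed VersionInfo field
    print(apis.service_api.metrics(anonymize=True))   # metrics in text form
    apis.close()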
@@ -0,0 +1,937 @@
# flake8: noqa E501
from typing import IO, TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
from pydantic import BaseModel
from pydantic.main import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from qdrant_client.http.models import *
from qdrant_client.http.models import models as m
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
Model = TypeVar("Model", bound="BaseModel")
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
file = None  # placeholder type used by the generated code for raw file responses
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
if PYDANTIC_V2:
return model.model_dump_json(*args, **kwargs)
else:
return model.json(*args, **kwargs)
def jsonable_encoder(
obj: Any,
include: Union[SetIntStr, DictIntStrAny] = None,
exclude=None,
by_alias: bool = True,
skip_defaults: bool = None,
exclude_unset: bool = True,
exclude_none: bool = True,
):
if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
return to_json(
obj,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=bool(exclude_unset or skip_defaults),
exclude_none=exclude_none,
)
return obj
if TYPE_CHECKING:
from qdrant_client.http.api_client import ApiClient
class _SnapshotsApi:
def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
self.api_client = api_client
def _build_for_create_full_snapshot(
self,
wait: bool = None,
):
"""
Create new snapshot of the whole storage
"""
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
headers = {}
return self.api_client.request(
type_=m.InlineResponse20011,
method="POST",
url="/snapshots",
headers=headers if headers else None,
params=query_params,
)
def _build_for_create_shard_snapshot(
self,
collection_name: str,
shard_id: int,
wait: bool = None,
):
"""
Create new snapshot of a shard for a collection
"""
path_params = {
"collection_name": str(collection_name),
"shard_id": str(shard_id),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
headers = {}
return self.api_client.request(
type_=m.InlineResponse20011,
method="POST",
url="/collections/{collection_name}/shards/{shard_id}/snapshots",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
)
def _build_for_create_snapshot(
self,
collection_name: str,
wait: bool = None,
):
"""
Create new snapshot for a collection
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
headers = {}
return self.api_client.request(
type_=m.InlineResponse20011,
method="POST",
url="/collections/{collection_name}/snapshots",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
)
def _build_for_delete_full_snapshot(
self,
snapshot_name: str,
wait: bool = None,
):
"""
Delete snapshot of the whole storage
"""
path_params = {
"snapshot_name": str(snapshot_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
headers = {}
return self.api_client.request(
type_=m.InlineResponse2009,
method="DELETE",
url="/snapshots/{snapshot_name}",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
)
def _build_for_delete_shard_snapshot(
self,
collection_name: str,
shard_id: int,
snapshot_name: str,
wait: bool = None,
):
"""
Delete snapshot of a shard for a collection
"""
path_params = {
"collection_name": str(collection_name),
"shard_id": str(shard_id),
"snapshot_name": str(snapshot_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
headers = {}
return self.api_client.request(
type_=m.InlineResponse2009,
method="DELETE",
url="/collections/{collection_name}/shards/{shard_id}/snapshots/{snapshot_name}",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
)
def _build_for_delete_snapshot(
self,
collection_name: str,
snapshot_name: str,
wait: bool = None,
):
"""
Delete snapshot for a collection
"""
path_params = {
"collection_name": str(collection_name),
"snapshot_name": str(snapshot_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
headers = {}
return self.api_client.request(
type_=m.InlineResponse2009,
method="DELETE",
url="/collections/{collection_name}/snapshots/{snapshot_name}",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
)
def _build_for_get_full_snapshot(
self,
snapshot_name: str,
):
"""
Download specified snapshot of the whole storage as a file
"""
path_params = {
"snapshot_name": str(snapshot_name),
}
headers = {}
return self.api_client.request(
type_=file,
method="GET",
url="/snapshots/{snapshot_name}",
headers=headers if headers else None,
path_params=path_params,
)
def _build_for_get_shard_snapshot(
self,
collection_name: str,
shard_id: int,
snapshot_name: str,
):
"""
Download specified snapshot of a shard from a collection as a file
"""
path_params = {
"collection_name": str(collection_name),
"shard_id": str(shard_id),
"snapshot_name": str(snapshot_name),
}
headers = {}
return self.api_client.request(
type_=file,
method="GET",
url="/collections/{collection_name}/shards/{shard_id}/snapshots/{snapshot_name}",
headers=headers if headers else None,
path_params=path_params,
)
def _build_for_get_snapshot(
self,
collection_name: str,
snapshot_name: str,
):
"""
Download specified snapshot from a collection as a file
"""
path_params = {
"collection_name": str(collection_name),
"snapshot_name": str(snapshot_name),
}
headers = {}
return self.api_client.request(
type_=file,
method="GET",
url="/collections/{collection_name}/snapshots/{snapshot_name}",
headers=headers if headers else None,
path_params=path_params,
)
def _build_for_list_full_snapshots(
self,
):
"""
Get list of snapshots of the whole storage
"""
headers = {}
return self.api_client.request(
type_=m.InlineResponse20010,
method="GET",
url="/snapshots",
headers=headers if headers else None,
)
def _build_for_list_shard_snapshots(
self,
collection_name: str,
shard_id: int,
):
"""
Get list of snapshots for a shard of a collection
"""
path_params = {
"collection_name": str(collection_name),
"shard_id": str(shard_id),
}
headers = {}
return self.api_client.request(
type_=m.InlineResponse20010,
method="GET",
url="/collections/{collection_name}/shards/{shard_id}/snapshots",
headers=headers if headers else None,
path_params=path_params,
)
def _build_for_list_snapshots(
self,
collection_name: str,
):
"""
Get list of snapshots for a collection
"""
path_params = {
"collection_name": str(collection_name),
}
headers = {}
return self.api_client.request(
type_=m.InlineResponse20010,
method="GET",
url="/collections/{collection_name}/snapshots",
headers=headers if headers else None,
path_params=path_params,
)
def _build_for_recover_from_snapshot(
self,
collection_name: str,
wait: bool = None,
snapshot_recover: m.SnapshotRecover = None,
):
"""
        Recover local collection data from a snapshot. This will overwrite any data stored on this node for the collection. If the collection does not exist, it will be created.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
headers = {}
body = jsonable_encoder(snapshot_recover)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse2009,
method="PUT",
url="/collections/{collection_name}/snapshots/recover",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_recover_from_uploaded_snapshot(
self,
collection_name: str,
wait: bool = None,
priority: SnapshotPriority = None,
checksum: str = None,
snapshot: IO[Any] = None,
):
"""
        Recover local collection data from an uploaded snapshot. This will overwrite any data stored on this node for the collection. If the collection does not exist, it will be created.
"""
path_params = {
"collection_name": str(collection_name),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if priority is not None:
query_params["priority"] = str(priority)
if checksum is not None:
query_params["checksum"] = str(checksum)
headers = {}
files: Dict[str, IO[Any]] = {} # noqa F841
data: Dict[str, Any] = {} # noqa F841
if snapshot is not None:
files["snapshot"] = snapshot
return self.api_client.request(
type_=m.InlineResponse2009,
method="POST",
url="/collections/{collection_name}/snapshots/upload",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
data=data,
files=files,
)
def _build_for_recover_shard_from_snapshot(
self,
collection_name: str,
shard_id: int,
wait: bool = None,
shard_snapshot_recover: m.ShardSnapshotRecover = None,
):
"""
        Recover a shard of a local collection from a snapshot. This will overwrite any data stored in this shard for the collection.
"""
path_params = {
"collection_name": str(collection_name),
"shard_id": str(shard_id),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
headers = {}
body = jsonable_encoder(shard_snapshot_recover)
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return self.api_client.request(
type_=m.InlineResponse2009,
method="PUT",
url="/collections/{collection_name}/shards/{shard_id}/snapshots/recover",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
content=body,
)
def _build_for_recover_shard_from_uploaded_snapshot(
self,
collection_name: str,
shard_id: int,
wait: bool = None,
priority: SnapshotPriority = None,
checksum: str = None,
snapshot: IO[Any] = None,
):
"""
        Recover a shard of a local collection from an uploaded snapshot. This will overwrite any data stored on this node for the collection shard.
"""
path_params = {
"collection_name": str(collection_name),
"shard_id": str(shard_id),
}
query_params = {}
if wait is not None:
query_params["wait"] = str(wait).lower()
if priority is not None:
query_params["priority"] = str(priority)
if checksum is not None:
query_params["checksum"] = str(checksum)
headers = {}
files: Dict[str, IO[Any]] = {} # noqa F841
data: Dict[str, Any] = {} # noqa F841
if snapshot is not None:
files["snapshot"] = snapshot
return self.api_client.request(
type_=m.InlineResponse2009,
method="POST",
url="/collections/{collection_name}/shards/{shard_id}/snapshots/upload",
headers=headers if headers else None,
path_params=path_params,
params=query_params,
data=data,
files=files,
)
class AsyncSnapshotsApi(_SnapshotsApi):
async def create_full_snapshot(
self,
wait: bool = None,
) -> m.InlineResponse20011:
"""
Create new snapshot of the whole storage
"""
return await self._build_for_create_full_snapshot(
wait=wait,
)
async def create_shard_snapshot(
self,
collection_name: str,
shard_id: int,
wait: bool = None,
) -> m.InlineResponse20011:
"""
Create new snapshot of a shard for a collection
"""
return await self._build_for_create_shard_snapshot(
collection_name=collection_name,
shard_id=shard_id,
wait=wait,
)
async def create_snapshot(
self,
collection_name: str,
wait: bool = None,
) -> m.InlineResponse20011:
"""
Create new snapshot for a collection
"""
return await self._build_for_create_snapshot(
collection_name=collection_name,
wait=wait,
)
async def delete_full_snapshot(
self,
snapshot_name: str,
wait: bool = None,
) -> m.InlineResponse2009:
"""
Delete snapshot of the whole storage
"""
return await self._build_for_delete_full_snapshot(
snapshot_name=snapshot_name,
wait=wait,
)
async def delete_shard_snapshot(
self,
collection_name: str,
shard_id: int,
snapshot_name: str,
wait: bool = None,
) -> m.InlineResponse2009:
"""
Delete snapshot of a shard for a collection
"""
return await self._build_for_delete_shard_snapshot(
collection_name=collection_name,
shard_id=shard_id,
snapshot_name=snapshot_name,
wait=wait,
)
async def delete_snapshot(
self,
collection_name: str,
snapshot_name: str,
wait: bool = None,
) -> m.InlineResponse2009:
"""
Delete snapshot for a collection
"""
return await self._build_for_delete_snapshot(
collection_name=collection_name,
snapshot_name=snapshot_name,
wait=wait,
)
async def get_full_snapshot(
self,
snapshot_name: str,
) -> file:
"""
Download specified snapshot of the whole storage as a file
"""
return await self._build_for_get_full_snapshot(
snapshot_name=snapshot_name,
)
async def get_shard_snapshot(
self,
collection_name: str,
shard_id: int,
snapshot_name: str,
) -> file:
"""
Download specified snapshot of a shard from a collection as a file
"""
return await self._build_for_get_shard_snapshot(
collection_name=collection_name,
shard_id=shard_id,
snapshot_name=snapshot_name,
)
async def get_snapshot(
self,
collection_name: str,
snapshot_name: str,
) -> file:
"""
Download specified snapshot from a collection as a file
"""
return await self._build_for_get_snapshot(
collection_name=collection_name,
snapshot_name=snapshot_name,
)
async def list_full_snapshots(
self,
) -> m.InlineResponse20010:
"""
Get list of snapshots of the whole storage
"""
return await self._build_for_list_full_snapshots()
async def list_shard_snapshots(
self,
collection_name: str,
shard_id: int,
) -> m.InlineResponse20010:
"""
Get list of snapshots for a shard of a collection
"""
return await self._build_for_list_shard_snapshots(
collection_name=collection_name,
shard_id=shard_id,
)
async def list_snapshots(
self,
collection_name: str,
) -> m.InlineResponse20010:
"""
Get list of snapshots for a collection
"""
return await self._build_for_list_snapshots(
collection_name=collection_name,
)
async def recover_from_snapshot(
self,
collection_name: str,
wait: bool = None,
snapshot_recover: m.SnapshotRecover = None,
) -> m.InlineResponse2009:
"""
        Recover local collection data from a snapshot. This will overwrite any data stored on this node for the collection. If the collection does not exist, it will be created.
"""
return await self._build_for_recover_from_snapshot(
collection_name=collection_name,
wait=wait,
snapshot_recover=snapshot_recover,
)
async def recover_from_uploaded_snapshot(
self,
collection_name: str,
wait: bool = None,
priority: SnapshotPriority = None,
checksum: str = None,
snapshot: IO[Any] = None,
) -> m.InlineResponse2009:
"""
        Recover local collection data from an uploaded snapshot. This will overwrite any data stored on this node for the collection. If the collection does not exist, it will be created.
"""
return await self._build_for_recover_from_uploaded_snapshot(
collection_name=collection_name,
wait=wait,
priority=priority,
checksum=checksum,
snapshot=snapshot,
)
async def recover_shard_from_snapshot(
self,
collection_name: str,
shard_id: int,
wait: bool = None,
shard_snapshot_recover: m.ShardSnapshotRecover = None,
) -> m.InlineResponse2009:
"""
        Recover a shard of a local collection from a snapshot. This will overwrite any data stored in this shard for the collection.
"""
return await self._build_for_recover_shard_from_snapshot(
collection_name=collection_name,
shard_id=shard_id,
wait=wait,
shard_snapshot_recover=shard_snapshot_recover,
)
async def recover_shard_from_uploaded_snapshot(
self,
collection_name: str,
shard_id: int,
wait: bool = None,
priority: SnapshotPriority = None,
checksum: str = None,
snapshot: IO[Any] = None,
) -> m.InlineResponse2009:
"""
        Recover a shard of a local collection from an uploaded snapshot. This will overwrite any data stored on this node for the collection shard.
"""
return await self._build_for_recover_shard_from_uploaded_snapshot(
collection_name=collection_name,
shard_id=shard_id,
wait=wait,
priority=priority,
checksum=checksum,
snapshot=snapshot,
)
class SyncSnapshotsApi(_SnapshotsApi):
def create_full_snapshot(
self,
wait: bool = None,
) -> m.InlineResponse20011:
"""
Create new snapshot of the whole storage
"""
return self._build_for_create_full_snapshot(
wait=wait,
)
def create_shard_snapshot(
self,
collection_name: str,
shard_id: int,
wait: bool = None,
) -> m.InlineResponse20011:
"""
Create new snapshot of a shard for a collection
"""
return self._build_for_create_shard_snapshot(
collection_name=collection_name,
shard_id=shard_id,
wait=wait,
)
def create_snapshot(
self,
collection_name: str,
wait: bool = None,
) -> m.InlineResponse20011:
"""
Create new snapshot for a collection
"""
return self._build_for_create_snapshot(
collection_name=collection_name,
wait=wait,
)
def delete_full_snapshot(
self,
snapshot_name: str,
wait: bool = None,
) -> m.InlineResponse2009:
"""
Delete snapshot of the whole storage
"""
return self._build_for_delete_full_snapshot(
snapshot_name=snapshot_name,
wait=wait,
)
def delete_shard_snapshot(
self,
collection_name: str,
shard_id: int,
snapshot_name: str,
wait: bool = None,
) -> m.InlineResponse2009:
"""
Delete snapshot of a shard for a collection
"""
return self._build_for_delete_shard_snapshot(
collection_name=collection_name,
shard_id=shard_id,
snapshot_name=snapshot_name,
wait=wait,
)
def delete_snapshot(
self,
collection_name: str,
snapshot_name: str,
wait: bool = None,
) -> m.InlineResponse2009:
"""
Delete snapshot for a collection
"""
return self._build_for_delete_snapshot(
collection_name=collection_name,
snapshot_name=snapshot_name,
wait=wait,
)
def get_full_snapshot(
self,
snapshot_name: str,
) -> file:
"""
Download specified snapshot of the whole storage as a file
"""
return self._build_for_get_full_snapshot(
snapshot_name=snapshot_name,
)
def get_shard_snapshot(
self,
collection_name: str,
shard_id: int,
snapshot_name: str,
) -> file:
"""
Download specified snapshot of a shard from a collection as a file
"""
return self._build_for_get_shard_snapshot(
collection_name=collection_name,
shard_id=shard_id,
snapshot_name=snapshot_name,
)
def get_snapshot(
self,
collection_name: str,
snapshot_name: str,
) -> file:
"""
Download specified snapshot from a collection as a file
"""
return self._build_for_get_snapshot(
collection_name=collection_name,
snapshot_name=snapshot_name,
)
def list_full_snapshots(
self,
) -> m.InlineResponse20010:
"""
Get list of snapshots of the whole storage
"""
return self._build_for_list_full_snapshots()
def list_shard_snapshots(
self,
collection_name: str,
shard_id: int,
) -> m.InlineResponse20010:
"""
Get list of snapshots for a shard of a collection
"""
return self._build_for_list_shard_snapshots(
collection_name=collection_name,
shard_id=shard_id,
)
def list_snapshots(
self,
collection_name: str,
) -> m.InlineResponse20010:
"""
Get list of snapshots for a collection
"""
return self._build_for_list_snapshots(
collection_name=collection_name,
)
def recover_from_snapshot(
self,
collection_name: str,
wait: bool = None,
snapshot_recover: m.SnapshotRecover = None,
) -> m.InlineResponse2009:
"""
        Recover local collection data from a snapshot. This will overwrite any data stored on this node for the collection. If the collection does not exist, it will be created.
"""
return self._build_for_recover_from_snapshot(
collection_name=collection_name,
wait=wait,
snapshot_recover=snapshot_recover,
)
def recover_from_uploaded_snapshot(
self,
collection_name: str,
wait: bool = None,
priority: SnapshotPriority = None,
checksum: str = None,
snapshot: IO[Any] = None,
) -> m.InlineResponse2009:
"""
        Recover local collection data from an uploaded snapshot. This will overwrite any data stored on this node for the collection. If the collection does not exist, it will be created.
"""
return self._build_for_recover_from_uploaded_snapshot(
collection_name=collection_name,
wait=wait,
priority=priority,
checksum=checksum,
snapshot=snapshot,
)
def recover_shard_from_snapshot(
self,
collection_name: str,
shard_id: int,
wait: bool = None,
shard_snapshot_recover: m.ShardSnapshotRecover = None,
) -> m.InlineResponse2009:
"""
        Recover a shard of a local collection from a snapshot. This will overwrite any data stored in this shard for the collection.
"""
return self._build_for_recover_shard_from_snapshot(
collection_name=collection_name,
shard_id=shard_id,
wait=wait,
shard_snapshot_recover=shard_snapshot_recover,
)
def recover_shard_from_uploaded_snapshot(
self,
collection_name: str,
shard_id: int,
wait: bool = None,
priority: SnapshotPriority = None,
checksum: str = None,
snapshot: IO[Any] = None,
) -> m.InlineResponse2009:
"""
        Recover a shard of a local collection from an uploaded snapshot. This will overwrite any data stored on this node for the collection shard.
"""
return self._build_for_recover_shard_from_uploaded_snapshot(
collection_name=collection_name,
shard_id=shard_id,
wait=wait,
priority=priority,
checksum=checksum,
snapshot=snapshot,
)
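# --- Editor's sketch (illustrative, not part of the generated client) ---
# Creating and listing collection snapshots through the synchronous API above.
# Host, collection name, and the `result` attribute on the generated response
# models are assumptions made for the example.
def _example_snapshots_usage_sketch() -> None:
    from qdrant_client.http.api_client import SyncApis

    apis = SyncApis(host="http://localhost:6333")  # assumed local Qdrant instance
    created = apis.snapshots_api.create_snapshot(collection_name="demo", wait=True)
    listing = apis.snapshots_api.list_snapshots(collection_name="demo")
    print(created.result, listing.result)  # assumed response fields
    apis.close()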
@@ -0,0 +1,263 @@
from asyncio import get_event_loop
from functools import lru_cache
from typing import Any, Awaitable, Callable, Dict, Generic, Type, TypeVar, overload
from urllib.parse import urljoin
from httpx import AsyncClient, Client, Request, Response
from pydantic import ValidationError
from qdrant_client.common.client_exceptions import ResourceExhaustedResponse
from qdrant_client.http.api.aliases_api import AsyncAliasesApi, SyncAliasesApi
from qdrant_client.http.api.beta_api import AsyncBetaApi, SyncBetaApi
from qdrant_client.http.api.collections_api import AsyncCollectionsApi, SyncCollectionsApi
from qdrant_client.http.api.distributed_api import AsyncDistributedApi, SyncDistributedApi
from qdrant_client.http.api.indexes_api import AsyncIndexesApi, SyncIndexesApi
from qdrant_client.http.api.points_api import AsyncPointsApi, SyncPointsApi
from qdrant_client.http.api.search_api import AsyncSearchApi, SyncSearchApi
from qdrant_client.http.api.service_api import AsyncServiceApi, SyncServiceApi
from qdrant_client.http.api.snapshots_api import AsyncSnapshotsApi, SyncSnapshotsApi
from qdrant_client.http.exceptions import ResponseHandlingException, UnexpectedResponse
ClientT = TypeVar("ClientT", bound="ApiClient")
AsyncClientT = TypeVar("AsyncClientT", bound="AsyncApiClient")
class AsyncApis(Generic[AsyncClientT]):
def __init__(self, host: str, **kwargs: Any):
self.client = AsyncApiClient(host, **kwargs)
self.aliases_api = AsyncAliasesApi(self.client)
self.beta_api = AsyncBetaApi(self.client)
self.collections_api = AsyncCollectionsApi(self.client)
self.distributed_api = AsyncDistributedApi(self.client)
self.indexes_api = AsyncIndexesApi(self.client)
self.points_api = AsyncPointsApi(self.client)
self.search_api = AsyncSearchApi(self.client)
self.service_api = AsyncServiceApi(self.client)
self.snapshots_api = AsyncSnapshotsApi(self.client)
async def aclose(self) -> None:
await self.client.aclose()
class SyncApis(Generic[ClientT]):
def __init__(self, host: str, **kwargs: Any):
self.client = ApiClient(host, **kwargs)
self.aliases_api = SyncAliasesApi(self.client)
self.beta_api = SyncBetaApi(self.client)
self.collections_api = SyncCollectionsApi(self.client)
self.distributed_api = SyncDistributedApi(self.client)
self.indexes_api = SyncIndexesApi(self.client)
self.points_api = SyncPointsApi(self.client)
self.search_api = SyncSearchApi(self.client)
self.service_api = SyncServiceApi(self.client)
self.snapshots_api = SyncSnapshotsApi(self.client)
def close(self) -> None:
self.client.close()
T = TypeVar("T")
Send = Callable[[Request], Response]
SendAsync = Callable[[Request], Awaitable[Response]]
MiddlewareT = Callable[[Request, Send], Response]
AsyncMiddlewareT = Callable[[Request, SendAsync], Awaitable[Response]]
class ApiClient:
def __init__(self, host: str, **kwargs: Any) -> None:
self.host = host
self.middleware: MiddlewareT = BaseMiddleware()
self._client = Client(**kwargs)
@overload
def request(self, *, type_: Type[T], method: str, url: str, path_params: Dict[str, Any] = None, **kwargs: Any) -> T:
...
@overload # noqa F811
def request(self, *, type_: None, method: str, url: str, path_params: Dict[str, Any] = None, **kwargs: Any) -> None:
...
def request( # noqa F811
self, *, type_: Any, method: str, url: str, path_params: Dict[str, Any] = None, **kwargs: Any
) -> Any:
if path_params is None:
path_params = {}
host = self.host if self.host.endswith("/") else self.host + "/"
url = url[1:] if url.startswith("/") else url
        # For a correct join, urljoin requires base_url to end with "/" and url not to start with "/";
        # otherwise url is treated as an absolute path and may truncate the prefix in base_url.
url = urljoin(host, url.format(**path_params))
if "params" in kwargs and "timeout" in kwargs["params"]:
kwargs["timeout"] = int(kwargs["params"]["timeout"])
request = self._client.build_request(method, url, **kwargs)
return self.send(request, type_)
@overload
def request_sync(self, *, type_: Type[T], **kwargs: Any) -> T:
...
@overload # noqa F811
def request_sync(self, *, type_: None, **kwargs: Any) -> None:
...
def request_sync(self, *, type_: Any, **kwargs: Any) -> Any: # noqa F811
"""
        This method is not used by the generated APIs, but is included for convenience
"""
return get_event_loop().run_until_complete(self.request(type_=type_, **kwargs))
def send(self, request: Request, type_: Type[T]) -> T:
response = self.middleware(request, self.send_inner)
if response.status_code == 429:
retry_after_s = response.headers.get("Retry-After", None)
try:
resp = response.json()
message = resp["status"]["error"] if resp["status"] and resp["status"]["error"] else ""
except Exception:
message = ""
if retry_after_s:
raise ResourceExhaustedResponse(message, retry_after_s)
if response.status_code in [200, 201, 202]:
try:
return parse_as_type(response.json(), type_)
except ValidationError as e:
raise ResponseHandlingException(e)
raise UnexpectedResponse.for_response(response)
def send_inner(self, request: Request) -> Response:
try:
response = self._client.send(request)
except Exception as e:
raise ResponseHandlingException(e)
return response
def close(self) -> None:
self._client.close()
def add_middleware(self, middleware: MiddlewareT) -> None:
current_middleware = self.middleware
def new_middleware(request: Request, call_next: Send) -> Response:
def inner_send(request: Request) -> Response:
return current_middleware(request, call_next)
return middleware(request, inner_send)
self.middleware = new_middleware
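# Editor's sketch (illustrative, not part of the original module): add_middleware
# above wraps the current middleware chain, so the most recently added middleware
# runs first. The header name and key value below are assumptions for the example.
def _example_api_key_middleware_sketch() -> None:
    def api_key_middleware(request: Request, call_next: Send) -> Response:
        request.headers["api-key"] = "<your-api-key>"  # assumed header name and value
        return call_next(request)

    client = ApiClient(host="http://localhost:6333")  # assumed local Qdrant instance
    client.add_middleware(api_key_middleware)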
class AsyncApiClient:
def __init__(self, host: str = None, **kwargs: Any) -> None:
self.host = host
self.middleware: AsyncMiddlewareT = BaseAsyncMiddleware()
self._async_client = AsyncClient(**kwargs)
@overload
async def request(
self, *, type_: Type[T], method: str, url: str, path_params: Dict[str, Any] = None, **kwargs: Any
) -> T:
...
@overload # noqa F811
async def request(
self, *, type_: None, method: str, url: str, path_params: Dict[str, Any] = None, **kwargs: Any
) -> None:
...
async def request( # noqa F811
self, *, type_: Any, method: str, url: str, path_params: Dict[str, Any] = None, **kwargs: Any
) -> Any:
if path_params is None:
path_params = {}
host = self.host if self.host.endswith("/") else self.host + "/"
url = url[1:] if url.startswith("/") else url
        # For a correct join, urljoin requires base_url to end with "/" and url not to start with "/";
        # otherwise url is treated as an absolute path and may truncate the prefix in base_url.
url = urljoin(host, url.format(**path_params))
request = self._async_client.build_request(method, url, **kwargs)
return await self.send(request, type_)
@overload
def request_sync(self, *, type_: Type[T], **kwargs: Any) -> T:
...
@overload # noqa F811
def request_sync(self, *, type_: None, **kwargs: Any) -> None:
...
def request_sync(self, *, type_: Any, **kwargs: Any) -> Any: # noqa F811
"""
This method is not used by the generated apis, but is included for convenience
"""
return get_event_loop().run_until_complete(self.request(type_=type_, **kwargs))
async def send(self, request: Request, type_: Type[T]) -> T:
response = await self.middleware(request, self.send_inner)
if response.status_code == 429:
retry_after_s = response.headers.get("Retry-After", None)
try:
resp = response.json()
message = resp["status"]["error"] if resp["status"] and resp["status"]["error"] else ""
except Exception:
message = ""
if retry_after_s:
raise ResourceExhaustedResponse(message, retry_after_s)
if response.status_code in [200, 201, 202]:
try:
return parse_as_type(response.json(), type_)
except ValidationError as e:
raise ResponseHandlingException(e)
raise UnexpectedResponse.for_response(response)
async def send_inner(self, request: Request) -> Response:
try:
response = await self._async_client.send(request)
except Exception as e:
raise ResponseHandlingException(e)
return response
async def aclose(self) -> None:
await self._async_client.aclose()
def add_middleware(self, middleware: AsyncMiddlewareT) -> None:
current_middleware = self.middleware
async def new_middleware(request: Request, call_next: SendAsync) -> Response:
async def inner_send(request: Request) -> Response:
return await current_middleware(request, call_next)
return await middleware(request, inner_send)
self.middleware = new_middleware
class BaseAsyncMiddleware:
async def __call__(self, request: Request, call_next: SendAsync) -> Response:
return await call_next(request)
class BaseMiddleware:
def __call__(self, request: Request, call_next: Send) -> Response:
return call_next(request)
@lru_cache(maxsize=None)
def _get_parsing_type(type_: Any, source: str) -> Any:
from pydantic.main import create_model
type_name = getattr(type_, "__name__", str(type_))
return create_model(f"ParsingModel[{type_name}] (for {source})", obj=(type_, ...))
def parse_as_type(obj: Any, type_: Type[T]) -> T:
model_type = _get_parsing_type(type_, source=parse_as_type.__name__)
return model_type(obj=obj).obj
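# Editor's sketch (illustrative, not part of the original module): parse_as_type
# wraps the payload in a throwaway pydantic model so that the declared response
# type is validated. The input below is purely illustrative.
def _example_parse_as_type_sketch() -> None:
    parsed = parse_as_type({"a": 1, "b": 2}, Dict[str, int])
    assert parsed == {"a": 1, "b": 2}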
@@ -0,0 +1,5 @@
from typing import Union
# This is a dirty hack: the proper way is to upgrade the OpenAPI generator to at least 5.0.
# But that upgrade would also require updating all of the templates. Maybe some other day.
AnyOfstringinteger = Union[str, int]
@@ -0,0 +1,46 @@
import json
from typing import Any, Dict, Optional
from httpx import Headers, Response
MAX_CONTENT = 200
class ApiException(Exception):
"""Base class"""
class UnexpectedResponse(ApiException):
def __init__(self, status_code: Optional[int], reason_phrase: str, content: bytes, headers: Headers) -> None:
self.status_code = status_code
self.reason_phrase = reason_phrase
self.content = content
self.headers = headers
@staticmethod
def for_response(response: Response) -> "ApiException":
return UnexpectedResponse(
status_code=response.status_code,
reason_phrase=response.reason_phrase,
content=response.content,
headers=response.headers,
)
def __str__(self) -> str:
status_code_str = f"{self.status_code}" if self.status_code is not None else ""
if self.reason_phrase == "" and self.status_code is not None:
reason_phrase_str = "(Unrecognized Status Code)"
else:
reason_phrase_str = f"({self.reason_phrase})"
status_str = f"{status_code_str} {reason_phrase_str}".strip()
short_content = self.content if len(self.content) <= MAX_CONTENT else self.content[: MAX_CONTENT - 3] + b" ..."
raw_content_str = f"Raw response content:\n{short_content!r}"
return f"Unexpected Response: {status_str}\n{raw_content_str}"
def structured(self) -> Dict[str, Any]:
return json.loads(self.content)
class ResponseHandlingException(ApiException):
def __init__(self, source: Exception):
self.source = source
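# Editor's sketch (illustrative, not part of the original module): how a caller
# might handle UnexpectedResponse. `apis` is assumed to be a SyncApis instance
# and "missing-collection" an arbitrary, likely nonexistent collection name.
def _example_unexpected_response_handling_sketch(apis: Any) -> None:
    try:
        apis.snapshots_api.list_snapshots(collection_name="missing-collection")
    except UnexpectedResponse as err:
        print(err)               # e.g. "Unexpected Response: 404 (Not Found) ..."
        print(err.structured())  # parsed JSON error body; raises if the body is not JSON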
@@ -0,0 +1 @@
from .models import *
File diff suppressed because it is too large
@@ -0,0 +1,370 @@
import math
from qdrant_client.conversions.common_types import get_args_subscribed
from qdrant_client.http import models
from typing import Union, Any, Tuple
from qdrant_client.local import datetime_utils
from qdrant_client.local.geo import geo_distance
from qdrant_client.local.payload_filters import check_condition
from qdrant_client.local.payload_value_extractor import value_by_key
DEFAULT_SCORE = 0.0
DEFAULT_DECAY_TARGET = 0.0
DEFAULT_DECAY_MIDPOINT = 0.5
DEFAULT_DECAY_SCALE = 1.0
def evaluate_expression(
expression: models.Expression,
point_id: models.ExtendedPointId,
scores: list[dict[models.ExtendedPointId, float]],
payload: models.Payload,
has_vector: dict[str, bool],
defaults: dict[str, Any],
) -> float:
if isinstance(expression, (float, int)): # Constant
return float(expression)
elif isinstance(expression, str): # Variable
return evaluate_variable(expression, point_id, scores, payload, defaults)
elif isinstance(expression, get_args_subscribed(models.Condition)):
if check_condition(expression, payload, point_id, has_vector): # type: ignore
return 1.0
return 0.0
elif isinstance(expression, models.MultExpression):
factors: list[float] = []
for expr in expression.mult:
factor = evaluate_expression(expr, point_id, scores, payload, has_vector, defaults)
# Return early if any factor is zero
if factor == 0.0:
return factor
factors.append(factor)
return math.prod(factors)
elif isinstance(expression, models.SumExpression):
return sum(
evaluate_expression(expr, point_id, scores, payload, has_vector, defaults)
for expr in expression.sum
)
elif isinstance(expression, models.NegExpression):
value = evaluate_expression(
expression.neg, point_id, scores, payload, has_vector, defaults
)
return -value
elif isinstance(expression, models.AbsExpression):
return abs(
evaluate_expression(expression.abs, point_id, scores, payload, has_vector, defaults)
)
elif isinstance(expression, models.DivExpression):
left = evaluate_expression(
expression.div.left, point_id, scores, payload, has_vector, defaults
)
if left == 0.0:
return left
right = evaluate_expression(
expression.div.right, point_id, scores, payload, has_vector, defaults
)
if right == 0.0:
if expression.div.by_zero_default is not None:
return expression.div.by_zero_default
raise_non_finite_error(f"{left}/{right}")
result = left / right
if math.isfinite(result):
return result
raise_non_finite_error(f"{left}/{right}")
elif isinstance(expression, models.SqrtExpression):
value = evaluate_expression(
expression.sqrt, point_id, scores, payload, has_vector, defaults
)
if value >= 0:
return math.sqrt(value)
raise_non_finite_error(f"{value}")
elif isinstance(expression, models.PowExpression):
base = evaluate_expression(
expression.pow.base, point_id, scores, payload, has_vector, defaults
)
exponent = evaluate_expression(
expression.pow.exponent, point_id, scores, payload, has_vector, defaults
)
# Check for valid input
if base >= 0 or (base != 0 and exponent.is_integer()):
try:
return math.pow(base, exponent)
except OverflowError:
pass
raise_non_finite_error(f"{base}^{exponent}")
elif isinstance(expression, models.ExpExpression):
value = evaluate_expression(
expression.exp, point_id, scores, payload, has_vector, defaults
)
try:
return math.exp(value)
except OverflowError:
raise_non_finite_error(f"exp({value})")
elif isinstance(expression, models.Log10Expression):
value = evaluate_expression(
expression.log10, point_id, scores, payload, has_vector, defaults
)
if value > 0:
try:
return math.log10(value)
except OverflowError:
pass
raise_non_finite_error(f"log10({value})")
elif isinstance(expression, models.LnExpression):
value = evaluate_expression(expression.ln, point_id, scores, payload, has_vector, defaults)
if value > 0:
try:
return math.log(value)
except OverflowError:
pass
raise_non_finite_error(f"ln({value})")
elif isinstance(expression, models.GeoDistance):
origin = expression.geo_distance.origin
to = expression.geo_distance.to
# Get value from payload
geo_value = try_extract_payload_value(to, payload, defaults)
if isinstance(geo_value, dict):
# let this fail if it is not a valid geo point
destination = models.GeoPoint(**geo_value)
return geo_distance(origin.lon, origin.lat, destination.lon, destination.lat)
raise ValueError(
f"Expected geo point for {to} in the payload and/or in the formula defaults."
)
elif isinstance(expression, models.DatetimeExpression):
# try to parse as datetime
dt = datetime_utils.parse(expression.datetime)
if dt is None:
raise ValueError(f"Expected datetime in supported format for {expression.datetime}")
return dt.timestamp()
elif isinstance(expression, models.DatetimeKeyExpression):
dt_str = try_extract_payload_value(expression.datetime_key, payload, defaults)
dt = datetime_utils.parse(dt_str)
if dt is None:
raise ValueError(
f"Expected datetime for {expression.datetime_key} in the payload and/or in the formula defaults."
)
return dt.timestamp()
elif isinstance(expression, models.LinDecayExpression):
x, target, midpoint, scale = evaluate_decay_params(
expression.lin_decay, point_id, scores, payload, has_vector, defaults
)
lambda_factor = (1.0 - midpoint) / scale
diff = abs(x - target)
return max(0.0, -lambda_factor * diff + 1.0)
elif isinstance(expression, models.ExpDecayExpression):
x, target, midpoint, scale = evaluate_decay_params(
expression.exp_decay, point_id, scores, payload, has_vector, defaults
)
lambda_factor = math.log(midpoint) / scale
diff = abs(x - target)
return math.exp(lambda_factor * diff)
elif isinstance(expression, models.GaussDecayExpression):
x, target, midpoint, scale = evaluate_decay_params(
expression.gauss_decay, point_id, scores, payload, has_vector, defaults
)
lambda_factor = math.log(midpoint) / (scale * scale)
diff = x - target
return math.exp(lambda_factor * diff * diff)
raise ValueError(f"Unsupported expression type: {type(expression)}")
def evaluate_decay_params(
params: models.DecayParamsExpression,
point_id: models.ExtendedPointId,
scores: list[dict[models.ExtendedPointId, float]],
payload: models.Payload,
has_vector: dict[str, bool],
defaults: dict[str, Any],
) -> Tuple[float, float, float, float]:
x = evaluate_expression(params.x, point_id, scores, payload, has_vector, defaults)
if params.target is None:
target = DEFAULT_DECAY_TARGET
else:
target = evaluate_expression(
params.target, point_id, scores, payload, has_vector, defaults
)
midpoint = params.midpoint if params.midpoint is not None else DEFAULT_DECAY_MIDPOINT
if midpoint <= 0.0 or midpoint >= 1.0:
raise ValueError(f"Midpoint must be between 0 and 1, got {midpoint}")
scale = params.scale if params.scale is not None else DEFAULT_DECAY_SCALE
if scale <= 0.0:
raise ValueError(f"Scale must be non-zero positive, got {scale}")
return x, target, midpoint, scale
def try_extract_payload_value(key: str, payload: models.Payload, defaults: dict[str, Any]) -> Any:
# Get value from payload
value = value_by_key(payload, key)
if value is None or len(value) == 0:
# Or from defaults
value = defaults.get(key, None)
# Consider it None if it is an empty list
if isinstance(value, list) and len(value) == 0:
value = None
# Consider it a single value if it's a list with one element
if isinstance(value, list) and len(value) == 1:
return value[0]
if value is None:
raise ValueError(f"No value found for {key} in the payload nor the formula defaults")
return value
def evaluate_variable(
variable: str,
point_id: models.ExtendedPointId,
scores: list[dict[models.ExtendedPointId, float]],
payload: models.Payload,
defaults: dict[str, Any],
) -> float:
var = parse_variable(variable)
if isinstance(var, str):
value = try_extract_payload_value(var, payload, defaults)
if is_number(value):
return value
raise ValueError(
f"Expected number value for {var} in the payload and/or in the formula defaults. Error: Value is not a number"
)
elif isinstance(var, int):
# Get score from scores
score = None
if var < len(scores):
score = scores[var].get(point_id, None)
if score is not None:
return score
defined_default = defaults.get(variable, None)
if defined_default is not None:
return defined_default
return DEFAULT_SCORE
raise ValueError(f"Invalid variable type: {type(var)}")
def parse_variable(var: str) -> Union[str, int]:
# Try to parse score pattern
if not var.startswith("$score"):
# Treat as payload path
return var
remaining = var.replace("$score", "", 1)
if remaining == "":
# end of string, default idx is 0
return 0
    # it must be followed by brackets
if not remaining.startswith("["):
raise ValueError(f"Invalid score pattern: {var}")
remaining = remaining.replace("[", "", 1)
bracket_end = remaining.find("]")
if bracket_end == -1:
raise ValueError(f"Invalid score pattern: {var}")
# try parsing the content in between brackets as integer
try:
idx = int(remaining[:bracket_end])
except ValueError:
raise ValueError(f"Invalid score pattern: {var}")
# make sure the string ends after the closing bracket
if len(remaining) > bracket_end + 1:
raise ValueError(f"Invalid score pattern: {var}")
return idx
def raise_non_finite_error(expression: str) -> None:
raise ValueError(f"The expression {expression} produced a non-finite number")
def is_number(value: Any) -> bool:
return isinstance(value, (int, float)) and not isinstance(value, bool)
def test_parsing_variable() -> None:
assert parse_variable("$score") == 0
assert parse_variable("$score[0]") == 0
assert parse_variable("$score[1]") == 1
assert parse_variable("$score[2]") == 2
try:
parse_variable("$score[invalid]")
assert False
except ValueError as e:
assert str(e) == "Invalid score pattern: $score[invalid]"
try:
parse_variable("$score[10].other")
assert False
except ValueError as e:
assert str(e) == "Invalid score pattern: $score[10].other"
def test_try_extract_payload_value() -> None:
for payload_value, expected in [(1.2, 1.2), ([1.2], 1.2), ([1.2, 2.3], [1.2, 2.3])]:
empty_defaults: dict[str, Any] = {}
payload = {"key": payload_value}
assert try_extract_payload_value("key", payload, empty_defaults) == expected
defaults = {"key": payload_value}
empty_payload: dict[str, Any] = {}
assert try_extract_payload_value("key", empty_payload, defaults) == expected
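# Editor's sketch (added for illustration, not part of the original module):
# worked arithmetic for the decay formulas implemented in evaluate_expression
# above, using the module defaults. At distance |x - target| == scale, all
# three decays evaluate to the midpoint.
def _decay_formula_arithmetic_sketch() -> None:
    x, target = 1.0, DEFAULT_DECAY_TARGET            # |x - target| == 1.0 == scale
    midpoint, scale = DEFAULT_DECAY_MIDPOINT, DEFAULT_DECAY_SCALE
    diff = abs(x - target)
    lin = max(0.0, -((1.0 - midpoint) / scale) * diff + 1.0)                # 1 - 0.5 * 1 = 0.5
    exp_ = math.exp((math.log(midpoint) / scale) * diff)                    # exp(ln 0.5) = 0.5
    gauss = math.exp((math.log(midpoint) / (scale * scale)) * diff * diff)  # exp(ln 0.5) = 0.5
    assert math.isclose(lin, 0.5) and math.isclose(exp_, 0.5) and math.isclose(gauss, 0.5)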
@@ -0,0 +1,79 @@
from typing import Optional
from qdrant_client.http import models
DEFAULT_RANKING_CONSTANT_K = 2
def reciprocal_rank_fusion(
responses: list[list[models.ScoredPoint]],
limit: int = 10,
ranking_constant_k: Optional[int] = None,
) -> list[models.ScoredPoint]:
def compute_score(pos: int) -> float:
ranking_constant = (
ranking_constant_k if ranking_constant_k is not None else DEFAULT_RANKING_CONSTANT_K
) # mitigates the impact of high rankings by outlier systems
return 1 / (ranking_constant + pos)
scores: dict[models.ExtendedPointId, float] = {}
point_pile = {}
for response in responses:
for i, scored_point in enumerate(response):
if scored_point.id in scores:
scores[scored_point.id] += compute_score(i)
else:
point_pile[scored_point.id] = scored_point
scores[scored_point.id] = compute_score(i)
sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
sorted_points = []
for point_id, score in sorted_scores[:limit]:
point = point_pile[point_id]
point.score = score
sorted_points.append(point)
return sorted_points
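# Editor's sketch (illustrative): worked arithmetic for the RRF score above,
# assuming the default ranking constant k = 2. A point ranked 0 in one response
# and 1 in another accumulates 1/(2 + 0) + 1/(2 + 1) = 5/6, which is why points
# shared across responses tend to rise to the top of the fused list.
def _rrf_arithmetic_sketch() -> None:
    k = DEFAULT_RANKING_CONSTANT_K
    assert abs((1 / (k + 0) + 1 / (k + 1)) - 5 / 6) < 1e-9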
def distribution_based_score_fusion(
responses: list[list[models.ScoredPoint]], limit: int
) -> list[models.ScoredPoint]:
def normalize(response: list[models.ScoredPoint]) -> list[models.ScoredPoint]:
if len(response) == 1:
response[0].score = 0.5
return response
total = sum([point.score for point in response])
mean = total / len(response)
variance = sum([(point.score - mean) ** 2 for point in response]) / (len(response) - 1)
if variance == 0:
for point in response:
point.score = 0.5
return response
std_dev = variance**0.5
low = mean - 3 * std_dev
high = mean + 3 * std_dev
for point in response:
point.score = (point.score - low) / (high - low)
return response
points_map: dict[models.ExtendedPointId, models.ScoredPoint] = {}
for response in responses:
if not response:
continue
normalized = normalize(response)
for point in normalized:
entry = points_map.get(point.id)
if entry is None:
points_map[point.id] = point
else:
entry.score += point.score
sorted_points = sorted(points_map.values(), key=lambda item: item.score, reverse=True)
return sorted_points[:limit]
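# Editor's sketch (illustrative): worked arithmetic for the normalization used
# in distribution_based_score_fusion above. Scores are mapped onto [0, 1] using
# the interval [mean - 3 * std, mean + 3 * std]; for [85, 76, 68] the sample
# mean is ~76.33 and the sample standard deviation ~8.51, so 85 maps to ~0.67.
def _dbsf_normalization_sketch() -> None:
    scores = [85.0, 76.0, 68.0]
    mean = sum(scores) / len(scores)
    variance = sum((s - mean) ** 2 for s in scores) / (len(scores) - 1)
    std_dev = variance**0.5
    low, high = mean - 3 * std_dev, mean + 3 * std_dev
    normalized = [(s - low) / (high - low) for s in scores]
    assert all(0.0 < n < 1.0 for n in normalized)
    assert abs(normalized[0] - 0.67) < 0.01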
@@ -0,0 +1,117 @@
import numpy as np
from qdrant_client.http import models
from qdrant_client.hybrid.fusion import reciprocal_rank_fusion, distribution_based_score_fusion
def test_reciprocal_rank_fusion() -> None:
responses = [
[
models.ScoredPoint(id="1", score=0.1, version=1),
models.ScoredPoint(id="2", score=0.2, version=1),
models.ScoredPoint(id="3", score=0.3, version=1),
],
[
models.ScoredPoint(id="5", score=12.0, version=1),
models.ScoredPoint(id="6", score=8.0, version=1),
models.ScoredPoint(id="7", score=5.0, version=1),
models.ScoredPoint(id="2", score=3.0, version=1),
],
]
fused = reciprocal_rank_fusion(responses)
assert fused[0].id == "2"
assert fused[1].id in ["1", "5"]
assert np.isclose(fused[1].score, 1 / 2)
assert fused[2].id in ["1", "5"]
assert np.isclose(fused[2].score, 1 / 2)
def test_distribution_based_score_fusion() -> None:
responses = [
[
models.ScoredPoint(id=1, version=0, score=85.0),
models.ScoredPoint(id=0, version=0, score=76.0),
models.ScoredPoint(id=5, version=0, score=68.0),
],
[
models.ScoredPoint(id=1, version=0, score=62.0),
models.ScoredPoint(id=0, version=0, score=61.0),
models.ScoredPoint(id=4, version=0, score=57.0),
models.ScoredPoint(id=3, version=0, score=51.0),
models.ScoredPoint(id=2, version=0, score=44.0),
],
]
fused = distribution_based_score_fusion(responses, limit=3)
assert fused[0].id == 1
assert fused[1].id == 0
assert fused[2].id == 4
def test_reciprocal_rank_fusion_empty_responses() -> None:
responses: list[list[models.ScoredPoint]] = [[]]
fused = reciprocal_rank_fusion(responses)
assert fused == []
responses = [
[
models.ScoredPoint(id="1", score=0.1, version=1),
models.ScoredPoint(id="2", score=0.2, version=1),
models.ScoredPoint(id="3", score=0.3, version=1),
],
[],
]
fused = reciprocal_rank_fusion(responses)
assert fused[0].id == "1"
assert np.isclose(fused[0].score, 1 / 2)
assert fused[1].id == "2"
assert np.isclose(fused[1].score, 1 / 3)
assert fused[2].id == "3"
assert np.isclose(fused[2].score, 1 / 4)
def test_distribution_based_score_fusion_empty_response() -> None:
responses: list[list[models.ScoredPoint]] = [[]]
fused = distribution_based_score_fusion(responses, limit=3)
assert fused == []
responses = [
[
models.ScoredPoint(id=1, version=0, score=85.0),
models.ScoredPoint(id=0, version=0, score=76.0),
models.ScoredPoint(id=5, version=0, score=68.0),
],
[],
]
fused = distribution_based_score_fusion(responses, limit=3)
assert fused[0].id == 1
assert fused[1].id == 0
assert fused[2].id == 5
def test_distribution_based_score_fusion_zero_variance() -> None:
score = 85.0
responses = [
[
models.ScoredPoint(id=1, version=0, score=score),
models.ScoredPoint(id=0, version=0, score=score),
models.ScoredPoint(id=5, version=0, score=score),
],
[],
]
fused = distribution_based_score_fusion(
[[models.ScoredPoint(id=1, version=0, score=score)]], limit=3
)
assert fused[0].id == 1
assert fused[0].score == 0.5
fused = distribution_based_score_fusion(responses, limit=3)
assert len(fused) == 3
assert all([p.score == 0.5 for p in fused])
@@ -0,0 +1,973 @@
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
#
# This file is autogenerated. Do not edit it manually.
# To regenerate this file, use
#
# ```
# bash -x tools/generate_async_client.sh
# ```
#
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
import importlib.metadata
import itertools
import json
import os
import shutil
import uuid
from copy import deepcopy
from io import TextIOWrapper
from typing import Any, Generator, Iterable, Mapping, Optional, Sequence, Union, get_args
from uuid import uuid4
import numpy as np
import portalocker
from qdrant_client.common.client_warnings import show_warning, show_warning_once
from qdrant_client._pydantic_compat import to_dict
from qdrant_client.async_client_base import AsyncQdrantBase
from qdrant_client.conversions import common_types as types
from qdrant_client.http import models as rest_models
from qdrant_client.local.local_collection import (
LocalCollection,
DEFAULT_VECTOR_NAME,
ignore_mentioned_ids_filter,
)
META_INFO_FILENAME = "meta.json"
class AsyncQdrantLocal(AsyncQdrantBase):
"""
Everything Qdrant server can do, but locally.
Use this implementation to run vector search without running a Qdrant server.
Everything that works with local Qdrant will work with server Qdrant as well.
Use for small-scale data, demos, and tests.
If you need more speed or size, use Qdrant server.
"""
LARGE_DATA_THRESHOLD = 20000
def __init__(self, location: str, force_disable_check_same_thread: bool = False) -> None:
"""
Initialize local Qdrant.
Args:
location: Where to store data. Can be a path to a directory or `:memory:` for in-memory storage.
force_disable_check_same_thread: Disable SQLite check_same_thread check. Use only if you know what you are doing.
"""
super().__init__()
self.force_disable_check_same_thread = force_disable_check_same_thread
self.location = location
self.persistent = location != ":memory:"
self.collections: dict[str, LocalCollection] = {}
self.aliases: dict[str, str] = {}
self._flock_file: Optional[TextIOWrapper] = None
self._load()
self._closed: bool = False
@property
def closed(self) -> bool:
return self._closed
async def close(self, **kwargs: Any) -> None:
self._closed = True
for collection in self.collections.values():
if collection is not None:
collection.close()
else:
show_warning(
message=f"Collection appears to be None before closing. The existing collections are: {list(self.collections.keys())}",
category=UserWarning,
stacklevel=4,
)
try:
if self._flock_file is not None and (not self._flock_file.closed):
portalocker.unlock(self._flock_file)
self._flock_file.close()
except TypeError:
pass
def _load(self) -> None:
deprecated_config_fields = ("init_from",)
if not self.persistent:
return
meta_path = os.path.join(self.location, META_INFO_FILENAME)
if not os.path.exists(meta_path):
os.makedirs(self.location, exist_ok=True)
with open(meta_path, "w") as f:
f.write(json.dumps({"collections": {}, "aliases": {}}))
else:
with open(meta_path, "r") as f:
meta = json.load(f)
for collection_name, config_json in meta["collections"].items():
for key in deprecated_config_fields:
config_json.pop(key, None)
config = rest_models.CreateCollection(**config_json)
collection_path = self._collection_path(collection_name)
collection = LocalCollection(
config,
collection_path,
force_disable_check_same_thread=self.force_disable_check_same_thread,
)
self.collections[collection_name] = collection
if len(collection.ids) > self.LARGE_DATA_THRESHOLD:
show_warning(
f"Local mode is not recommended for collections with more than {self.LARGE_DATA_THRESHOLD:,} points. Collection <{collection_name}> contains {len(collection.ids)} points. Consider using Qdrant in Docker or Qdrant Cloud for better performance with large datasets.",
category=UserWarning,
stacklevel=5,
)
self.aliases = meta["aliases"]
lock_file_path = os.path.join(self.location, ".lock")
if not os.path.exists(lock_file_path):
os.makedirs(self.location, exist_ok=True)
with open(lock_file_path, "w") as f:
f.write("tmp lock file")
self._flock_file = open(lock_file_path, "r+")
try:
portalocker.lock(
self._flock_file,
portalocker.LockFlags.EXCLUSIVE | portalocker.LockFlags.NON_BLOCKING,
)
except portalocker.exceptions.LockException:
raise RuntimeError(
f"Storage folder {self.location} is already accessed by another instance of Qdrant client. If you require concurrent access, use Qdrant server instead."
)
def _save(self) -> None:
if not self.persistent:
return
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
meta_path = os.path.join(self.location, META_INFO_FILENAME)
with open(meta_path, "w") as f:
f.write(
json.dumps(
{
"collections": {
collection_name: to_dict(collection.config)
for (collection_name, collection) in self.collections.items()
},
"aliases": self.aliases,
}
)
)
def _get_collection(self, collection_name: str) -> LocalCollection:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
if collection_name in self.collections:
return self.collections[collection_name]
if collection_name in self.aliases:
return self.collections[self.aliases[collection_name]]
raise ValueError(f"Collection {collection_name} not found")
def search(
self,
collection_name: str,
query_vector: Union[
types.NumpyArray,
Sequence[float],
tuple[str, list[float]],
types.NamedVector,
types.NamedSparseVector,
],
query_filter: Optional[types.Filter] = None,
search_params: Optional[types.SearchParams] = None,
limit: int = 10,
offset: Optional[int] = None,
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
score_threshold: Optional[float] = None,
**kwargs: Any,
) -> list[types.ScoredPoint]:
collection = self._get_collection(collection_name)
return collection.search(
query_vector=query_vector,
query_filter=query_filter,
limit=limit,
offset=offset,
with_payload=with_payload,
with_vectors=with_vectors,
score_threshold=score_threshold,
)
async def search_matrix_offsets(
self,
collection_name: str,
query_filter: Optional[types.Filter] = None,
limit: int = 3,
sample: int = 10,
using: Optional[str] = None,
**kwargs: Any,
) -> types.SearchMatrixOffsetsResponse:
collection = self._get_collection(collection_name)
return collection.search_matrix_offsets(
query_filter=query_filter, limit=limit, sample=sample, using=using
)
async def search_matrix_pairs(
self,
collection_name: str,
query_filter: Optional[types.Filter] = None,
limit: int = 3,
sample: int = 10,
using: Optional[str] = None,
**kwargs: Any,
) -> types.SearchMatrixPairsResponse:
collection = self._get_collection(collection_name)
return collection.search_matrix_pairs(
query_filter=query_filter, limit=limit, sample=sample, using=using
)
def _resolve_query_input(
self,
collection_name: str,
query: Optional[types.Query],
using: Optional[str],
lookup_from: Optional[types.LookupLocation],
) -> tuple[types.Query, set[types.PointId]]:
"""
Resolves any possible ids into vectors and returns a new query object, along with a set of the mentioned
point ids that should be filtered when searching.
"""
lookup_collection_name = lookup_from.collection if lookup_from else collection_name
collection = self._get_collection(lookup_collection_name)
search_in_vector_name = using if using is not None else DEFAULT_VECTOR_NAME
vector_name = (
lookup_from.vector
if lookup_from is not None and lookup_from.vector is not None
else search_in_vector_name
)
sparse = vector_name in collection.sparse_vectors
multi = vector_name in collection.multivectors
if sparse:
collection_vectors = collection.sparse_vectors
elif multi:
collection_vectors = collection.multivectors
else:
collection_vectors = collection.vectors
mentioned_ids: set[types.PointId] = set()
def input_into_vector(vector_input: types.VectorInput) -> types.VectorInput:
if isinstance(vector_input, get_args(types.PointId)):
if isinstance(vector_input, uuid.UUID):
vector_input = str(vector_input)
point_id = vector_input
if point_id not in collection.ids:
raise ValueError(f"Point {point_id} is not found in the collection")
idx = collection.ids[point_id]
if vector_name in collection_vectors:
vec = collection_vectors[vector_name][idx]
else:
raise ValueError(f"Vector {vector_name} not found")
if isinstance(vec, np.ndarray):
vec = vec.tolist()
if collection_name == lookup_collection_name:
mentioned_ids.add(point_id)
return vec
else:
return vector_input
query = deepcopy(query)
if isinstance(query, rest_models.NearestQuery):
query.nearest = input_into_vector(query.nearest)
elif isinstance(query, rest_models.RecommendQuery):
if query.recommend.negative is not None:
query.recommend.negative = [
input_into_vector(vector_input) for vector_input in query.recommend.negative
]
if query.recommend.positive is not None:
query.recommend.positive = [
input_into_vector(vector_input) for vector_input in query.recommend.positive
]
elif isinstance(query, rest_models.DiscoverQuery):
query.discover.target = input_into_vector(query.discover.target)
pairs = (
query.discover.context
if isinstance(query.discover.context, list)
else [query.discover.context]
)
query.discover.context = [
rest_models.ContextPair(
positive=input_into_vector(pair.positive),
negative=input_into_vector(pair.negative),
)
for pair in pairs
]
elif isinstance(query, rest_models.ContextQuery):
pairs = query.context if isinstance(query.context, list) else [query.context]
query.context = [
rest_models.ContextPair(
positive=input_into_vector(pair.positive),
negative=input_into_vector(pair.negative),
)
for pair in pairs
]
elif isinstance(query, rest_models.OrderByQuery):
pass
elif isinstance(query, rest_models.FusionQuery):
pass
elif isinstance(query, rest_models.RrfQuery):
pass
return (query, mentioned_ids)
def _resolve_prefetches_input(
self,
prefetch: Optional[Union[Sequence[types.Prefetch], types.Prefetch]],
collection_name: str,
) -> list[types.Prefetch]:
if prefetch is None:
return []
if isinstance(prefetch, list) and len(prefetch) == 0:
return []
prefetches = []
if isinstance(prefetch, types.Prefetch):
prefetches = [prefetch]
prefetches.extend(
prefetch.prefetch if isinstance(prefetch.prefetch, list) else [prefetch.prefetch]
)
elif isinstance(prefetch, Sequence):
prefetches = list(prefetch)
return [
self._resolve_prefetch_input(prefetch, collection_name)
for prefetch in prefetches
if prefetch is not None
]
def _resolve_prefetch_input(
self, prefetch: types.Prefetch, collection_name: str
) -> types.Prefetch:
if prefetch.query is None:
return prefetch
prefetch = deepcopy(prefetch)
(query, mentioned_ids) = self._resolve_query_input(
collection_name, prefetch.query, prefetch.using, prefetch.lookup_from
)
prefetch.query = query
prefetch.filter = ignore_mentioned_ids_filter(prefetch.filter, list(mentioned_ids))
prefetch.prefetch = self._resolve_prefetches_input(prefetch.prefetch, collection_name)
return prefetch
async def query_points(
self,
collection_name: str,
query: Optional[types.Query] = None,
using: Optional[str] = None,
prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
query_filter: Optional[types.Filter] = None,
search_params: Optional[types.SearchParams] = None,
limit: int = 10,
offset: Optional[int] = None,
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
score_threshold: Optional[float] = None,
lookup_from: Optional[types.LookupLocation] = None,
**kwargs: Any,
) -> types.QueryResponse:
collection = self._get_collection(collection_name)
if query is not None:
(query, mentioned_ids) = self._resolve_query_input(
collection_name, query, using, lookup_from
)
query_filter = ignore_mentioned_ids_filter(query_filter, list(mentioned_ids))
prefetch = self._resolve_prefetches_input(prefetch, collection_name)
return collection.query_points(
query=query,
prefetch=prefetch,
query_filter=query_filter,
using=using,
score_threshold=score_threshold,
limit=limit,
offset=offset or 0,
with_payload=with_payload,
with_vectors=with_vectors,
)
async def query_batch_points(
self, collection_name: str, requests: Sequence[types.QueryRequest], **kwargs: Any
) -> list[types.QueryResponse]:
return [
await self.query_points(
collection_name=collection_name,
query=request.query,
prefetch=request.prefetch,
query_filter=request.filter,
limit=request.limit or 10,
offset=request.offset,
with_payload=request.with_payload,
with_vectors=request.with_vector,
score_threshold=request.score_threshold,
using=request.using,
lookup_from=request.lookup_from,
)
for request in requests
]
async def query_points_groups(
self,
collection_name: str,
group_by: str,
query: Union[
types.PointId,
list[float],
list[list[float]],
types.SparseVector,
types.Query,
types.NumpyArray,
types.Document,
types.Image,
types.InferenceObject,
None,
] = None,
using: Optional[str] = None,
prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
query_filter: Optional[types.Filter] = None,
search_params: Optional[types.SearchParams] = None,
limit: int = 10,
group_size: int = 3,
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
score_threshold: Optional[float] = None,
with_lookup: Optional[types.WithLookupInterface] = None,
lookup_from: Optional[types.LookupLocation] = None,
**kwargs: Any,
) -> types.GroupsResult:
collection = self._get_collection(collection_name)
if query is not None:
(query, mentioned_ids) = self._resolve_query_input(
collection_name, query, using, lookup_from
)
query_filter = ignore_mentioned_ids_filter(query_filter, list(mentioned_ids))
with_lookup_collection = None
if with_lookup is not None:
if isinstance(with_lookup, str):
with_lookup_collection = self._get_collection(with_lookup)
else:
with_lookup_collection = self._get_collection(with_lookup.collection)
return collection.query_groups(
query=query,
query_filter=query_filter,
using=using,
prefetch=prefetch,
limit=limit,
group_by=group_by,
group_size=group_size,
with_payload=with_payload,
with_vectors=with_vectors,
score_threshold=score_threshold,
with_lookup=with_lookup,
with_lookup_collection=with_lookup_collection,
)
async def scroll(
self,
collection_name: str,
scroll_filter: Optional[types.Filter] = None,
limit: int = 10,
order_by: Optional[types.OrderBy] = None,
offset: Optional[types.PointId] = None,
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
**kwargs: Any,
) -> tuple[list[types.Record], Optional[types.PointId]]:
collection = self._get_collection(collection_name)
return collection.scroll(
scroll_filter=scroll_filter,
limit=limit,
order_by=order_by,
offset=offset,
with_payload=with_payload,
with_vectors=with_vectors,
)
async def count(
self,
collection_name: str,
count_filter: Optional[types.Filter] = None,
exact: bool = True,
**kwargs: Any,
) -> types.CountResult:
collection = self._get_collection(collection_name)
return collection.count(count_filter=count_filter)
async def facet(
self,
collection_name: str,
key: str,
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
exact: bool = False,
**kwargs: Any,
) -> types.FacetResponse:
collection = self._get_collection(collection_name)
return collection.facet(key=key, facet_filter=facet_filter, limit=limit)
async def upsert(
self,
collection_name: str,
points: types.Points,
update_filter: Optional[types.Filter] = None,
**kwargs: Any,
) -> types.UpdateResult:
collection = self._get_collection(collection_name)
collection.upsert(points, update_filter=update_filter)
return self._default_update_result()
async def update_vectors(
self,
collection_name: str,
points: Sequence[types.PointVectors],
update_filter: Optional[types.Filter] = None,
**kwargs: Any,
) -> types.UpdateResult:
collection = self._get_collection(collection_name)
collection.update_vectors(points, update_filter=update_filter)
return self._default_update_result()
async def delete_vectors(
self,
collection_name: str,
vectors: Sequence[str],
points: types.PointsSelector,
**kwargs: Any,
) -> types.UpdateResult:
collection = self._get_collection(collection_name)
collection.delete_vectors(vectors, points)
return self._default_update_result()
async def retrieve(
self,
collection_name: str,
ids: Sequence[types.PointId],
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
**kwargs: Any,
) -> list[types.Record]:
collection = self._get_collection(collection_name)
return collection.retrieve(ids, with_payload, with_vectors)
@classmethod
def _default_update_result(cls, operation_id: int = 0) -> types.UpdateResult:
return types.UpdateResult(
operation_id=operation_id, status=rest_models.UpdateStatus.COMPLETED
)
async def delete(
self, collection_name: str, points_selector: types.PointsSelector, **kwargs: Any
) -> types.UpdateResult:
collection = self._get_collection(collection_name)
collection.delete(points_selector)
return self._default_update_result()
async def set_payload(
self,
collection_name: str,
payload: types.Payload,
points: types.PointsSelector,
key: Optional[str] = None,
**kwargs: Any,
) -> types.UpdateResult:
collection = self._get_collection(collection_name)
collection.set_payload(payload=payload, selector=points, key=key)
return self._default_update_result()
async def overwrite_payload(
self,
collection_name: str,
payload: types.Payload,
points: types.PointsSelector,
**kwargs: Any,
) -> types.UpdateResult:
collection = self._get_collection(collection_name)
collection.overwrite_payload(payload=payload, selector=points)
return self._default_update_result()
async def delete_payload(
self,
collection_name: str,
keys: Sequence[str],
points: types.PointsSelector,
**kwargs: Any,
) -> types.UpdateResult:
collection = self._get_collection(collection_name)
collection.delete_payload(keys=keys, selector=points)
return self._default_update_result()
async def clear_payload(
self, collection_name: str, points_selector: types.PointsSelector, **kwargs: Any
) -> types.UpdateResult:
collection = self._get_collection(collection_name)
collection.clear_payload(selector=points_selector)
return self._default_update_result()
async def batch_update_points(
self,
collection_name: str,
update_operations: Sequence[types.UpdateOperation],
**kwargs: Any,
) -> list[types.UpdateResult]:
collection = self._get_collection(collection_name)
collection.batch_update_points(update_operations)
return [self._default_update_result()] * len(update_operations)
async def update_collection_aliases(
self, change_aliases_operations: Sequence[types.AliasOperations], **kwargs: Any
) -> bool:
for operation in change_aliases_operations:
if isinstance(operation, rest_models.CreateAliasOperation):
self._get_collection(operation.create_alias.collection_name)
self.aliases[operation.create_alias.alias_name] = (
operation.create_alias.collection_name
)
elif isinstance(operation, rest_models.DeleteAliasOperation):
self.aliases.pop(operation.delete_alias.alias_name, None)
elif isinstance(operation, rest_models.RenameAliasOperation):
new_name = operation.rename_alias.new_alias_name
old_name = operation.rename_alias.old_alias_name
self.aliases[new_name] = self.aliases.pop(old_name)
else:
raise ValueError(f"Unknown operation: {operation}")
self._save()
return True
async def get_collection_aliases(
self, collection_name: str, **kwargs: Any
) -> types.CollectionsAliasesResponse:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
return types.CollectionsAliasesResponse(
aliases=[
rest_models.AliasDescription(alias_name=alias_name, collection_name=name)
for (alias_name, name) in self.aliases.items()
if name == collection_name
]
)
async def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
return types.CollectionsAliasesResponse(
aliases=[
rest_models.AliasDescription(alias_name=alias_name, collection_name=name)
for (alias_name, name) in self.aliases.items()
]
)
async def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
return types.CollectionsResponse(
collections=[
rest_models.CollectionDescription(name=name)
for (name, _) in self.collections.items()
]
)
async def get_collection(self, collection_name: str, **kwargs: Any) -> types.CollectionInfo:
collection = self._get_collection(collection_name)
return collection.info()
async def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
try:
self._get_collection(collection_name)
return True
except ValueError:
return False
async def update_collection(
self,
collection_name: str,
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
metadata: Optional[types.Payload] = None,
**kwargs: Any,
) -> bool:
_collection = self._get_collection(collection_name)
updated = False
if sparse_vectors_config is not None:
for vector_name, vector_params in sparse_vectors_config.items():
_collection.update_sparse_vectors_config(vector_name, vector_params)
updated = True
if metadata is not None:
if _collection.config.metadata is not None:
_collection.config.metadata.update(metadata)
else:
_collection.config.metadata = deepcopy(metadata)
updated = True
self._save()
return updated
def _collection_path(self, collection_name: str) -> Optional[str]:
if self.persistent:
return os.path.join(self.location, "collection", collection_name)
else:
return None
async def delete_collection(self, collection_name: str, **kwargs: Any) -> bool:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
_collection = self.collections.pop(collection_name, None)
del _collection
self.aliases = {
alias_name: name
for (alias_name, name) in self.aliases.items()
if name != collection_name
}
collection_path = self._collection_path(collection_name)
if collection_path is not None:
shutil.rmtree(collection_path, ignore_errors=True)
self._save()
return True
async def create_collection(
self,
collection_name: str,
vectors_config: Optional[
Union[types.VectorParams, Mapping[str, types.VectorParams]]
] = None,
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
metadata: Optional[types.Payload] = None,
**kwargs: Any,
) -> bool:
if self.closed:
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
if collection_name in self.collections:
raise ValueError(f"Collection {collection_name} already exists")
collection_path = self._collection_path(collection_name)
if collection_path is not None:
os.makedirs(collection_path, exist_ok=True)
collection = LocalCollection(
rest_models.CreateCollection(
vectors=vectors_config or {},
sparse_vectors=sparse_vectors_config,
metadata=deepcopy(metadata),
),
location=collection_path,
force_disable_check_same_thread=self.force_disable_check_same_thread,
)
self.collections[collection_name] = collection
self._save()
return True
async def recreate_collection(
self,
collection_name: str,
vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
metadata: Optional[types.Payload] = None,
**kwargs: Any,
) -> bool:
await self.delete_collection(collection_name)
return await self.create_collection(
collection_name, vectors_config, sparse_vectors_config, metadata=metadata
)
def upload_points(
self,
collection_name: str,
points: Iterable[types.PointStruct],
update_filter: Optional[types.Filter] = None,
**kwargs: Any,
) -> None:
self._upload_points(collection_name, points, update_filter=update_filter)
def _upload_points(
self,
collection_name: str,
points: Iterable[Union[types.PointStruct, types.Record]],
update_filter: Optional[types.Filter] = None,
) -> None:
collection = self._get_collection(collection_name)
collection.upsert(
[
rest_models.PointStruct(
id=point.id, vector=point.vector or {}, payload=point.payload or {}
)
for point in points
],
update_filter=update_filter,
)
def upload_collection(
self,
collection_name: str,
vectors: Union[
dict[str, types.NumpyArray], types.NumpyArray, Iterable[types.VectorStruct]
],
payload: Optional[Iterable[dict[Any, Any]]] = None,
ids: Optional[Iterable[types.PointId]] = None,
update_filter: Optional[types.Filter] = None,
**kwargs: Any,
) -> None:
def uuid_generator() -> Generator[str, None, None]:
while True:
yield str(uuid4())
collection = self._get_collection(collection_name)
if isinstance(vectors, dict) and any(
(isinstance(v, np.ndarray) for v in vectors.values())
):
assert (
len(set([arr.shape[0] for arr in vectors.values()])) == 1
), "Each named vector should have the same number of vectors"
num_vectors = next(iter(vectors.values())).shape[0]
vectors = [
{name: vectors[name][i].tolist() for name in vectors.keys()}
for i in range(num_vectors)
]
collection.upsert(
[
rest_models.PointStruct(
id=str(point_id) if isinstance(point_id, uuid.UUID) else point_id,
vector=(vector.tolist() if isinstance(vector, np.ndarray) else vector) or {},
payload=payload or {},
)
for (point_id, vector, payload) in zip(
ids or uuid_generator(), iter(vectors), payload or itertools.cycle([{}])
)
],
update_filter=update_filter,
)
async def create_payload_index(
self,
collection_name: str,
field_name: str,
field_schema: Optional[types.PayloadSchemaType] = None,
field_type: Optional[types.PayloadSchemaType] = None,
**kwargs: Any,
) -> types.UpdateResult:
show_warning_once(
message="Payload indexes have no effect in the local Qdrant. Please use server Qdrant if you need payload indexes.",
category=UserWarning,
idx="create-local-payload-indexes",
stacklevel=5,
)
return self._default_update_result()
async def delete_payload_index(
self, collection_name: str, field_name: str, **kwargs: Any
) -> types.UpdateResult:
show_warning_once(
message="Payload indexes have no effect in the local Qdrant. Please use server Qdrant if you need payload indexes.",
category=UserWarning,
idx="delete-local-payload-indexes",
stacklevel=5,
)
return self._default_update_result()
async def list_snapshots(
self, collection_name: str, **kwargs: Any
) -> list[types.SnapshotDescription]:
return []
async def create_snapshot(
self, collection_name: str, **kwargs: Any
) -> Optional[types.SnapshotDescription]:
raise NotImplementedError(
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
)
async def delete_snapshot(
self, collection_name: str, snapshot_name: str, **kwargs: Any
) -> bool:
raise NotImplementedError(
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
)
async def list_full_snapshots(self, **kwargs: Any) -> list[types.SnapshotDescription]:
return []
async def create_full_snapshot(self, **kwargs: Any) -> types.SnapshotDescription:
raise NotImplementedError(
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
)
async def delete_full_snapshot(self, snapshot_name: str, **kwargs: Any) -> bool:
raise NotImplementedError(
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
)
async def recover_snapshot(self, collection_name: str, location: str, **kwargs: Any) -> bool:
raise NotImplementedError(
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
)
async def list_shard_snapshots(
self, collection_name: str, shard_id: int, **kwargs: Any
) -> list[types.SnapshotDescription]:
return []
async def create_shard_snapshot(
self, collection_name: str, shard_id: int, **kwargs: Any
) -> Optional[types.SnapshotDescription]:
raise NotImplementedError(
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need snapshots."
)
async def delete_shard_snapshot(
self, collection_name: str, shard_id: int, snapshot_name: str, **kwargs: Any
) -> bool:
raise NotImplementedError(
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need snapshots."
)
async def recover_shard_snapshot(
self, collection_name: str, shard_id: int, location: str, **kwargs: Any
) -> bool:
raise NotImplementedError(
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need snapshots."
)
async def create_shard_key(
self,
collection_name: str,
shard_key: types.ShardKey,
shards_number: Optional[int] = None,
replication_factor: Optional[int] = None,
placement: Optional[list[int]] = None,
**kwargs: Any,
) -> bool:
raise NotImplementedError(
"Sharding is not supported in the local Qdrant. Please use server Qdrant if you need sharding."
)
async def delete_shard_key(
self, collection_name: str, shard_key: types.ShardKey, **kwargs: Any
) -> bool:
raise NotImplementedError(
"Sharding is not supported in the local Qdrant. Please use server Qdrant if you need sharding."
)
async def info(self) -> types.VersionInfo:
version = importlib.metadata.version("qdrant-client")
return rest_models.VersionInfo(
title="qdrant - vector search engine", version=version, commit=None
)
async def cluster_collection_update(
self, collection_name: str, cluster_operation: types.ClusterOperations, **kwargs: Any
) -> bool:
raise NotImplementedError(
"Cluster collection update is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
)
async def collection_cluster_info(self, collection_name: str) -> types.CollectionClusterInfo:
raise NotImplementedError(
"Collection cluster info is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
)
async def cluster_status(self) -> types.ClusterStatus:
raise NotImplementedError(
"Cluster status is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
)
async def recover_current_peer(self) -> bool:
raise NotImplementedError(
"Recover current peer is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
)
async def remove_peer(self, peer_id: int, **kwargs: Any) -> bool:
raise NotImplementedError(
"Remove peer info is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
)
@@ -0,0 +1,50 @@
from datetime import datetime, timezone
from typing import Optional
# These are the formats accepted by qdrant core
available_formats = [
"%Y-%m-%dT%H:%M:%S.%f%z",
"%Y-%m-%d %H:%M:%S.%f%z",
"%Y-%m-%dT%H:%M:%S%z",
"%Y-%m-%d %H:%M:%S%z",
"%Y-%m-%dT%H:%M:%S.%f",
"%Y-%m-%d %H:%M:%S.%f",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y-%m-%d",
]
def parse(date_str: str) -> Optional[datetime]:
"""Parses one section of the date string at a time.
Args:
date_str (str): Accepts any of the formats in qdrant core (see https://github.com/qdrant/qdrant/blob/0ed86ce0575d35930268db19e1f7680287072c58/lib/segment/src/types.rs#L1388-L1410)
Returns:
Optional[datetime]: the datetime if the string is valid, otherwise None
"""
def parse_available_formats(datetime_str: str) -> Optional[datetime]:
for fmt in available_formats:
try:
dt = datetime.strptime(datetime_str, fmt)
if dt.tzinfo is None:
# Assume UTC if no timezone is provided
dt = dt.replace(tzinfo=timezone.utc)
return dt
except ValueError:
pass
return None
parsed_dt = parse_available_formats(date_str)
if parsed_dt is not None:
return parsed_dt
# Python can't parse timezones containing only hours (+HH), but it can parse timezones with hours and minutes
# So we add :00 to the assumed timezone and try parsing it again
# dt examples to handle:
# "2021-01-01 00:00:00.000+01"
# "2021-01-01 00:00:00.000-10"
return parse_available_formats(date_str + ":00")
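# A minimal usage sketch (illustrative only): `parse` tries each accepted format in turn,
# assumes UTC when no timezone is given, and retries hour-only offsets such as "+01"
# as "+01:00", since strptime's %z does not accept an offset without minutes.
if __name__ == "__main__":
    from datetime import timedelta

    assert parse("2021-01-01") == datetime(2021, 1, 1, tzinfo=timezone.utc)
    assert parse("2021-01-01 00:00:00+01").utcoffset() == timedelta(hours=1)
    assert parse("not a date") is None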
@@ -0,0 +1,302 @@
from enum import Enum
from typing import Optional, Union
import numpy as np
from qdrant_client.conversions import common_types as types
from qdrant_client.http import models
EPSILON = 1.1920929e-7 # https://doc.rust-lang.org/std/f32/constant.EPSILON.html
# https://github.com/qdrant/qdrant/blob/7164ac4a5987d28f1c93f5712aef8e09e7d93555/lib/segment/src/spaces/simple_avx.rs#L99C10-L99C10
class DistanceOrder(str, Enum):
BIGGER_IS_BETTER = "bigger_is_better"
SMALLER_IS_BETTER = "smaller_is_better"
class RecoQuery:
def __init__(
self,
positive: Optional[list[list[float]]] = None,
negative: Optional[list[list[float]]] = None,
strategy: Optional[models.RecommendStrategy] = None,
):
assert strategy is not None, "Recommend strategy must be provided"
self.strategy = strategy
positive = positive if positive is not None else []
negative = negative if negative is not None else []
self.positive: list[types.NumpyArray] = [np.array(vector) for vector in positive]
self.negative: list[types.NumpyArray] = [np.array(vector) for vector in negative]
assert not np.isnan(self.positive).any(), "Positive vectors must not contain NaN"
assert not np.isnan(self.negative).any(), "Negative vectors must not contain NaN"
class ContextPair:
def __init__(self, positive: list[float], negative: list[float]):
self.positive: types.NumpyArray = np.array(positive)
self.negative: types.NumpyArray = np.array(negative)
assert not np.isnan(self.positive).any(), "Positive vector must not contain NaN"
assert not np.isnan(self.negative).any(), "Negative vector must not contain NaN"
class DiscoveryQuery:
def __init__(self, target: list[float], context: list[ContextPair]):
self.target: types.NumpyArray = np.array(target)
self.context = context
assert not np.isnan(self.target).any(), "Target vector must not contain NaN"
class ContextQuery:
def __init__(self, context_pairs: list[ContextPair]):
self.context_pairs = context_pairs
DenseQueryVector = Union[
DiscoveryQuery,
ContextQuery,
RecoQuery,
]
def distance_to_order(distance: models.Distance) -> DistanceOrder:
"""
Convert distance to order
Args:
distance: distance to convert
Returns:
order
"""
if distance == models.Distance.EUCLID:
return DistanceOrder.SMALLER_IS_BETTER
elif distance == models.Distance.MANHATTAN:
return DistanceOrder.SMALLER_IS_BETTER
return DistanceOrder.BIGGER_IS_BETTER
def cosine_similarity(query: types.NumpyArray, vectors: types.NumpyArray) -> types.NumpyArray:
"""
Calculate cosine distance between query and vectors
Args:
query: query vector
vectors: vectors to calculate distance with
Returns:
distances
"""
vectors_norm = np.linalg.norm(vectors, axis=-1)[:, np.newaxis]
vectors /= np.where(vectors_norm != 0.0, vectors_norm, EPSILON)
if len(query.shape) == 1:
query_norm = np.linalg.norm(query)
query /= np.where(query_norm != 0.0, query_norm, EPSILON)
return np.dot(vectors, query)
query_norm = np.linalg.norm(query, axis=-1)[:, np.newaxis]
query /= np.where(query_norm != 0.0, query_norm, EPSILON)
return np.dot(query, vectors.T)
def dot_product(query: types.NumpyArray, vectors: types.NumpyArray) -> types.NumpyArray:
"""
Calculate dot product between query and vectors
Args:
query: query vector.
vectors: vectors to calculate distance with
Returns:
distances
"""
if len(query.shape) == 1:
return np.dot(vectors, query)
else:
return np.dot(query, vectors.T)
def euclidean_distance(query: types.NumpyArray, vectors: types.NumpyArray) -> types.NumpyArray:
"""
Calculate euclidean distance between query and vectors
Args:
query: query vector.
vectors: vectors to calculate distance with
Returns:
distances
"""
if len(query.shape) == 1:
return np.linalg.norm(vectors - query, axis=-1)
else:
return np.linalg.norm(vectors - query[:, np.newaxis], axis=-1)
def manhattan_distance(query: types.NumpyArray, vectors: types.NumpyArray) -> types.NumpyArray:
"""
Calculate manhattan distance between query and vectors
Args:
query: query vector.
vectors: vectors to calculate distance with
Returns:
distances
"""
if len(query.shape) == 1:
return np.sum(np.abs(vectors - query), axis=-1)
else:
return np.sum(np.abs(vectors - query[:, np.newaxis]), axis=-1)
def calculate_distance(
query: types.NumpyArray, vectors: types.NumpyArray, distance_type: models.Distance
) -> types.NumpyArray:
assert not np.isnan(query).any(), "Query vector must not contain NaN"
if distance_type == models.Distance.COSINE:
return cosine_similarity(query, vectors)
elif distance_type == models.Distance.DOT:
return dot_product(query, vectors)
elif distance_type == models.Distance.EUCLID:
return euclidean_distance(query, vectors)
elif distance_type == models.Distance.MANHATTAN:
return manhattan_distance(query, vectors)
else:
raise ValueError(f"Unknown distance type {distance_type}")
def calculate_distance_core(
query: types.NumpyArray, vectors: types.NumpyArray, distance_type: models.Distance
) -> types.NumpyArray:
"""
Calculate same internal distances as in core, rather than the final displayed distance
"""
assert not np.isnan(query).any(), "Query vector must not contain NaN"
if distance_type == models.Distance.EUCLID:
return -np.square(vectors - query, dtype=np.float32).sum(axis=1, dtype=np.float32)
if distance_type == models.Distance.MANHATTAN:
return -np.abs(vectors - query, dtype=np.float32).sum(axis=1, dtype=np.float32)
else:
return calculate_distance(query, vectors, distance_type)
def fast_sigmoid(x: np.float32) -> np.float32:
if np.isnan(x) or np.isinf(x):
# Avoid dividing NaN or inf, which would emit: RuntimeWarning: invalid value encountered in scalar divide
return x
return x / np.add(1.0, abs(x))
def scaled_fast_sigmoid(x: np.float32) -> np.float32:
return 0.5 * (np.add(fast_sigmoid(x), 1.0))
def calculate_recommend_best_scores(
query: RecoQuery, vectors: types.NumpyArray, distance_type: models.Distance
) -> types.NumpyArray:
def get_best_scores(examples: list[types.NumpyArray]) -> types.NumpyArray:
vector_count = vectors.shape[0]
# Get scores to all examples
scores: list[types.NumpyArray] = []
for example in examples:
score = calculate_distance_core(example, vectors, distance_type)
scores.append(score)
# Keep only max for each vector
if len(scores) == 0:
scores.append(np.full(vector_count, -np.inf))
best_scores = np.array(scores, dtype=np.float32).max(axis=0)
return best_scores
pos = get_best_scores(query.positive)
neg = get_best_scores(query.negative)
# Choose from best positive or best negative,
# in both cases we apply sigmoid and then negate depending on the order
return np.where(
pos > neg,
np.fromiter((scaled_fast_sigmoid(xi) for xi in pos), pos.dtype),
np.fromiter((-scaled_fast_sigmoid(xi) for xi in neg), neg.dtype),
)
def calculate_recommend_sum_scores(
query: RecoQuery, vectors: types.NumpyArray, distance_type: models.Distance
) -> types.NumpyArray:
def get_sum_scores(examples: list[types.NumpyArray]) -> types.NumpyArray:
vector_count = vectors.shape[0]
scores: list[types.NumpyArray] = []
for example in examples:
score = calculate_distance_core(example, vectors, distance_type)
scores.append(score)
if len(scores) == 0:
scores.append(np.zeros(vector_count))
sum_scores = np.array(scores, dtype=np.float32).sum(axis=0)
return sum_scores
pos = get_sum_scores(query.positive)
neg = get_sum_scores(query.negative)
return pos - neg
def calculate_discovery_ranks(
context: list[ContextPair],
vectors: types.NumpyArray,
distance_type: models.Distance,
) -> types.NumpyArray:
overall_ranks = np.zeros(vectors.shape[0], dtype=np.int32)
for pair in context:
# Get distances to positive and negative vectors
pos = calculate_distance_core(pair.positive, vectors, distance_type)
neg = calculate_distance_core(pair.negative, vectors, distance_type)
pair_ranks = np.array(
[
1 if is_bigger else 0 if is_equal else -1
for is_bigger, is_equal in zip(pos > neg, pos == neg)
]
)
overall_ranks += pair_ranks
return overall_ranks
def calculate_discovery_scores(
query: DiscoveryQuery, vectors: types.NumpyArray, distance_type: models.Distance
) -> types.NumpyArray:
ranks = calculate_discovery_ranks(query.context, vectors, distance_type)
# Get distances to target
distances_to_target = calculate_distance_core(query.target, vectors, distance_type)
sigmoided_distances = np.fromiter(
(scaled_fast_sigmoid(xi) for xi in distances_to_target), np.float32
)
return ranks + sigmoided_distances
def calculate_context_scores(
query: ContextQuery, vectors: types.NumpyArray, distance_type: models.Distance
) -> types.NumpyArray:
overall_scores = np.zeros(vectors.shape[0], dtype=np.float32)
for pair in query.context_pairs:
# Get distances to positive and negative vectors
pos = calculate_distance_core(pair.positive, vectors, distance_type)
neg = calculate_distance_core(pair.negative, vectors, distance_type)
difference = pos - neg - EPSILON
pair_scores = np.fromiter(
(fast_sigmoid(xi) for xi in np.minimum(difference, 0.0)), np.float32
)
overall_scores += pair_scores
return overall_scores
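# A minimal usage sketch (illustrative only): for COSINE the result is a normalized
# similarity in [-1, 1] (bigger is better), while for EUCLID `calculate_distance`
# returns the actual distance (smaller is better), as `distance_to_order` reflects.
if __name__ == "__main__":
    query = np.array([1.0, 0.0])
    vectors = np.array([[1.0, 0.0], [0.0, 1.0]])
    cosine = calculate_distance(query, vectors, models.Distance.COSINE)
    assert np.allclose(cosine, [1.0, 0.0])
    euclid = calculate_distance(np.array([1.0, 0.0]), vectors.copy(), models.Distance.EUCLID)
    assert np.allclose(euclid, [0.0, np.sqrt(2.0)])
    assert distance_to_order(models.Distance.EUCLID) == DistanceOrder.SMALLER_IS_BETTER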
@@ -0,0 +1,90 @@
from math import asin, cos, radians, sin, sqrt
# Radius of earth in meters, [as recommended by the IUGG](ftp://athena.fsv.cvut.cz/ZFG/grs80-Moritz.pdf)
MEAN_EARTH_RADIUS = 6371008.8
def geo_distance(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
"""
Calculate distance between two points on Earth using Haversine formula.
Args:
lon1: longitude of first point
lat1: latitude of first point
lon2: longitude of second point
lat2: latitude of second point
Returns:
distance in meters
"""
# convert decimal degrees to radians
lon1, lat1, lon2, lat2 = map(radians, [lon1, lat1, lon2, lat2])
# haversine formula
dlon = lon2 - lon1
dlat = lat2 - lat1
a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
c = 2 * asin(sqrt(a))
return MEAN_EARTH_RADIUS * c
def test_geo_distance() -> None:
moscow = {"lon": 37.6173, "lat": 55.7558}
london = {"lon": -0.1278, "lat": 51.5074}
berlin = {"lon": 13.4050, "lat": 52.5200}
assert geo_distance(moscow["lon"], moscow["lat"], moscow["lon"], moscow["lat"]) < 1.0
assert geo_distance(moscow["lon"], moscow["lat"], london["lon"], london["lat"]) > 2400 * 1000
assert geo_distance(moscow["lon"], moscow["lat"], london["lon"], london["lat"]) < 2600 * 1000
assert geo_distance(moscow["lon"], moscow["lat"], berlin["lon"], berlin["lat"]) > 1600 * 1000
assert geo_distance(moscow["lon"], moscow["lat"], berlin["lon"], berlin["lat"]) < 1650 * 1000
def boolean_point_in_polygon(
point: tuple[float, float],
exterior: list[tuple[float, float]],
interiors: list[list[tuple[float, float]]],
) -> bool:
inside_poly = False
if in_ring(point, exterior, True):
in_hole = False
k = 0
while k < len(interiors) and not in_hole:
if in_ring(point, interiors[k], False):
in_hole = True
k += 1
if not in_hole:
inside_poly = True
return inside_poly
def in_ring(
pt: tuple[float, float], ring: list[tuple[float, float]], ignore_boundary: bool
) -> bool:
is_inside = False
if ring[0][0] == ring[len(ring) - 1][0] and ring[0][1] == ring[len(ring) - 1][1]:
ring = ring[0 : len(ring) - 1]
j = len(ring) - 1
for i in range(0, len(ring)):
xi = ring[i][0]
yi = ring[i][1]
xj = ring[j][0]
yj = ring[j][1]
on_boundary = (
(pt[1] * (xi - xj) + yi * (xj - pt[0]) + yj * (pt[0] - xi) == 0)
and ((xi - pt[0]) * (xj - pt[0]) <= 0)
and ((yi - pt[1]) * (yj - pt[1]) <= 0)
)
if on_boundary:
return not ignore_boundary
intersect = ((yi > pt[1]) != (yj > pt[1])) and (
pt[0] < (xj - xi) * (pt[1] - yi) / (yj - yi) + xi
)
if intersect:
is_inside = not is_inside
j = i
return is_inside
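# A minimal usage sketch (illustrative only): a 10x10 square with a 2x2 hole in the
# middle. Points inside the exterior but outside the hole match; points inside the
# hole or outside the square do not. The payload geo filter passes (lat, lon) pairs,
# but the ray-casting test itself is agnostic to axis order.
if __name__ == "__main__":
    exterior = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0), (0.0, 10.0), (0.0, 0.0)]
    hole = [(4.0, 4.0), (6.0, 4.0), (6.0, 6.0), (4.0, 6.0), (4.0, 4.0)]
    assert boolean_point_in_polygon((2.0, 2.0), exterior, [hole])
    assert not boolean_point_in_polygon((5.0, 5.0), exterior, [hole])
    assert not boolean_point_in_polygon((20.0, 20.0), exterior, [hole])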
@@ -0,0 +1,151 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel
class JsonPathItemType(str, Enum):
KEY = "key"
INDEX = "index"
WILDCARD_INDEX = "wildcard_index"
class JsonPathItem(BaseModel):
item_type: JsonPathItemType
index: Optional[int] = (
None # split into index and key instead of using Union, because pydantic coerces
)
# int to str even in case of Union[int, str]. Tested with pydantic==1.10.14
key: Optional[str] = None
def parse_json_path(key: str) -> list[JsonPathItem]:
"""Parse and validate json path
Args:
key: json path
Returns:
list[JsonPathItem]: json path split into separate keys
Raises:
ValueError: if json path is invalid or empty
Examples:
# >>> parse_json_path("a[0][1].b")
# [
#     JsonPathItem(item_type=<JsonPathItemType.KEY: 'key'>, key='a'),
#     JsonPathItem(item_type=<JsonPathItemType.INDEX: 'index'>, index=0),
#     JsonPathItem(item_type=<JsonPathItemType.INDEX: 'index'>, index=1),
#     JsonPathItem(item_type=<JsonPathItemType.KEY: 'key'>, key='b')
# ]
"""
keys = []
json_path = key
while json_path:
json_path_item, rest = match_quote(json_path)
if json_path_item is None:
json_path_item, rest = match_key(json_path)
if json_path_item is None:
raise ValueError("Invalid path")
keys.append(json_path_item)
brackets_chunks, rest = match_brackets(rest)
keys.extend(brackets_chunks)
json_path = trunk_sep(rest)
if not json_path:
return keys
continue
raise ValueError("Invalid path")
def trunk_sep(path: str) -> str:
if not path:
return path
if len(path) == 1:
raise ValueError("Invalid path")
if path.startswith("."):
return path[1:]
elif path.startswith("["):
return path
else:
raise ValueError("Invalid path")
def match_quote(path: str) -> tuple[Optional[JsonPathItem], str]:
if not path.startswith('"'):
return None, path
left_quote_pos = 0
right_quote_pos = path.find('"', 1)
if path.count('"') < 2:
raise ValueError("Invalid path")
return (
JsonPathItem(
item_type=JsonPathItemType.KEY, key=path[left_quote_pos + 1 : right_quote_pos]
),
path[right_quote_pos + 1 :],
)
def match_key(path: str) -> tuple[Optional[JsonPathItem], str]:
char_counter = 0
for char in path:
if not char.isalnum() and char not in ["_", "-"]:
break
char_counter += 1
if char_counter == 0:
return None, path
return (
JsonPathItem(item_type=JsonPathItemType.KEY, key=path[:char_counter]),
path[char_counter:],
)
def match_brackets(rest: str) -> tuple[list[JsonPathItem], str]:
keys = []
while rest:
json_path_item, rest = _match_brackets(rest)
if json_path_item is None:
break
keys.append(json_path_item)
return keys, rest
def _match_brackets(path: str) -> tuple[Optional[JsonPathItem], str]:
if "[" not in path or not path.startswith("["):
return None, path
left_bracket_pos = 0
right_bracket_pos = path.find("]", left_bracket_pos + 1)
if right_bracket_pos == -1:
raise ValueError("Invalid path")
if right_bracket_pos == (left_bracket_pos + 1):
return (
JsonPathItem(item_type=JsonPathItemType.WILDCARD_INDEX),
path[right_bracket_pos + 1 :],
)
try:
index = int(path[left_bracket_pos + 1 : right_bracket_pos])
return (
JsonPathItem(item_type=JsonPathItemType.INDEX, index=index),
path[right_bracket_pos + 1 :],
)
except ValueError as e:
raise ValueError("Invalid path") from e
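# A minimal usage sketch (illustrative only): "a[0][].b" splits into a key, an index,
# a wildcard index, and another key; quoted keys keep dots and brackets literal.
if __name__ == "__main__":
    items = parse_json_path("a[0][].b")
    assert [item.item_type for item in items] == [
        JsonPathItemType.KEY,
        JsonPathItemType.INDEX,
        JsonPathItemType.WILDCARD_INDEX,
        JsonPathItemType.KEY,
    ]
    assert (items[0].key, items[1].index, items[3].key) == ("a", 0, "b")
    assert parse_json_path('"a.b"')[0].key == "a.b"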
File diff suppressed because it is too large
@@ -0,0 +1,220 @@
from typing import Optional, Union, Any
import numpy as np
from qdrant_client.http import models
from qdrant_client.conversions import common_types as types
from qdrant_client.local.distances import (
calculate_distance,
scaled_fast_sigmoid,
EPSILON,
fast_sigmoid,
)
class MultiRecoQuery:
def __init__(
self,
positive: Optional[list[list[list[float]]]] = None, # list of matrices
negative: Optional[list[list[list[float]]]] = None, # list of matrices
strategy: Optional[models.RecommendStrategy] = None,
):
assert strategy is not None, "Recommend strategy must be provided"
self.strategy = strategy
positive = positive if positive is not None else []
negative = negative if negative is not None else []
for vector in positive:
assert not np.isnan(vector).any(), "Positive vectors must not contain NaN"
for vector in negative:
assert not np.isnan(vector).any(), "Negative vectors must not contain NaN"
self.positive: list[types.NumpyArray] = [np.array(vector) for vector in positive]
self.negative: list[types.NumpyArray] = [np.array(vector) for vector in negative]
class MultiContextPair:
def __init__(self, positive: list[list[float]], negative: list[list[float]]):
self.positive: types.NumpyArray = np.array(positive)
self.negative: types.NumpyArray = np.array(negative)
assert not np.isnan(self.positive).any(), "Positive vector must not contain NaN"
assert not np.isnan(self.negative).any(), "Negative vector must not contain NaN"
class MultiDiscoveryQuery:
def __init__(self, target: list[list[float]], context: list[MultiContextPair]):
self.target: types.NumpyArray = np.array(target)
self.context = context
assert not np.isnan(self.target).any(), "Target vector must not contain NaN"
class MultiContextQuery:
def __init__(self, context_pairs: list[MultiContextPair]):
self.context_pairs = context_pairs
MultiQueryVector = Union[
MultiDiscoveryQuery,
MultiContextQuery,
MultiRecoQuery,
]
def calculate_multi_distance(
query_matrix: types.NumpyArray,
matrices: list[types.NumpyArray],
distance_type: models.Distance,
) -> types.NumpyArray:
assert not np.isnan(query_matrix).any(), "Query matrix must not contain NaN"
assert len(query_matrix.shape) == 2, "Query must be a matrix"
distances = calculate_multi_distance_core(query_matrix, matrices, distance_type)
if distance_type == models.Distance.EUCLID:
distances = np.sqrt(np.abs(distances))
elif distance_type == models.Distance.MANHATTAN:
distances = np.abs(distances)
return distances
def calculate_multi_distance_core(
query_matrix: types.NumpyArray,
matrices: list[types.NumpyArray],
distance_type: models.Distance,
) -> types.NumpyArray:
def euclidean(q: types.NumpyArray, m: types.NumpyArray, *_: Any) -> types.NumpyArray:
return -np.square(m - q, dtype=np.float32).sum(axis=-1, dtype=np.float32)
def manhattan(q: types.NumpyArray, m: types.NumpyArray, *_: Any) -> types.NumpyArray:
return -np.abs(m - q, dtype=np.float32).sum(axis=-1, dtype=np.float32)
assert not np.isnan(query_matrix).any(), "Query vector must not contain NaN"
similarities: list[float] = []
# Euclid and Manhattan are the only distances that core computes differently during candidate selection;
# here we make sure to use the same internal similarity function as core does.
if distance_type in [models.Distance.EUCLID, models.Distance.MANHATTAN]:
query_matrix = query_matrix[:, np.newaxis]
dist_func = euclidean if distance_type == models.Distance.EUCLID else manhattan
else:
dist_func = calculate_distance # type: ignore
for matrix in matrices:
sim_matrix = dist_func(query_matrix, matrix, distance_type)
similarity = float(np.sum(np.max(sim_matrix, axis=-1)))
similarities.append(similarity)
return np.array(similarities)
def calculate_multi_recommend_best_scores(
query: MultiRecoQuery, matrices: list[types.NumpyArray], distance_type: models.Distance
) -> types.NumpyArray:
def get_best_scores(examples: list[types.NumpyArray]) -> types.NumpyArray:
matrix_count = len(matrices)
# Get scores to all examples
scores: list[types.NumpyArray] = []
for example in examples:
score = calculate_multi_distance_core(example, matrices, distance_type)
scores.append(score)
# Keep only max for each vector
if len(scores) == 0:
scores.append(np.full(matrix_count, -np.inf))
best_scores = np.array(scores, dtype=np.float32).max(axis=0)
return best_scores
pos = get_best_scores(query.positive)
neg = get_best_scores(query.negative)
# Choose from the best positive or the best negative,
# in both cases we apply sigmoid and then negate depending on the order
return np.where(
pos > neg,
np.fromiter((scaled_fast_sigmoid(xi) for xi in pos), pos.dtype),
np.fromiter((-scaled_fast_sigmoid(xi) for xi in neg), neg.dtype),
)
def calculate_multi_recommend_sum_scores(
query: MultiRecoQuery, matrices: list[types.NumpyArray], distance_type: models.Distance
) -> types.NumpyArray:
def get_sum_scores(examples: list[types.NumpyArray]) -> types.NumpyArray:
matrix_count = len(matrices)
scores: list[types.NumpyArray] = []
for example in examples:
score = calculate_multi_distance_core(example, matrices, distance_type)
scores.append(score)
if len(scores) == 0:
scores.append(np.zeros(matrix_count))
sum_scores = np.array(scores, dtype=np.float32).sum(axis=0)
return sum_scores
pos = get_sum_scores(query.positive)
neg = get_sum_scores(query.negative)
return pos - neg
def calculate_multi_discovery_ranks(
context: list[MultiContextPair],
matrices: list[types.NumpyArray],
distance_type: models.Distance,
) -> types.NumpyArray:
overall_ranks: types.NumpyArray = np.zeros(len(matrices), dtype=np.int32)
for pair in context:
# Get distances to positive and negative vectors
pos = calculate_multi_distance_core(pair.positive, matrices, distance_type)
neg = calculate_multi_distance_core(pair.negative, matrices, distance_type)
pair_ranks = np.array(
[
1 if is_bigger else 0 if is_equal else -1
for is_bigger, is_equal in zip(pos > neg, pos == neg)
]
)
overall_ranks += pair_ranks
return overall_ranks
def calculate_multi_discovery_scores(
query: MultiDiscoveryQuery, matrices: list[types.NumpyArray], distance_type: models.Distance
) -> types.NumpyArray:
ranks = calculate_multi_discovery_ranks(query.context, matrices, distance_type)
# Get distances to target
distances_to_target = calculate_multi_distance_core(query.target, matrices, distance_type)
sigmoided_distances = np.fromiter(
(scaled_fast_sigmoid(xi) for xi in distances_to_target), np.float32
)
return ranks + sigmoided_distances
def calculate_multi_context_scores(
query: MultiContextQuery, matrices: list[types.NumpyArray], distance_type: models.Distance
) -> types.NumpyArray:
overall_scores: types.NumpyArray = np.zeros(len(matrices), dtype=np.float32)
for pair in query.context_pairs:
# Get distances to positive and negative vectors
pos = calculate_multi_distance_core(pair.positive, matrices, distance_type)
neg = calculate_multi_distance_core(pair.negative, matrices, distance_type)
difference = pos - neg - EPSILON
pair_scores = np.fromiter(
(fast_sigmoid(xi) for xi in np.minimum(difference, 0.0)), np.float32
)
overall_scores += pair_scores
return overall_scores
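# A minimal usage sketch (illustrative only): multivector scoring is a max-sim sum.
# For every row of the query matrix the best-matching row of the stored matrix is
# taken and those maxima are summed, so a matrix covering both query rows (m_a)
# scores higher than one covering only the first row (m_b) under DOT distance.
if __name__ == "__main__":
    query = np.array([[1.0, 0.0], [0.0, 1.0]])
    m_a = np.array([[1.0, 0.0], [0.0, 1.0]])
    m_b = np.array([[1.0, 0.0]])
    scores = calculate_multi_distance_core(query, [m_a, m_b], models.Distance.DOT)
    assert np.allclose(scores, [2.0, 1.0])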
@@ -0,0 +1,30 @@
from datetime import datetime
from typing import Optional, Union
from qdrant_client.http.models import OrderValue
from qdrant_client.local.datetime_utils import parse
MICROS_PER_SECOND = 1_000_000
def datetime_to_microseconds(dt: datetime) -> int:
return int(dt.timestamp() * MICROS_PER_SECOND)
def to_order_value(value: Union[None, OrderValue, datetime, str]) -> Optional[OrderValue]:
if value is None:
return None
# check if OrderValue
if isinstance(value, (int, float)):
return value
if isinstance(value, datetime):
return datetime_to_microseconds(value)
if isinstance(value, str):
dt = parse(value)
if dt is not None:
return datetime_to_microseconds(dt)
return None
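# Illustrative sketch, not part of the original module: `_example_to_order_value` is a
# hypothetical helper added for illustration only. Plain numbers already are OrderValue
# instances and pass through, while parseable datetime strings are converted to
# microsecond timestamps.
def _example_to_order_value() -> None:
    assert to_order_value(42) == 42
    assert to_order_value("2021-01-01T00:00:00Z") == 1609459200 * MICROS_PER_SECOND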
@@ -0,0 +1,337 @@
from datetime import date, datetime, timezone
from typing import Any, Optional, Union, Dict
from uuid import UUID
import numpy as np
from qdrant_client.http import models
from qdrant_client.local import datetime_utils
from qdrant_client.local.geo import boolean_point_in_polygon, geo_distance
from qdrant_client.local.payload_value_extractor import value_by_key
from qdrant_client.conversions import common_types as types
def get_value_counts(values: list[Any]) -> list[int]:
counts = []
if all(value is None for value in values):
counts.append(0)
else:
for value in values:
if value is None:
counts.append(0)
elif hasattr(value, "__len__") and not isinstance(value, str):
counts.append(len(value))
else:
counts.append(1)
return counts
def check_values_count(condition: models.ValuesCount, values: Optional[list[Any]]) -> bool:
if values is None:
return False
counts = get_value_counts(values)
if condition.lt is not None and all(count >= condition.lt for count in counts):
return False
if condition.lte is not None and all(count > condition.lte for count in counts):
return False
if condition.gt is not None and all(count <= condition.gt for count in counts):
return False
if condition.gte is not None and all(count < condition.gte for count in counts):
return False
return True
def check_geo_radius(condition: models.GeoRadius, values: Any) -> bool:
if isinstance(values, dict) and "lat" in values and "lon" in values:
lat = values["lat"]
lon = values["lon"]
distance = geo_distance(
lon1=lon,
lat1=lat,
lon2=condition.center.lon,
lat2=condition.center.lat,
)
return distance < condition.radius
return False
def check_geo_bounding_box(condition: models.GeoBoundingBox, values: Any) -> bool:
if isinstance(values, dict) and "lat" in values and "lon" in values:
lat = values["lat"]
lon = values["lon"]
# handle anti-meridian crossing case
if condition.top_left.lon > condition.bottom_right.lon:
longitude_condition = (
condition.top_left.lon <= lon <= 180 or -180 <= lon <= condition.bottom_right.lon
)
else:
longitude_condition = condition.top_left.lon <= lon <= condition.bottom_right.lon
latitude_condition = condition.top_left.lat >= lat >= condition.bottom_right.lat
return longitude_condition and latitude_condition
return False
def check_geo_polygon(condition: models.GeoPolygon, values: Any) -> bool:
if isinstance(values, dict) and "lat" in values and "lon" in values:
lat = values["lat"]
lon = values["lon"]
exterior = [(point.lat, point.lon) for point in condition.exterior.points]
interiors = []
if condition.interiors is not None:
interiors = [
[(point.lat, point.lon) for point in interior.points]
for interior in condition.interiors
]
return boolean_point_in_polygon(point=(lat, lon), exterior=exterior, interiors=interiors)
return False
def check_range_interface(condition: models.RangeInterface, value: Any) -> bool:
if isinstance(condition, models.Range):
return check_range(condition, value)
if isinstance(condition, models.DatetimeRange):
return check_datetime_range(condition, value)
return False
def check_range(condition: models.Range, value: Any) -> bool:
if not isinstance(value, (int, float)):
return False
return (
(condition.lt is None or value < condition.lt)
and (condition.lte is None or value <= condition.lte)
and (condition.gt is None or value > condition.gt)
and (condition.gte is None or value >= condition.gte)
)
def check_datetime_range(condition: models.DatetimeRange, value: Any) -> bool:
def make_condition_tz_aware(dt: Optional[Union[datetime, date]]) -> Optional[datetime]:
if isinstance(dt, date) and not isinstance(dt, datetime):
dt = datetime.combine(dt, datetime.min.time())
if dt is None or dt.tzinfo is not None:
return dt
# Assume UTC if no timezone is provided
return dt.replace(tzinfo=timezone.utc)
if not isinstance(value, str):
return False
dt = datetime_utils.parse(value)
if dt is None:
return False
condition.lt = make_condition_tz_aware(condition.lt)
condition.lte = make_condition_tz_aware(condition.lte)
condition.gt = make_condition_tz_aware(condition.gt)
condition.gte = make_condition_tz_aware(condition.gte)
return (
(condition.lt is None or dt < condition.lt)
and (condition.lte is None or dt <= condition.lte)
and (condition.gt is None or dt > condition.gt)
and (condition.gte is None or dt >= condition.gte)
)
def check_match(condition: models.Match, value: Any) -> bool:
if isinstance(condition, models.MatchValue):
return value == condition.value
if isinstance(condition, models.MatchText):
return value is not None and condition.text in value
if isinstance(condition, models.MatchTextAny):
return value is not None and any(word in value for word in condition.text_any.split())
if isinstance(condition, models.MatchAny):
return value in condition.any
if isinstance(condition, models.MatchExcept):
return value not in condition.except_
raise ValueError(f"Unknown match condition: {condition}")
def check_nested_filter(nested_filter: models.Filter, values: list[Any]) -> bool:
return any(check_filter(nested_filter, v, point_id=-1, has_vector={}) for v in values)
def check_condition(
condition: models.Condition,
payload: dict[str, Any],
point_id: models.ExtendedPointId,
has_vector: Dict[str, bool],
) -> bool:
if isinstance(condition, models.IsNullCondition):
values = value_by_key(payload, condition.is_null.key, flat=False)
if values is None:
return False
if any(v is None for v in values):
return True
elif isinstance(condition, models.IsEmptyCondition):
values = value_by_key(payload, condition.is_empty.key, flat=False)
if (
values is None
or len(values) == 0
or all((v is None or (isinstance(v, list) and len(v) == 0)) for v in values)
):
return True
elif isinstance(condition, models.HasIdCondition):
ids = [str(id_) if isinstance(id_, UUID) else id_ for id_ in condition.has_id]
if point_id in ids:
return True
elif isinstance(condition, models.HasVectorCondition):
if condition.has_vector in has_vector and has_vector[condition.has_vector]:
return True
elif isinstance(condition, models.FieldCondition):
values = value_by_key(payload, condition.key)
if condition.match is not None:
if values is None:
return False
return any(check_match(condition.match, v) for v in values)
if condition.range is not None:
if values is None:
return False
return any(check_range_interface(condition.range, v) for v in values)
if condition.geo_bounding_box is not None:
if values is None:
return False
return any(check_geo_bounding_box(condition.geo_bounding_box, v) for v in values)
if condition.geo_radius is not None:
if values is None:
return False
return any(check_geo_radius(condition.geo_radius, v) for v in values)
if condition.values_count is not None:
values = value_by_key(payload, condition.key, flat=False)
return check_values_count(condition.values_count, values)
if condition.geo_polygon is not None:
if values is None:
return False
return any(check_geo_polygon(condition.geo_polygon, v) for v in values)
elif isinstance(condition, models.NestedCondition):
values = value_by_key(payload, condition.nested.key)
if values is None:
return False
return check_nested_filter(condition.nested.filter, values)
elif isinstance(condition, models.Filter):
return check_filter(condition, payload, point_id, has_vector)
else:
raise ValueError(f"Unknown condition: {condition}")
return False
def check_must(
conditions: list[models.Condition],
payload: dict,
point_id: models.ExtendedPointId,
has_vector: Dict[str, bool],
) -> bool:
return all(
check_condition(condition, payload, point_id, has_vector) for condition in conditions
)
def check_must_not(
conditions: list[models.Condition],
payload: dict,
point_id: models.ExtendedPointId,
has_vector: Dict[str, bool],
) -> bool:
return all(
not check_condition(condition, payload, point_id, has_vector) for condition in conditions
)
def check_should(
conditions: list[models.Condition],
payload: dict,
point_id: models.ExtendedPointId,
has_vector: Dict[str, bool],
) -> bool:
return any(
check_condition(condition, payload, point_id, has_vector) for condition in conditions
)
def check_min_should(
conditions: list[models.Condition],
payload: dict,
point_id: models.ExtendedPointId,
has_vector: Dict[str, bool],
min_count: int,
) -> bool:
return (
sum(check_condition(condition, payload, point_id, has_vector) for condition in conditions)
>= min_count
)
def check_filter(
payload_filter: models.Filter,
payload: dict,
point_id: models.ExtendedPointId,
has_vector: Dict[str, bool],
) -> bool:
def ensure_condition_list(
condition: Union[models.Condition, list[models.Condition]],
) -> list[models.Condition]:
if isinstance(condition, list):
return condition
return [condition]
if payload_filter.must is not None:
if not check_must(
ensure_condition_list(payload_filter.must), payload, point_id, has_vector
):
return False
if payload_filter.must_not is not None:
if not check_must_not(
ensure_condition_list(payload_filter.must_not), payload, point_id, has_vector
):
return False
if payload_filter.should is not None:
if not check_should(
ensure_condition_list(payload_filter.should), payload, point_id, has_vector
):
return False
if payload_filter.min_should is not None:
if not check_min_should(
payload_filter.min_should.conditions,
payload,
point_id,
has_vector,
payload_filter.min_should.min_count,
):
return False
return True
def calculate_payload_mask(
payloads: list[dict],
payload_filter: Optional[models.Filter],
ids_inv: list[models.ExtendedPointId],
deleted_per_vector: Dict[str, np.ndarray],
) -> types.NumpyArray:
if payload_filter is None:
return np.ones(len(payloads), dtype=bool)
mask: types.NumpyArray = np.zeros(len(payloads), dtype=bool)
for i, payload in enumerate(payloads):
has_vector = {}
for vector_name, deleted in deleted_per_vector.items():
if not deleted[i]:
has_vector[vector_name] = True
if check_filter(payload_filter, payload, ids_inv[i], has_vector):
mask[i] = True
return mask
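# Illustrative sketch, not part of the original module: `_example_check_filter` is a
# hypothetical helper added for illustration only. It shows a single MatchValue condition
# against a flat payload; point_id and has_vector are irrelevant for a plain field match.
def _example_check_filter() -> None:
    condition = models.FieldCondition(key="city", match=models.MatchValue(value="Berlin"))
    payload_filter = models.Filter(must=[condition])
    assert check_filter(payload_filter, {"city": "Berlin"}, point_id=1, has_vector={})
    assert not check_filter(payload_filter, {"city": "Munich"}, point_id=1, has_vector={})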
@@ -0,0 +1,92 @@
import uuid
from typing import Any, Optional
from qdrant_client.local.json_path_parser import (
JsonPathItem,
JsonPathItemType,
parse_json_path,
)
def value_by_key(payload: dict[str, Any], key: str, flat: bool = True) -> Optional[list[Any]]:
"""
Get value from payload by key.
Args:
payload: arbitrary json-like object
flat: If True, extend the result with list values; if False, append them as nested lists.
The default (True) flattens the arrays, which filters rely on, while the `count` method needs the arrays kept as-is.
key:
Key or path to value in payload.
Examples:
- "name"
- "address.city"
- "location[].name"
- "location[0].name"
Returns:
List of values or None if key not found.
"""
keys = parse_json_path(key)
result = []
def _get_value(data: Any, k_list: list[JsonPathItem]) -> None:
if not k_list:
return
current_key = k_list.pop(0)
if len(k_list) == 0:
if isinstance(data, dict) and current_key.item_type == JsonPathItemType.KEY:
if current_key.key in data:
value = data[current_key.key]
if isinstance(value, list) and flat:
result.extend(value)
else:
result.append(value)
elif isinstance(data, list):
if current_key.item_type == JsonPathItemType.WILDCARD_INDEX:
result.extend(data)
elif current_key.item_type == JsonPathItemType.INDEX:
assert current_key.index is not None
if current_key.index < len(data):
result.append(data[current_key.index])
elif current_key.item_type == JsonPathItemType.KEY:
if not isinstance(data, dict):
return
if current_key.key in data:
_get_value(data[current_key.key], k_list.copy())
elif current_key.item_type == JsonPathItemType.INDEX:
assert current_key.index is not None
if not isinstance(data, list):
return
if current_key.index < len(data):
_get_value(data[current_key.index], k_list.copy())
elif current_key.item_type == JsonPathItemType.WILDCARD_INDEX:
if not isinstance(data, list):
return
for item in data:
_get_value(item, k_list.copy())
_get_value(payload, keys)
return result if result else None
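# Illustrative sketch, not part of the original module: `_example_value_by_key` is a
# hypothetical helper added for illustration only, showing how the docstring key examples
# resolve against a small payload.
def _example_value_by_key() -> None:
    payload = {"location": [{"name": "home"}, {"name": "work"}]}
    assert value_by_key(payload, "location[].name") == ["home", "work"]
    assert value_by_key(payload, "location[0].name") == ["home"]
    assert value_by_key(payload, "missing") is None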
def parse_uuid(value: Any) -> Optional[uuid.UUID]:
"""
Parse UUID from value.
Args:
value: arbitrary value
"""
try:
return uuid.UUID(str(value))
except ValueError:
return None
@@ -0,0 +1,247 @@
from typing import Any, Optional, Type
from qdrant_client.local.json_path_parser import JsonPathItem, JsonPathItemType
def set_value_by_key(payload: dict, keys: list[JsonPathItem], value: Any) -> None:
"""
Set value in payload by key.
Args:
payload: arbitrary json-like object
keys:
list of json path items, e.g.:
[
JsonPathItem(item_type=<JsonPathItemType.KEY: 'key'>, key='a'),
JsonPathItem(item_type=<JsonPathItemType.INDEX: 'index'>, index=0),
JsonPathItem(item_type=<JsonPathItemType.INDEX: 'index'>, index=1),
JsonPathItem(item_type=<JsonPathItemType.KEY: 'key'>, key='b')
]
The original keys could look like this:
- "name"
- "address.city"
- "location[].name"
- "location[0].name"
value: value to set
"""
Setter.set(payload, keys.copy(), value, None, None)
class Setter:
TYPE: Any
SETTERS: dict[JsonPathItemType, Type["Setter"]] = {}
@classmethod
def add_setter(cls, item_type: JsonPathItemType, setter: Type["Setter"]) -> None:
cls.SETTERS[item_type] = setter
@classmethod
def set(
cls,
data: Any,
k_list: list[JsonPathItem],
value: dict[str, Any],
prev_data: Any,
prev_key: Optional[JsonPathItem],
) -> None:
if not k_list:
return
current_key = k_list.pop(0)
cls.SETTERS[current_key.item_type]._set(
data,
current_key,
k_list,
value,
prev_data,
prev_key,
)
@classmethod
def _set(
cls,
data: Any,
current_key: JsonPathItem,
k_list: list[JsonPathItem],
value: dict[str, Any],
prev_data: Any,
prev_key: Optional[JsonPathItem],
) -> None:
if isinstance(data, cls.TYPE):
cls._set_compatible_types(
data=data, current_key=current_key, k_list=k_list, value=value
)
else:
cls._set_incompatible_types(
current_key=current_key,
k_list=k_list,
value=value,
prev_data=prev_data,
prev_key=prev_key,
)
@classmethod
def _set_compatible_types(
cls,
data: Any,
current_key: JsonPathItem,
k_list: list[JsonPathItem],
value: dict[str, Any],
) -> None:
raise NotImplementedError()
@classmethod
def _set_incompatible_types(
cls,
current_key: JsonPathItem,
k_list: list[JsonPathItem],
value: dict[str, Any],
prev_data: Any,
prev_key: Optional[JsonPathItem],
) -> None:
raise NotImplementedError()
class KeySetter(Setter):
TYPE = dict
@classmethod
def _set_compatible_types(
cls,
data: Any,
current_key: JsonPathItem,
k_list: list[JsonPathItem],
value: dict[str, Any],
) -> None:
if current_key.key not in data:
data[current_key.key] = {}
if len(k_list) == 0:
if isinstance(data[current_key.key], dict):
data[current_key.key].update(value)
else:
data[current_key.key] = value
else:
cls.set(data[current_key.key], k_list.copy(), value, data, current_key)
@classmethod
def _set_incompatible_types(
cls,
current_key: JsonPathItem,
k_list: list[JsonPathItem],
value: dict[str, Any],
prev_data: Any,
prev_key: Optional[JsonPathItem],
) -> None:
assert prev_key is not None
if len(k_list) == 0:
if prev_key.item_type == JsonPathItemType.KEY:
prev_data[prev_key.key] = {current_key.key: value}
else: # if prev key was WILDCARD, we need to pass INDEX instead with an index set
prev_data[prev_key.index] = {current_key.key: value}
else:
if prev_key.item_type == JsonPathItemType.KEY:
prev_data[prev_key.key] = {current_key.key: {}}
cls.set(
prev_data[prev_key.key][current_key.key],
k_list.copy(),
value,
prev_data[prev_key.key],
current_key,
)
else:
prev_data[prev_key.index] = {current_key.key: {}}
cls.set(
prev_data[prev_key.index][current_key.key],
k_list.copy(),
value,
prev_data[prev_key.index],
current_key,
)
class _ListSetter(Setter):
TYPE = list
@classmethod
def _set_incompatible_types(
cls,
current_key: JsonPathItem,
k_list: list[JsonPathItem],
value: dict[str, Any],
prev_data: Any,
prev_key: Optional[JsonPathItem],
) -> None:
assert prev_key is not None
if prev_key.item_type == JsonPathItemType.KEY:
prev_data[prev_key.key] = []
return
else:
prev_data[prev_key.index] = []
return
@classmethod
def _set_compatible_types(
cls,
data: Any,
current_key: JsonPathItem,
k_list: list[JsonPathItem],
value: dict[str, Any],
) -> None:
raise NotImplementedError()
class IndexSetter(_ListSetter):
@classmethod
def _set_compatible_types(
cls,
data: Any,
current_key: JsonPathItem,
k_list: list[JsonPathItem],
value: dict[str, Any],
) -> None:
assert current_key.index is not None
if current_key.index < len(data):
if len(k_list) == 0:
if isinstance(data[current_key.index], dict):
data[current_key.index].update(value)
else:
data[current_key.index] = value
return
cls.set(data[current_key.index], k_list.copy(), value, data, current_key)
class WildcardIndexSetter(_ListSetter):
@classmethod
def _set_compatible_types(
cls,
data: Any,
current_key: JsonPathItem,
k_list: list[JsonPathItem],
value: dict[str, Any],
) -> None:
if len(k_list) == 0:
for i, item in enumerate(data):
if isinstance(item, dict):
data[i].update(value)
else:
data[i] = value
else:
for i, item in enumerate(data):
cls.set(
item,
k_list.copy(),
value,
data,
JsonPathItem(item_type=JsonPathItemType.INDEX, index=i),
)
Setter.add_setter(JsonPathItemType.KEY, KeySetter)
Setter.add_setter(JsonPathItemType.INDEX, IndexSetter)
Setter.add_setter(JsonPathItemType.WILDCARD_INDEX, WildcardIndexSetter)
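# Illustrative sketch, not part of the original module: `_example_set_value_by_key` is a
# hypothetical helper added for illustration only. Setting a key whose target is already
# a dict merges the new value into it.
def _example_set_value_by_key() -> None:
    payload = {"a": {"a": 2}}
    keys = [JsonPathItem(item_type=JsonPathItemType.KEY, key="a")]
    set_value_by_key(payload, keys, {"b": 3})
    assert payload == {"a": {"a": 2, "b": 3}}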
@@ -0,0 +1,175 @@
import base64
import dbm
import logging
import pickle
import sqlite3
from pathlib import Path
from typing import Iterable, Optional
from qdrant_client.http import models
STORAGE_FILE_NAME_OLD = "storage.dbm"
STORAGE_FILE_NAME = "storage.sqlite"
def try_migrate_to_sqlite(location: str) -> None:
dbm_path = Path(location) / STORAGE_FILE_NAME_OLD
sql_path = Path(location) / STORAGE_FILE_NAME
if sql_path.exists():
return
if not dbm_path.exists():
return
try:
dbm_storage = dbm.open(str(dbm_path), "c")
con = sqlite3.connect(str(sql_path))
cur = con.cursor()
# Create table
cur.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)")
for key in dbm_storage.keys():
value = dbm_storage[key]
if isinstance(key, str):
key = key.encode("utf-8")
key = pickle.loads(key)
sqlite_key = CollectionPersistence.encode_key(key)
# Insert a row of data
cur.execute(
"INSERT INTO points VALUES (?, ?)",
(
sqlite_key,
sqlite3.Binary(value),
),
)
con.commit()
con.close()
dbm_storage.close()
dbm_path.unlink()
except Exception as e:
logging.error("Failed to migrate dbm to sqlite:", e)
logging.error(
"Please try to use previous version of qdrant-client or re-create collection"
)
raise e
class CollectionPersistence:
CHECK_SAME_THREAD: Optional[bool] = None
@classmethod
def encode_key(cls, key: models.ExtendedPointId) -> str:
return base64.b64encode(pickle.dumps(key)).decode("utf-8")
def __init__(self, location: str, force_disable_check_same_thread: bool = False):
"""
Create or load a collection from the local storage.
Args:
location: path to the collection directory.
force_disable_check_same_thread: if True, always disable sqlite's `check_same_thread` check.
"""
try_migrate_to_sqlite(location)
self.location = Path(location) / STORAGE_FILE_NAME
self.location.parent.mkdir(exist_ok=True, parents=True)
if self.CHECK_SAME_THREAD is None and force_disable_check_same_thread is False:
with sqlite3.connect(":memory:") as tmp_conn:
# it is unsafe to rely on `sqlite3.threadsafety` before python3.11 since it was hardcoded to 1,
# so we fetch the actual threadsafety level with a query
# THREADSAFE = 0: Threads may not share the module
# THREADSAFE = 1: Threads may share the module, connections and cursors. Default for Linux.
# THREADSAFE = 2: Threads may share the module, but not connections. Default for macOS.
threadsafe = tmp_conn.execute(
"select * from pragma_compile_options where compile_options like 'THREADSAFE=%'"
).fetchone()[0]
self.__class__.CHECK_SAME_THREAD = threadsafe != "THREADSAFE=1"
if force_disable_check_same_thread:
self.__class__.CHECK_SAME_THREAD = False
self.storage = sqlite3.connect(
str(self.location), check_same_thread=self.CHECK_SAME_THREAD # type: ignore
)
self._ensure_table()
def close(self) -> None:
self.storage.close()
def _ensure_table(self) -> None:
cursor = self.storage.cursor()
cursor.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)")
self.storage.commit()
def persist(self, point: models.PointStruct) -> None:
"""
Persist a point in the local storage.
Args:
point: point to persist
"""
key = self.encode_key(point.id)
value = pickle.dumps(point)
cursor = self.storage.cursor()
# Insert or update by key
cursor.execute(
"INSERT OR REPLACE INTO points VALUES (?, ?)",
(
key,
sqlite3.Binary(value),
),
)
self.storage.commit()
def delete(self, point_id: models.ExtendedPointId) -> None:
"""
Delete a point from the local storage.
Args:
point_id: id of the point to delete
"""
key = self.encode_key(point_id)
cursor = self.storage.cursor()
cursor.execute(
"DELETE FROM points WHERE id = ?",
(key,),
)
self.storage.commit()
def load(self) -> Iterable[models.PointStruct]:
"""
Load all stored points from the local storage.
Yields:
point: loaded point
"""
cursor = self.storage.cursor()
cursor.execute("SELECT point FROM points")
for row in cursor.fetchall():
yield pickle.loads(row[0])
def test_persistence() -> None:
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
persistence = CollectionPersistence(tmpdir)
point = models.PointStruct(id=1, vector=[1.0, 2.0, 3.0], payload={"a": 1})
persistence.persist(point)
for loaded_point in persistence.load():
assert loaded_point == point
break
del persistence
persistence = CollectionPersistence(tmpdir)
for loaded_point in persistence.load():
assert loaded_point == point
break
persistence.delete(point.id)
persistence.delete(point.id)
for _ in persistence.load():
assert False, "Should not load anything"
File diff suppressed because it is too large
@@ -0,0 +1,36 @@
import numpy as np
from qdrant_client.http.models import SparseVector
def empty_sparse_vector() -> SparseVector:
return SparseVector(
indices=[],
values=[],
)
def validate_sparse_vector(vector: SparseVector) -> None:
assert len(vector.indices) == len(
vector.values
), "Indices and values must have the same length"
assert not np.isnan(vector.values).any(), "Values must not contain NaN"
assert len(vector.indices) == len(set(vector.indices)), "Indices must be unique"
def is_sorted(vector: SparseVector) -> bool:
for i in range(1, len(vector.indices)):
if vector.indices[i] < vector.indices[i - 1]:
return False
return True
def sort_sparse_vector(vector: SparseVector) -> SparseVector:
if is_sorted(vector):
return vector
sorted_indices = np.argsort(vector.indices)
return SparseVector(
indices=[vector.indices[i] for i in sorted_indices],
values=[vector.values[i] for i in sorted_indices],
)
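# Illustrative sketch, not part of the original module: `_example_sort_sparse_vector` is a
# hypothetical helper added for illustration only, validating a sparse vector and sorting
# it by index.
def _example_sort_sparse_vector() -> None:
    vector = SparseVector(indices=[3, 1], values=[0.5, 0.25])
    validate_sparse_vector(vector)
    sorted_vector = sort_sparse_vector(vector)
    assert sorted_vector.indices == [1, 3]
    assert sorted_vector.values == [0.25, 0.5]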
@@ -0,0 +1,314 @@
from typing import Callable, Optional, Sequence, Union
import numpy as np
from qdrant_client.conversions import common_types as types
from qdrant_client.http.models import SparseVector
from qdrant_client.local.distances import EPSILON, fast_sigmoid, scaled_fast_sigmoid
from qdrant_client.local.sparse import (
empty_sparse_vector,
is_sorted,
sort_sparse_vector,
validate_sparse_vector,
)
class SparseRecoQuery:
def __init__(
self,
positive: Optional[list[SparseVector]] = None,
negative: Optional[list[SparseVector]] = None,
strategy: Optional[types.RecommendStrategy] = None,
):
assert strategy is not None, "Recommend strategy must be provided"
self.strategy = strategy
positive = positive if positive is not None else []
negative = negative if negative is not None else []
for i, vector in enumerate(positive):
validate_sparse_vector(vector)
positive[i] = sort_sparse_vector(vector)
for i, vector in enumerate(negative):
validate_sparse_vector(vector)
negative[i] = sort_sparse_vector(vector)
self.positive = positive
self.negative = negative
def transform_sparse(
self, foo: Callable[["SparseVector"], "SparseVector"]
) -> "SparseRecoQuery":
return SparseRecoQuery(
positive=[foo(vector) for vector in self.positive],
negative=[foo(vector) for vector in self.negative],
strategy=self.strategy,
)
class SparseContextPair:
def __init__(self, positive: SparseVector, negative: SparseVector):
validate_sparse_vector(positive)
validate_sparse_vector(negative)
self.positive: SparseVector = sort_sparse_vector(positive)
self.negative: SparseVector = sort_sparse_vector(negative)
class SparseDiscoveryQuery:
def __init__(self, target: SparseVector, context: list[SparseContextPair]):
validate_sparse_vector(target)
self.target: SparseVector = sort_sparse_vector(target)
self.context = context
def transform_sparse(
self, foo: Callable[["SparseVector"], "SparseVector"]
) -> "SparseDiscoveryQuery":
return SparseDiscoveryQuery(
target=foo(self.target),
context=[
SparseContextPair(foo(pair.positive), foo(pair.negative)) for pair in self.context
],
)
class SparseContextQuery:
def __init__(self, context_pairs: list[SparseContextPair]):
self.context_pairs = context_pairs
def transform_sparse(
self, foo: Callable[["SparseVector"], "SparseVector"]
) -> "SparseContextQuery":
return SparseContextQuery(
context_pairs=[
SparseContextPair(foo(pair.positive), foo(pair.negative))
for pair in self.context_pairs
]
)
SparseQueryVector = Union[
SparseVector,
SparseDiscoveryQuery,
SparseContextQuery,
SparseRecoQuery,
]
def calculate_distance_sparse(
query: SparseVector, vectors: list[SparseVector], empty_is_zero: bool = False
) -> types.NumpyArray:
"""Calculate distances between a query sparse vector and a list of sparse vectors.
Args:
query (SparseVector): The query sparse vector.
vectors (list[SparseVector]): A list of sparse vectors to compare against.
empty_is_zero (bool): If True, distance between vectors with no overlap is treated as zero.
Otherwise, it is treated as negative infinity.
Simple nearest search requires `empty_is_zero` to be False, while methods like
recommend, discovery, and context search require True.
"""
scores = []
for vector in vectors:
score = sparse_dot_product(query, vector)
if score is not None:
scores.append(score)
elif not empty_is_zero:
# means no overlap
scores.append(np.float32("-inf"))
else:
scores.append(np.float32(0.0))
return np.array(scores, dtype=np.float32)
# Expects sorted indices
# Returns None if no overlap
def sparse_dot_product(vector1: SparseVector, vector2: SparseVector) -> Optional[np.float32]:
result = 0.0
i, j = 0, 0
overlap = False
assert is_sorted(vector1), "Query sparse vector must be sorted"
assert is_sorted(vector2), "Sparse vector to compare with must be sorted"
while i < len(vector1.indices) and j < len(vector2.indices):
if vector1.indices[i] == vector2.indices[j]:
overlap = True
result += vector1.values[i] * vector2.values[j]
i += 1
j += 1
elif vector1.indices[i] < vector2.indices[j]:
i += 1
else:
j += 1
if overlap:
return np.float32(result)
else:
return None
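# Illustrative sketch, not part of the original module: `_example_sparse_dot_product` is a
# hypothetical helper added for illustration only. Only overlapping indices contribute to
# the product, and fully disjoint vectors yield None rather than 0.0.
def _example_sparse_dot_product() -> None:
    first = SparseVector(indices=[1, 2], values=[1.0, 2.0])
    second = SparseVector(indices=[2, 3], values=[3.0, 4.0])
    assert sparse_dot_product(first, second) == np.float32(6.0)  # only index 2 overlaps
    assert sparse_dot_product(first, SparseVector(indices=[5], values=[1.0])) is None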
def calculate_sparse_discovery_ranks(
context: list[SparseContextPair],
vectors: list[SparseVector],
) -> types.NumpyArray:
overall_ranks: types.NumpyArray = np.zeros(len(vectors), dtype=np.int32)
for pair in context:
# Get distances to positive and negative vectors
pos = calculate_distance_sparse(pair.positive, vectors, empty_is_zero=True)
neg = calculate_distance_sparse(pair.negative, vectors, empty_is_zero=True)
pair_ranks = np.array(
[
1 if is_bigger else 0 if is_equal else -1
for is_bigger, is_equal in zip(pos > neg, pos == neg)
]
)
overall_ranks += pair_ranks
return overall_ranks
def calculate_sparse_discovery_scores(
query: SparseDiscoveryQuery, vectors: list[SparseVector]
) -> types.NumpyArray:
ranks = calculate_sparse_discovery_ranks(query.context, vectors)
# Get distances to target
distances_to_target = calculate_distance_sparse(query.target, vectors, empty_is_zero=True)
sigmoided_distances = np.fromiter(
(scaled_fast_sigmoid(xi) for xi in distances_to_target), np.float32
)
return ranks + sigmoided_distances
def calculate_sparse_context_scores(
query: SparseContextQuery, vectors: list[SparseVector]
) -> types.NumpyArray:
overall_scores: types.NumpyArray = np.zeros(len(vectors), dtype=np.float32)
for pair in query.context_pairs:
# Get distances to positive and negative vectors
pos = calculate_distance_sparse(pair.positive, vectors, empty_is_zero=True)
neg = calculate_distance_sparse(pair.negative, vectors, empty_is_zero=True)
difference = pos - neg - EPSILON
pair_scores = np.fromiter(
(fast_sigmoid(xi) for xi in np.minimum(difference, 0.0)), np.float32
)
overall_scores += pair_scores
return overall_scores
def calculate_sparse_recommend_best_scores(
query: SparseRecoQuery, vectors: list[SparseVector]
) -> types.NumpyArray:
def get_best_scores(examples: list[SparseVector]) -> types.NumpyArray:
vector_count = len(vectors)
# Get scores to all examples
scores: list[types.NumpyArray] = []
for example in examples:
score = calculate_distance_sparse(example, vectors, empty_is_zero=True)
scores.append(score)
# Keep only max for each vector
if len(scores) == 0:
scores.append(np.full(vector_count, -np.inf))
best_scores = np.array(scores, dtype=np.float32).max(axis=0)
return best_scores
pos = get_best_scores(query.positive)
neg = get_best_scores(query.negative)
# Choose the best positive or the best negative score for each vector;
# in both cases we apply a sigmoid and negate it depending on which side wins
return np.where(
pos > neg,
np.fromiter((scaled_fast_sigmoid(xi) for xi in pos), pos.dtype),
np.fromiter((-scaled_fast_sigmoid(xi) for xi in neg), neg.dtype),
)
def calculate_sparse_recommend_sum_scores(
query: SparseRecoQuery, vectors: list[SparseVector]
) -> types.NumpyArray:
def get_sum_scores(examples: list[SparseVector]) -> types.NumpyArray:
vector_count = len(vectors)
scores: list[types.NumpyArray] = []
for example in examples:
score = calculate_distance_sparse(example, vectors, empty_is_zero=True)
scores.append(score)
if len(scores) == 0:
scores.append(np.zeros(vector_count))
sum_scores = np.array(scores, dtype=np.float32).sum(axis=0)
return sum_scores
pos = get_sum_scores(query.positive)
neg = get_sum_scores(query.negative)
return pos - neg
# Expects sorted indices
def combine_aggregate(vector1: SparseVector, vector2: SparseVector, op: Callable) -> SparseVector:
result = empty_sparse_vector()
i, j = 0, 0
while i < len(vector1.indices) and j < len(vector2.indices):
if vector1.indices[i] == vector2.indices[j]:
result.indices.append(vector1.indices[i])
result.values.append(op(vector1.values[i], vector2.values[j]))
i += 1
j += 1
elif vector1.indices[i] < vector2.indices[j]:
result.indices.append(vector1.indices[i])
result.values.append(op(vector1.values[i], 0.0))
i += 1
else:
result.indices.append(vector2.indices[j])
result.values.append(op(0.0, vector2.values[j]))
j += 1
while i < len(vector1.indices):
result.indices.append(vector1.indices[i])
result.values.append(op(vector1.values[i], 0.0))
i += 1
while j < len(vector2.indices):
result.indices.append(vector2.indices[j])
result.values.append(op(0.0, vector2.values[j]))
j += 1
return result
# Expects sorted indices
def sparse_avg(vectors: Sequence[SparseVector]) -> SparseVector:
result = empty_sparse_vector()
if len(vectors) == 0:
return result
sparse_count = 0
for vector in vectors:
sparse_count += 1
result = combine_aggregate(result, vector, lambda v1, v2: v1 + v2)
result.values = np.divide(result.values, sparse_count).tolist()
return result
# Expects sorted indices
def merge_positive_and_negative_avg(
positive: SparseVector, negative: SparseVector
) -> SparseVector:
return combine_aggregate(positive, negative, lambda pos, neg: pos + pos - neg)
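# Illustrative sketch, not part of the original module: `_example_sparse_aggregation` is a
# hypothetical helper added for illustration only. Averaging treats indices missing from a
# vector as 0.0, and the positive/negative merge computes 2 * positive - negative per index.
def _example_sparse_aggregation() -> None:
    positive = SparseVector(indices=[1], values=[2.0])
    negative = SparseVector(indices=[1, 2], values=[4.0, 6.0])
    averaged = sparse_avg([positive, negative])
    assert averaged.indices == [1, 2] and averaged.values == [3.0, 3.0]
    merged = merge_positive_and_negative_avg(positive, negative)
    assert merged.indices == [1, 2] and merged.values == [0.0, -6.0]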
@@ -0,0 +1,57 @@
from datetime import datetime, timedelta, timezone
import pytest
from qdrant_client.local.datetime_utils import parse
@pytest.mark.parametrize( # type: ignore
"date_str, expected",
[
("2021-01-01T00:00:00", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
("2021-01-01T00:00:00Z", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
("2021-01-01T00:00:00+00:00", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
("2021-01-01T00:00:00.000000", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
("2021-01-01T00:00:00.000000Z", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
(
"2021-01-01T00:00:00.000000+01:00",
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=1))),
),
(
"2021-01-01T00:00:00.000000-10:00",
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=-10))),
),
("2021-01-01", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
("2021-01-01 00:00:00", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
("2021-01-01 00:00:00Z", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
(
"2021-01-01 00:00:00+0200",
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=2))),
),
("2021-01-01 00:00:00.000000", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
("2021-01-01 00:00:00.000000Z", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
(
"2021-01-01 00:00:00.000000+00:30",
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(minutes=30))),
),
(
"2021-01-01 00:00:00.000009+00:30",
datetime(2021, 1, 1, 0, 0, 0, 9, tzinfo=timezone(timedelta(minutes=30))),
),
# this is accepted in core but not here, there is no specifier for only-hour offset
(
"2021-01-01 00:00:00.000+01",
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=1))),
),
(
"2021-01-01 00:00:00.000-10",
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=-10))),
),
(
"2021-01-01 00:00:00-03:00",
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=-3))),
),
],
)
def test_parse_dates(date_str: str, expected: datetime):
assert parse(date_str) == expected
@@ -0,0 +1,57 @@
import numpy as np
from qdrant_client.http import models
from qdrant_client.local.distances import calculate_distance
from qdrant_client.local.multi_distances import calculate_multi_distance
from qdrant_client.local.sparse_distances import calculate_distance_sparse
def test_distances() -> None:
query = np.array([1.0, 2.0, 3.0])
vectors = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
assert np.allclose(calculate_distance(query, vectors, models.Distance.DOT), [14.0, 14.0])
assert np.allclose(calculate_distance(query, vectors, models.Distance.EUCLID), [0.0, 0.0])
assert np.allclose(calculate_distance(query, vectors, models.Distance.MANHATTAN), [0.0, 0.0])
# cosine modifies vectors inplace
assert np.allclose(calculate_distance(query, vectors, models.Distance.COSINE), [1.0, 1.0])
query = np.array([1.0, 0.0, 1.0])
vectors = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]])
assert np.allclose(
calculate_distance(query, vectors, models.Distance.DOT), [4.0, 0.0], atol=0.0001
)
assert np.allclose(
calculate_distance(query, vectors, models.Distance.EUCLID),
[2.82842712, 1.7320508],
atol=0.0001,
)
assert np.allclose(
calculate_distance(query, vectors, models.Distance.MANHATTAN),
[4.0, 3.0],
atol=0.0001,
)
# cosine modifies vectors inplace
assert np.allclose(
calculate_distance(query, vectors, models.Distance.COSINE),
[0.75592895, 0.0],
atol=0.0001,
)
sparse_query = models.SparseVector(indices=[1, 2], values=[1, 2])
sparse_vectors = [models.SparseVector(indices=[10, 20], values=[1, 2])]
assert calculate_distance_sparse(sparse_query, sparse_vectors) == [np.float32("-inf")]
sparse_vectors = [
models.SparseVector(indices=[1, 2], values=[3, 4]),
models.SparseVector(indices=[1, 2, 3], values=[1, 2, 3]),
]
assert np.allclose(
calculate_distance_sparse(sparse_query, sparse_vectors), [11.0, 5], atol=0.0001
)
multivector_query = np.array([[1, 2, 3], [3, 4, 5]])
docs = [np.array([[1, 2, 3], [0, 1, 2]])]
assert calculate_multi_distance(multivector_query, docs, models.Distance.DOT)[0] == 40.0
@@ -0,0 +1,189 @@
from qdrant_client.http.models import models
from qdrant_client.local.payload_filters import check_filter
def test_nested_payload_filters():
payload = {
"country": {
"name": "Germany",
"capital": "Berlin",
"cities": [
{
"name": "Berlin",
"population": 3.7,
"location": {
"lon": 13.76116,
"lat": 52.33826,
},
"sightseeing": ["Brandenburg Gate", "Reichstag"],
},
{
"name": "Munich",
"population": 1.5,
"location": {
"lon": 11.57549,
"lat": 48.13743,
},
"sightseeing": ["Marienplatz", "Olympiapark"],
},
{
"name": "Hamburg",
"population": 1.8,
"location": {
"lon": 9.99368,
"lat": 53.55108,
},
"sightseeing": ["Reeperbahn", "Elbphilharmonie"],
},
],
}
}
query = models.Filter(
**{
"must": [
{
"nested": {
"key": "country.cities",
"filter": {
"must": [
{
"key": "population",
"range": {
"gte": 1.0,
},
}
],
"must_not": [{"key": "sightseeing", "values_count": {"gt": 1}}],
},
}
}
]
}
)
res = check_filter(query, payload, 0, has_vector={})
assert res is False
query = models.Filter(
**{
"must": [
{
"nested": {
"key": "country.cities",
"filter": {
"must": [
{
"key": "population",
"range": {
"gte": 1.0,
},
}
]
},
}
}
]
}
)
res = check_filter(query, payload, 0, has_vector={})
assert res is True
query = models.Filter(
**{
"must": [
{
"nested": {
"key": "country.cities",
"filter": {
"must": [
{
"key": "population",
"range": {
"gte": 1.0,
},
},
{"key": "sightseeing", "values_count": {"gt": 2}},
]
},
}
}
]
}
)
res = check_filter(query, payload, 0, has_vector={})
assert res is False
query = models.Filter(
**{
"must": [
{
"nested": {
"key": "country.cities",
"filter": {
"must": [
{
"key": "population",
"range": {
"gte": 9.0,
},
}
]
},
}
}
]
}
)
res = check_filter(query, payload, 0, has_vector={})
assert res is False
def test_geo_polygon_filter_query():
payload = {
"location": [
{
"lon": 70.0,
"lat": 70.0,
},
]
}
query = models.Filter(
**{
"must": [
{
"key": "location",
"geo_polygon": {
"exterior": {
"points": [
{"lon": 55.455868, "lat": 55.495862},
{"lon": 86.455868, "lat": 55.495862},
{"lon": 86.455868, "lat": 86.495862},
{"lon": 55.455868, "lat": 86.495862},
{"lon": 55.455868, "lat": 55.495862},
]
},
},
}
]
}
)
res = check_filter(query, payload, 0, has_vector={})
assert res is True
payload = {
"location": [
{
"lon": 30.693738,
"lat": 30.502165,
},
]
}
res = check_filter(query, payload, 0, has_vector={})
assert res is False
@@ -0,0 +1,549 @@
from typing import Any
import pytest
from qdrant_client.local.json_path_parser import (
JsonPathItem,
JsonPathItemType,
parse_json_path,
)
from qdrant_client.local.payload_value_extractor import value_by_key
from qdrant_client.local.payload_value_setter import set_value_by_key
def test_parse_json_path() -> None:
jp_key = "a"
keys = parse_json_path(jp_key)
assert keys == [JsonPathItem(item_type=JsonPathItemType.KEY, key="a")]
jp_key = "a.b"
keys = parse_json_path(jp_key)
assert keys == [
JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
]
jp_key = 'a."a[b]".c'
keys = parse_json_path(jp_key)
assert keys == [
JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
JsonPathItem(item_type=JsonPathItemType.KEY, key="a[b]"),
JsonPathItem(item_type=JsonPathItemType.KEY, key="c"),
]
jp_key = "a[0]"
keys = parse_json_path(jp_key)
assert keys == [
JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
]
jp_key = "a[0].b"
keys = parse_json_path(jp_key)
assert keys == [
JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
]
jp_key = "a[0].b[1]"
keys = parse_json_path(jp_key)
assert keys == [
JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
JsonPathItem(item_type=JsonPathItemType.INDEX, index=1),
]
jp_key = "a[][]"
keys = parse_json_path(jp_key)
assert keys == [
JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
JsonPathItem(item_type=JsonPathItemType.WILDCARD_INDEX, index=None),
JsonPathItem(item_type=JsonPathItemType.WILDCARD_INDEX, index=None),
]
jp_key = "a[0][1]"
keys = parse_json_path(jp_key)
assert keys == [
JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
JsonPathItem(item_type=JsonPathItemType.INDEX, index=1),
]
jp_key = "a[0][1].b"
keys = parse_json_path(jp_key)
assert keys == [
JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
JsonPathItem(item_type=JsonPathItemType.INDEX, index=1),
JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
]
jp_key = 'a."k.c"'
keys = parse_json_path(jp_key)
assert keys == [
JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
JsonPathItem(item_type=JsonPathItemType.KEY, key="k.c"),
]
jp_key = 'a."c[][]".b'
keys = parse_json_path(jp_key)
assert keys == [
JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
JsonPathItem(item_type=JsonPathItemType.KEY, key="c[][]"),
JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
]
jp_key = 'a."c..q".b'
keys = parse_json_path(jp_key)
assert keys == [
JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
JsonPathItem(item_type=JsonPathItemType.KEY, key="c..q"),
JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
]
with pytest.raises(ValueError):
jp_key = 'a."k.c'
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = 'a."k.c".'
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = 'a."k.c".[]'
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = "a.'k.c'"
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = "a["
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = "a]"
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = "a[]]"
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = "a[][]."
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = "a[][]b"
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = ".a"
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = "a[x]"
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = 'a[]""'
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = '""b'
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = "[]"
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = "a[.]"
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = 'a["1"]'
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = ""
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = "a..c"
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = "a.c[]b[]"
parse_json_path(jp_key)
with pytest.raises(ValueError):
jp_key = "a.c[].[]"
parse_json_path(jp_key)
def test_value_by_key() -> None:
payload = {
"name": "John",
"age": 25,
"counts": [1, 2, 3],
"address": {
"city": "New York",
},
"location": [
{"name": "home", "counts": [1, 2, 3]},
{"name": "work", "counts": [4, 5, 6]},
],
"nested": [{"empty": []}, {"empty": []}, {"empty": None}],
"the_null": None,
"the": {"nested.key": "cuckoo"},
"double-nest-array": [[1, 2], [3, 4], [5, 6]],
}
# region flat=True
assert value_by_key(payload, "name") == ["John"]
assert value_by_key(payload, "address.city") == ["New York"]
assert value_by_key(payload, "location[].name") == ["home", "work"]
assert value_by_key(payload, "location[0].name") == ["home"]
assert value_by_key(payload, "location[1].name") == ["work"]
assert value_by_key(payload, "location[2].name") is None
assert value_by_key(payload, "location[].name[0]") is None
assert value_by_key(payload, "location[0]") == [{"name": "home", "counts": [1, 2, 3]}]
assert value_by_key(payload, "not_exits") is None
assert value_by_key(payload, "address") == [{"city": "New York"}]
assert value_by_key(payload, "address.city[0]") is None
assert value_by_key(payload, "counts") == [1, 2, 3]
assert value_by_key(payload, "location[].counts") == [1, 2, 3, 4, 5, 6]
assert value_by_key(payload, "nested[].empty") == [None]
assert value_by_key(payload, "the_null") == [None]
assert value_by_key(payload, 'the."nested.key"') == ["cuckoo"]
assert value_by_key(payload, "double-nest-array[][]") == [1, 2, 3, 4, 5, 6]
assert value_by_key(payload, "double-nest-array[0][]") == [1, 2]
assert value_by_key(payload, "double-nest-array[0][0]") == [1]
assert value_by_key(payload, "double-nest-array[0][0]") == [1]
assert value_by_key(payload, "double-nest-array[][1]") == [2, 4, 6]
# endregion
# region flat=False
assert value_by_key(payload, "name", flat=False) == ["John"]
assert value_by_key(payload, "address.city", flat=False) == ["New York"]
assert value_by_key(payload, "location[].name", flat=False) == ["home", "work"]
assert value_by_key(payload, "location[0].name", flat=False) == ["home"]
assert value_by_key(payload, "location[1].name", flat=False) == ["work"]
assert value_by_key(payload, "location[2].name", flat=False) is None
assert value_by_key(payload, "location[].name[0]", flat=False) is None
assert value_by_key(payload, "location[0]", flat=False) == [
{"name": "home", "counts": [1, 2, 3]}
]
assert value_by_key(payload, "not_exist", flat=False) is None
assert value_by_key(payload, "address", flat=False) == [{"city": "New York"}]
assert value_by_key(payload, "address.city[0]", flat=False) is None
assert value_by_key(payload, "counts", flat=False) == [[1, 2, 3]]
assert value_by_key(payload, "location[].counts", flat=False) == [
[1, 2, 3],
[4, 5, 6],
]
assert value_by_key(payload, "nested[].empty", flat=False) == [[], [], None]
assert value_by_key(payload, "the_null", flat=False) == [None]
assert value_by_key(payload, "age.nested.not_exist") is None
# endregion
def test_set_value_by_key() -> None:
# region valid keys
payload: dict[str, Any] = {}
new_value: dict[str, Any] = {}
key = "a"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {}}, payload
payload = {"a": {"a": 2}}
new_value = {}
key = "a"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"a": 2}}, payload
payload = {"a": {"a": 2}}
new_value = {"b": 3}
key = "a"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"a": 2, "b": 3}}, payload
payload = {"a": {"a": 2}}
new_value = {"a": 3}
key = "a"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"a": 3}}, payload
payload = {"a": {"a": 2}}
new_value = {"a": 3}
key = "a.a"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"a": {"a": 3}}}, payload
payload = {"a": {"a": {"a": 1}}}
new_value = {"b": 2}
key = "a.a"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"a": {"a": 1, "b": 2}}}, payload
payload = {"a": {"a": {"a": 1}}}
new_value = {"a": 2}
key = "a.a"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"a": {"a": 2}}}, payload
payload = {"a": []}
new_value = {"b": 2}
key = "a[0]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": []}, payload
payload = {"a": [{}]}
new_value = {"b": 2}
key = "a[0]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [{"b": 2}]}, payload
payload = {"a": [{"a": 1}]}
new_value = {"b": 2}
key = "a[0]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [{"a": 1, "b": 2}]}, payload
payload = {"a": [[]]}
new_value = {"b": 2}
key = "a[0]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [{"b": 2}]}, payload
payload = {"a": [[]]}
new_value = {"b": 2}
key = "a[1]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [[]]}, payload
payload = {"a": [{"a": []}]}
new_value = {"b": 2}
key = "a[0].a"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [{"a": {"b": 2}}]}, payload
payload = {"a": [{"a": []}]}
new_value = {"b": 2}
key = "a[].a"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [{"a": {"b": 2}}]}, payload
payload = {"a": [{"a": []}, {"a": []}]}
new_value = {"b": 2}
key = "a[].a"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [{"a": {"b": 2}}, {"a": {"b": 2}}]}, payload
payload = {"a": 1, "b": 2}
new_value = {"c": 3}
key = "c"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": 1, "b": 2, "c": {"c": 3}}, payload
payload = {"a": {"b": {"c": 1}}}
new_value = {"d": 2}
key = "a.b.d"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"b": {"c": 1, "d": {"d": 2}}}}, payload
payload = {"a": {"b": {"c": 1}}}
new_value = {"c": 2}
key = "a.b"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"b": {"c": 2}}}, payload
payload = {"a": [{"b": 1}, {"b": 2}]}
new_value = {"c": 3}
key = "a[1]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [{"b": 1}, {"b": 2, "c": 3}]}, payload
payload = {"a": []}
new_value = {"b": {"c": 1}}
key = "a[0]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": []}, payload
payload = {"a": {"b": {"c": {"d": {"e": 1}}}}}
new_value = {"f": 2}
key = "a.b.c.d"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"b": {"c": {"d": {"e": 1, "f": 2}}}}}, payload
payload = {"a": {"b": {"c": 1}}}
new_value = {"d": {"e": 2}}
key = "a.b.c"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"b": {"c": {"d": {"e": 2}}}}}, payload
payload = {"a": [{"b": 1}]}
new_value = {"c": 2}
key = "a[1]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [{"b": 1}]}, payload
payload = {"a": {"b": [{"c": 1}, {"c": 2}]}}
new_value = {"d": 3}
key = "a.b[0].c"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"b": [{"c": {"d": 3}}, {"c": 2}]}}, payload
payload = {"a": {"b": {"c": [{"d": 1}]}}}
new_value = {"e": {"f": 2}}
key = "a.b.c[0].d"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"b": {"c": [{"d": {"e": {"f": 2}}}]}}}, payload
payload = {"a": [[{"b": 1}], [{"b": 2}]]}
new_value = {"c": 3}
key = "a[0][0]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [[{"b": 1, "c": 3}], [{"b": 2}]]}, payload
payload = {"a": [[{"b": 1}], [{"b": 2}]]}
new_value = {"c": 3}
key = "a[1][0]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [[{"b": 1}], [{"b": 2, "c": 3}]]}, payload
payload = {"a": [[{"b": 1}], [{"b": 2}]]}
new_value = {"c": 3}
key = "a[1][1]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [[{"b": 1}], [{"b": 2}]]}, payload
payload = {"a": [[{"b": 1}], [{"b": 2}]]}
new_value = {"c": 3}
key = "a[][0]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [[{"b": 1, "c": 3}], [{"b": 2, "c": 3}]]}, payload
payload = {"a": [[{"b": 1}], [{"b": 2}]]}
new_value = {"c": 3}
key = "a[][]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": [[{"b": 1, "c": 3}], [{"b": 2, "c": 3}]]}, payload
payload = {"a": []}
new_value = {"c": 3}
key = 'a."b.c"'
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"b.c": {"c": 3}}}, payload
payload = {"a": {"c": [1]}}
new_value = {"a": 1}
key = "a.c[0]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"c": [{"a": 1}]}}, payload
payload = {"a": {"c": [1]}}
new_value = {"a": 1}
key = "a.c[0].d"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"c": [{"d": {"a": 1}}]}}, payload
payload = {"": 2}
new_value = {"a": 1}
key = '""'
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"": {"a": 1}}, payload
# endregion
# region exceptions
try:
payload = {"a": []}
new_value = {"c": 3}
key = "a.'b.c'"
set_value_by_key(payload, parse_json_path(key), new_value)
assert False, f"Should've raised an exception due to the key with incorrect quotes: {key}"
except Exception:
assert True
try:
payload = {"a": [{"b": 1}, {"b": 2}]}
new_value = {"c": 3}
key = "a[-1]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert False, "Negative indexation is not supported"
except Exception:
assert True
try:
payload = {"a": [{"b": 1}, {"b": 2}]}
new_value = {"c": 3}
key = "a["
set_value_by_key(payload, parse_json_path(key), new_value)
assert False, f"Should've raised an exception due to the incorrect key: {key}"
except Exception:
assert True
try:
payload = {"a": [{"b": 1}, {"b": 2}]}
new_value = {"c": 3}
key = "a]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert False, f"Should've raise an exception due to the incorrect key: {key}"
except Exception:
assert True
# endregion
# region wrong keys
payload = {"a": []}
new_value = {}
key = "a.b[0]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"b": []}}, payload
payload = {"a": []}
new_value = {}
key = "a.b"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"b": {}}}, payload
payload = {"a": []}
new_value = {"c": 2}
key = "a.b"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"b": {"c": 2}}}, payload
payload = {"a": [[{"a": 1}]]}
new_value = {"a": 2}
key = "a.b[0][0]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"b": []}}, payload
payload = {"a": {"c": 2}}
new_value = {"a": 1}
key = "a[]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": []}, payload
payload = {"a": {"c": 2}}
new_value = {"a": 1}
key = "a[].b"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": []}, payload
payload = {"a": {"c": [1]}}
new_value = {"a": 1}
key = "a.c[][][0]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"c": [[]]}}, payload
payload = {"a": {"c": [{"d": 1}]}}
new_value = {"a": 1}
key = "a.c[][]"
set_value_by_key(payload, parse_json_path(key), new_value)
assert payload == {"a": {"c": [[]]}}, payload
# endregion
@@ -0,0 +1,135 @@
import copy
import pytest
from qdrant_client.local.qdrant_local import QdrantLocal
from qdrant_client import models
@pytest.fixture(scope="module", autouse=True)
def client():
"""
Sets up multiple collections with a bunch of points
"""
client = QdrantLocal(":memory:")
client.create_collection(
"collection_default",
vectors_config=models.VectorParams(
size=4,
distance=models.Distance.DOT,
),
)
client.create_collection(
"collection_multiple_vectors",
vectors_config={
"": models.VectorParams(
size=4,
distance=models.Distance.DOT,
),
"byte": models.VectorParams(
size=4, distance=models.Distance.DOT, datatype=models.Datatype.UINT8
),
"colbert": models.VectorParams(
size=4,
distance=models.Distance.DOT,
multivector_config=models.MultiVectorConfig(
comparator=models.MultiVectorComparator.MAX_SIM
),
),
},
sparse_vectors_config={"sparse": models.SparseVectorParams()},
)
client.upsert(
"collection_default",
[
models.PointStruct(id=1, vector=[0.25, 0.0, 0.0, 0.0]),
],
)
client.upsert(
"collection_multiple_vectors",
[
models.PointStruct(
id=1,
vector={
"": [0.0, 0.25, 0.0, 0.0],
"byte": [0, 25, 0, 0],
"colbert": [[0.0, 0.25, 0.0, 0.0], [0.0, 0.25, 0.0, 0.0]],
"sparse": models.SparseVector(indices=[1], values=[0.25]),
},
),
],
)
return client
@pytest.mark.parametrize(
"query",
[
models.NearestQuery(nearest=1),
models.RecommendQuery(recommend=models.RecommendInput(positive=[1], negative=[1])),
models.DiscoverQuery(
discover=models.DiscoverInput(
target=1, context=[models.ContextPair(**{"positive": 1, "negative": 1})]
)
),
models.ContextQuery(context=[models.ContextPair(**{"positive": 1, "negative": 1})]),
models.OrderByQuery(order_by=models.OrderBy(key="price", direction=models.Direction.ASC)),
models.FusionQuery(fusion=models.Fusion.RRF),
],
)
@pytest.mark.parametrize(
"using, lookup_from, expected, mentioned",
[
(None, None, [0.25, 0.0, 0.0, 0.0], True),
("", None, [0.25, 0.0, 0.0, 0.0], True),
(
"byte",
models.LookupLocation(collection="collection_multiple_vectors"),
[0, 25, 0, 0],
False,
),
(
"",
models.LookupLocation(collection="collection_multiple_vectors", vector="colbert"),
[[0.0, 0.25, 0.0, 0.0], [0.0, 0.25, 0.0, 0.0]],
False,
),
(
None,
models.LookupLocation(collection="collection_multiple_vectors", vector="sparse"),
models.SparseVector(indices=[1], values=[0.25]),
False,
),
],
)
def test_vector_dereferencing(client, query, using, lookup_from, expected, mentioned):
resolved, mentioned_ids = client._resolve_query_input(
collection_name="collection_default",
query=copy.deepcopy(query),
using=using,
lookup_from=lookup_from,
)
if isinstance(resolved, models.NearestQuery):
assert resolved.nearest == expected
elif isinstance(resolved, models.RecommendQuery):
assert resolved.recommend.positive == [expected]
assert resolved.recommend.negative == [expected]
elif isinstance(resolved, models.DiscoverQuery):
assert resolved.discover.target == expected
assert resolved.discover.context[0].positive == expected
assert resolved.discover.context[0].negative == expected
elif isinstance(resolved, models.ContextQuery):
assert resolved.context[0].positive == expected
assert resolved.context[0].negative == expected
else:
mentioned = False
assert resolved == query
if mentioned:
assert mentioned_ids == {1}
@@ -0,0 +1,21 @@
import random
from qdrant_client import models
from qdrant_client.local.local_collection import LocalCollection, DEFAULT_VECTOR_NAME
def test_get_vectors():
collection = LocalCollection(
models.CreateCollection(
vectors=models.VectorParams(size=2, distance=models.Distance.MANHATTAN)
)
)
collection.upsert(
points=[
models.PointStruct(id=i, vector=[random.random(), random.random()]) for i in range(10)
]
)
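# _get_vectors accepts a vector name, True (return all vectors) or False
# (return no vectors); the asserts below cover each form.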
assert collection._get_vectors(idx=1, with_vectors=DEFAULT_VECTOR_NAME)
assert collection._get_vectors(idx=2, with_vectors=True)
assert collection._get_vectors(idx=3, with_vectors=False) is None
@@ -0,0 +1 @@
from .migrate import migrate
@@ -0,0 +1,182 @@
import time
from typing import Iterable, Optional
from qdrant_client._pydantic_compat import to_dict, model_fields
from qdrant_client.client_base import QdrantBase
from qdrant_client.http import models
def upload_with_retry(
client: QdrantBase,
collection_name: str,
points: Iterable[models.PointStruct],
max_attempts: int = 3,
pause: float = 3.0,
) -> None:
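"""Upload points into a collection, retrying on failure.
Retries up to ``max_attempts`` times, sleeping ``pause`` seconds between
attempts, and raises if every attempt fails.
"""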
attempts = 1
while attempts <= max_attempts:
try:
client.upload_points(
collection_name=collection_name,
points=points,
wait=True,
)
return
except Exception as e:
print(f"Exception: {e}, attempt {attempts}/{max_attempts}")
if attempts < max_attempts:
print(f"Next attempt in {pause} seconds")
time.sleep(pause)
attempts += 1
raise Exception(f"Failed to upload points after {max_attempts} attempts")
def migrate(
source_client: QdrantBase,
dest_client: QdrantBase,
collection_names: Optional[list[str]] = None,
recreate_on_collision: bool = False,
batch_size: int = 100,
) -> None:
"""
Migrate collections from source client to destination client
Args:
source_client (QdrantBase): Source client
dest_client (QdrantBase): Destination client
collection_names (list[str], optional): List of collection names to migrate.
If None - migrate all source client collections. Defaults to None.
recreate_on_collision (bool, optional): If True - recreate collection if it exists, otherwise
raise ValueError.
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
"""
collection_names = _select_source_collections(source_client, collection_names)
if any(
_has_custom_shards(source_client, collection_name) for collection_name in collection_names
):
raise ValueError("Migration of collections with custom shards is not supported yet")
collisions = _find_collisions(dest_client, collection_names)
absent_dest_collections = set(collection_names) - set(collisions)
if collisions and not recreate_on_collision:
raise ValueError(f"Collections already exist in dest_client: {collisions}")
for collection_name in absent_dest_collections:
_recreate_collection(source_client, dest_client, collection_name)
_migrate_collection(source_client, dest_client, collection_name, batch_size)
for collection_name in collisions:
_recreate_collection(source_client, dest_client, collection_name)
_migrate_collection(source_client, dest_client, collection_name, batch_size)
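# A minimal usage sketch (the endpoint URLs and collection name below are
# illustrative assumptions, not part of this module):
#
#     from qdrant_client import QdrantClient
#     source = QdrantClient(url="http://localhost:6333")
#     dest = QdrantClient(url="http://localhost:6334")
#     migrate(source, dest, collection_names=["books"], recreate_on_collision=True)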
def _has_custom_shards(source_client: QdrantBase, collection_name: str) -> bool:
collection_info = source_client.get_collection(collection_name)
return (
getattr(collection_info.config.params, "sharding_method", None)
== models.ShardingMethod.CUSTOM
)
def _select_source_collections(
source_client: QdrantBase, collection_names: Optional[list[str]] = None
) -> list[str]:
source_collections = source_client.get_collections().collections
source_collection_names = [collection.name for collection in source_collections]
if collection_names is not None:
assert all(
collection_name in source_collection_names for collection_name in collection_names
), f"Source client does not have collections: {set(collection_names) - set(source_collection_names)}"
else:
collection_names = source_collection_names
return collection_names
def _find_collisions(dest_client: QdrantBase, collection_names: list[str]) -> list[str]:
dest_collections = dest_client.get_collections().collections
dest_collection_names = {collection.name for collection in dest_collections}
existing_dest_collections = dest_collection_names & set(collection_names)
return list(existing_dest_collections)
def _recreate_collection(
source_client: QdrantBase,
dest_client: QdrantBase,
collection_name: str,
) -> None:
src_collection_info = source_client.get_collection(collection_name)
src_config = src_collection_info.config
src_payload_schema = src_collection_info.payload_schema
if dest_client.collection_exists(collection_name):
dest_client.delete_collection(collection_name)
strict_mode_config: Optional[models.StrictModeConfig] = None
if src_config.strict_mode_config is not None:
strict_mode_config = models.StrictModeConfig(
**{
k: v
for k, v in to_dict(src_config.strict_mode_config).items()
if k in model_fields(models.StrictModeConfig)
}
)
dest_client.create_collection(
collection_name,
vectors_config=src_config.params.vectors,
sparse_vectors_config=src_config.params.sparse_vectors,
shard_number=src_config.params.shard_number,
replication_factor=src_config.params.replication_factor,
write_consistency_factor=src_config.params.write_consistency_factor,
on_disk_payload=src_config.params.on_disk_payload,
hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)),
optimizers_config=models.OptimizersConfigDiff(**to_dict(src_config.optimizer_config)),
wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)),
quantization_config=src_config.quantization_config,
strict_mode_config=strict_mode_config,
)
_recreate_payload_schema(dest_client, collection_name, src_payload_schema)
def _recreate_payload_schema(
dest_client: QdrantBase,
collection_name: str,
payload_schema: dict[str, models.PayloadIndexInfo],
) -> None:
for field_name, field_info in payload_schema.items():
dest_client.create_payload_index(
collection_name,
field_name=field_name,
field_schema=field_info.data_type if field_info.params is None else field_info.params,
)
def _migrate_collection(
source_client: QdrantBase,
dest_client: QdrantBase,
collection_name: str,
batch_size: int = 100,
) -> None:
"""Migrate collection from source client to destination client
Args:
collection_name (str): Collection name
source_client (QdrantBase): Source client
dest_client (QdrantBase): Destination client
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
"""
records, next_offset = source_client.scroll(collection_name, limit=batch_size, with_vectors=True)
upload_with_retry(client=dest_client, collection_name=collection_name, points=records) # type: ignore
while next_offset is not None:
records, next_offset = source_client.scroll(
collection_name, offset=next_offset, limit=batch_size, with_vectors=True
)
upload_with_retry(client=dest_client, collection_name=collection_name, points=records) # type: ignore
source_client_vectors_count = source_client.count(collection_name).count
dest_client_vectors_count = dest_client.count(collection_name).count
assert (
source_client_vectors_count == dest_client_vectors_count
), f"Migration failed, vectors count are not equal: source vector count {source_client_vectors_count}, dest vector count {dest_client_vectors_count}"
@@ -0,0 +1,3 @@
from qdrant_client.http.models import *
from qdrant_client.fastembed_common import *
from qdrant_client.embed.models import *
@@ -0,0 +1,239 @@
import logging
import os
from collections import defaultdict
from enum import Enum
from multiprocessing import Queue, get_context
from multiprocessing.context import BaseContext
from multiprocessing.process import BaseProcess
from multiprocessing.sharedctypes import Synchronized as BaseValue
from queue import Empty
from typing import Any, Iterable, Optional, Type
# Single item should be processed in less than:
processing_timeout = 10 * 60 # seconds
MAX_INTERNAL_BATCH_SIZE = 200
class QueueSignals(str, Enum):
stop = "stop"
confirm = "confirm"
error = "error"
class Worker:
@classmethod
def start(cls, *args: Any, **kwargs: Any) -> "Worker":
raise NotImplementedError()
def process(self, items: Iterable[Any]) -> Iterable[Any]:
raise NotImplementedError()
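# Illustrative only (not part of the client API): a minimal Worker subclass
# showing the start()/process() contract the pool below expects.
#
#     class SquaringWorker(Worker):
#         @classmethod
#         def start(cls, **kwargs: Any) -> "SquaringWorker":
#             return cls()
#         def process(self, items: Iterable[Any]) -> Iterable[Any]:
#             for item in items:
#                 yield item * item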
def _worker(
worker_class: Type[Worker],
input_queue: Queue,
output_queue: Queue,
num_active_workers: BaseValue,
worker_id: int,
kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
A worker that pulls data points off the input queue and places the execution results on the output queue.
When there are no data points left on the input queue, it decrements
num_active_workers to signal completion.
"""
if kwargs is None:
kwargs = {}
logging.info(f"Reader worker: {worker_id} PID: {os.getpid()}")
try:
worker = worker_class.start(**kwargs)
# Keep going until you get an item that's None.
def input_queue_iterable() -> Iterable[Any]:
while True:
item = input_queue.get()
if item == QueueSignals.stop:
break
yield item
for processed_item in worker.process(input_queue_iterable()):
output_queue.put(processed_item)
except Exception as e: # pylint: disable=broad-except
logging.exception(e)
output_queue.put(QueueSignals.error)
finally:
# It's important that we close and join the queue here before
# decrementing num_active_workers. Otherwise our parent may join us
# before the queue's feeder thread has passed all buffered items to
# the underlying pipe resulting in a deadlock.
#
# See:
# https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#pipes-and-queues
# https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#programming-guidelines
input_queue.close()
output_queue.close()
input_queue.join_thread()
output_queue.join_thread()
with num_active_workers.get_lock():
num_active_workers.value -= 1
logging.info(f"Reader worker {worker_id} finished")
class ParallelWorkerPool:
def __init__(
self,
num_workers: int,
worker: Type[Worker],
start_method: Optional[str] = None,
max_internal_batch_size: int = MAX_INTERNAL_BATCH_SIZE,
):
self.worker_class = worker
self.num_workers = num_workers
self.input_queue: Optional[Queue] = None
self.output_queue: Optional[Queue] = None
self.ctx: BaseContext = get_context(start_method)
self.processes: list[BaseProcess] = []
self.queue_size = self.num_workers * max_internal_batch_size
self.emergency_shutdown = False
self.num_active_workers: Optional[BaseValue] = None
def start(self, **kwargs: Any) -> None:
self.input_queue = self.ctx.Queue(self.queue_size)
self.output_queue = self.ctx.Queue(self.queue_size)
ctx_value = self.ctx.Value("i", self.num_workers)
assert isinstance(ctx_value, BaseValue)
self.num_active_workers = ctx_value
for worker_id in range(0, self.num_workers):
assert hasattr(self.ctx, "Process")
process = self.ctx.Process(
target=_worker,
args=(
self.worker_class,
self.input_queue,
self.output_queue,
self.num_active_workers,
worker_id,
kwargs.copy(),
),
)
process.start()
self.processes.append(process)
def unordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
try:
self.start(**kwargs)
assert self.input_queue is not None, "Input queue was not initialized"
assert self.output_queue is not None, "Output queue was not initialized"
pushed = 0
read = 0
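# Backpressure: while fewer than queue_size items are in flight, poll the
# output queue without blocking; once the pipeline is full, block for a
# result before pushing the next input item.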
for item in stream:
self.check_worker_health()
if pushed - read < self.queue_size:
try:
out_item = self.output_queue.get_nowait()
except Empty:
out_item = None
else:
try:
out_item = self.output_queue.get(timeout=processing_timeout)
except Empty as e:
self.join_or_terminate()
raise e
if out_item is not None:
if out_item == QueueSignals.error:
self.join_or_terminate()
raise RuntimeError("Worker process unexpectedly terminated")
yield out_item
read += 1
self.input_queue.put(item)
pushed += 1
for _ in range(self.num_workers):
self.input_queue.put(QueueSignals.stop)
while read < pushed:
out_item = self.output_queue.get(timeout=processing_timeout)
if out_item == QueueSignals.error:
self.join_or_terminate()
raise RuntimeError("Worker process unexpectedly terminated")
yield out_item
read += 1
finally:
assert self.input_queue is not None, "Input queue is None"
assert self.output_queue is not None, "Output queue is None"
self.join()
self.input_queue.close()
self.output_queue.close()
if self.emergency_shutdown:
self.input_queue.cancel_join_thread()
self.output_queue.cancel_join_thread()
else:
self.input_queue.join_thread()
self.output_queue.join_thread()
def semi_ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
return self.unordered_map(enumerate(stream), *args, **kwargs)
def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
buffer = defaultdict(int)
next_expected = 0
for idx, item in self.semi_ordered_map(stream, *args, **kwargs):
buffer[idx] = item
while next_expected in buffer:
yield buffer.pop(next_expected)
next_expected += 1
def check_worker_health(self) -> None:
"""
Checks if any worker process has terminated unexpectedly
"""
for process in self.processes:
if not process.is_alive() and process.exitcode != 0:
self.emergency_shutdown = True
self.join_or_terminate()
raise RuntimeError(
f"Worker PID: {process.pid} terminated unexpectedly with code {process.exitcode}"
)
def join_or_terminate(self, timeout: Optional[int] = 1) -> None:
"""
Emergency shutdown: give each worker process `timeout` seconds to join,
then terminate any process that is still alive.
"""
self.emergency_shutdown = True
for process in self.processes:
process.join(timeout=timeout)
if process.is_alive():
process.terminate()
self.processes.clear()
def join(self) -> None:
for process in self.processes:
process.join()
self.processes.clear()
def __del__(self) -> None:
"""
Terminate processes if the user hasn't joined. This is necessary as
leaving stray processes running can corrupt shared state. In brief,
we've observed shared memory counters being reused (when the memory was
free from the perspective of the parent process) while the stray
workers still held a reference to them.
For a discussion of using destructors in Python in this manner, see
https://eli.thegreenplace.net/2009/06/12/safely-using-destructors-in-python/.
"""
for process in self.processes:
if process.is_alive():
process.terminate()
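# A minimal usage sketch (illustrative; SquaringWorker is the hypothetical
# Worker subclass sketched above):
#
#     pool = ParallelWorkerPool(num_workers=4, worker=SquaringWorker)
#     for result in pool.unordered_map(range(100)):
#         print(result)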
