refactor: pydantic v1/v2 compatibility helpers and async client base
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
from .async_qdrant_client import AsyncQdrantClient as AsyncQdrantClient
|
||||
from .qdrant_client import QdrantClient as QdrantClient
|
||||
@@ -0,0 +1,69 @@
|
||||
import json
|
||||
|
||||
from typing import Any, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
|
||||
# True when the installed pydantic is the 2.x line; selects which API surface
# the compatibility helpers below delegate to.
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")

# TypeVar bound to BaseModel so helpers can return the caller's concrete model type.
Model = TypeVar("Model", bound="BaseModel")


if PYDANTIC_V2:
    import pydantic_core

    # pydantic-core ships a native JSON-compatible converter; reuse it directly.
    to_jsonable_python = pydantic_core.to_jsonable_python
else:
    from pydantic.json import ENCODERS_BY_TYPE

    def to_jsonable_python(x: Any) -> Any:
        """Convert *x* to a JSON-serializable value using pydantic v1's encoder table."""
        # NOTE(review): looks up the encoder for x's *exact* type, so it raises
        # KeyError for subclasses or for types absent from ENCODERS_BY_TYPE —
        # presumably callers only pass types pydantic v1 knows how to encode.
        return ENCODERS_BY_TYPE[type(x)](x)
|
||||
|
||||
|
||||
def update_forward_refs(model_class: Type[BaseModel], *args: Any, **kwargs: Any) -> None:
    """Resolve a model's forward references, regardless of pydantic major version.

    Delegates to ``model_rebuild`` on pydantic v2 and to the legacy
    ``update_forward_refs`` on v1; extra arguments are passed through untouched.
    """
    rebuild = model_class.model_rebuild if PYDANTIC_V2 else model_class.update_forward_refs
    rebuild(*args, **kwargs)
|
||||
|
||||
|
||||
def construct(model_class: Type[Model], *args: Any, **kwargs: Any) -> Model:
    """Build a model instance without running validation.

    Uses pydantic v2's ``model_construct`` when available, otherwise the
    v1 ``construct`` classmethod. Returns an instance of *model_class*.
    """
    if not PYDANTIC_V2:
        return model_class.construct(*args, **kwargs)
    return model_class.model_construct(*args, **kwargs)
|
||||
|
||||
|
||||
def to_dict(model: BaseModel, *args: Any, **kwargs: Any) -> dict[Any, Any]:
    """Serialize *model* to a plain dict across pydantic major versions.

    Forwards to ``model_dump`` (v2) or ``dict`` (v1) with all arguments intact.
    """
    dump = model.model_dump if PYDANTIC_V2 else model.dict
    return dump(*args, **kwargs)
|
||||
|
||||
|
||||
def model_fields_set(model: BaseModel) -> set:
    """Return the names of fields explicitly set on *model*.

    Reads ``model_fields_set`` on pydantic v2 and the private
    ``__fields_set__`` attribute on v1.
    """
    return model.model_fields_set if PYDANTIC_V2 else model.__fields_set__
|
||||
|
||||
|
||||
def model_fields(model: Type[BaseModel]) -> dict:
    """Return the field-definition mapping of a model class.

    On pydantic v2 this is ``model_fields``; on v1 the legacy ``__fields__``.
    """
    if not PYDANTIC_V2:
        return model.__fields__
    return model.model_fields  # type: ignore # pydantic type issue
|
||||
|
||||
|
||||
def model_json_schema(model: Type[BaseModel], *args: Any, **kwargs: Any) -> dict[str, Any]:
    """Produce the JSON schema of a model class as a dict.

    pydantic v2 returns a dict directly from ``model_json_schema``; v1 only
    offers ``schema_json`` (a string), which is decoded here for parity.
    """
    if not PYDANTIC_V2:
        # v1 emits a JSON string; parse it so both branches yield a dict.
        return json.loads(model.schema_json(*args, **kwargs))
    return model.model_json_schema(*args, **kwargs)
|
||||
|
||||
|
||||
def model_config(model: Type[BaseModel]) -> dict[str, Any]:
    """Return a model class's configuration as a dict.

    pydantic v2 stores config as the ``model_config`` dict; v1 uses a nested
    ``Config`` class, whose attributes are snapshotted into a fresh dict.
    """
    if not PYDANTIC_V2:
        # Copy the v1 Config class namespace so callers get a plain dict.
        return dict(vars(model.__config__))
    return model.model_config
|
||||
@@ -0,0 +1,386 @@
|
||||
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
|
||||
#
|
||||
# This file is autogenerated. Do not edit it manually.
|
||||
# To regenerate this file, use
|
||||
#
|
||||
# ```
|
||||
# bash -x tools/generate_async_client.sh
|
||||
# ```
|
||||
#
|
||||
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
|
||||
|
||||
from typing import Any, Iterable, Mapping, Optional, Sequence, Union
|
||||
from qdrant_client.conversions import common_types as types
|
||||
|
||||
|
||||
class AsyncQdrantBase:
    """Abstract async interface shared by all Qdrant client implementations.

    Every operation is declared here as a stub that raises
    ``NotImplementedError``; concrete subclasses (remote/local clients)
    override the methods they support. This file is autogenerated from the
    synchronous client — do not edit it manually.

    Note: ``upload_points``, ``upload_collection`` and ``migrate`` are
    deliberately synchronous in this interface; everything else is async.
    """

    def __init__(self, **kwargs: Any):
        # Accepts and ignores arbitrary kwargs so subclasses can cooperate
        # in multiple-inheritance __init__ chains.
        pass

    # --- search / query operations ---

    async def search_matrix_offsets(
        self,
        collection_name: str,
        query_filter: Optional[types.Filter] = None,
        limit: int = 3,
        sample: int = 10,
        using: Optional[str] = None,
        **kwargs: Any,
    ) -> types.SearchMatrixOffsetsResponse:
        raise NotImplementedError()

    async def search_matrix_pairs(
        self,
        collection_name: str,
        query_filter: Optional[types.Filter] = None,
        limit: int = 3,
        sample: int = 10,
        using: Optional[str] = None,
        **kwargs: Any,
    ) -> types.SearchMatrixPairsResponse:
        raise NotImplementedError()

    async def query_batch_points(
        self, collection_name: str, requests: Sequence[types.QueryRequest], **kwargs: Any
    ) -> list[types.QueryResponse]:
        raise NotImplementedError()

    async def query_points(
        self,
        collection_name: str,
        query: Union[
            types.PointId,
            list[float],
            list[list[float]],
            types.SparseVector,
            types.Query,
            types.NumpyArray,
            types.Document,
            types.Image,
            types.InferenceObject,
            None,
        ] = None,
        using: Optional[str] = None,
        prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
        query_filter: Optional[types.Filter] = None,
        search_params: Optional[types.SearchParams] = None,
        limit: int = 10,
        offset: Optional[int] = None,
        with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
        with_vectors: Union[bool, Sequence[str]] = False,
        score_threshold: Optional[float] = None,
        lookup_from: Optional[types.LookupLocation] = None,
        **kwargs: Any,
    ) -> types.QueryResponse:
        raise NotImplementedError()

    async def query_points_groups(
        self,
        collection_name: str,
        group_by: str,
        query: Union[
            types.PointId,
            list[float],
            list[list[float]],
            types.SparseVector,
            types.Query,
            types.NumpyArray,
            types.Document,
            types.Image,
            types.InferenceObject,
            None,
        ] = None,
        using: Optional[str] = None,
        prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
        query_filter: Optional[types.Filter] = None,
        search_params: Optional[types.SearchParams] = None,
        limit: int = 10,
        group_size: int = 3,
        with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
        with_vectors: Union[bool, Sequence[str]] = False,
        score_threshold: Optional[float] = None,
        with_lookup: Optional[types.WithLookupInterface] = None,
        lookup_from: Optional[types.LookupLocation] = None,
        **kwargs: Any,
    ) -> types.GroupsResult:
        raise NotImplementedError()

    async def scroll(
        self,
        collection_name: str,
        scroll_filter: Optional[types.Filter] = None,
        limit: int = 10,
        order_by: Optional[types.OrderBy] = None,
        offset: Optional[types.PointId] = None,
        with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
        with_vectors: Union[bool, Sequence[str]] = False,
        **kwargs: Any,
    ) -> tuple[list[types.Record], Optional[types.PointId]]:
        raise NotImplementedError()

    async def count(
        self,
        collection_name: str,
        count_filter: Optional[types.Filter] = None,
        exact: bool = True,
        **kwargs: Any,
    ) -> types.CountResult:
        raise NotImplementedError()

    async def facet(
        self,
        collection_name: str,
        key: str,
        facet_filter: Optional[types.Filter] = None,
        limit: int = 10,
        exact: bool = False,
        **kwargs: Any,
    ) -> types.FacetResponse:
        raise NotImplementedError()

    # --- point mutation operations ---

    async def upsert(
        self, collection_name: str, points: types.Points, **kwargs: Any
    ) -> types.UpdateResult:
        raise NotImplementedError()

    async def update_vectors(
        self, collection_name: str, points: Sequence[types.PointVectors], **kwargs: Any
    ) -> types.UpdateResult:
        raise NotImplementedError()

    async def delete_vectors(
        self,
        collection_name: str,
        vectors: Sequence[str],
        points: types.PointsSelector,
        **kwargs: Any,
    ) -> types.UpdateResult:
        raise NotImplementedError()

    async def retrieve(
        self,
        collection_name: str,
        ids: Sequence[types.PointId],
        with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
        with_vectors: Union[bool, Sequence[str]] = False,
        **kwargs: Any,
    ) -> list[types.Record]:
        raise NotImplementedError()

    async def delete(
        self, collection_name: str, points_selector: types.PointsSelector, **kwargs: Any
    ) -> types.UpdateResult:
        raise NotImplementedError()

    # --- payload operations ---

    async def set_payload(
        self,
        collection_name: str,
        payload: types.Payload,
        points: types.PointsSelector,
        key: Optional[str] = None,
        **kwargs: Any,
    ) -> types.UpdateResult:
        raise NotImplementedError()

    async def overwrite_payload(
        self,
        collection_name: str,
        payload: types.Payload,
        points: types.PointsSelector,
        **kwargs: Any,
    ) -> types.UpdateResult:
        raise NotImplementedError()

    async def delete_payload(
        self,
        collection_name: str,
        keys: Sequence[str],
        points: types.PointsSelector,
        **kwargs: Any,
    ) -> types.UpdateResult:
        raise NotImplementedError()

    async def clear_payload(
        self, collection_name: str, points_selector: types.PointsSelector, **kwargs: Any
    ) -> types.UpdateResult:
        raise NotImplementedError()

    async def batch_update_points(
        self,
        collection_name: str,
        update_operations: Sequence[types.UpdateOperation],
        **kwargs: Any,
    ) -> list[types.UpdateResult]:
        raise NotImplementedError()

    # --- collection & alias management ---

    async def update_collection_aliases(
        self, change_aliases_operations: Sequence[types.AliasOperations], **kwargs: Any
    ) -> bool:
        raise NotImplementedError()

    async def get_collection_aliases(
        self, collection_name: str, **kwargs: Any
    ) -> types.CollectionsAliasesResponse:
        raise NotImplementedError()

    async def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
        raise NotImplementedError()

    async def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
        raise NotImplementedError()

    async def get_collection(self, collection_name: str, **kwargs: Any) -> types.CollectionInfo:
        raise NotImplementedError()

    async def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
        raise NotImplementedError()

    async def update_collection(self, collection_name: str, **kwargs: Any) -> bool:
        raise NotImplementedError()

    async def delete_collection(self, collection_name: str, **kwargs: Any) -> bool:
        raise NotImplementedError()

    async def create_collection(
        self,
        collection_name: str,
        vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
        **kwargs: Any,
    ) -> bool:
        raise NotImplementedError()

    async def recreate_collection(
        self,
        collection_name: str,
        vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
        **kwargs: Any,
    ) -> bool:
        raise NotImplementedError()

    # --- bulk upload (synchronous by design) ---

    def upload_points(
        self, collection_name: str, points: Iterable[types.PointStruct], **kwargs: Any
    ) -> None:
        raise NotImplementedError()

    def upload_collection(
        self,
        collection_name: str,
        vectors: Union[
            dict[str, types.NumpyArray], types.NumpyArray, Iterable[types.VectorStruct]
        ],
        payload: Optional[Iterable[dict[Any, Any]]] = None,
        ids: Optional[Iterable[types.PointId]] = None,
        **kwargs: Any,
    ) -> None:
        raise NotImplementedError()

    # --- payload indexes ---

    async def create_payload_index(
        self,
        collection_name: str,
        field_name: str,
        field_schema: Optional[types.PayloadSchemaType] = None,
        field_type: Optional[types.PayloadSchemaType] = None,
        **kwargs: Any,
    ) -> types.UpdateResult:
        raise NotImplementedError()

    async def delete_payload_index(
        self, collection_name: str, field_name: str, **kwargs: Any
    ) -> types.UpdateResult:
        raise NotImplementedError()

    # --- snapshots ---

    async def list_snapshots(
        self, collection_name: str, **kwargs: Any
    ) -> list[types.SnapshotDescription]:
        raise NotImplementedError()

    async def create_snapshot(
        self, collection_name: str, **kwargs: Any
    ) -> Optional[types.SnapshotDescription]:
        raise NotImplementedError()

    async def delete_snapshot(
        self, collection_name: str, snapshot_name: str, **kwargs: Any
    ) -> Optional[bool]:
        raise NotImplementedError()

    async def list_full_snapshots(self, **kwargs: Any) -> list[types.SnapshotDescription]:
        raise NotImplementedError()

    async def create_full_snapshot(self, **kwargs: Any) -> Optional[types.SnapshotDescription]:
        raise NotImplementedError()

    async def delete_full_snapshot(self, snapshot_name: str, **kwargs: Any) -> Optional[bool]:
        raise NotImplementedError()

    async def recover_snapshot(
        self, collection_name: str, location: str, **kwargs: Any
    ) -> Optional[bool]:
        raise NotImplementedError()

    async def list_shard_snapshots(
        self, collection_name: str, shard_id: int, **kwargs: Any
    ) -> list[types.SnapshotDescription]:
        raise NotImplementedError()

    async def create_shard_snapshot(
        self, collection_name: str, shard_id: int, **kwargs: Any
    ) -> Optional[types.SnapshotDescription]:
        raise NotImplementedError()

    async def delete_shard_snapshot(
        self, collection_name: str, shard_id: int, snapshot_name: str, **kwargs: Any
    ) -> Optional[bool]:
        raise NotImplementedError()

    async def recover_shard_snapshot(
        self, collection_name: str, shard_id: int, location: str, **kwargs: Any
    ) -> Optional[bool]:
        raise NotImplementedError()

    # --- lifecycle / misc ---

    async def close(self, **kwargs: Any) -> None:
        # No-op by default; subclasses release connections here.
        pass

    def migrate(
        self,
        dest_client: "AsyncQdrantBase",
        collection_names: Optional[list[str]] = None,
        batch_size: int = 100,
        recreate_on_collision: bool = False,
    ) -> None:
        raise NotImplementedError()

    # --- sharding ---

    async def create_shard_key(
        self,
        collection_name: str,
        shard_key: types.ShardKey,
        shards_number: Optional[int] = None,
        replication_factor: Optional[int] = None,
        placement: Optional[list[int]] = None,
        **kwargs: Any,
    ) -> bool:
        raise NotImplementedError()

    async def delete_shard_key(
        self, collection_name: str, shard_key: types.ShardKey, **kwargs: Any
    ) -> bool:
        raise NotImplementedError()

    # --- cluster operations ---

    async def info(self) -> types.VersionInfo:
        raise NotImplementedError()

    async def cluster_collection_update(
        self, collection_name: str, cluster_operation: types.ClusterOperations, **kwargs: Any
    ) -> bool:
        raise NotImplementedError()

    async def collection_cluster_info(self, collection_name: str) -> types.CollectionClusterInfo:
        raise NotImplementedError()

    async def cluster_status(self) -> types.ClusterStatus:
        raise NotImplementedError()

    async def recover_current_peer(self) -> bool:
        raise NotImplementedError()

    async def remove_peer(self, peer_id: int, **kwargs: Any) -> bool:
        raise NotImplementedError()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,842 @@
|
||||
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
|
||||
#
|
||||
# This file is autogenerated. Do not edit it manually.
|
||||
# To regenerate this file, use
|
||||
#
|
||||
# ```
|
||||
# bash -x tools/generate_async_client.sh
|
||||
# ```
|
||||
#
|
||||
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
|
||||
|
||||
import uuid
|
||||
from itertools import tee
|
||||
from typing import Any, Iterable, Optional, Sequence, Union, get_args
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client import grpc
|
||||
from qdrant_client.common.client_warnings import show_warning, show_warning_once
|
||||
from qdrant_client.async_client_base import AsyncQdrantBase
|
||||
from qdrant_client.embed.embedder import Embedder
|
||||
from qdrant_client.embed.model_embedder import ModelEmbedder
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.conversions import common_types as types
|
||||
from qdrant_client.conversions.conversion import GrpcToRest
|
||||
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
|
||||
from qdrant_client.embed.schema_parser import ModelSchemaParser
|
||||
from qdrant_client.hybrid.fusion import reciprocal_rank_fusion
|
||||
from qdrant_client.fastembed_common import FastEmbedMisc, OnnxProvider
|
||||
from qdrant_client.fastembed_common import (
|
||||
QueryResponse,
|
||||
TextEmbedding,
|
||||
SparseTextEmbedding,
|
||||
IDF_EMBEDDING_MODELS,
|
||||
)
|
||||
|
||||
|
||||
class AsyncQdrantFastembedMixin(AsyncQdrantBase):
|
||||
DEFAULT_EMBEDDING_MODEL = "BAAI/bge-small-en"
|
||||
DEFAULT_BATCH_SIZE = 8
|
||||
_FASTEMBED_INSTALLED: bool
|
||||
|
||||
    def __init__(
        self, parser: ModelSchemaParser, is_local_mode: bool, server_version: Optional[str]
    ):
        """Initialize fastembed integration state.

        Args:
            parser: schema parser handed to the model embedder.
            is_local_mode: whether the client runs against an in-process Qdrant.
            server_version: remote server version string, if known.
        """
        # Probe once per process and cache on the class so every instance
        # shares the fastembed-availability result.
        self.__class__._FASTEMBED_INSTALLED = FastEmbedMisc.is_installed()
        # Selected model names stay None until set_model / set_sparse_model.
        self._embedding_model_name: Optional[str] = None
        self._sparse_embedding_model_name: Optional[str] = None
        self._model_embedder = ModelEmbedder(
            parser=parser, is_local_mode=is_local_mode, server_version=server_version
        )
        super().__init__()
|
||||
|
||||
    @classmethod
    async def list_text_models(cls) -> dict[str, tuple[int, models.Distance]]:
        """Lists the supported dense text models.

        Returns:
            dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
        """
        # async only for interface symmetry with the sync client; delegates directly.
        return FastEmbedMisc.list_text_models()
|
||||
|
||||
    @classmethod
    async def list_image_models(cls) -> dict[str, tuple[int, models.Distance]]:
        """Lists the supported image dense models.

        Returns:
            dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
        """
        # async only for interface symmetry with the sync client; delegates directly.
        return FastEmbedMisc.list_image_models()
|
||||
|
||||
    @classmethod
    async def list_late_interaction_text_models(cls) -> dict[str, tuple[int, models.Distance]]:
        """Lists the supported late interaction text models.

        Returns:
            dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
        """
        # async only for interface symmetry with the sync client; delegates directly.
        return FastEmbedMisc.list_late_interaction_text_models()
|
||||
|
||||
    @classmethod
    async def list_late_interaction_multimodal_models(
        cls,
    ) -> dict[str, tuple[int, models.Distance]]:
        """Lists the supported late interaction multimodal models.

        Returns:
            dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
        """
        # async only for interface symmetry with the sync client; delegates directly.
        return FastEmbedMisc.list_late_interaction_multimodal_models()
|
||||
|
||||
    @classmethod
    async def list_sparse_models(cls) -> dict[str, dict[str, Any]]:
        """Lists the supported sparse text models.

        Returns:
            dict[str, dict[str, Any]]: A dict of model names and their descriptions.
        """
        # async only for interface symmetry with the sync client; delegates directly.
        return FastEmbedMisc.list_sparse_models()
|
||||
|
||||
    @property
    def embedding_model_name(self) -> str:
        """Name of the active dense model; falls back to the default on first access."""
        if self._embedding_model_name is None:
            # Lazily pin the default so later reads are stable.
            self._embedding_model_name = self.DEFAULT_EMBEDDING_MODEL
        return self._embedding_model_name
|
||||
|
||||
    @property
    def sparse_embedding_model_name(self) -> Optional[str]:
        """Name of the active sparse model, or None when sparse embeddings are unused."""
        return self._sparse_embedding_model_name
|
||||
|
||||
    def set_model(
        self,
        embedding_model_name: str,
        max_length: Optional[int] = None,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        lazy_load: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Set embedding model to use for encoding documents and queries.

        Args:
            embedding_model_name: One of the supported embedding models. See `SUPPORTED_EMBEDDING_MODELS` for details.
            max_length (int, optional): Deprecated. Defaults to None.
            cache_dir (str, optional): The path to the cache directory.
                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                Defaults to `fastembed_cache` in the system's temp directory.
            threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
            providers: The list of onnx providers (with or without options) to use. Defaults to None.
                Example configuration:
                https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options
            cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
                Defaults to False.
            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
                workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
                Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
        Raises:
            ValueError: If embedding model is not supported.
            ImportError: If fastembed is not installed.

        Returns:
            None
        """
        if max_length is not None:
            # max_length is accepted only for backward compatibility.
            show_warning(
                message="max_length parameter is deprecated and will be removed in the future. It's not used by fastembed models.",
                category=DeprecationWarning,
                stacklevel=3,
            )
        # Eagerly initialize (or fetch the cached) dense model so configuration
        # errors surface here rather than at the first embed call.
        self._get_or_init_model(
            model_name=embedding_model_name,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_ids=device_ids,
            lazy_load=lazy_load,
            deprecated=True,
            **kwargs,
        )
        # Record the selection only after successful initialization.
        self._embedding_model_name = embedding_model_name
|
||||
|
||||
    def set_sparse_model(
        self,
        embedding_model_name: Optional[str],
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        lazy_load: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Set sparse embedding model to use for hybrid search over documents in combination with dense embeddings.

        Args:
            embedding_model_name: One of the supported sparse embedding models. See `SUPPORTED_SPARSE_EMBEDDING_MODELS` for details.
                If None, sparse embeddings will not be used.
            cache_dir (str, optional): The path to the cache directory.
                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                Defaults to `fastembed_cache` in the system's temp directory.
            threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
            providers: The list of onnx providers (with or without options) to use. Defaults to None.
                Example configuration:
                https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options
            cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
                Defaults to False.
            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
                workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
                Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
        Raises:
            ValueError: If embedding model is not supported.
            ImportError: If fastembed is not installed.

        Returns:
            None
        """
        if embedding_model_name is not None:
            # Eagerly initialize so configuration errors surface here.
            self._get_or_init_sparse_model(
                model_name=embedding_model_name,
                cache_dir=cache_dir,
                threads=threads,
                providers=providers,
                cuda=cuda,
                device_ids=device_ids,
                lazy_load=lazy_load,
                deprecated=True,
                **kwargs,
            )
        # Unconditional assignment: passing None disables sparse embeddings.
        self._sparse_embedding_model_name = embedding_model_name
|
||||
|
||||
    @classmethod
    def _get_model_params(cls, model_name: str) -> tuple[int, models.Distance]:
        """Return (embedding size, distance metric) for a known dense model name.

        Raises:
            ValueError: if the model is sparse (no fixed size) or unknown.
        """
        FastEmbedMisc.import_fastembed()
        # Search every dense model registry in turn for the requested name.
        for descriptions in (
            FastEmbedMisc.list_text_models(),
            FastEmbedMisc.list_image_models(),
            FastEmbedMisc.list_late_interaction_text_models(),
            FastEmbedMisc.list_late_interaction_multimodal_models(),
        ):
            if params := descriptions.get(model_name):
                return params
        # Sparse models exist but have no fixed dimensionality — reject explicitly
        # with a more specific message than the generic "unsupported" below.
        if model_name in FastEmbedMisc.list_sparse_models():
            raise ValueError(
                "Sparse embeddings do not return fixed embedding size and distance type"
            )
        raise ValueError(f"Unsupported embedding model: {model_name}")
|
||||
|
||||
    def _get_or_init_model(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        deprecated: bool = False,
        **kwargs: Any,
    ) -> "TextEmbedding":
        """Return a cached dense embedding model, initializing it on first use.

        Pass-through to the shared Embedder; `deprecated=True` marks calls from
        the legacy (pre-inference-object) API paths.
        """
        FastEmbedMisc.import_fastembed()
        # The embedder attribute is only an Embedder once fastembed is importable.
        assert isinstance(self._model_embedder.embedder, Embedder)
        return self._model_embedder.embedder.get_or_init_model(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
            deprecated=deprecated,
            **kwargs,
        )
|
||||
|
||||
    def _get_or_init_sparse_model(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        deprecated: bool = False,
        **kwargs: Any,
    ) -> "SparseTextEmbedding":
        """Return a cached sparse embedding model, initializing it on first use.

        Mirrors `_get_or_init_model` for the sparse model registry.
        """
        FastEmbedMisc.import_fastembed()
        # The embedder attribute is only an Embedder once fastembed is importable.
        assert isinstance(self._model_embedder.embedder, Embedder)
        return self._model_embedder.embedder.get_or_init_sparse_model(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
            deprecated=deprecated,
            **kwargs,
        )
|
||||
|
||||
    def _embed_documents(
        self,
        documents: Iterable[str],
        embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
        batch_size: int = 32,
        embed_type: str = "default",
        parallel: Optional[int] = None,
    ) -> Iterable[tuple[str, list[float]]]:
        """Lazily embed *documents*, yielding (document, vector-as-list) pairs.

        `embed_type` selects the model entry point: "passage", "query" or
        "default". Raises ValueError for any other value.
        """
        embedding_model = self._get_or_init_model(model_name=embedding_model_name, deprecated=True)
        # tee() duplicates the (possibly one-shot) iterator: one copy feeds the
        # model, the other is re-paired with the resulting vectors below.
        (documents_a, documents_b) = tee(documents, 2)
        if embed_type == "passage":
            vectors_iter = embedding_model.passage_embed(
                documents_a, batch_size=batch_size, parallel=parallel
            )
        elif embed_type == "query":
            # query_embed yields per-query iterables; take the single vector each.
            vectors_iter = (
                list(embedding_model.query_embed(query=query))[0] for query in documents_a
            )
        elif embed_type == "default":
            vectors_iter = embedding_model.embed(
                documents_a, batch_size=batch_size, parallel=parallel
            )
        else:
            raise ValueError(f"Unknown embed type: {embed_type}")
        for vector, doc in zip(vectors_iter, documents_b):
            # Vectors arrive as numpy arrays; convert to plain lists for the API.
            yield (doc, vector.tolist())
|
||||
|
||||
    def _sparse_embed_documents(
        self,
        documents: Iterable[str],
        embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
        batch_size: int = 32,
        parallel: Optional[int] = None,
    ) -> Iterable[types.SparseVector]:
        """Lazily embed *documents* with the sparse model, yielding SparseVectors."""
        sparse_embedding_model = self._get_or_init_sparse_model(
            model_name=embedding_model_name, deprecated=True
        )
        vectors_iter = sparse_embedding_model.embed(
            documents, batch_size=batch_size, parallel=parallel
        )
        for sparse_vector in vectors_iter:
            # Convert numpy index/value arrays into the REST SparseVector type.
            yield types.SparseVector(
                indices=sparse_vector.indices.tolist(), values=sparse_vector.values.tolist()
            )
|
||||
|
||||
    def get_vector_field_name(self) -> str:
        """
        Returns name of the vector field in qdrant collection, used by current fastembed model.
        Returns:
            Name of the vector field.
        """
        # Derive a stable field name from the last path segment of the model id,
        # e.g. "BAAI/bge-small-en" -> "fast-bge-small-en".
        model_name = self.embedding_model_name.split("/")[-1].lower()
        return f"fast-{model_name}"
|
||||
|
||||
    def get_sparse_vector_field_name(self) -> Optional[str]:
        """
        Returns name of the vector field in qdrant collection, used by current fastembed model.
        Returns:
            Name of the vector field.
        """
        if self.sparse_embedding_model_name is not None:
            # Same derivation as the dense field, with a "fast-sparse-" prefix.
            model_name = self.sparse_embedding_model_name.split("/")[-1].lower()
            return f"fast-sparse-{model_name}"
        # No sparse model configured -> no sparse vector field.
        return None
|
||||
|
||||
    def _scored_points_to_query_responses(
        self, scored_points: list[types.ScoredPoint]
    ) -> list[QueryResponse]:
        """Convert raw scored points into fastembed QueryResponse objects.

        Extracts the dense/sparse vectors stored under this client's derived
        field names, when the point's vector payload is a named-vector dict.
        """
        response = []
        # Resolve field names once; they are invariant across points.
        vector_field_name = self.get_vector_field_name()
        sparse_vector_field_name = self.get_sparse_vector_field_name()
        for scored_point in scored_points:
            # Named-vector collections store a dict; anything else has no
            # per-field embedding to surface.
            embedding = (
                scored_point.vector.get(vector_field_name, None)
                if isinstance(scored_point.vector, dict)
                else None
            )
            sparse_embedding = None
            if sparse_vector_field_name is not None:
                sparse_embedding = (
                    scored_point.vector.get(sparse_vector_field_name, None)
                    if isinstance(scored_point.vector, dict)
                    else None
                )
            response.append(
                QueryResponse(
                    id=scored_point.id,
                    embedding=embedding,
                    sparse_embedding=sparse_embedding,
                    metadata=scored_point.payload,
                    # NOTE(review): assumes payload is a dict here — a point
                    # returned without payload would raise AttributeError.
                    document=scored_point.payload.get("document", ""),
                    score=scored_point.score,
                )
            )
        return response
|
||||
|
||||
    def _points_iterator(
        self,
        ids: Optional[Iterable[models.ExtendedPointId]],
        metadata: Optional[Iterable[dict[str, Any]]],
        encoded_docs: Iterable[tuple[str, list[float]]],
        ids_accumulator: list,
        sparse_vectors: Optional[Iterable[types.SparseVector]] = None,
    ) -> Iterable[models.PointStruct]:
        """Lazily assemble PointStructs from parallel streams of ids, metadata,
        encoded documents and optional sparse vectors.

        Generated point ids are appended to *ids_accumulator* as a side channel
        for the caller. `encoded_docs` is the driving stream: the two-argument
        `iter(callable, sentinel)` defaults below are infinite, so zip() stops
        when the docs are exhausted.
        """
        if ids is None:
            # Infinite stream of fresh UUID hex ids (uuid4() is never None).
            ids = iter(lambda: uuid.uuid4().hex, None)
        if metadata is None:
            # Infinite stream of fresh empty dicts ({} is never None).
            metadata = iter(lambda: {}, None)
        if sparse_vectors is None:
            # Infinite stream of None (the callable never returns True).
            sparse_vectors = iter(lambda: None, True)
        vector_name = self.get_vector_field_name()
        sparse_vector_name = self.get_sparse_vector_field_name()
        for idx, meta, (doc, vector), sparse_vector in zip(
            ids, metadata, encoded_docs, sparse_vectors
        ):
            # Expose the id used for this point to the caller.
            ids_accumulator.append(idx)
            # The raw document text rides along in the payload under "document".
            payload = {"document": doc, **meta}
            point_vector: dict[str, models.Vector] = {vector_name: vector}
            if sparse_vector_name is not None and sparse_vector is not None:
                point_vector[sparse_vector_name] = sparse_vector
            yield models.PointStruct(id=idx, payload=payload, vector=point_vector)
|
||||
|
||||
def _validate_collection_info(self, collection_info: models.CollectionInfo) -> None:
    """Assert that an existing collection matches the configured models.

    Verifies the dense vector field name, size and distance, and — when a
    sparse model is configured — the sparse field name and, for IDF-based
    models, the IDF modifier.

    NOTE: validation is done with ``assert``, so it is skipped under
    ``python -O`` (matches the surrounding codebase's convention).
    """
    embeddings_size, distance = self._get_model_params(model_name=self.embedding_model_name)
    vector_field_name = self.get_vector_field_name()

    dense_config = collection_info.config.params.vectors
    assert isinstance(
        dense_config, dict
    ), f"Collection have incompatible vector params: {collection_info.config.params.vectors}"
    assert (
        vector_field_name in dense_config
    ), f"Collection have incompatible vector params: {collection_info.config.params.vectors}, expected {vector_field_name}"

    vector_params = dense_config[vector_field_name]
    assert (
        embeddings_size == vector_params.size
    ), f"Embedding size mismatch: {embeddings_size} != {vector_params.size}"
    assert (
        distance == vector_params.distance
    ), f"Distance mismatch: {distance} != {vector_params.distance}"

    sparse_vector_field_name = self.get_sparse_vector_field_name()
    if sparse_vector_field_name is None:
        return

    assert (
        sparse_vector_field_name in collection_info.config.params.sparse_vectors
    ), f"Collection have incompatible vector params: {collection_info.config.params.vectors}"
    if self.sparse_embedding_model_name in IDF_EMBEDDING_MODELS:
        modifier = collection_info.config.params.sparse_vectors[sparse_vector_field_name].modifier
        assert (
            modifier == models.Modifier.IDF
        ), f"{self.sparse_embedding_model_name} requires modifier IDF, current modifier is {modifier}"
|
||||
|
||||
def get_embedding_size(self, model_name: Optional[str] = None) -> int:
    """Return the dimensionality of the embeddings for a dense model.

    Args:
        model_name: model to inspect; falls back to the instance's default
            dense embedding model when None.

    Returns:
        int: embedding vector size.

    Raises:
        ValueError: if a sparse model name is passed or the model is unknown.
    """
    embeddings_size, _distance = self._get_model_params(
        model_name=model_name or self.embedding_model_name
    )
    return embeddings_size
|
||||
|
||||
def get_fastembed_vector_params(
    self,
    on_disk: Optional[bool] = None,
    quantization_config: Optional[models.QuantizationConfig] = None,
    hnsw_config: Optional[models.HnswConfigDiff] = None,
) -> dict[str, models.VectorParams]:
    """Build a dense ``vectors_config`` matching the configured fastembed model.

    Args:
        on_disk: store vectors on disk when True; server default when None.
        quantization_config: optional quantization settings (None disables it).
        hnsw_config: optional HNSW index overrides.

    Returns:
        Mapping suitable for the ``vectors_config`` argument of
        ``create_collection``: {vector_field_name: VectorParams}.
    """
    field_name = self.get_vector_field_name()
    embeddings_size, distance = self._get_model_params(model_name=self.embedding_model_name)
    params = models.VectorParams(
        size=embeddings_size,
        distance=distance,
        on_disk=on_disk,
        quantization_config=quantization_config,
        hnsw_config=hnsw_config,
    )
    return {field_name: params}
|
||||
|
||||
def get_fastembed_sparse_vector_params(
    self, on_disk: Optional[bool] = None, modifier: Optional[models.Modifier] = None
) -> Optional[dict[str, models.SparseVectorParams]]:
    """Build a sparse ``sparse_vectors_config`` for the configured sparse model.

    Args:
        on_disk: store the sparse index on disk when True; default when None.
        modifier: query modifier; for IDF-based models it defaults to
            ``models.Modifier.IDF`` unless explicitly overridden.

    Returns:
        {sparse_field_name: SparseVectorParams}, or None when no sparse
        embedding model is configured.
    """
    field_name = self.get_sparse_vector_field_name()
    if field_name is None:
        # No sparse model configured — nothing to generate.
        return None
    if modifier is None and self.sparse_embedding_model_name in IDF_EMBEDDING_MODELS:
        modifier = models.Modifier.IDF
    return {
        field_name: models.SparseVectorParams(
            index=models.SparseIndexParams(on_disk=on_disk), modifier=modifier
        )
    }
|
||||
|
||||
async def add(
    self,
    collection_name: str,
    documents: Iterable[str],
    metadata: Optional[Iterable[dict[str, Any]]] = None,
    ids: Optional[Iterable[models.ExtendedPointId]] = None,
    batch_size: int = 32,
    parallel: Optional[int] = None,
    **kwargs: Any,
) -> list[Union[str, int]]:
    """Embed text documents and upsert them into a collection.

    Creates the collection with default fastembed-compatible parameters if
    it does not exist. Each document is stored under the "document" payload
    key, merged with its metadata dict.

    Args:
        collection_name: target collection.
        documents: texts to embed and upload.
        metadata: optional payload dicts, one per document.
        ids: optional point ids; random UUIDs are generated when omitted.
        batch_size: documents per embedding/upload batch. Defaults to 32.
        parallel: number of data-parallel embedding workers; None = serial.
        **kwargs: forwarded to ``upload_points``.

    Returns:
        Ids of the added documents (client-generated UUIDs when ``ids`` is None).

    Raises:
        ImportError: if fastembed is not installed.
    """
    show_warning_once(
        "`add` method has been deprecated and will be removed in 1.17. Instead, inference can be done internally within regular methods like `upsert` by wrapping data into `models.Document` or `models.Image`."
    )
    encoded_docs = self._embed_documents(
        documents=documents,
        embedding_model_name=self.embedding_model_name,
        batch_size=batch_size,
        embed_type="passage",
        parallel=parallel,
    )
    encoded_sparse_docs = None
    if self.sparse_embedding_model_name is not None:
        encoded_sparse_docs = self._sparse_embed_documents(
            documents=documents,
            embedding_model_name=self.sparse_embedding_model_name,
            batch_size=batch_size,
            parallel=parallel,
        )

    # EAFP: fetch the collection; on any failure create it with defaults
    # and fetch again so validation always runs against live schema.
    try:
        collection_info = await self.get_collection(collection_name=collection_name)
    except Exception:
        await self.create_collection(
            collection_name=collection_name,
            vectors_config=self.get_fastembed_vector_params(),
            sparse_vectors_config=self.get_fastembed_sparse_vector_params(),
        )
        collection_info = await self.get_collection(collection_name=collection_name)
    self._validate_collection_info(collection_info)

    inserted_ids: list = []
    points = self._points_iterator(
        ids=ids,
        metadata=metadata,
        encoded_docs=encoded_docs,
        ids_accumulator=inserted_ids,
        sparse_vectors=encoded_sparse_docs,
    )
    # NOTE(review): upload_points is called without `await` here — it is
    # presumably a synchronous batching helper; mirrors the original code.
    self.upload_points(
        collection_name=collection_name,
        points=points,
        wait=True,
        parallel=parallel or 1,
        batch_size=batch_size,
        **kwargs,
    )
    return inserted_ids
|
||||
|
||||
async def query(
    self,
    collection_name: str,
    query_text: str,
    query_filter: Optional[models.Filter] = None,
    limit: int = 10,
    **kwargs: Any,
) -> list[QueryResponse]:
    """Embed ``query_text`` and search the collection with it.

    With only a dense model configured this is a single ``query_points``
    call; with a sparse model as well, dense and sparse requests are run
    as a batch and fused with reciprocal rank fusion.

    Args:
        collection_name: collection to search in.
        query_text: text to embed and use as the query vector.
        query_filter: optional filter; None searches all vectors.
        limit: maximum number of results.
        **kwargs: extra search parameters (see ``models.QueryRequest``).

    Returns:
        List of QueryResponse objects ordered by relevance.
    """
    show_warning_once(
        "`query` method has been deprecated and will be removed in 1.17. Instead, inference can be done internally within regular methods like `query_points` by wrapping data into `models.Document` or `models.Image`."
    )
    dense_model = self._get_or_init_model(
        model_name=self.embedding_model_name, deprecated=True
    )
    query_vector = list(dense_model.query_embed(query=query_text))[0].tolist()

    if self.sparse_embedding_model_name is None:
        # Dense-only path: a single query_points round-trip.
        dense_result = await self.query_points(
            collection_name=collection_name,
            query=query_vector,
            using=self.get_vector_field_name(),
            query_filter=query_filter,
            limit=limit,
            with_payload=True,
            **kwargs,
        )
        return self._scored_points_to_query_responses(dense_result.points)

    sparse_model = self._get_or_init_sparse_model(
        model_name=self.sparse_embedding_model_name, deprecated=True
    )
    raw_sparse = list(sparse_model.query_embed(query=query_text))[0]
    sparse_query_vector = models.SparseVector(
        indices=raw_sparse.indices.tolist(), values=raw_sparse.values.tolist()
    )

    dense_request = models.QueryRequest(
        query=query_vector,
        using=self.get_vector_field_name(),
        filter=query_filter,
        limit=limit,
        with_payload=True,
        **kwargs,
    )
    sparse_request = models.QueryRequest(
        query=sparse_query_vector,
        using=self.get_sparse_vector_field_name(),
        filter=query_filter,
        limit=limit,
        with_payload=True,
        **kwargs,
    )
    dense_response, sparse_response = await self.query_batch_points(
        collection_name=collection_name, requests=[dense_request, sparse_request]
    )
    fused = reciprocal_rank_fusion(
        [dense_response.points, sparse_response.points], limit=limit
    )
    return self._scored_points_to_query_responses(fused)
|
||||
|
||||
async def query_batch(
    self,
    collection_name: str,
    query_texts: list[str],
    query_filter: Optional[models.Filter] = None,
    limit: int = 10,
    **kwargs: Any,
) -> list[list[QueryResponse]]:
    """Run one embedded search per text in ``query_texts``.

    Dense requests are issued for every text; when a sparse model is
    configured, sparse requests are appended to the same batch and the
    paired results are fused with reciprocal rank fusion.

    Args:
        collection_name: collection to search in.
        query_texts: texts to embed, one search request per text.
        query_filter: optional filter applied to every request.
        limit: maximum number of results per request.
        **kwargs: extra search parameters (see ``models.QueryRequest``).

    Returns:
        One list of QueryResponse objects per input text, in order.
    """
    show_warning_once(
        "`query_batch` method has been deprecated and will be removed in 1.17. Instead, inference can be done internally within regular methods like `query_batch_points` by wrapping data into `models.Document` or `models.Image`."
    )
    dense_model = self._get_or_init_model(
        model_name=self.embedding_model_name, deprecated=True
    )
    requests = []
    for dense_vector in list(dense_model.query_embed(query=query_texts)):
        requests.append(
            models.QueryRequest(
                query=dense_vector.tolist(),
                using=self.get_vector_field_name(),
                filter=query_filter,
                limit=limit,
                with_payload=True,
                **kwargs,
            )
        )

    if self.sparse_embedding_model_name is None:
        responses = await self.query_batch_points(
            collection_name=collection_name, requests=requests
        )
        return [
            self._scored_points_to_query_responses(response.points)
            for response in responses
        ]

    sparse_model = self._get_or_init_sparse_model(
        model_name=self.sparse_embedding_model_name, deprecated=True
    )
    # NOTE(review): the sparse side calls `embed` (document mode), not
    # `query_embed` — this mirrors the original code; confirm intended.
    for raw_sparse in sparse_model.embed(documents=query_texts):
        sparse_query = models.SparseVector(
            indices=raw_sparse.indices.tolist(), values=raw_sparse.values.tolist()
        )
        requests.append(
            models.QueryRequest(
                using=self.get_sparse_vector_field_name(),
                query=sparse_query,
                filter=query_filter,
                limit=limit,
                with_payload=True,
                **kwargs,
            )
        )

    responses = await self.query_batch_points(
        collection_name=collection_name, requests=requests
    )
    # First len(query_texts) responses are dense, the rest are sparse.
    dense_responses = responses[: len(query_texts)]
    sparse_responses = responses[len(query_texts) :]
    fused = [
        reciprocal_rank_fusion([dense.points, sparse.points], limit=limit)
        for dense, sparse in zip(dense_responses, sparse_responses)
    ]
    return [self._scored_points_to_query_responses(points) for points in fused]
|
||||
|
||||
@classmethod
def _resolve_query(
    cls,
    query: Union[
        types.PointId,
        list[float],
        list[list[float]],
        types.SparseVector,
        types.Query,
        types.NumpyArray,
        models.Document,
        models.Image,
        models.InferenceObject,
        None,
    ],
) -> Optional[models.Query]:
    """Normalize any supported query value into a ``models.Query``.

    Args:
        query: a Query model or a plain structure (vector, point id,
            sparse vector, numpy array, inference object) or None.

    Returns:
        The query unchanged if already a Query model, otherwise
        ``models.NearestQuery(nearest=...)``, or None for None.

    Raises:
        ValueError: for unsupported query types.
    """
    if query is None:
        return None
    if isinstance(query, get_args(types.Query)):
        # Already a Query model — pass through untouched.
        return query
    if isinstance(query, types.SparseVector):
        return models.NearestQuery(nearest=query)
    if isinstance(query, np.ndarray):
        return models.NearestQuery(nearest=query.tolist())
    if isinstance(query, list):
        return models.NearestQuery(nearest=query)
    if isinstance(query, get_args(types.PointId)):
        # gRPC point ids are converted to their REST representation first.
        if isinstance(query, grpc.PointId):
            query = GrpcToRest.convert_point_id(query)
        return models.NearestQuery(nearest=query)
    if isinstance(query, get_args(INFERENCE_OBJECT_TYPES)):
        return models.NearestQuery(nearest=query)
    raise ValueError(f"Unsupported query type: {type(query)}")
|
||||
|
||||
def _resolve_query_request(self, query: models.QueryRequest) -> models.QueryRequest:
    """Return a deep copy of ``query`` with its ``query`` field normalized.

    The copy guarantees the caller's request object is never mutated.
    """
    resolved = deepcopy(query)
    resolved.query = self._resolve_query(resolved.query)
    return resolved
|
||||
|
||||
def _resolve_query_batch_request(
    self, requests: Sequence[models.QueryRequest]
) -> Sequence[models.QueryRequest]:
    """Normalize the query field of every request in a batch.

    Returns:
        New list of deep-copied requests with resolved query fields;
        the input sequence is left untouched.
    """
    return [self._resolve_query_request(request) for request in requests]
|
||||
|
||||
def _embed_models(
    self,
    raw_models: Union[BaseModel, Iterable[BaseModel]],
    is_query: bool = False,
    batch_size: Optional[int] = None,
) -> Iterable[BaseModel]:
    """Stream models through the embedder, filling in inference results.

    ``batch_size`` falls back to DEFAULT_BATCH_SIZE when falsy.
    """
    effective_batch_size = batch_size or self.DEFAULT_BATCH_SIZE
    yield from self._model_embedder.embed_models(
        raw_models=raw_models,
        is_query=is_query,
        batch_size=effective_batch_size,
    )
|
||||
|
||||
def _embed_models_strict(
    self,
    raw_models: Iterable[Union[dict[str, BaseModel], BaseModel]],
    batch_size: Optional[int] = None,
    parallel: Optional[int] = None,
) -> Iterable[BaseModel]:
    """Strict variant of ``_embed_models`` with optional parallel workers.

    ``batch_size`` falls back to DEFAULT_BATCH_SIZE when falsy.
    """
    effective_batch_size = batch_size or self.DEFAULT_BATCH_SIZE
    yield from self._model_embedder.embed_models_strict(
        raw_models=raw_models,
        batch_size=effective_batch_size,
        parallel=parallel,
    )
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1 @@
|
||||
from qdrant_client.auth.bearer_auth import BearerAuth
|
||||
@@ -0,0 +1,42 @@
|
||||
import asyncio
|
||||
from typing import Awaitable, Callable, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class BearerAuth(httpx.Auth):
    """httpx auth hook that injects a ``Authorization: Bearer <token>`` header.

    The token provider may be a plain callable (sync flow) or a coroutine
    function (async flow). The async flow falls back to the sync provider
    when no async one was given.
    """

    def __init__(
        self,
        auth_token_provider: Union[Callable[[], str], Callable[[], Awaitable[str]]],
    ):
        # Exactly one of these is set, depending on the provider's kind.
        self.async_token: Optional[Callable[[], Awaitable[str]]] = None
        self.sync_token: Optional[Callable[[], str]] = None

        if asyncio.iscoroutinefunction(auth_token_provider):
            self.async_token = auth_token_provider
        elif callable(auth_token_provider):
            self.sync_token = auth_token_provider  # type: ignore
        else:
            raise ValueError("auth_token_provider must be a callable or awaitable")

    def _sync_get_token(self) -> str:
        if self.sync_token is None:
            raise ValueError("Synchronous token provider is not set.")
        return self.sync_token()

    def sync_auth_flow(self, request: httpx.Request) -> httpx.Request:
        # httpx auth flows are generators: set the header, yield the request.
        request.headers["Authorization"] = f"Bearer {self._sync_get_token()}"
        yield request

    async def _async_get_token(self) -> str:
        if self.async_token is not None:
            return await self.async_token()  # type: ignore
        # Fallback to synchronous token if asynchronous token is not available
        return self._sync_get_token()

    async def async_auth_flow(self, request: httpx.Request) -> httpx.Request:
        request.headers["Authorization"] = f"Bearer {await self._async_get_token()}"
        yield request
|
||||
@@ -0,0 +1,416 @@
|
||||
from typing import Any, Iterable, Mapping, Optional, Sequence, Union
|
||||
|
||||
from qdrant_client.conversions import common_types as types
|
||||
|
||||
|
||||
class QdrantBase:
|
||||
def __init__(self, **kwargs: Any):
|
||||
pass
|
||||
|
||||
def search_matrix_offsets(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_filter: Optional[types.Filter] = None,
|
||||
limit: int = 3,
|
||||
sample: int = 10,
|
||||
using: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.SearchMatrixOffsetsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
def search_matrix_pairs(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_filter: Optional[types.Filter] = None,
|
||||
limit: int = 3,
|
||||
sample: int = 10,
|
||||
using: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.SearchMatrixPairsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
def query_batch_points(
|
||||
self,
|
||||
collection_name: str,
|
||||
requests: Sequence[types.QueryRequest],
|
||||
**kwargs: Any,
|
||||
) -> list[types.QueryResponse]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def query_points(
|
||||
self,
|
||||
collection_name: str,
|
||||
query: Union[
|
||||
types.PointId,
|
||||
list[float],
|
||||
list[list[float]],
|
||||
types.SparseVector,
|
||||
types.Query,
|
||||
types.NumpyArray,
|
||||
types.Document,
|
||||
types.Image,
|
||||
types.InferenceObject,
|
||||
None,
|
||||
] = None,
|
||||
using: Optional[str] = None,
|
||||
prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
|
||||
query_filter: Optional[types.Filter] = None,
|
||||
search_params: Optional[types.SearchParams] = None,
|
||||
limit: int = 10,
|
||||
offset: Optional[int] = None,
|
||||
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
|
||||
with_vectors: Union[bool, Sequence[str]] = False,
|
||||
score_threshold: Optional[float] = None,
|
||||
lookup_from: Optional[types.LookupLocation] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.QueryResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
def query_points_groups(
|
||||
self,
|
||||
collection_name: str,
|
||||
group_by: str,
|
||||
query: Union[
|
||||
types.PointId,
|
||||
list[float],
|
||||
list[list[float]],
|
||||
types.SparseVector,
|
||||
types.Query,
|
||||
types.NumpyArray,
|
||||
types.Document,
|
||||
types.Image,
|
||||
types.InferenceObject,
|
||||
None,
|
||||
] = None,
|
||||
using: Optional[str] = None,
|
||||
prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
|
||||
query_filter: Optional[types.Filter] = None,
|
||||
search_params: Optional[types.SearchParams] = None,
|
||||
limit: int = 10,
|
||||
group_size: int = 3,
|
||||
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
|
||||
with_vectors: Union[bool, Sequence[str]] = False,
|
||||
score_threshold: Optional[float] = None,
|
||||
with_lookup: Optional[types.WithLookupInterface] = None,
|
||||
lookup_from: Optional[types.LookupLocation] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.GroupsResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
def scroll(
|
||||
self,
|
||||
collection_name: str,
|
||||
scroll_filter: Optional[types.Filter] = None,
|
||||
limit: int = 10,
|
||||
order_by: Optional[types.OrderBy] = None,
|
||||
offset: Optional[types.PointId] = None,
|
||||
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
|
||||
with_vectors: Union[bool, Sequence[str]] = False,
|
||||
**kwargs: Any,
|
||||
) -> tuple[list[types.Record], Optional[types.PointId]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def count(
|
||||
self,
|
||||
collection_name: str,
|
||||
count_filter: Optional[types.Filter] = None,
|
||||
exact: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> types.CountResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
def facet(
|
||||
self,
|
||||
collection_name: str,
|
||||
key: str,
|
||||
facet_filter: Optional[types.Filter] = None,
|
||||
limit: int = 10,
|
||||
exact: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> types.FacetResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
def upsert(
|
||||
self,
|
||||
collection_name: str,
|
||||
points: types.Points,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_vectors(
|
||||
self,
|
||||
collection_name: str,
|
||||
points: Sequence[types.PointVectors],
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
def delete_vectors(
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors: Sequence[str],
|
||||
points: types.PointsSelector,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Sequence[types.PointId],
|
||||
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
|
||||
with_vectors: Union[bool, Sequence[str]] = False,
|
||||
**kwargs: Any,
|
||||
) -> list[types.Record]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
points_selector: types.PointsSelector,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
def set_payload(
|
||||
self,
|
||||
collection_name: str,
|
||||
payload: types.Payload,
|
||||
points: types.PointsSelector,
|
||||
key: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
def overwrite_payload(
|
||||
self,
|
||||
collection_name: str,
|
||||
payload: types.Payload,
|
||||
points: types.PointsSelector,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
def delete_payload(
|
||||
self,
|
||||
collection_name: str,
|
||||
keys: Sequence[str],
|
||||
points: types.PointsSelector,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
def clear_payload(
|
||||
self,
|
||||
collection_name: str,
|
||||
points_selector: types.PointsSelector,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
def batch_update_points(
|
||||
self,
|
||||
collection_name: str,
|
||||
update_operations: Sequence[types.UpdateOperation],
|
||||
**kwargs: Any,
|
||||
) -> list[types.UpdateResult]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_collection_aliases(
|
||||
self,
|
||||
change_aliases_operations: Sequence[types.AliasOperations],
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_collection_aliases(
|
||||
self, collection_name: str, **kwargs: Any
|
||||
) -> types.CollectionsAliasesResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_collection(self, collection_name: str, **kwargs: Any) -> types.CollectionInfo:
|
||||
raise NotImplementedError()
|
||||
|
||||
def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_collection(
|
||||
self,
|
||||
collection_name: str,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
def delete_collection(self, collection_name: str, **kwargs: Any) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_collection(
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
def recreate_collection(
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
def upload_points(
|
||||
self,
|
||||
collection_name: str,
|
||||
points: Iterable[types.PointStruct],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def upload_collection(
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors: Union[
|
||||
dict[str, types.NumpyArray], types.NumpyArray, Iterable[types.VectorStruct]
|
||||
],
|
||||
payload: Optional[Iterable[dict[Any, Any]]] = None,
|
||||
ids: Optional[Iterable[types.PointId]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_payload_index(
|
||||
self,
|
||||
collection_name: str,
|
||||
field_name: str,
|
||||
field_schema: Optional[types.PayloadSchemaType] = None,
|
||||
field_type: Optional[types.PayloadSchemaType] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
def delete_payload_index(
|
||||
self,
|
||||
collection_name: str,
|
||||
field_name: str,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
raise NotImplementedError()
|
||||
|
||||
def list_snapshots(
|
||||
self, collection_name: str, **kwargs: Any
|
||||
) -> list[types.SnapshotDescription]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_snapshot(
|
||||
self, collection_name: str, **kwargs: Any
|
||||
) -> Optional[types.SnapshotDescription]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def delete_snapshot(
|
||||
self, collection_name: str, snapshot_name: str, **kwargs: Any
|
||||
) -> Optional[bool]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def list_full_snapshots(self, **kwargs: Any) -> list[types.SnapshotDescription]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_full_snapshot(self, **kwargs: Any) -> Optional[types.SnapshotDescription]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def delete_full_snapshot(self, snapshot_name: str, **kwargs: Any) -> Optional[bool]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def recover_snapshot(
|
||||
self,
|
||||
collection_name: str,
|
||||
location: str,
|
||||
**kwargs: Any,
|
||||
) -> Optional[bool]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def list_shard_snapshots(
|
||||
self, collection_name: str, shard_id: int, **kwargs: Any
|
||||
) -> list[types.SnapshotDescription]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_shard_snapshot(
|
||||
self, collection_name: str, shard_id: int, **kwargs: Any
|
||||
) -> Optional[types.SnapshotDescription]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def delete_shard_snapshot(
|
||||
self, collection_name: str, shard_id: int, snapshot_name: str, **kwargs: Any
|
||||
) -> Optional[bool]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def recover_shard_snapshot(
|
||||
self,
|
||||
collection_name: str,
|
||||
shard_id: int,
|
||||
location: str,
|
||||
**kwargs: Any,
|
||||
) -> Optional[bool]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def close(self, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def migrate(
|
||||
self,
|
||||
dest_client: "QdrantBase",
|
||||
collection_names: Optional[list[str]] = None,
|
||||
batch_size: int = 100,
|
||||
recreate_on_collision: bool = False,
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_shard_key(
|
||||
self,
|
||||
collection_name: str,
|
||||
shard_key: types.ShardKey,
|
||||
shards_number: Optional[int] = None,
|
||||
replication_factor: Optional[int] = None,
|
||||
placement: Optional[list[int]] = None,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
def delete_shard_key(
|
||||
self,
|
||||
collection_name: str,
|
||||
shard_key: types.ShardKey,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
def info(self) -> types.VersionInfo:
|
||||
raise NotImplementedError()
|
||||
|
||||
def cluster_collection_update(
|
||||
self,
|
||||
collection_name: str,
|
||||
cluster_operation: types.ClusterOperations,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
def collection_cluster_info(self, collection_name: str) -> types.CollectionClusterInfo:
|
||||
raise NotImplementedError()
|
||||
|
||||
def cluster_status(self) -> types.ClusterStatus:
|
||||
raise NotImplementedError()
|
||||
|
||||
def recover_current_peer(self) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
def remove_peer(self, peer_id: int, **kwargs: Any) -> bool:
    """Remove peer ``peer_id`` from the cluster. Abstract stub: implementations override this."""
    raise NotImplementedError()
|
||||
@@ -0,0 +1,16 @@
|
||||
class QdrantException(Exception):
    """Base class for all qdrant-client exceptions."""
|
||||
|
||||
|
||||
class ResourceExhaustedResponse(QdrantException):
    """Raised when the server rejects a request with a resource-exhausted /
    rate-limit status carrying a ``Retry-After`` hint.

    Attributes:
        message: human-readable reason; falls back to a generic phrase.
        retry_after_s: seconds the caller should wait before retrying.

    Raises:
        QdrantException: if ``retry_after_s`` cannot be converted to ``int``.
    """

    def __init__(self, message: str, retry_after_s: int) -> None:
        self.message = message if message else "Resource Exhausted Response"
        try:
            self.retry_after_s = int(retry_after_s)
        except Exception as ex:  # int() may raise TypeError/ValueError (or errors from a custom __int__)
            raise QdrantException(
                f"Retry-After header value is not a valid integer: {retry_after_s}"
            ) from ex
        # Fix: populate Exception.args so repr(), pickling and generic handlers
        # see the message (previously args was always empty).
        super().__init__(self.message)

    def __str__(self) -> str:
        return self.message.strip()
|
||||
@@ -0,0 +1,24 @@
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
# Deduplication keys (an explicit `idx` or the message text itself) already
# warned about by `show_warning_once`; lives for the lifetime of the process.
SEEN_MESSAGES = set()
|
||||
|
||||
|
||||
def show_warning(message: str, category: type[Warning] = UserWarning, stacklevel: int = 2) -> None:
    """Emit *message* as a warning of *category*, attributed *stacklevel* frames up."""
    warnings.warn(message, category=category, stacklevel=stacklevel)
|
||||
|
||||
|
||||
def show_warning_once(
    message: str,
    category: type[Warning] = UserWarning,
    idx: Optional[str] = None,
    stacklevel: int = 1,
) -> None:
    """Emit *message* at most once per program run.

    Warnings are deduplicated by *idx* when given, otherwise by the message
    text itself.
    """
    dedup_key = idx if idx is not None else message
    if dedup_key in SEEN_MESSAGES:
        return
    SEEN_MESSAGES.add(dedup_key)
    show_warning(message, category, stacklevel)
|
||||
@@ -0,0 +1,65 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from collections import namedtuple
|
||||
|
||||
import httpx
|
||||
|
||||
from qdrant_client.auth import BearerAuth
|
||||
|
||||
# Parsed version: integer major/minor plus the remaining dot-separated
# components as a list of strings (see `parse_version`).
Version = namedtuple("Version", ["major", "minor", "rest"])
|
||||
|
||||
|
||||
def get_server_version(
    rest_uri: str, rest_headers: dict[str, Any], auth_provider: Optional[BearerAuth]
) -> Optional[str]:
    """Fetch the version string reported by the server's root REST endpoint.

    Returns the ``"version"`` field of the JSON body on HTTP 200 (which may
    itself be falsy, in which case it is logged and returned as-is), or
    ``None`` for any other status code.
    """
    response = httpx.get(rest_uri, headers=rest_headers, auth=auth_provider)

    if response.status_code != 200:
        logging.debug(
            f"Unexpected response from server: {response}, server version defaults to None"
        )
        return None

    version_info = response.json().get("version", None)
    if not version_info:
        logging.debug(
            f"Unable to parse response from server: {response}, server version defaults to None"
        )
    return version_info
|
||||
|
||||
|
||||
def parse_version(version: str) -> Version:
    """Split an ``"x.y[.z...]"`` string into a ``Version`` named tuple.

    Raises:
        ValueError: if *version* is empty/None or its first two components
            are not integers.
    """
    if not version:
        raise ValueError("Version is None")
    try:
        head, second, *remainder = version.split(".")
        return Version(major=int(head), minor=int(second), rest=remainder)
    except ValueError as er:
        raise ValueError(
            f"Unable to parse version, expected format: x.y.z, found: {version}"
        ) from er
|
||||
|
||||
|
||||
def is_compatible(client_version: Optional[str], server_version: Optional[str]) -> bool:
    """Heuristic client/server compatibility check.

    Versions are compatible when equal, or when their majors match and their
    minors differ by at most one.  Missing or unparsable versions compare as
    incompatible (logged at debug level).
    """
    if not client_version:
        logging.debug(f"Unable to compare with client version {client_version}")
        return False
    if not server_version:
        logging.debug(f"Unable to compare with server version {server_version}")
        return False
    if client_version == server_version:
        return True

    try:
        server = parse_version(server_version)
        client = parse_version(client_version)
    except ValueError as er:
        logging.debug(f"Unable to compare versions: {er}")
        return False

    if server.major != client.major:
        return False
    return abs(server.minor - client.minor) <= 1
|
||||
@@ -0,0 +1,353 @@
|
||||
import asyncio
|
||||
import collections
|
||||
from typing import Any, Awaitable, Callable, Optional, Union
|
||||
|
||||
import grpc
|
||||
|
||||
from qdrant_client.common.client_exceptions import ResourceExhaustedResponse
|
||||
from qdrant_client.common.client_warnings import show_warning_once
|
||||
|
||||
|
||||
# type: ignore # noqa: F401
|
||||
# Source <https://github.com/grpc/grpc/blob/master/examples/python/interceptors/headers/generic_client_interceptor.py>
|
||||
class _GenericClientInterceptor(
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
):
    """Synchronous client interceptor delegating every RPC kind to one function.

    The wrapped function is called as ``fn(client_call_details,
    request_iterator, request_streaming, response_streaming)`` and must return
    ``(new_details, new_request_iterator, postprocess)``; when ``postprocess``
    is not None it is applied to the raw response (or response iterator).
    """

    def __init__(self, interceptor_function: Callable):
        self._fn = interceptor_function

    def intercept_unary_unary(
        self, continuation: Any, client_call_details: Any, request: Any
    ) -> Any:
        details, request_iter, postprocess = self._fn(
            client_call_details, iter((request,)), False, False
        )
        outcome = continuation(details, next(request_iter))
        if postprocess:
            return postprocess(outcome)
        return outcome

    def intercept_unary_stream(
        self, continuation: Any, client_call_details: Any, request: Any
    ) -> Any:
        details, request_iter, postprocess = self._fn(
            client_call_details, iter((request,)), False, True
        )
        outcome = continuation(details, next(request_iter))
        if postprocess:
            return postprocess(outcome)
        return outcome

    def intercept_stream_unary(
        self, continuation: Any, client_call_details: Any, request_iterator: Any
    ) -> Any:
        details, request_iter, postprocess = self._fn(
            client_call_details, request_iterator, True, False
        )
        outcome = continuation(details, request_iter)
        if postprocess:
            return postprocess(outcome)
        return outcome

    def intercept_stream_stream(
        self, continuation: Any, client_call_details: Any, request_iterator: Any
    ) -> Any:
        details, request_iter, postprocess = self._fn(
            client_call_details, request_iterator, True, True
        )
        outcome = continuation(details, request_iter)
        if postprocess:
            return postprocess(outcome)
        return outcome
|
||||
|
||||
|
||||
class _GenericAsyncClientInterceptor(
    grpc.aio.UnaryUnaryClientInterceptor,
    grpc.aio.UnaryStreamClientInterceptor,
    grpc.aio.StreamUnaryClientInterceptor,
    grpc.aio.StreamStreamClientInterceptor,
):
    """Async counterpart of ``_GenericClientInterceptor``.

    Both the wrapped interceptor function and the optional ``postprocess``
    callback it returns are awaited.
    """

    def __init__(self, interceptor_function: Callable):
        self._fn = interceptor_function

    async def intercept_unary_unary(
        self, continuation: Any, client_call_details: Any, request: Any
    ) -> Any:
        details, request_iter, postprocess = await self._fn(
            client_call_details, iter((request,)), False, False
        )
        outcome = await continuation(details, next(request_iter))
        if postprocess:
            return await postprocess(outcome)
        return outcome

    async def intercept_unary_stream(
        self, continuation: Any, client_call_details: Any, request: Any
    ) -> Any:
        details, request_iter, postprocess = await self._fn(
            client_call_details, iter((request,)), False, True
        )
        outcome = await continuation(details, next(request_iter))
        if postprocess:
            return await postprocess(outcome)
        return outcome

    async def intercept_stream_unary(
        self, continuation: Any, client_call_details: Any, request_iterator: Any
    ) -> Any:
        details, request_iter, postprocess = await self._fn(
            client_call_details, request_iterator, True, False
        )
        outcome = await continuation(details, request_iter)
        if postprocess:
            return await postprocess(outcome)
        return outcome

    async def intercept_stream_stream(
        self, continuation: Any, client_call_details: Any, request_iterator: Any
    ) -> Any:
        details, request_iter, postprocess = await self._fn(
            client_call_details, request_iterator, True, True
        )
        outcome = await continuation(details, request_iter)
        if postprocess:
            return await postprocess(outcome)
        return outcome
|
||||
|
||||
|
||||
def create_generic_client_interceptor(intercept_call: Any) -> _GenericClientInterceptor:
    """Wrap *intercept_call* in a sync interceptor covering all four RPC kinds."""
    return _GenericClientInterceptor(intercept_call)
|
||||
|
||||
|
||||
def create_generic_async_client_interceptor(
    intercept_call: Any,
) -> _GenericAsyncClientInterceptor:
    """Wrap *intercept_call* in an async interceptor covering all four RPC kinds."""
    return _GenericAsyncClientInterceptor(intercept_call)
|
||||
|
||||
|
||||
# Source:
|
||||
# <https://github.com/grpc/grpc/blob/master/examples/python/interceptors/headers/header_manipulator_client_interceptor.py>
|
||||
class _ClientCallDetails(
    collections.namedtuple("_ClientCallDetails", ("method", "timeout", "metadata", "credentials")),
    grpc.ClientCallDetails,
):
    """Immutable ``grpc.ClientCallDetails`` used to rebuild call details with
    modified metadata in the sync interceptor."""

    pass
|
||||
|
||||
|
||||
class _ClientAsyncCallDetails(
    # Fix: the namedtuple typename was copy-pasted as "_ClientCallDetails",
    # making repr()/pickling identify the wrong class.
    collections.namedtuple(
        "_ClientAsyncCallDetails", ("method", "timeout", "metadata", "credentials")
    ),
    grpc.aio.ClientCallDetails,
):
    """Immutable ``grpc.aio.ClientCallDetails`` counterpart of
    ``_ClientCallDetails`` for the async interceptor."""

    pass
|
||||
|
||||
|
||||
def header_adder_interceptor(
    new_metadata: list[tuple[str, str]],
    auth_token_provider: Optional[Callable[[], str]] = None,
) -> _GenericClientInterceptor:
    """Build a sync interceptor that appends *new_metadata* (plus an optional
    ``authorization: Bearer ...`` entry) to every outgoing call and converts
    RESOURCE_EXHAUSTED responses carrying a ``retry-after`` trailer into
    ``ResourceExhaustedResponse``.

    Args:
        new_metadata: key/value pairs appended to each call's metadata.
        auth_token_provider: synchronous callable returning a bearer token;
            passing a coroutine function raises ValueError at call time.
    """

    def process_response(response: Any) -> Any:
        # Translate server-side rate limiting into a typed client exception.
        if response.code() == grpc.StatusCode.RESOURCE_EXHAUSTED:
            retry_after = None
            for item in response.trailing_metadata():
                if item.key == "retry-after":
                    try:
                        retry_after = int(item.value)
                    except Exception:
                        retry_after = None
                    break
            reason_phrase = response.details() if response.details() else ""
            # NOTE: a missing, zero or unparsable retry-after falls through and
            # the raw response is returned instead of raising.
            if retry_after:
                raise ResourceExhaustedResponse(message=reason_phrase, retry_after_s=retry_after)
        return response

    def intercept_call(
        client_call_details: _ClientCallDetails,
        request_iterator: Any,
        _request_streaming: Any,
        _response_streaming: Any,
    ) -> tuple[_ClientCallDetails, Any, Any]:
        # Start from the call's existing metadata so nothing is dropped.
        metadata = []
        if client_call_details.metadata is not None:
            metadata = list(client_call_details.metadata)
        for header, value in new_metadata:
            metadata.append(
                (
                    header,
                    value,
                )
            )

        if auth_token_provider:
            if not asyncio.iscoroutinefunction(auth_token_provider):
                metadata.append(("authorization", f"Bearer {auth_token_provider()}"))
            else:
                raise ValueError("Synchronous channel requires synchronous auth token provider.")

        # The namedtuple-based details object is immutable: rebuild it with
        # the extended metadata.
        client_call_details = _ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            metadata,
            client_call_details.credentials,
        )
        return client_call_details, request_iterator, process_response

    return create_generic_client_interceptor(intercept_call)
|
||||
|
||||
|
||||
def header_adder_async_interceptor(
    new_metadata: list[tuple[str, str]],
    auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None,
) -> _GenericAsyncClientInterceptor:
    """Async counterpart of ``header_adder_interceptor``.

    Appends *new_metadata* and an optional bearer token to each call; the
    token provider may be a plain callable or a coroutine function.
    RESOURCE_EXHAUSTED errors with a parsable ``retry-after`` trailer are
    re-raised as ``ResourceExhaustedResponse``.
    """

    async def process_response(call: Any) -> Any:
        # Await the call; map rate-limit errors with a parsable retry-after
        # trailer to ResourceExhaustedResponse, re-raise everything else.
        try:
            return await call
        except grpc.aio.AioRpcError as er:
            if er.code() == grpc.StatusCode.RESOURCE_EXHAUSTED:
                retry_after = None
                for item in er.trailing_metadata():
                    if item[0] == "retry-after":
                        try:
                            retry_after = int(item[1])
                        except Exception:
                            retry_after = None
                        break
                reason_phrase = er.details() if er.details() else ""
                if retry_after:
                    raise ResourceExhaustedResponse(
                        message=reason_phrase, retry_after_s=retry_after
                    ) from er
            raise

    async def intercept_call(
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: Any,
        _request_streaming: Any,
        _response_streaming: Any,
    ) -> tuple[_ClientAsyncCallDetails, Any, Any]:
        # Start from the call's existing metadata so nothing is dropped.
        metadata = []
        if client_call_details.metadata is not None:
            metadata = list(client_call_details.metadata)
        for header, value in new_metadata:
            metadata.append(
                (
                    header,
                    value,
                )
            )

        if auth_token_provider:
            # Support both plain callables and coroutine functions.
            if asyncio.iscoroutinefunction(auth_token_provider):
                token = await auth_token_provider()
            else:
                token = auth_token_provider()
            metadata.append(("authorization", f"Bearer {token}"))

        # grpc.aio call details support _replace; no manual rebuild needed.
        client_call_details = client_call_details._replace(metadata=metadata)
        return client_call_details, request_iterator, process_response

    return create_generic_async_client_interceptor(intercept_call)
|
||||
|
||||
|
||||
def parse_channel_options(options: Optional[dict[str, Any]] = None) -> list[tuple[str, Any]]:
    """Merge user-supplied gRPC channel options with the library defaults.

    User options come first (preserving their dict order); a default is only
    appended when the user did not set that option name.
    """
    default_options: list[tuple[str, Any]] = [
        ("grpc.max_send_message_length", -1),
        ("grpc.max_receive_message_length", -1),
    ]

    if options is None:
        return default_options

    merged: list[tuple[str, Any]] = list(options.items())
    merged.extend(
        (name, value) for name, value in default_options if name not in options
    )
    return merged
|
||||
|
||||
|
||||
def parse_ssl_credentials(options: Optional[dict[str, Any]] = None) -> dict[str, Optional[bytes]]:
    """Extract ssl credential entries for ``grpc.ssl_channel_credentials``.

    WARN: Directly modifies input `options` — recognized keys are popped.

    Return:
        dict[str, Optional[bytes]]: dict(root_certificates=..., private_key=..., certificate_chain=...)
    """
    ssl_options: dict[str, Optional[bytes]] = {
        "root_certificates": None,
        "private_key": None,
        "certificate_chain": None,
    }

    if options is None:
        return ssl_options

    for name in ssl_options:
        value: Any = options.pop(name, None)
        # A `grpc.`-prefixed variant of the key is a likely user mistake.
        if f"grpc.{name}" in options:
            show_warning_once(
                f"`{name}` is supposed to be used without `grpc.` prefix",
                idx=f"grpc.{name}",
                stacklevel=10,
            )

        if value is None:
            continue
        if not isinstance(value, bytes):
            raise TypeError(f"{name} must be a byte string")
        ssl_options[name] = value

    return ssl_options
|
||||
|
||||
|
||||
def get_channel(
    host: str,
    port: int,
    ssl: bool,
    metadata: Optional[list[tuple[str, str]]] = None,
    options: Optional[dict[str, Any]] = None,
    compression: Optional[grpc.Compression] = None,
    auth_token_provider: Optional[Callable[[], str]] = None,
) -> grpc.Channel:
    """Create a synchronous gRPC channel (secure or insecure) wrapped with a
    metadata/auth-header interceptor."""
    # `parse_ssl_credentials` pops keys from the dict it receives — copy first.
    copied_options = options.copy() if options is not None else None
    ssl_cred_options = parse_ssl_credentials(copied_options)
    channel_options = parse_channel_options(copied_options)

    interceptor = header_adder_interceptor(
        new_metadata=metadata or [], auth_token_provider=auth_token_provider
    )

    if ssl:
        creds = grpc.ssl_channel_credentials(**ssl_cred_options)
        channel = grpc.secure_channel(f"{host}:{port}", creds, channel_options, compression)
    else:
        channel = grpc.insecure_channel(f"{host}:{port}", channel_options, compression)
    return grpc.intercept_channel(channel, interceptor)
|
||||
|
||||
|
||||
def get_async_channel(
    host: str,
    port: int,
    ssl: bool,
    metadata: Optional[list[tuple[str, str]]] = None,
    options: Optional[dict[str, Any]] = None,
    compression: Optional[grpc.Compression] = None,
    auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None,
) -> grpc.aio.Channel:
    """Create an async gRPC channel (secure or insecure) with the async
    metadata/auth-header interceptor attached."""
    # `parse_ssl_credentials` pops keys from the dict it receives — copy first.
    copied_options = options.copy() if options is not None else None
    ssl_cred_options = parse_ssl_credentials(copied_options)
    channel_options = parse_channel_options(copied_options)

    interceptor = header_adder_async_interceptor(
        new_metadata=metadata or [], auth_token_provider=auth_token_provider
    )

    if ssl:
        creds = grpc.ssl_channel_credentials(**ssl_cred_options)
        return grpc.aio.secure_channel(
            f"{host}:{port}",
            creds,
            channel_options,
            compression,
            interceptors=[interceptor],
        )
    return grpc.aio.insecure_channel(
        f"{host}:{port}", channel_options, compression, interceptors=[interceptor]
    )
|
||||
@@ -0,0 +1,163 @@
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeAlias
|
||||
else:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from typing import Union, get_args, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from qdrant_client import grpc
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
# Mapping from pydantic's strict scalar types to the plain builtins they wrap;
# used by `remap_type` so the results work in isinstance checks.
typing_remap = {
    rest.StrictStr: str,
    rest.StrictInt: int,
    rest.StrictFloat: float,
    rest.StrictBool: bool,
}
|
||||
|
||||
|
||||
def remap_type(tp: type) -> type:
    """Remap type to a type that can be used in type annotations

    Pydantic uses custom types for strict types, so we need to remap them to standard types
    so that they can be used in type annotations and isinstance checks
    """
    if tp in typing_remap:
        return typing_remap[tp]
    return tp
|
||||
|
||||
|
||||
def get_args_subscribed(tp):  # type: ignore
    """Get type arguments with all substitutions performed. Supports subscripted generics having __origin__

    Args:
        tp: type to get arguments from. Can be either a type or a subscripted generic

    Returns:
        tuple of type arguments
    """
    resolved = (
        arg.__origin__ if hasattr(arg, "__origin__") else arg for arg in get_args(tp)
    )
    return tuple(remap_type(arg) for arg in resolved)
|
||||
|
||||
|
||||
# --- Request/filter types accepted in both REST and gRPC flavours ---
Filter = Union[rest.Filter, grpc.Filter]
SearchParams = Union[rest.SearchParams, grpc.SearchParams]
PayloadSelector = Union[rest.PayloadSelector, grpc.WithPayloadSelector]
Distance = Union[rest.Distance, int]  # type(grpc.Distance) == int
HnswConfigDiff = Union[rest.HnswConfigDiff, grpc.HnswConfigDiff]
VectorsConfigDiff = Union[rest.VectorsConfigDiff, grpc.VectorsConfigDiff]
QuantizationConfigDiff = Union[rest.QuantizationConfigDiff, grpc.QuantizationConfigDiff]
OptimizersConfigDiff = Union[rest.OptimizersConfigDiff, grpc.OptimizersConfigDiff]
CollectionParamsDiff = Union[rest.CollectionParamsDiff, grpc.CollectionParamsDiff]
WalConfigDiff = Union[rest.WalConfigDiff, grpc.WalConfigDiff]
QuantizationConfig = Union[rest.QuantizationConfig, grpc.QuantizationConfig]
PointId = Union[int, str, UUID, grpc.PointId]
PayloadSchemaType = Union[
    rest.PayloadSchemaType,
    rest.PayloadSchemaParams,
    int,
    grpc.PayloadIndexParams,
]  # type(grpc.PayloadSchemaType) == int
PointStruct: TypeAlias = rest.PointStruct
Batch: TypeAlias = rest.Batch
Points = Union[Batch, Sequence[Union[rest.PointStruct, grpc.PointStruct]]]
PointsSelector = Union[
    list[PointId],
    rest.Filter,
    grpc.Filter,
    rest.PointsSelector,
    grpc.PointsSelector,
]
LookupLocation = Union[rest.LookupLocation, grpc.LookupLocation]
RecommendStrategy: TypeAlias = rest.RecommendStrategy
OrderBy = Union[rest.OrderByInterface, grpc.OrderBy]
ShardingMethod: TypeAlias = rest.ShardingMethod
ShardKey: TypeAlias = rest.ShardKey
ShardKeySelector: TypeAlias = rest.ShardKeySelector

AliasOperations = Union[
    rest.CreateAliasOperation,
    rest.RenameAliasOperation,
    rest.DeleteAliasOperation,
    grpc.AliasOperations,
]
Payload: TypeAlias = rest.Payload

# --- REST-model aliases re-exported under a transport-neutral name ---
ScoredPoint: TypeAlias = rest.ScoredPoint
UpdateResult: TypeAlias = rest.UpdateResult
Record: TypeAlias = rest.Record
CollectionsResponse: TypeAlias = rest.CollectionsResponse
CollectionInfo: TypeAlias = rest.CollectionInfo
CountResult: TypeAlias = rest.CountResult
SnapshotDescription: TypeAlias = rest.SnapshotDescription
NamedVector: TypeAlias = rest.NamedVector
NamedSparseVector: TypeAlias = rest.NamedSparseVector
SparseVector: TypeAlias = rest.SparseVector
PointVectors: TypeAlias = rest.PointVectors
Vector: TypeAlias = rest.Vector
VectorInput: TypeAlias = rest.VectorInput
VectorStruct: TypeAlias = rest.VectorStruct
VectorParams: TypeAlias = rest.VectorParams
SparseVectorParams: TypeAlias = rest.SparseVectorParams
SnapshotPriority: TypeAlias = rest.SnapshotPriority
CollectionsAliasesResponse: TypeAlias = rest.CollectionsAliasesResponse
UpdateOperation: TypeAlias = rest.UpdateOperation
Query: TypeAlias = rest.Query
Prefetch: TypeAlias = rest.Prefetch
Document: TypeAlias = rest.Document
Image: TypeAlias = rest.Image
InferenceObject: TypeAlias = rest.InferenceObject
StrictModeConfig: TypeAlias = rest.StrictModeConfig

QueryRequest: TypeAlias = rest.QueryRequest

Mmr: TypeAlias = rest.Mmr

ReadConsistency: TypeAlias = rest.ReadConsistency
WriteOrdering: TypeAlias = rest.WriteOrdering
WithLookupInterface: TypeAlias = rest.WithLookupInterface

GroupsResult: TypeAlias = rest.GroupsResult
QueryResponse: TypeAlias = rest.QueryResponse

FacetValue: TypeAlias = rest.FacetValue
FacetResponse: TypeAlias = rest.FacetResponse
SearchMatrixRequest = Union[rest.SearchMatrixRequest, grpc.SearchMatrixPoints]
SearchMatrixOffsetsResponse: TypeAlias = rest.SearchMatrixOffsetsResponse
SearchMatrixPairsResponse: TypeAlias = rest.SearchMatrixPairsResponse
SearchMatrixPair: TypeAlias = rest.SearchMatrixPair

VersionInfo: TypeAlias = rest.VersionInfo

# --- Cluster management ---
ReplicaState: TypeAlias = rest.ReplicaState
ClusterOperations: TypeAlias = rest.ClusterOperations
ClusterStatus: TypeAlias = rest.ClusterStatus
CollectionClusterInfo: TypeAlias = rest.CollectionClusterInfo

# we can't use `nptyping` package due to numpy/python-version incompatibilities
# thus we need to define precise type annotations while we support python3.7
_np_numeric = Union[
    np.bool_,  # pylance can't handle np.bool8 alias
    np.int8,
    np.int16,
    np.int32,
    np.int64,
    np.uint8,
    np.uint16,
    np.uint32,
    np.uint64,
    np.intp,
    np.uintp,
    np.float16,
    np.float32,
    np.float64,
    np.longdouble,  # np.float96 and np.float128 are platform dependant aliases for longdouble
]

NumpyArray: TypeAlias = npt.NDArray[_np_numeric]
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,94 @@
|
||||
from typing import Optional, Any
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.embed.models import NumericVector
|
||||
|
||||
|
||||
class BuiltinEmbedder:
    """Fallback embedder used when neither FastEmbed nor cloud inference is
    available.

    It only recognizes the sparse ``Qdrant/Bm25`` model; instead of computing
    vectors, :meth:`embed` wraps the input texts into ``models.Document``
    payloads.
    """

    # Case-insensitive allow-list checked by `is_supported_sparse_model`.
    _SUPPORTED_MODELS = ("Qdrant/Bm25",)

    def __init__(self, **kwargs: Any) -> None:
        # Accepts and ignores arbitrary kwargs to stay interchangeable with
        # other embedder implementations.
        pass

    def embed(
        self,
        model_name: str,
        texts: Optional[list[str]] = None,
        options: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> NumericVector:
        """Wrap *texts* into ``models.Document`` objects for *model_name*.

        Args:
            model_name: must be a supported sparse model (``Qdrant/Bm25``).
            texts: input strings; required.
            options: forwarded to each ``models.Document``.

        Raises:
            ValueError: if texts are missing, images are passed, or the model
                is unsupported.

        NOTE(review): annotated as ``NumericVector`` but returns a list of
        ``models.Document`` — confirm the alias covers this.
        """
        if texts is None:
            if "images" in kwargs:
                raise ValueError(
                    "Image processing is only available with cloud inference of FastEmbed"
                )

            raise ValueError("Texts must be provided for the inference")

        if not self.is_supported_sparse_model(model_name):
            raise ValueError(
                f"Model {model_name} is not supported in {self.__class__.__name__}. "
                f"Did you forget to enable cloud inference or install FastEmbed for local inference?"
            )

        return [models.Document(text=text, options=options, model=model_name) for text in texts]

    @classmethod
    def is_supported_text_model(cls, model_name: str) -> bool:
        """Dense text models are never supported by the builtin embedder.

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return False  # currently only Qdrant/Bm25 is supported

    @classmethod
    def is_supported_image_model(cls, model_name: str) -> bool:
        """Image models are never supported by the builtin embedder.

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return False  # currently only Qdrant/Bm25 is supported

    @classmethod
    def is_supported_late_interaction_text_model(cls, model_name: str) -> bool:
        """Late-interaction text models are never supported by the builtin embedder.

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return False  # currently only Qdrant/Bm25 is supported

    @classmethod
    def is_supported_late_interaction_multimodal_model(cls, model_name: str) -> bool:
        """Late-interaction multimodal models are never supported by the builtin embedder.

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return False  # currently only Qdrant/Bm25 is supported

    @classmethod
    def is_supported_sparse_model(cls, model_name: str) -> bool:
        """Checks if the model is supported. Only `Qdrant/Bm25` is supported

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return model_name.lower() in [model.lower() for model in cls._SUPPORTED_MODELS]
|
||||
@@ -0,0 +1,6 @@
|
||||
from typing import Union
|
||||
|
||||
from qdrant_client.http import models
|
||||
|
||||
# Class names of the raw-inference input models, for name-based checks.
INFERENCE_OBJECT_NAMES: set[str] = {"Document", "Image", "InferenceObject"}
# Union of the corresponding model classes, for isinstance-based checks.
INFERENCE_OBJECT_TYPES = Union[models.Document, models.Image, models.InferenceObject]
|
||||
@@ -0,0 +1,176 @@
|
||||
from copy import copy
|
||||
from typing import Union, Optional, Iterable, get_args
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from qdrant_client._pydantic_compat import model_fields_set
|
||||
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
|
||||
from qdrant_client.embed.schema_parser import ModelSchemaParser
|
||||
|
||||
from qdrant_client.embed.utils import convert_paths, FieldPath
|
||||
|
||||
|
||||
class InspectorEmbed:
|
||||
"""Inspector which collects paths to objects requiring inference in the received models
|
||||
|
||||
Attributes:
|
||||
parser: ModelSchemaParser instance
|
||||
"""
|
||||
|
||||
def __init__(self, parser: Optional[ModelSchemaParser] = None) -> None:
    # Reuse a caller-supplied parser (sharing its schema cache) or create a fresh one.
    self.parser = ModelSchemaParser() if parser is None else parser
|
||||
|
||||
def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> list[FieldPath]:
    """Collect every path to an object requiring inference in *points*.

    Accepts a single model, a dict of values, or an iterable of models;
    dict values are inspected recursively.  Results are deduplicated,
    sorted and converted via ``convert_paths``.
    """
    collected: list = []
    if isinstance(points, BaseModel):
        self.parser.parse_model(points.__class__)
        collected += self._inspect_model(points)
    elif isinstance(points, dict):
        for nested in points.values():
            collected += self.inspect(nested)
    elif isinstance(points, Iterable):
        for entry in points:
            if isinstance(entry, BaseModel):
                self.parser.parse_model(entry.__class__)
                collected += self._inspect_model(entry)

    return convert_paths(sorted(set(collected)))
|
||||
|
||||
def _inspect_model(
    self, mod: BaseModel, paths: Optional[list[FieldPath]] = None, accum: Optional[str] = None
) -> list[str]:
    """Return dot-separated paths (rooted at *accum*) to inference objects in *mod*.

    When *paths* is None, the candidate field paths are looked up in the
    parser's cache by the model's class name.
    """
    if paths is None:
        paths = self.parser.path_cache.get(mod.__class__.__name__, [])

    collected: list[str] = []
    for field_path in paths:
        collected += self._inspect_inner_models(
            mod, field_path.current, field_path.tail if field_path.tail else [], accum
        )
    return collected
|
||||
|
||||
    def _inspect_inner_models(
        self,
        original_model: BaseModel,
        current_path: str,
        tail: list[FieldPath],
        accum: Optional[str] = None,
    ) -> list[str]:
        """Looks for all the paths to objects requiring inference in the received model

        Inspects the field named ``current_path`` on ``original_model`` and recurses
        into models, lists and dicts found there, extending ``accum`` at each step.

        Args:
            original_model: model to inspect
            current_path: the field to inspect on the current iteration
            tail: list of FieldPath objects to the fields possibly containing objects for inference
            accum: accumulator for the path. Path is a dot separated string of field names which we assemble recursively

        Returns:
            list of paths to the model fields containing objects for inference
        """
        found_paths = []
        # Extend the dot-separated path with the field inspected on this call.
        if accum is None:
            accum = current_path
        else:
            accum += f".{current_path}"

        def inspect_recursive(member: BaseModel, accumulator: str) -> list[str]:
            """Iterates over the set model fields, expand recursive ones and find paths to objects requiring inference

            Args:
                member: currently inspected model, which may or may not contain recursive fields
                accumulator: accumulator for the path, which is a dot separated string assembled recursively
            """
            recursive_paths = []
            # Only fields the user actually set are worth expanding; recursive
            # references are resolved through the parser's mapping into cached paths.
            for field in model_fields_set(member):
                if field in self.parser.name_recursive_ref_mapping:
                    mapped_field = self.parser.name_recursive_ref_mapping[field]
                    recursive_paths.extend(self.parser.path_cache[mapped_field])

            # copy() protects the cached path list from mutation downstream.
            return self._inspect_model(member, copy(recursive_paths), accumulator)

        model = getattr(original_model, current_path, None)
        if model is None:
            return []

        # A raw inference object (Document/Image/...) terminates the search here.
        if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)):
            return [accum]

        if isinstance(model, BaseModel):
            found_paths.extend(inspect_recursive(model, accum))

            for next_path in tail:
                found_paths.extend(
                    self._inspect_inner_models(
                        model, next_path.current, next_path.tail if next_path.tail else [], accum
                    )
                )

            return found_paths

        elif isinstance(model, list):
            for current_model in model:
                if not isinstance(current_model, BaseModel):
                    continue

                if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)):
                    found_paths.append(accum)

                found_paths.extend(inspect_recursive(current_model, accum))

            for next_path in tail:
                for current_model in model:
                    found_paths.extend(
                        self._inspect_inner_models(
                            current_model,
                            next_path.current,
                            next_path.tail if next_path.tail else [],
                            accum,
                        )
                    )
            return found_paths

        elif isinstance(model, dict):
            found_paths = []
            for key, values in model.items():
                # Normalize to a list so single values and lists share one code path.
                values = [values] if not isinstance(values, list) else values
                for current_model in values:
                    if not isinstance(current_model, BaseModel):
                        continue

                    if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)):
                        found_paths.append(accum)

                    found_paths.extend(inspect_recursive(current_model, accum))

                for next_path in tail:
                    for current_model in values:
                        found_paths.extend(
                            self._inspect_inner_models(
                                current_model,
                                next_path.current,
                                next_path.tail if next_path.tail else [],
                                accum,
                            )
                        )
            return found_paths
        # NOTE(review): a field holding a plain scalar (str/int/...) falls through
        # all branches and implicitly returns None, not list[str] — callers extend()
        # the result, which would raise. Presumably such fields never appear in the
        # path cache — TODO confirm.
@@ -0,0 +1,447 @@
|
||||
from collections import defaultdict
|
||||
from typing import Optional, Sequence, Any, TypeVar, Generic
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.embed.models import NumericVector
|
||||
from qdrant_client.fastembed_common import (
|
||||
OnnxProvider,
|
||||
ImageInput,
|
||||
TextEmbedding,
|
||||
SparseTextEmbedding,
|
||||
LateInteractionTextEmbedding,
|
||||
LateInteractionMultimodalEmbedding,
|
||||
ImageEmbedding,
|
||||
FastEmbedMisc,
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ModelInstance(BaseModel, Generic[T], arbitrary_types_allowed=True):  # type: ignore[call-arg]
    """A cached embedding model together with the options it was instantiated with."""

    # The instantiated embedding model object (e.g. TextEmbedding, ImageEmbedding).
    model: T
    # Constructor options the model was built with; compared on cache lookups.
    options: dict[str, Any]
    # When True, cache lookups may match on this flag instead of exact options.
    deprecated: bool = False
class Embedder:
    """Lazy registry of fastembed models, one cache per model family.

    Models are instantiated on first request via the ``get_or_init_*`` methods
    and reused on subsequent requests with matching options.
    """

    def __init__(self, threads: Optional[int] = None, **kwargs: Any) -> None:
        # One registry per model family: model name -> instantiated models,
        # each stored together with the options it was built with.
        self.embedding_models: dict[str, list[ModelInstance[TextEmbedding]]] = defaultdict(list)
        self.sparse_embedding_models: dict[str, list[ModelInstance[SparseTextEmbedding]]] = (
            defaultdict(list)
        )
        self.late_interaction_embedding_models: dict[
            str, list[ModelInstance[LateInteractionTextEmbedding]]
        ] = defaultdict(list)
        self.image_embedding_models: dict[str, list[ModelInstance[ImageEmbedding]]] = defaultdict(
            list
        )
        self.late_interaction_multimodal_embedding_models: dict[
            str, list[ModelInstance[LateInteractionMultimodalEmbedding]]
        ] = defaultdict(list)
        # Default thread count applied when a get_or_init_* call does not pass one.
        self._threads = threads
def get_or_init_model(
|
||||
self,
|
||||
model_name: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
threads: Optional[int] = None,
|
||||
providers: Optional[Sequence["OnnxProvider"]] = None,
|
||||
cuda: bool = False,
|
||||
device_ids: Optional[list[int]] = None,
|
||||
deprecated: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> TextEmbedding:
|
||||
if not FastEmbedMisc.is_supported_text_model(model_name):
|
||||
raise ValueError(
|
||||
f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_text_models()}"
|
||||
)
|
||||
options = {
|
||||
"cache_dir": cache_dir,
|
||||
"threads": threads or self._threads,
|
||||
"providers": providers,
|
||||
"cuda": cuda,
|
||||
"device_ids": device_ids,
|
||||
**kwargs,
|
||||
}
|
||||
for instance in self.embedding_models[model_name]:
|
||||
if (deprecated and instance.deprecated) or (
|
||||
not deprecated and instance.options == options
|
||||
):
|
||||
return instance.model
|
||||
|
||||
model = TextEmbedding(model_name=model_name, **options)
|
||||
model_instance: ModelInstance[TextEmbedding] = ModelInstance(
|
||||
model=model, options=options, deprecated=deprecated
|
||||
)
|
||||
self.embedding_models[model_name].append(model_instance)
|
||||
return model
|
||||
|
||||
def get_or_init_sparse_model(
|
||||
self,
|
||||
model_name: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
threads: Optional[int] = None,
|
||||
providers: Optional[Sequence["OnnxProvider"]] = None,
|
||||
cuda: bool = False,
|
||||
device_ids: Optional[list[int]] = None,
|
||||
deprecated: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> SparseTextEmbedding:
|
||||
if not FastEmbedMisc.is_supported_sparse_model(model_name):
|
||||
raise ValueError(
|
||||
f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_sparse_models()}"
|
||||
)
|
||||
|
||||
options = {
|
||||
"cache_dir": cache_dir,
|
||||
"threads": threads or self._threads,
|
||||
"providers": providers,
|
||||
"cuda": cuda,
|
||||
"device_ids": device_ids,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
for instance in self.sparse_embedding_models[model_name]:
|
||||
if (deprecated and instance.deprecated) or (
|
||||
not deprecated and instance.options == options
|
||||
):
|
||||
return instance.model
|
||||
|
||||
model = SparseTextEmbedding(model_name=model_name, **options)
|
||||
model_instance: ModelInstance[SparseTextEmbedding] = ModelInstance(
|
||||
model=model, options=options, deprecated=deprecated
|
||||
)
|
||||
self.sparse_embedding_models[model_name].append(model_instance)
|
||||
return model
|
||||
|
||||
def get_or_init_late_interaction_model(
|
||||
self,
|
||||
model_name: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
threads: Optional[int] = None,
|
||||
providers: Optional[Sequence["OnnxProvider"]] = None,
|
||||
cuda: bool = False,
|
||||
device_ids: Optional[list[int]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LateInteractionTextEmbedding:
|
||||
if not FastEmbedMisc.is_supported_late_interaction_text_model(model_name):
|
||||
raise ValueError(
|
||||
f"Unsupported embedding model: {model_name}. "
|
||||
f"Supported models: {FastEmbedMisc.list_late_interaction_text_models()}"
|
||||
)
|
||||
options = {
|
||||
"cache_dir": cache_dir,
|
||||
"threads": threads or self._threads,
|
||||
"providers": providers,
|
||||
"cuda": cuda,
|
||||
"device_ids": device_ids,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
for instance in self.late_interaction_embedding_models[model_name]:
|
||||
if instance.options == options:
|
||||
return instance.model
|
||||
|
||||
model = LateInteractionTextEmbedding(model_name=model_name, **options)
|
||||
model_instance: ModelInstance[LateInteractionTextEmbedding] = ModelInstance(
|
||||
model=model, options=options
|
||||
)
|
||||
self.late_interaction_embedding_models[model_name].append(model_instance)
|
||||
return model
|
||||
|
||||
def get_or_init_late_interaction_multimodal_model(
|
||||
self,
|
||||
model_name: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
threads: Optional[int] = None,
|
||||
providers: Optional[Sequence["OnnxProvider"]] = None,
|
||||
cuda: bool = False,
|
||||
device_ids: Optional[list[int]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LateInteractionMultimodalEmbedding:
|
||||
if not FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
|
||||
raise ValueError(
|
||||
f"Unsupported embedding model: {model_name}. "
|
||||
f"Supported models: {FastEmbedMisc.list_late_interaction_multimodal_models()}"
|
||||
)
|
||||
options = {
|
||||
"cache_dir": cache_dir,
|
||||
"threads": threads or self._threads,
|
||||
"providers": providers,
|
||||
"cuda": cuda,
|
||||
"device_ids": device_ids,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
for instance in self.late_interaction_multimodal_embedding_models[model_name]:
|
||||
if instance.options == options:
|
||||
return instance.model
|
||||
|
||||
model = LateInteractionMultimodalEmbedding(model_name=model_name, **options)
|
||||
model_instance: ModelInstance[LateInteractionMultimodalEmbedding] = ModelInstance(
|
||||
model=model, options=options
|
||||
)
|
||||
self.late_interaction_multimodal_embedding_models[model_name].append(model_instance)
|
||||
return model
|
||||
|
||||
def get_or_init_image_model(
|
||||
self,
|
||||
model_name: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
threads: Optional[int] = None,
|
||||
providers: Optional[Sequence["OnnxProvider"]] = None,
|
||||
cuda: bool = False,
|
||||
device_ids: Optional[list[int]] = None,
|
||||
**kwargs: Any,
|
||||
) -> ImageEmbedding:
|
||||
if not FastEmbedMisc.is_supported_image_model(model_name):
|
||||
raise ValueError(
|
||||
f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_image_models()}"
|
||||
)
|
||||
options = {
|
||||
"cache_dir": cache_dir,
|
||||
"threads": threads or self._threads,
|
||||
"providers": providers,
|
||||
"cuda": cuda,
|
||||
"device_ids": device_ids,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
for instance in self.image_embedding_models[model_name]:
|
||||
if instance.options == options:
|
||||
return instance.model
|
||||
|
||||
model = ImageEmbedding(model_name=model_name, **options)
|
||||
model_instance: ModelInstance[ImageEmbedding] = ModelInstance(model=model, options=options)
|
||||
self.image_embedding_models[model_name].append(model_instance)
|
||||
return model
|
||||
|
||||
def embed(
|
||||
self,
|
||||
model_name: str,
|
||||
texts: Optional[list[str]] = None,
|
||||
images: Optional[list[ImageInput]] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
is_query: bool = False,
|
||||
batch_size: int = 8,
|
||||
) -> NumericVector:
|
||||
if (texts is None) is (images is None):
|
||||
raise ValueError("Either documents or images should be provided")
|
||||
|
||||
embeddings: NumericVector # define type for a static type checker
|
||||
if texts is not None:
|
||||
if FastEmbedMisc.is_supported_text_model(model_name):
|
||||
embeddings = self._embed_dense_text(
|
||||
texts, model_name, options, is_query, batch_size
|
||||
)
|
||||
elif FastEmbedMisc.is_supported_sparse_model(model_name):
|
||||
embeddings = self._embed_sparse_text(
|
||||
texts, model_name, options, is_query, batch_size
|
||||
)
|
||||
elif FastEmbedMisc.is_supported_late_interaction_text_model(model_name):
|
||||
embeddings = self._embed_late_interaction_text(
|
||||
texts, model_name, options, is_query, batch_size
|
||||
)
|
||||
elif FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
|
||||
embeddings = self._embed_late_interaction_multimodal_text(
|
||||
texts, model_name, options, batch_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding model: {model_name}")
|
||||
else:
|
||||
assert (
|
||||
images is not None
|
||||
) # just to satisfy mypy which can't infer it from the previous conditions
|
||||
if FastEmbedMisc.is_supported_image_model(model_name):
|
||||
embeddings = self._embed_dense_image(images, model_name, options, batch_size)
|
||||
elif FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
|
||||
embeddings = self._embed_late_interaction_multimodal_image(
|
||||
images, model_name, options, batch_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding model: {model_name}")
|
||||
|
||||
return embeddings
|
||||
|
||||
def _embed_dense_text(
|
||||
self,
|
||||
texts: list[str],
|
||||
model_name: str,
|
||||
options: Optional[dict[str, Any]],
|
||||
is_query: bool,
|
||||
batch_size: int,
|
||||
) -> list[list[float]]:
|
||||
embedding_model_inst = self.get_or_init_model(model_name=model_name, **options or {})
|
||||
|
||||
if not is_query:
|
||||
embeddings = [
|
||||
embedding.tolist()
|
||||
for embedding in embedding_model_inst.embed(documents=texts, batch_size=batch_size)
|
||||
]
|
||||
else:
|
||||
embeddings = [
|
||||
embedding.tolist() for embedding in embedding_model_inst.query_embed(query=texts)
|
||||
]
|
||||
return embeddings
|
||||
|
||||
def _embed_sparse_text(
|
||||
self,
|
||||
texts: list[str],
|
||||
model_name: str,
|
||||
options: Optional[dict[str, Any]],
|
||||
is_query: bool,
|
||||
batch_size: int,
|
||||
) -> list[models.SparseVector]:
|
||||
embedding_model_inst = self.get_or_init_sparse_model(
|
||||
model_name=model_name, **options or {}
|
||||
)
|
||||
if not is_query:
|
||||
embeddings = [
|
||||
models.SparseVector(
|
||||
indices=sparse_embedding.indices.tolist(),
|
||||
values=sparse_embedding.values.tolist(),
|
||||
)
|
||||
for sparse_embedding in embedding_model_inst.embed(
|
||||
documents=texts, batch_size=batch_size
|
||||
)
|
||||
]
|
||||
else:
|
||||
embeddings = [
|
||||
models.SparseVector(
|
||||
indices=sparse_embedding.indices.tolist(),
|
||||
values=sparse_embedding.values.tolist(),
|
||||
)
|
||||
for sparse_embedding in embedding_model_inst.query_embed(query=texts)
|
||||
]
|
||||
return embeddings
|
||||
|
||||
def _embed_late_interaction_text(
|
||||
self,
|
||||
texts: list[str],
|
||||
model_name: str,
|
||||
options: Optional[dict[str, Any]],
|
||||
is_query: bool,
|
||||
batch_size: int,
|
||||
) -> list[list[list[float]]]:
|
||||
embedding_model_inst = self.get_or_init_late_interaction_model(
|
||||
model_name=model_name, **options or {}
|
||||
)
|
||||
if not is_query:
|
||||
embeddings = [
|
||||
embedding.tolist()
|
||||
for embedding in embedding_model_inst.embed(documents=texts, batch_size=batch_size)
|
||||
]
|
||||
else:
|
||||
embeddings = [
|
||||
embedding.tolist() for embedding in embedding_model_inst.query_embed(query=texts)
|
||||
]
|
||||
return embeddings
|
||||
|
||||
def _embed_late_interaction_multimodal_text(
|
||||
self,
|
||||
texts: list[str],
|
||||
model_name: str,
|
||||
options: Optional[dict[str, Any]],
|
||||
batch_size: int,
|
||||
) -> list[list[list[float]]]:
|
||||
embedding_model_inst = self.get_or_init_late_interaction_multimodal_model(
|
||||
model_name=model_name, **options or {}
|
||||
)
|
||||
return [
|
||||
embedding.tolist()
|
||||
for embedding in embedding_model_inst.embed_text(
|
||||
documents=texts, batch_size=batch_size
|
||||
)
|
||||
]
|
||||
|
||||
def _embed_late_interaction_multimodal_image(
|
||||
self,
|
||||
images: list[ImageInput],
|
||||
model_name: str,
|
||||
options: Optional[dict[str, Any]],
|
||||
batch_size: int,
|
||||
) -> list[list[list[float]]]:
|
||||
embedding_model_inst = self.get_or_init_late_interaction_multimodal_model(
|
||||
model_name=model_name, **options or {}
|
||||
)
|
||||
return [
|
||||
embedding.tolist()
|
||||
for embedding in embedding_model_inst.embed_image(images=images, batch_size=batch_size)
|
||||
]
|
||||
|
||||
def _embed_dense_image(
|
||||
self,
|
||||
images: list[ImageInput],
|
||||
model_name: str,
|
||||
options: Optional[dict[str, Any]],
|
||||
batch_size: int,
|
||||
) -> list[list[float]]:
|
||||
embedding_model_inst = self.get_or_init_image_model(model_name=model_name, **options or {})
|
||||
embeddings = [
|
||||
embedding.tolist()
|
||||
for embedding in embedding_model_inst.embed(images=images, batch_size=batch_size)
|
||||
]
|
||||
return embeddings
|
||||
|
||||
@classmethod
|
||||
def is_supported_text_model(cls, model_name: str) -> bool:
|
||||
"""Check if model is supported by fastembed
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is supported, False otherwise.
|
||||
"""
|
||||
return FastEmbedMisc.is_supported_text_model(model_name)
|
||||
|
||||
@classmethod
|
||||
def is_supported_image_model(cls, model_name: str) -> bool:
|
||||
"""Check if model is supported by fastembed
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is supported, False otherwise.
|
||||
"""
|
||||
return FastEmbedMisc.is_supported_image_model(model_name)
|
||||
|
||||
@classmethod
|
||||
def is_supported_late_interaction_text_model(cls, model_name: str) -> bool:
|
||||
"""Check if model is supported by fastembed
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is supported, False otherwise.
|
||||
"""
|
||||
return FastEmbedMisc.is_supported_late_interaction_text_model(model_name)
|
||||
|
||||
@classmethod
|
||||
def is_supported_late_interaction_multimodal_model(cls, model_name: str) -> bool:
|
||||
"""Check if model is supported by fastembed
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is supported, False otherwise.
|
||||
"""
|
||||
return FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name)
|
||||
|
||||
@classmethod
|
||||
def is_supported_sparse_model(cls, model_name: str) -> bool:
|
||||
"""Check if model is supported by fastembed
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is supported, False otherwise.
|
||||
"""
|
||||
return FastEmbedMisc.is_supported_sparse_model(model_name)
|
||||
@@ -0,0 +1,498 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from multiprocessing import get_all_start_methods
|
||||
from typing import Optional, Union, Iterable, Any, Type, get_args
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from qdrant_client.embed.builtin_embedder import BuiltinEmbedder
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
|
||||
from qdrant_client.embed.embed_inspector import InspectorEmbed
|
||||
from qdrant_client.embed.embedder import Embedder
|
||||
from qdrant_client.embed.models import NumericVector, NumericVectorStruct
|
||||
from qdrant_client.embed.schema_parser import ModelSchemaParser
|
||||
from qdrant_client.embed.utils import FieldPath
|
||||
from qdrant_client.fastembed_common import FastEmbedMisc
|
||||
from qdrant_client.parallel_processor import ParallelWorkerPool, Worker
|
||||
from qdrant_client.uploader.uploader import iter_batch
|
||||
|
||||
|
||||
class ModelEmbedderWorker(Worker):
    """Parallel-pool worker that embeds batches of models in a child process."""

    def __init__(self, batch_size: int, **kwargs: Any):
        self.model_embedder = ModelEmbedder(**kwargs)
        self.batch_size = batch_size

    @classmethod
    def start(cls, batch_size: int, **kwargs: Any) -> "ModelEmbedderWorker":
        # Each worker keeps its embedder single-threaded; parallelism comes
        # from the process pool itself.
        return cls(threads=1, batch_size=batch_size, **kwargs)

    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
        """Embed each incoming (index, batch) pair, preserving the index for reordering."""
        for idx, batch in items:
            embedded = list(
                self.model_embedder.embed_models_batch(
                    batch, inference_batch_size=self.batch_size
                )
            )
            yield idx, embedded
class ModelEmbedder:
    """Finds raw-data fields in Qdrant http models and replaces them with embeddings.

    Objects requiring inference are first accumulated into batches, embedded
    (locally via fastembed or server-side via the builtin embedder), then the
    computed vectors are substituted back into the models.
    """

    # Upper bound passed to the parallel worker pool's internal queue.
    MAX_INTERNAL_BATCH_SIZE = 64

    def __init__(
        self,
        parser: Optional[ModelSchemaParser] = None,
        is_local_mode: bool = False,
        server_version: Optional[str] = None,
        **kwargs: Any,
    ):
        # model name -> objects awaiting inference, filled by _accumulate.
        self._batch_accumulator: dict[str, list[INFERENCE_OBJECT_TYPES]] = {}
        # model name -> computed embeddings, consumed in accumulation order.
        self._embed_storage: dict[str, list[NumericVector]] = {}
        self._embed_inspector = InspectorEmbed(parser=parser)
        self._is_builtin_embedder_available = self._check_builtin_embedder_availability(
            is_local_mode, server_version
        )
        # Use local fastembed when installed, otherwise server-side inference.
        self.embedder = (
            Embedder(**kwargs) if FastEmbedMisc.is_installed() else BuiltinEmbedder(**kwargs)
        )
@staticmethod
|
||||
def _check_builtin_embedder_availability(
|
||||
is_local_mode: bool, server_version: Optional[str]
|
||||
) -> bool:
|
||||
if is_local_mode:
|
||||
return False
|
||||
|
||||
if (
|
||||
server_version is None
|
||||
): # failed to detect server version, it might happen due to security or network
|
||||
# problems even on supported server versions, so we are not blocking usage of BuiltinEmbedder.
|
||||
return True
|
||||
|
||||
try:
|
||||
major, minor, patch = server_version.split(".")
|
||||
patch = patch.split("-")[0]
|
||||
|
||||
if (int(major), int(minor), int(patch)) >= (1, 15, 3):
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
def embed_models(
|
||||
self,
|
||||
raw_models: Union[BaseModel, Iterable[BaseModel]],
|
||||
is_query: bool = False,
|
||||
batch_size: int = 8,
|
||||
) -> Iterable[BaseModel]:
|
||||
"""Embed raw data fields in models and return models with vectors
|
||||
|
||||
If any of model fields required inference, a deepcopy of a model with computed embeddings is returned,
|
||||
otherwise returns original models.
|
||||
Args:
|
||||
raw_models: Iterable[BaseModel] - models which can contain fields with raw data
|
||||
is_query: bool - flag to determine which embed method to use. Defaults to False.
|
||||
batch_size: int - batch size for inference
|
||||
Returns:
|
||||
list[BaseModel]: models with embedded fields
|
||||
"""
|
||||
if not self._is_builtin_embedder_available:
|
||||
FastEmbedMisc.import_fastembed() # fail fast if fastembed is required
|
||||
|
||||
if isinstance(raw_models, BaseModel):
|
||||
raw_models = [raw_models]
|
||||
for raw_models_batch in iter_batch(raw_models, batch_size):
|
||||
yield from self.embed_models_batch(
|
||||
raw_models_batch, is_query, inference_batch_size=batch_size
|
||||
)
|
||||
|
||||
    def embed_models_strict(
        self,
        raw_models: Iterable[Union[dict[str, BaseModel], BaseModel]],
        batch_size: int = 8,
        parallel: Optional[int] = None,
    ) -> Iterable[Union[dict[str, BaseModel], BaseModel]]:
        """Embed raw data fields in models and return models with vectors

        Requires every input sequences element to contain raw data fields to inference.
        Does not accept ready vectors.

        Args:
            raw_models: Iterable[BaseModel] - models which contain fields with raw data to inference
            batch_size: int - batch size for inference
            parallel: int - number of parallel processes to use. Defaults to None
                (sequential). 0 means "use all CPUs".

        Returns:
            Iterable[Union[dict[str, BaseModel], BaseModel]]: models with embedded fields
        """
        if not self._is_builtin_embedder_available:
            FastEmbedMisc.import_fastembed()  # fail fast if fastembed is required

        # Small inputs are not worth the process-pool startup cost.
        is_small = False

        if isinstance(raw_models, list):
            if len(raw_models) < batch_size:
                is_small = True

        # Sequential path: builtin (server-side) embedder, no parallelism
        # requested, or input too small to amortize worker startup.
        if (
            isinstance(self.embedder, BuiltinEmbedder)
            or parallel is None
            or parallel == 1
            or is_small
        ):
            for batch in iter_batch(raw_models, batch_size):
                yield from self.embed_models_batch(batch, inference_batch_size=batch_size)
        else:
            multiprocessing_batch_size = 1  # larger batch sizes do not help with data parallel
            # on cpu. todo: adjust when multi-gpu is available
            raw_models_batches = iter_batch(raw_models, size=multiprocessing_batch_size)
            if parallel == 0:
                parallel = os.cpu_count()

            # "fork" is avoided: forking a process with loaded onnx/threads is unsafe.
            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
            assert parallel is not None  # just a mypy complaint
            pool = ParallelWorkerPool(
                num_workers=parallel,
                worker=self._get_worker_class(),
                start_method=start_method,
                max_internal_batch_size=self.MAX_INTERNAL_BATCH_SIZE,
            )

            # ordered_map preserves the original input order across workers.
            for batch in pool.ordered_map(
                raw_models_batches, batch_size=multiprocessing_batch_size
            ):
                yield from batch
def embed_models_batch(
|
||||
self,
|
||||
raw_models: list[Union[dict[str, BaseModel], BaseModel]],
|
||||
is_query: bool = False,
|
||||
inference_batch_size: int = 8,
|
||||
) -> Iterable[BaseModel]:
|
||||
"""Embed a batch of models with raw data fields and return models with vectors
|
||||
|
||||
If any of model fields required inference, a deepcopy of a model with computed embeddings is returned,
|
||||
otherwise returns original models.
|
||||
Args:
|
||||
raw_models: list[Union[dict[str, BaseModel], BaseModel]] - models which can contain fields with raw data
|
||||
is_query: bool - flag to determine which embed method to use. Defaults to False.
|
||||
inference_batch_size: int - batch size for inference
|
||||
Returns:
|
||||
Iterable[BaseModel]: models with embedded fields
|
||||
"""
|
||||
if not self._is_builtin_embedder_available:
|
||||
FastEmbedMisc.import_fastembed() # fail fast if fastembed is required
|
||||
|
||||
for raw_model in raw_models:
|
||||
self._process_model(raw_model, is_query=is_query, accumulating=True)
|
||||
|
||||
if not self._batch_accumulator:
|
||||
yield from raw_models
|
||||
else:
|
||||
yield from (
|
||||
self._process_model(
|
||||
raw_model,
|
||||
is_query=is_query,
|
||||
accumulating=False,
|
||||
inference_batch_size=inference_batch_size,
|
||||
)
|
||||
for raw_model in raw_models
|
||||
)
|
||||
|
||||
    def _process_model(
        self,
        model: Union[dict[str, BaseModel], BaseModel],
        paths: Optional[list[FieldPath]] = None,
        is_query: bool = False,
        accumulating: bool = False,
        inference_batch_size: Optional[int] = None,
    ) -> Union[dict[str, BaseModel], dict[str, NumericVector], BaseModel, NumericVector]:
        """Embed model's fields requiring inference

        Runs in two modes: with ``accumulating=True`` it only collects inference
        objects into the batch accumulator; with ``accumulating=False`` it drains
        the accumulator and writes computed embeddings back into a deepcopy.

        Args:
            model: Qdrant http model containing fields to embed
            paths: Path to fields to embed. E.g. [FieldPath(current="recommend", tail=[FieldPath(current="negative", tail=None)])]
            is_query: Flag to determine which embed method to use. Defaults to False.
            accumulating: Flag to determine if we are accumulating models for batch embedding. Defaults to False.
            inference_batch_size: Optional[int] - batch size for inference

        Returns:
            A deepcopy of the method with embedded fields
        """

        # A bare inference object: either stash it or swap it for its embedding.
        if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)):
            if accumulating:
                self._accumulate(model)  # type: ignore
            else:
                assert (
                    inference_batch_size is not None
                ), "inference_batch_size should be passed for inference"
                return self._drain_accumulator(
                    model,  # type: ignore
                    is_query=is_query,
                    inference_batch_size=inference_batch_size,
                )

        if paths is None:
            # Top-level call: copy before mutating (accumulation never mutates).
            model = deepcopy(model) if not accumulating else model

            if isinstance(model, dict):
                # e.g. named-vector dicts: process each value independently.
                for key, value in model.items():
                    if accumulating:
                        self._process_model(value, paths, accumulating=True)
                    else:
                        model[key] = self._process_model(
                            value,
                            paths,
                            is_query=is_query,
                            accumulating=False,
                            inference_batch_size=inference_batch_size,
                        )
                return model

        # Discover the field paths to inference objects if not supplied by the caller.
        paths = paths if paths is not None else self._embed_inspector.inspect(model)

        for path in paths:
            list_model = [model] if not isinstance(model, list) else model
            for item in list_model:
                current_model = getattr(item, path.current, None)
                if current_model is None:
                    continue
                if path.tail:
                    # Not a leaf yet: recurse along the remaining path.
                    self._process_model(
                        current_model,
                        path.tail,
                        is_query=is_query,
                        accumulating=accumulating,
                        inference_batch_size=inference_batch_size,
                    )
                else:
                    # Leaf field: normalize to a list, remember original shape.
                    was_list = isinstance(current_model, list)
                    current_model = current_model if was_list else [current_model]

                    if not accumulating:
                        assert (
                            inference_batch_size is not None
                        ), "inference_batch_size should be passed for inference"
                        embeddings = [
                            self._drain_accumulator(
                                data, is_query=is_query, inference_batch_size=inference_batch_size
                            )
                            for data in current_model
                        ]
                        # Restore the original field shape (single value vs list).
                        if was_list:
                            setattr(item, path.current, embeddings)
                        else:
                            setattr(item, path.current, embeddings[0])
                    else:
                        for data in current_model:
                            self._accumulate(data)
        return model
    def _accumulate(self, data: models.VectorStruct) -> None:
        """Add data to batch accumulator

        Args:
            data: models.VectorStruct - any vector struct data, if inference object types instances in `data` - add them
                to the accumulator, otherwise - do nothing. `InferenceObject` instances are converted to proper types.

        Returns:
            None
        """
        if isinstance(data, dict):
            # Named vectors: accumulate each value independently.
            for value in data.values():
                self._accumulate(value)
            return None

        if isinstance(data, list):
            for value in data:
                # First non-inference element means this is a plain vector list:
                # stop immediately, nothing to accumulate here.
                if not isinstance(value, get_args(INFERENCE_OBJECT_TYPES)):  # if value is a vector
                    return None
                self._accumulate(value)

        if not isinstance(data, get_args(INFERENCE_OBJECT_TYPES)):
            return None

        # Normalize InferenceObject into the concrete Document/Image/... type.
        data = self._resolve_inference_object(data)
        if data.model not in self._batch_accumulator:
            self._batch_accumulator[data.model] = []
        self._batch_accumulator[data.model].append(data)
        return None
    def _drain_accumulator(
        self, data: models.VectorStruct, is_query: bool, inference_batch_size: int = 8
    ) -> NumericVectorStruct:
        """Drain accumulator and replaces inference objects with computed embeddings

        It is assumed objects are traversed in the same order as they were added to the accumulator

        Args:
            data: models.VectorStruct - any vector struct data, if inference object types instances in `data` - replace
                them with computed embeddings. If embeddings haven't yet been computed - compute them and then replace
                inference objects.
            is_query: flag to determine which embed method to use
            inference_batch_size: int - batch size for inference

        Returns:
            NumericVectorStruct: data with replaced inference objects
        """
        if isinstance(data, dict):
            # Named vectors: drain each value in place.
            for key, value in data.items():
                data[key] = self._drain_accumulator(
                    value, is_query=is_query, inference_batch_size=inference_batch_size
                )
            return data

        if isinstance(data, list):
            for i, value in enumerate(data):
                # First non-inference element means this is already a plain
                # vector list — return it untouched (mirrors _accumulate).
                if not isinstance(value, get_args(INFERENCE_OBJECT_TYPES)):  # if value is vector
                    return data

                data[i] = self._drain_accumulator(
                    value, is_query=is_query, inference_batch_size=inference_batch_size
                )
            return data

        if not isinstance(
            data, get_args(INFERENCE_OBJECT_TYPES)
        ):  # ide type checker ignores `not` and scolds
            return data  # type: ignore

        # Lazily run inference the first time a model's storage runs dry.
        if not self._embed_storage or not self._embed_storage.get(data.model, None):
            self._embed_accumulator(is_query=is_query, inference_batch_size=inference_batch_size)

        # Pop the next embedding for this model, relying on accumulation order.
        return self._next_embed(data.model)
||||
def _embed_accumulator(self, is_query: bool = False, inference_batch_size: int = 8) -> None:
    """Embed all accumulated objects for all models

    Results are written to `self._embed_storage` keyed by model name, and the
    accumulator is cleared afterwards.

    Args:
        is_query: bool - flag to determine which embed method to use. Defaults to False.
        inference_batch_size: int - batch size for inference
    Returns:
        None
    """

    def embed(
        objects: list[INFERENCE_OBJECT_TYPES], model_name: str, batch_size: int
    ) -> list[NumericVector]:
        """
        Assemble batches by options and data type based groups, embeds and return embeddings in the original order
        """
        unique_options: list[dict[str, Any]] = []
        unique_options_is_text: list[bool] = []  # multimodal models can have both text
        # and image data, we need to track which data we process to construct separate batches for texts and images
        batches: list[Any] = []  # batches[j] holds raw texts/images for group j
        group_indices: dict[int, list[int]] = defaultdict(list)  # group -> original positions
        for i, obj in enumerate(objects):
            is_text = isinstance(obj, models.Document)
            # Linear scan over existing groups; objects join a group only when
            # both inference options and modality (text vs image) match.
            for j, (options, options_is_text) in enumerate(
                zip(unique_options, unique_options_is_text)
            ):
                if options == obj.options and is_text == options_is_text:
                    group_indices[j].append(i)
                    batches[j].append(obj.text if is_text else obj.image)
                    break
            else:
                # Create a new group if no match was found
                group_indices[len(unique_options)] = [i]
                unique_options.append(obj.options)
                unique_options_is_text.append(is_text)
                batches.append([obj.text if is_text else obj.image])

        # Embed one group at a time; `embeddings` ends up ordered group-by-group.
        embeddings = []
        for i, (options, is_text) in enumerate(zip(unique_options, unique_options_is_text)):
            embeddings.extend(
                [
                    embedding
                    for embedding in self.embedder.embed(
                        model_name=model_name,
                        texts=batches[i] if is_text else None,
                        images=batches[i] if not is_text else None,
                        is_query=is_query,
                        options=options or {},
                        batch_size=batch_size,
                    )
                ]
            )

        # Scatter group-ordered embeddings back to the objects' original order.
        # Every slot is overwritten below, so the shared-empty-list init is safe.
        iter_embeddings = iter(embeddings)
        ordered_embeddings: list[list[NumericVector]] = [[]] * len(objects)
        for indices in group_indices.values():
            for index in indices:
                ordered_embeddings[index] = next(iter_embeddings)
        return ordered_embeddings

    # Validate every accumulated model name before doing any work.
    for model in self._batch_accumulator:
        if not any(
            (
                self.embedder.is_supported_text_model(model),
                self.embedder.is_supported_sparse_model(model),
                self.embedder.is_supported_late_interaction_text_model(model),
                self.embedder.is_supported_image_model(model),
                self.embedder.is_supported_late_interaction_multimodal_model(model),
            )
        ):
            if isinstance(self.embedder, BuiltinEmbedder):
                raise ValueError(
                    f"{model} is not among supported models. "
                    f"Have you forgotten to set `cloud_inference` or install `fastembed` for local inference?"
                )
            else:
                raise ValueError(f"{model} is not among supported models")

    for model, data in self._batch_accumulator.items():
        self._embed_storage[model] = embed(
            objects=data, model_name=model, batch_size=inference_batch_size
        )
    self._batch_accumulator.clear()
|
||||
|
||||
def _next_embed(self, model_name: str) -> NumericVector:
    """Pop and return the next computed embedding for `model_name`.

    Embeddings were stored in accumulation order, so removing from the
    front preserves the original traversal order.

    Args:
        model_name: str - retrieve embedding from the storage by this model name

    Returns:
        NumericVector: computed embedding
    """
    pending = self._embed_storage[model_name]
    return pending.pop(0)
|
||||
|
||||
def _resolve_inference_object(self, data: models.VectorStruct) -> models.VectorStruct:
    """Resolve inference object into a model

    Args:
        data: models.VectorStruct - data to resolve, if it's an inference object, convert it to a proper type,
            otherwise - keep unchanged

    Returns:
        models.VectorStruct: resolved data

    Raises:
        ValueError: if the model is unsupported, or cannot be used through
            the generic `InferenceObject` interface.
    """
    # Anything that is not a raw InferenceObject passes through untouched.
    if not isinstance(data, models.InferenceObject):
        return data

    model_name = data.model
    raw_value = data.object
    inference_options = data.options

    # All checks are evaluated eagerly (tuple + any) so that each supported-model
    # cache gets refreshed, mirroring the validation done elsewhere.
    is_textual = any(
        (
            self.embedder.is_supported_text_model(model_name),
            self.embedder.is_supported_sparse_model(model_name),
            self.embedder.is_supported_late_interaction_text_model(model_name),
        )
    )
    if is_textual:
        return models.Document(model=model_name, text=raw_value, options=inference_options)
    if self.embedder.is_supported_image_model(model_name):
        return models.Image(model=model_name, image=raw_value, options=inference_options)
    if self.embedder.is_supported_late_interaction_multimodal_model(model_name):
        raise ValueError(f"{model_name} does not support `InferenceObject` interface")

    raise ValueError(f"{model_name} is not among supported models")
|
||||
|
||||
@classmethod
def _get_worker_class(cls) -> Type[ModelEmbedderWorker]:
    # Worker type used when embeddings are computed in parallel worker processes.
    return ModelEmbedderWorker
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import Union
|
||||
|
||||
from pydantic import StrictFloat, StrictStr
|
||||
|
||||
from qdrant_client.http.models import ExtendedPointId, SparseVector
|
||||
|
||||
|
||||
# Result of a completed inference call: a dense vector, a sparse vector,
# or a multi-vector (late-interaction models).
NumericVector = Union[
    list[StrictFloat],
    SparseVector,
    list[list[StrictFloat]],
]
# Like NumericVector, but may also reference an already-stored point by id.
NumericVectorInput = Union[
    list[StrictFloat],
    SparseVector,
    list[list[StrictFloat]],
    ExtendedPointId,
]
# Vector container for a whole point: an anonymous dense/multi vector, or a
# mapping of named vectors to their computed embeddings.
NumericVectorStruct = Union[
    list[StrictFloat],
    list[list[StrictFloat]],
    dict[StrictStr, NumericVector],
]

__all__ = ["NumericVector", "NumericVectorInput", "NumericVectorStruct"]
|
||||
@@ -0,0 +1,305 @@
|
||||
from copy import copy, deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Type, Union, Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from qdrant_client._pydantic_compat import model_json_schema
|
||||
from qdrant_client.embed.utils import FieldPath, convert_paths
|
||||
|
||||
|
||||
try:
|
||||
from qdrant_client.embed._inspection_cache import (
|
||||
DEFS,
|
||||
CACHE_STR_PATH,
|
||||
RECURSIVE_REFS,
|
||||
EXCLUDED_RECURSIVE_REFS,
|
||||
INCLUDED_RECURSIVE_REFS,
|
||||
NAME_RECURSIVE_REF_MAPPING,
|
||||
)
|
||||
except ImportError as e:
|
||||
DEFS = {}
|
||||
CACHE_STR_PATH = {}
|
||||
RECURSIVE_REFS = set() # type: ignore
|
||||
EXCLUDED_RECURSIVE_REFS = {"Filter"} # type: ignore
|
||||
INCLUDED_RECURSIVE_REFS = set() # type: ignore
|
||||
NAME_RECURSIVE_REF_MAPPING = {}
|
||||
|
||||
|
||||
class ModelSchemaParser:
    """Model schema parser. Parses json schemas to retrieve paths to objects requiring inference.

    The parser is stateful, it accumulates the results of parsing in its internal structures.

    Attributes:
        _defs: definitions extracted from json schemas
        _recursive_refs: set of recursive refs found in the processed schemas, e.g.:
            {"Filter", "Prefetch"}
        _excluded_recursive_refs: predefined time-consuming recursive refs which don't have inference objects, e.g.:
            {"Filter"}
        _included_recursive_refs: set of recursive refs which have inference objects, e.g.:
            {"Prefetch"}
        _cache: cache of string paths for models containing objects for inference, e.g.:
            {"Prefetch": ['prefetch.query', 'prefetch.query.context.negative', ...]}
        path_cache: cache of FieldPath objects for models containing objects for inference, e.g.:
            {
                "Prefetch": [
                    FieldPath(
                        current="prefetch",
                        tail=[
                            FieldPath(
                                current="query",
                                tail=[
                                    FieldPath(
                                        current="recommend",
                                        tail=[
                                            FieldPath(current="negative", tail=None),
                                            FieldPath(current="positive", tail=None),
                                        ],
                                    ),
                                    ...,
                                ],
                            ),
                        ],
                    )
                ]
            }
        name_recursive_ref_mapping: mapping of model field names to ref names, e.g.:
            {"prefetch": "Prefetch"}
    """

    # File name the parser state is persisted to (see `_persist`).
    CACHE_PATH = "_inspection_cache.py"
    # Schema titles that mark objects requiring inference.
    INFERENCE_OBJECT_NAMES = {"Document", "Image", "InferenceObject"}

    def __init__(self) -> None:
        # self._defs does not include the whole schema, but only the part with the structures used in $defs
        self._defs: dict[str, Union[dict[str, Any], list[dict[str, Any]]]] = deepcopy(DEFS)  # type: ignore[arg-type]
        self._cache: dict[str, list[str]] = deepcopy(CACHE_STR_PATH)

        self._recursive_refs: set[str] = set(RECURSIVE_REFS)
        self._excluded_recursive_refs: set[str] = set(EXCLUDED_RECURSIVE_REFS)
        self._included_recursive_refs: set[str] = set(INCLUDED_RECURSIVE_REFS)

        self.name_recursive_ref_mapping: dict[str, str] = {
            k: v for k, v in NAME_RECURSIVE_REF_MAPPING.items()
        }
        # Pre-built FieldPath trees for the cached string paths.
        self.path_cache: dict[str, list[FieldPath]] = {
            model: convert_paths(paths) for model, paths in self._cache.items()
        }
        # Memo for refs already identified as recursive during `_replace_refs`.
        self._processed_recursive_defs: dict[str, Any] = {}

    def _replace_refs(
        self,
        schema: Union[dict[str, Any], list[dict[str, Any]]],
        parent: Optional[str] = None,
        seen_refs: Optional[set] = None,
    ) -> Union[dict[str, Any], list[dict[str, Any]]]:
        """Replace refs in schema with their definitions

        Args:
            schema: schema to parse
            parent: previous level key
            seen_refs: set of seen refs to spot recursive paths

        Returns:
            schema with replaced refs
        """
        parent = parent if parent else None
        seen_refs = seen_refs if seen_refs else set()

        if isinstance(schema, dict):
            if "$ref" in schema:
                ref_path = schema["$ref"]
                def_key = ref_path.split("/")[-1]
                if def_key in self._processed_recursive_defs:
                    return self._processed_recursive_defs[def_key]

                # A ref pointing at its parent or an ancestor is recursive:
                # keep the raw "$ref" node instead of expanding forever.
                if def_key == parent or def_key in seen_refs:
                    self._recursive_refs.add(def_key)
                    self._processed_recursive_defs[def_key] = schema
                    return schema

                seen_refs.add(def_key)

                return self._replace_refs(
                    self._defs[def_key], parent=def_key, seen_refs=copy(seen_refs)
                )

            schemes = {}
            if "properties" in schema:
                # Only descend into "properties"; sibling keys are copied verbatim.
                for k, v in schema.items():
                    if k == "properties":
                        schemes[k] = self._replace_refs(
                            schema=v, parent=parent, seen_refs=copy(seen_refs)
                        )
                    else:
                        schemes[k] = v
            else:
                for k, v in schema.items():
                    # A nested object with its own "properties" starts a new parent scope.
                    parent_key = k if isinstance(v, dict) and "properties" in v else parent
                    schemes[k] = self._replace_refs(
                        schema=v, parent=parent_key, seen_refs=copy(seen_refs)
                    )

            return schemes
        elif isinstance(schema, list):
            return [
                self._replace_refs(schema=item, parent=parent, seen_refs=copy(seen_refs))  # type: ignore
                for item in schema
            ]
        else:
            # Scalars (titles, types, etc.) are returned as-is.
            return schema

    def _find_document_paths(
        self,
        schema: Union[dict[str, Any], list[dict[str, Any]]],
        current_path: str = "",
        after_properties: bool = False,
        seen_refs: Optional[set] = None,
    ) -> list[str]:
        """Read a schema and find paths to objects requiring inference

        Populates model fields names to ref names mapping

        Args:
            schema: schema to parse
            current_path: current path in the schema
            after_properties: flag indicating if the current path is after "properties" key
            seen_refs: set of seen refs to spot recursive paths

        Returns:
            List of string dot separated paths to objects requiring inference
        """
        document_paths: list[str] = []
        seen_recursive_refs = seen_refs if seen_refs is not None else set()

        parts = current_path.split(".")
        if len(parts) != len(set(parts)):  # check for recursive paths
            return document_paths

        if not isinstance(schema, dict):
            return document_paths

        if "title" in schema and schema["title"] in self.INFERENCE_OBJECT_NAMES:
            document_paths.append(current_path)
            return document_paths

        for key, value in schema.items():
            if key == "$defs":
                continue

            if key == "$ref":
                model_name = value.split("/")[-1]

                value = self._defs[model_name]
                if model_name in seen_recursive_refs:
                    continue

                if (
                    model_name in self._excluded_recursive_refs
                ):  # on the first run it might be empty
                    continue

                if (
                    model_name in self._recursive_refs
                ):  # included and excluded refs might not be filled up yet, we're looking in all recursive refs
                    # we would need to clean up name recursive ref mapping later and delete excluded refs from there
                    seen_recursive_refs.add(model_name)
                    self.name_recursive_ref_mapping[current_path.split(".")[-1]] = model_name

            if after_properties:  # field name seen in pydantic models comes after "properties" key
                if current_path:
                    new_path = f"{current_path}.{key}"
                else:
                    new_path = key
            else:
                new_path = current_path

            if isinstance(value, dict):
                document_paths.extend(
                    self._find_document_paths(
                        value, new_path, key == "properties", seen_refs=seen_recursive_refs
                    )
                )
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        document_paths.extend(
                            self._find_document_paths(
                                item,
                                new_path,
                                key == "properties",
                                seen_refs=seen_recursive_refs,
                            )
                        )

        # De-duplicate and keep deterministic ordering.
        return sorted(set(document_paths))

    def parse_model(self, model: Type[BaseModel]) -> None:
        """Parse model schema to retrieve paths to objects requiring inference.

        Checks model json schema, extracts definitions and finds paths to objects requiring inference.
        No parsing happens if model has already been processed.

        Args:
            model: model to parse

        Returns:
            None
        """
        model_name = model.__name__
        if model_name in self._cache:
            return None

        schema = model_json_schema(model)

        # Merge new definitions without overwriting previously collected ones.
        for k, v in schema.get("$defs", {}).items():
            if k not in self._defs:
                self._defs[k] = v

        if "$defs" in schema:
            raw_refs = (
                {"$ref": schema["$ref"]}
                if "$ref" in schema
                else {"properties": schema["properties"]}
            )
            refs = self._replace_refs(raw_refs)
            self._cache[model_name] = self._find_document_paths(refs)
        else:
            # No nested definitions: the model cannot contain inference objects.
            self._cache[model_name] = []

        # Classify any newly discovered recursive refs once, by whether they
        # can contain inference objects.
        for ref in self._recursive_refs:
            if ref in self._excluded_recursive_refs or ref in self._included_recursive_refs:
                continue

            if self._find_document_paths(self._defs[ref]):
                self._included_recursive_refs.add(ref)
            else:
                self._excluded_recursive_refs.add(ref)

        self.name_recursive_ref_mapping = {
            k: v
            for k, v in self.name_recursive_ref_mapping.items()
            if v not in self._excluded_recursive_refs
        }

        # convert str paths to FieldPath objects which group path parts and reduce the time of the traversal
        self.path_cache = {model: convert_paths(paths) for model, paths in self._cache.items()}

    def _persist(self, output_path: Union[Path, str] = CACHE_PATH) -> None:
        """Persist the parser state to a file

        Args:
            output_path: path to the file to save the parser state

        Returns:
            None
        """
        with open(output_path, "w") as f:
            f.write(f"CACHE_STR_PATH = {self._cache}\n")
            f.write(f"DEFS = {self._defs}\n")
            # sorted() is required so persisted files can be compared with `diff`
            f.write(f"RECURSIVE_REFS = {sorted(self._recursive_refs)}\n")
            f.write(f"INCLUDED_RECURSIVE_REFS = {sorted(self._included_recursive_refs)}\n")
            f.write(f"EXCLUDED_RECURSIVE_REFS = {sorted(self._excluded_recursive_refs)}\n")
            f.write(f"NAME_RECURSIVE_REF_MAPPING = {self.name_recursive_ref_mapping}\n")
|
||||
@@ -0,0 +1,149 @@
|
||||
from typing import Union, Optional, Iterable, get_args
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from qdrant_client._pydantic_compat import model_fields_set
|
||||
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
|
||||
|
||||
from qdrant_client.embed.schema_parser import ModelSchemaParser
|
||||
from qdrant_client.embed.utils import FieldPath
|
||||
|
||||
|
||||
class Inspector:
    """Inspector which tries to find at least one occurrence of an object requiring inference

    Inspector is stateful and accumulates parsed model schemes in its parser.

    Attributes:
        parser: ModelSchemaParser instance to inspect model json schemas
    """

    def __init__(self, parser: Optional[ModelSchemaParser] = None) -> None:
        self.parser = ModelSchemaParser() if parser is None else parser

    def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> bool:
        """Looks for at least one occurrence of an object requiring inference in the received models

        Args:
            points: models to inspect

        Returns:
            True if at least one object requiring inference is found, False otherwise
        """
        if isinstance(points, BaseModel):
            # Make sure the parser knows this model's inference paths first.
            self.parser.parse_model(points.__class__)
            return self._inspect_model(points)

        elif isinstance(points, dict):
            for value in points.values():
                if self.inspect(value):
                    return True

        elif isinstance(points, Iterable):
            for point in points:
                if isinstance(point, BaseModel):
                    self.parser.parse_model(point.__class__)
                    if self._inspect_model(point):
                        return True
                else:
                    # First non-model element: iterable is assumed homogeneous,
                    # so nothing here requires inference.
                    return False
        return False

    def _inspect_model(self, model: BaseModel, paths: Optional[list[FieldPath]] = None) -> bool:
        # Check whether `model` itself, or anything reachable via `paths`,
        # is an inference object.
        if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)):
            return True

        paths = (
            self.parser.path_cache.get(model.__class__.__name__, []) if paths is None else paths
        )

        for path in paths:
            type_found = self._inspect_inner_models(
                model, path.current, path.tail if path.tail else []
            )
            if type_found:
                return True
        return False

    def _inspect_inner_models(
        self, original_model: BaseModel, current_path: str, tail: list[FieldPath]
    ) -> bool:
        # Follow one path segment into `original_model` and recurse along `tail`.
        def inspect_recursive(member: BaseModel) -> bool:
            # Recursive refs (e.g. Prefetch) are not part of the static path
            # tree; resolve their paths dynamically from the set fields.
            recursive_paths = []
            for field_name in model_fields_set(member):
                if field_name in self.parser.name_recursive_ref_mapping:
                    mapped_model_name = self.parser.name_recursive_ref_mapping[field_name]
                    recursive_paths.extend(self.parser.path_cache[mapped_model_name])

            if recursive_paths:
                found = self._inspect_model(member, recursive_paths)
                if found:
                    return True

            return False

        model = getattr(original_model, current_path, None)
        if model is None:
            return False

        if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)):
            return True

        if isinstance(model, BaseModel):
            type_found = inspect_recursive(model)
            if type_found:
                return True

            for next_path in tail:
                type_found = self._inspect_inner_models(
                    model, next_path.current, next_path.tail if next_path.tail else []
                )
                if type_found:
                    return True
            return False

        elif isinstance(model, list):
            for current_model in model:
                if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)):
                    return True

                if not isinstance(current_model, BaseModel):
                    continue

                type_found = inspect_recursive(current_model)
                if type_found:
                    return True

            for next_path in tail:
                for current_model in model:
                    type_found = self._inspect_inner_models(
                        current_model, next_path.current, next_path.tail if next_path.tail else []
                    )
                    if type_found:
                        return True
            return False

        elif isinstance(model, dict):
            for key, values in model.items():
                # Normalize scalar values to a one-element list for uniform handling.
                values = [values] if not isinstance(values, list) else values
                for current_model in values:
                    if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)):
                        return True

                    if not isinstance(current_model, BaseModel):
                        continue

                    found_type = inspect_recursive(current_model)
                    if found_type:
                        return True

                for next_path in tail:
                    for current_model in values:
                        found_type = self._inspect_inner_models(
                            current_model,
                            next_path.current,
                            next_path.tail if next_path.tail else [],
                        )
                        if found_type:
                            return True
        return False
|
||||
@@ -0,0 +1,79 @@
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FieldPath(BaseModel):
    """One node of a dot-separated field path tree.

    `current` is this segment's field name; `tail` holds the child segments
    (None for a leaf).
    """

    current: str
    tail: Optional[list["FieldPath"]] = Field(default=None)

    def as_str_list(self) -> list[str]:
        """Flatten the tree into dot-separated root-to-leaf path strings.

        >>> FieldPath(current='a', tail=[FieldPath(current='b', tail=[FieldPath(current='c'), FieldPath(current='d')])]).as_str_list()
        ['a.b.c', 'a.b.d']
        """

        # Recursive function to collect all paths
        def collect_paths(path: FieldPath, prefix: str = "") -> list[str]:
            current_path = prefix + path.current
            if not path.tail:
                return [current_path]
            else:
                paths = []
                for sub_path in path.tail:
                    paths.extend(collect_paths(sub_path, current_path + "."))
                return paths

        # Collect all paths starting from this object
        return collect_paths(self)
|
||||
|
||||
|
||||
def convert_paths(paths: list[str]) -> list[FieldPath]:
    """Convert string paths into FieldPath objects

    Paths which share the same root are grouped together into a single tree.

    Args:
        paths: List[str]: List of str paths containing "." as separator

    Returns:
        List[FieldPath]: List of FieldPath objects
    """
    converted: list[FieldPath] = []
    last_root: Optional[str] = None

    # Sorting guarantees paths sharing a root arrive consecutively, so a
    # single "last root" marker is enough to group them.
    for path in sorted(paths):
        root, *remainder = path.split(".")
        if root != last_root:
            converted.append(FieldPath(current=root))
            last_root = root

        node = converted[-1]
        for part in remainder:
            if node.tail is None:
                node.tail = []
            # Descend into an existing child with this name, or create one.
            existing = next((child for child in node.tail if child.current == part), None)
            if existing is None:
                existing = FieldPath(current=part)
                node.tail.append(existing)
            node = existing
    return converted
|
||||
|
||||
|
||||
def read_base64(file_path: Union[str, Path]) -> str:
    """Convert a file path to a base64 encoded string."""
    path = Path(file_path)
    if not path.exists():
        raise FileNotFoundError(f"The file {path} does not exist.")

    # Read as raw bytes and encode; decode back to str for JSON-friendly output.
    return base64.b64encode(path.read_bytes()).decode("utf-8")
|
||||
@@ -0,0 +1,323 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from qdrant_client.conversions.common_types import SparseVector
|
||||
from qdrant_client.http import models
|
||||
|
||||
try:
|
||||
from fastembed import (
|
||||
TextEmbedding,
|
||||
SparseTextEmbedding,
|
||||
ImageEmbedding,
|
||||
LateInteractionTextEmbedding,
|
||||
LateInteractionMultimodalEmbedding,
|
||||
)
|
||||
from fastembed.common import OnnxProvider, ImageInput
|
||||
except ImportError:
|
||||
TextEmbedding = None
|
||||
SparseTextEmbedding = None
|
||||
ImageEmbedding = None
|
||||
LateInteractionTextEmbedding = None
|
||||
LateInteractionMultimodalEmbedding = None
|
||||
OnnxProvider = None
|
||||
ImageInput = None
|
||||
|
||||
|
||||
class QueryResponse(BaseModel, extra="forbid"):  # type: ignore
    """Single scored point returned by the fastembed-based query helpers."""

    id: Union[str, int]  # point id
    embedding: Optional[list[float]]  # dense embedding, when available
    sparse_embedding: Optional[SparseVector] = Field(default=None)  # sparse embedding, when available
    metadata: dict[str, Any]  # payload stored with the point
    document: str  # original document text
    score: float  # similarity score
|
||||
|
||||
|
||||
class FastEmbedMisc:
|
||||
IS_INSTALLED: bool = False
|
||||
_TEXT_MODELS: set[str] = set()
|
||||
_IMAGE_MODELS: set[str] = set()
|
||||
_LATE_INTERACTION_TEXT_MODELS: set[str] = set()
|
||||
_LATE_INTERACTION_MULTIMODAL_MODELS: set[str] = set()
|
||||
_SPARSE_MODELS: set[str] = set()
|
||||
|
||||
@classmethod
|
||||
def is_installed(cls) -> bool:
|
||||
if cls.IS_INSTALLED:
|
||||
return cls.IS_INSTALLED
|
||||
|
||||
try:
|
||||
from fastembed import (
|
||||
SparseTextEmbedding,
|
||||
TextEmbedding,
|
||||
ImageEmbedding,
|
||||
LateInteractionMultimodalEmbedding,
|
||||
LateInteractionTextEmbedding,
|
||||
)
|
||||
|
||||
assert len(SparseTextEmbedding.list_supported_models()) > 0
|
||||
assert len(TextEmbedding.list_supported_models()) > 0
|
||||
assert len(ImageEmbedding.list_supported_models()) > 0
|
||||
assert len(LateInteractionTextEmbedding.list_supported_models()) > 0
|
||||
assert len(LateInteractionMultimodalEmbedding.list_supported_models()) > 0
|
||||
cls.IS_INSTALLED = True
|
||||
except ImportError:
|
||||
cls.IS_INSTALLED = False
|
||||
|
||||
return cls.IS_INSTALLED
|
||||
|
||||
@classmethod
|
||||
def import_fastembed(cls) -> None:
|
||||
if cls.IS_INSTALLED:
|
||||
return
|
||||
|
||||
# If it's not, ask the user to install it
|
||||
raise ImportError(
|
||||
"fastembed is not installed."
|
||||
" Please install it to enable fast vector indexing with `pip install fastembed`."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_text_models(cls) -> dict[str, tuple[int, models.Distance]]:
|
||||
"""Lists the supported dense text models.
|
||||
|
||||
Requires invocation of TextEmbedding.list_supported_models() to support custom models.
|
||||
|
||||
Returns:
|
||||
dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
|
||||
"""
|
||||
return (
|
||||
{
|
||||
model["model"]: (model["dim"], models.Distance.COSINE)
|
||||
for model in TextEmbedding.list_supported_models()
|
||||
}
|
||||
if TextEmbedding
|
||||
else {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_image_models(cls) -> dict[str, tuple[int, models.Distance]]:
|
||||
"""Lists the supported image dense models.
|
||||
|
||||
Custom image models are not supported yet, but calls to ImageEmbedding.list_supported_models() is done each
|
||||
time in order for preserving the same style as with TextEmbedding.
|
||||
|
||||
Returns:
|
||||
dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
|
||||
"""
|
||||
return (
|
||||
{
|
||||
model["model"]: (model["dim"], models.Distance.COSINE)
|
||||
for model in ImageEmbedding.list_supported_models()
|
||||
}
|
||||
if ImageEmbedding
|
||||
else {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_late_interaction_text_models(cls) -> dict[str, tuple[int, models.Distance]]:
|
||||
"""Lists the supported late interaction text models.
|
||||
|
||||
Custom late interaction models are not supported yet, but calls to
|
||||
LateInteractionTextEmbedding.list_supported_models()
|
||||
is done each time in order for preserving the same style as with TextEmbedding.
|
||||
|
||||
Returns:
|
||||
dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
|
||||
"""
|
||||
return (
|
||||
{
|
||||
model["model"]: (model["dim"], models.Distance.COSINE)
|
||||
for model in LateInteractionTextEmbedding.list_supported_models()
|
||||
}
|
||||
if LateInteractionTextEmbedding
|
||||
else {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_late_interaction_multimodal_models(cls) -> dict[str, tuple[int, models.Distance]]:
|
||||
"""Lists the supported late interaction multimodal models.
|
||||
|
||||
Custom late interaction multimodal models are not supported yet, but calls to
|
||||
LateInteractionMultimodalEmbedding.list_supported_models()
|
||||
is done each time in order for preserving the same style as with TextEmbedding.
|
||||
|
||||
Returns:
|
||||
dict[str, tuple[int, models.Distance]]: A dict of model names, their dimensions and distance metrics.
|
||||
"""
|
||||
return (
|
||||
{
|
||||
model["model"]: (model["dim"], models.Distance.COSINE)
|
||||
for model in LateInteractionMultimodalEmbedding.list_supported_models()
|
||||
}
|
||||
if LateInteractionMultimodalEmbedding
|
||||
else {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_sparse_models(cls) -> dict[str, dict[str, Any]]:
|
||||
"""Lists the supported sparse models.
|
||||
|
||||
Custom sparse models are not supported yet, but calls to
|
||||
SparseTextEmbedding.list_supported_models()
|
||||
is done each time in order for preserving the same style as with TextEmbedding.
|
||||
|
||||
Returns:
|
||||
dict[str, dict[str, Any]]: A dict of model names and their descriptions.
|
||||
"""
|
||||
descriptions = {}
|
||||
if SparseTextEmbedding:
|
||||
for description in SparseTextEmbedding.list_supported_models():
|
||||
descriptions[description.pop("model")] = description
|
||||
return descriptions
|
||||
|
||||
@classmethod
|
||||
def is_supported_text_model(cls, model_name: str) -> bool:
|
||||
"""Checks if the model is supported by fastembed.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is supported, False otherwise.
|
||||
"""
|
||||
if model_name.lower() in cls._TEXT_MODELS:
|
||||
return True
|
||||
# update cached list in case custom models were added
|
||||
cls._TEXT_MODELS = {model.lower() for model in cls.list_text_models()}
|
||||
if model_name.lower() in cls._TEXT_MODELS:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_supported_image_model(cls, model_name: str) -> bool:
|
||||
"""Checks if the model is supported by fastembed.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is supported, False otherwise.
|
||||
"""
|
||||
if model_name.lower() in cls._IMAGE_MODELS:
|
||||
return True
|
||||
# update cached list in case custom models were added
|
||||
cls._IMAGE_MODELS = {model.lower() for model in cls.list_image_models()}
|
||||
if model_name.lower() in cls._IMAGE_MODELS:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_supported_late_interaction_text_model(cls, model_name: str) -> bool:
|
||||
"""Checks if the model is supported by fastembed.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is supported, False otherwise.
|
||||
"""
|
||||
if model_name.lower() in cls._LATE_INTERACTION_TEXT_MODELS:
|
||||
return True
|
||||
# update cached list in case custom models were added
|
||||
cls._LATE_INTERACTION_TEXT_MODELS = {
|
||||
model.lower() for model in cls.list_late_interaction_text_models()
|
||||
}
|
||||
if model_name.lower() in cls._LATE_INTERACTION_TEXT_MODELS:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
def is_supported_late_interaction_multimodal_model(cls, model_name: str) -> bool:
    """Check whether fastembed supports the given late-interaction multimodal model.

    Args:
        model_name (str): The name of the model to check.

    Returns:
        bool: True if the model is supported, False otherwise.
    """
    name = model_name.lower()
    if name in cls._LATE_INTERACTION_MULTIMODAL_MODELS:
        return True
    # The cached set may be stale if custom models were registered after it
    # was built — rebuild it once and re-check before giving up.
    cls._LATE_INTERACTION_MULTIMODAL_MODELS = {
        supported.lower() for supported in cls.list_late_interaction_multimodal_models()
    }
    return name in cls._LATE_INTERACTION_MULTIMODAL_MODELS
|
||||
|
||||
@classmethod
def is_supported_sparse_model(cls, model_name: str) -> bool:
    """Check whether fastembed supports the given sparse embedding model.

    Args:
        model_name (str): The name of the model to check.

    Returns:
        bool: True if the model is supported, False otherwise.
    """
    name = model_name.lower()
    if name in cls._SPARSE_MODELS:
        return True
    # The cached set may be stale if custom models were registered after it
    # was built — rebuild it once and re-check before giving up.
    cls._SPARSE_MODELS = {supported.lower() for supported in cls.list_sparse_models()}
    return name in cls._SPARSE_MODELS
|
||||
|
||||
|
||||
# region deprecated
# Prefer the methods built into QdrantClient instead of these module-level
# constants, e.g. list_supported_text_models, list_supported_idf_models, etc.
# NOTE(review): the `if TextEmbedding` / `if SparseTextEmbedding` truthiness
# guards below suggest these names are falsy (presumably None) when fastembed
# is not installed — confirm against the import site, which is outside this view.

# Dense text models: model name -> (vector dimension, distance metric).
# Distance is always COSINE here.
SUPPORTED_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = (
    {
        model["model"]: (model["dim"], models.Distance.COSINE)
        for model in TextEmbedding.list_supported_models()
    }
    if TextEmbedding
    else {}
)

# Sparse models: model name -> full model-description dict as reported by fastembed.
SUPPORTED_SPARSE_EMBEDDING_MODELS: dict[str, dict[str, Any]] = (
    {model["model"]: model for model in SparseTextEmbedding.list_supported_models()}
    if SparseTextEmbedding
    else {}
)

# Names of sparse models whose config sets a truthy "requires_idf" flag.
IDF_EMBEDDING_MODELS: set[str] = (
    {
        model_config["model"]
        for model_config in SparseTextEmbedding.list_supported_models()
        if model_config.get("requires_idf", None)
    }
    if SparseTextEmbedding
    else set()
)

# Late-interaction text models: model name -> (vector dimension, COSINE).
_LATE_INTERACTION_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = (
    {
        model["model"]: (model["dim"], models.Distance.COSINE)
        for model in LateInteractionTextEmbedding.list_supported_models()
    }
    if LateInteractionTextEmbedding
    else {}
)

# Image models: model name -> (vector dimension, COSINE).
_IMAGE_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = (
    {
        model["model"]: (model["dim"], models.Distance.COSINE)
        for model in ImageEmbedding.list_supported_models()
    }
    if ImageEmbedding
    else {}
)

# Late-interaction multimodal models: model name -> (vector dimension, COSINE).
_LATE_INTERACTION_MULTIMODAL_EMBEDDING_MODELS: dict[str, tuple[int, models.Distance]] = (
    {
        model["model"]: (model["dim"], models.Distance.COSINE)
        for model in LateInteractionMultimodalEmbedding.list_supported_models()
    }
    if LateInteractionMultimodalEmbedding
    else {}
)
# endregion
|
||||
@@ -0,0 +1,10 @@
|
||||
from .points_pb2 import *
|
||||
from .collections_pb2 import *
|
||||
from .qdrant_common_pb2 import *
|
||||
from .snapshots_service_pb2 import *
|
||||
from .json_with_int_pb2 import *
|
||||
from .collections_service_pb2_grpc import *
|
||||
from .points_service_pb2_grpc import *
|
||||
from .snapshots_service_pb2_grpc import *
|
||||
from .qdrant_pb2 import *
|
||||
from .qdrant_pb2_grpc import *
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,4 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: collections_service.proto
|
||||
# Protobuf Python Version: 4.25.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from . import collections_pb2 as collections__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19\x63ollections_service.proto\x12\x06qdrant\x1a\x11\x63ollections.proto2\xe2\x08\n\x0b\x43ollections\x12L\n\x03Get\x12 .qdrant.GetCollectionInfoRequest\x1a!.qdrant.GetCollectionInfoResponse\"\x00\x12I\n\x04List\x12\x1e.qdrant.ListCollectionsRequest\x1a\x1f.qdrant.ListCollectionsResponse\"\x00\x12I\n\x06\x43reate\x12\x18.qdrant.CreateCollection\x1a#.qdrant.CollectionOperationResponse\"\x00\x12I\n\x06Update\x12\x18.qdrant.UpdateCollection\x1a#.qdrant.CollectionOperationResponse\"\x00\x12I\n\x06\x44\x65lete\x12\x18.qdrant.DeleteCollection\x1a#.qdrant.CollectionOperationResponse\"\x00\x12M\n\rUpdateAliases\x12\x15.qdrant.ChangeAliases\x1a#.qdrant.CollectionOperationResponse\"\x00\x12\\\n\x15ListCollectionAliases\x12$.qdrant.ListCollectionAliasesRequest\x1a\x1b.qdrant.ListAliasesResponse\"\x00\x12H\n\x0bListAliases\x12\x1a.qdrant.ListAliasesRequest\x1a\x1b.qdrant.ListAliasesResponse\"\x00\x12\x66\n\x15\x43ollectionClusterInfo\x12$.qdrant.CollectionClusterInfoRequest\x1a%.qdrant.CollectionClusterInfoResponse\"\x00\x12W\n\x10\x43ollectionExists\x12\x1f.qdrant.CollectionExistsRequest\x1a .qdrant.CollectionExistsResponse\"\x00\x12{\n\x1cUpdateCollectionClusterSetup\x12+.qdrant.UpdateCollectionClusterSetupRequest\x1a,.qdrant.UpdateCollectionClusterSetupResponse\"\x00\x12Q\n\x0e\x43reateShardKey\x12\x1d.qdrant.CreateShardKeyRequest\x1a\x1e.qdrant.CreateShardKeyResponse\"\x00\x12Q\n\x0e\x44\x65leteShardKey\x12\x1d.qdrant.DeleteShardKeyRequest\x1a\x1e.qdrant.DeleteShardKeyResponse\"\x00\x42\x15\xaa\x02\x12Qdrant.Client.Grpcb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'collections_service_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'\252\002\022Qdrant.Client.Grpc'
|
||||
_globals['_COLLECTIONS']._serialized_start=57
|
||||
_globals['_COLLECTIONS']._serialized_end=1179
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
import google.protobuf.descriptor
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
+488
@@ -0,0 +1,488 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
from . import collections_pb2 as collections__pb2
|
||||
|
||||
|
||||
class CollectionsStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Get = channel.unary_unary(
|
||||
'/qdrant.Collections/Get',
|
||||
request_serializer=collections__pb2.GetCollectionInfoRequest.SerializeToString,
|
||||
response_deserializer=collections__pb2.GetCollectionInfoResponse.FromString,
|
||||
)
|
||||
self.List = channel.unary_unary(
|
||||
'/qdrant.Collections/List',
|
||||
request_serializer=collections__pb2.ListCollectionsRequest.SerializeToString,
|
||||
response_deserializer=collections__pb2.ListCollectionsResponse.FromString,
|
||||
)
|
||||
self.Create = channel.unary_unary(
|
||||
'/qdrant.Collections/Create',
|
||||
request_serializer=collections__pb2.CreateCollection.SerializeToString,
|
||||
response_deserializer=collections__pb2.CollectionOperationResponse.FromString,
|
||||
)
|
||||
self.Update = channel.unary_unary(
|
||||
'/qdrant.Collections/Update',
|
||||
request_serializer=collections__pb2.UpdateCollection.SerializeToString,
|
||||
response_deserializer=collections__pb2.CollectionOperationResponse.FromString,
|
||||
)
|
||||
self.Delete = channel.unary_unary(
|
||||
'/qdrant.Collections/Delete',
|
||||
request_serializer=collections__pb2.DeleteCollection.SerializeToString,
|
||||
response_deserializer=collections__pb2.CollectionOperationResponse.FromString,
|
||||
)
|
||||
self.UpdateAliases = channel.unary_unary(
|
||||
'/qdrant.Collections/UpdateAliases',
|
||||
request_serializer=collections__pb2.ChangeAliases.SerializeToString,
|
||||
response_deserializer=collections__pb2.CollectionOperationResponse.FromString,
|
||||
)
|
||||
self.ListCollectionAliases = channel.unary_unary(
|
||||
'/qdrant.Collections/ListCollectionAliases',
|
||||
request_serializer=collections__pb2.ListCollectionAliasesRequest.SerializeToString,
|
||||
response_deserializer=collections__pb2.ListAliasesResponse.FromString,
|
||||
)
|
||||
self.ListAliases = channel.unary_unary(
|
||||
'/qdrant.Collections/ListAliases',
|
||||
request_serializer=collections__pb2.ListAliasesRequest.SerializeToString,
|
||||
response_deserializer=collections__pb2.ListAliasesResponse.FromString,
|
||||
)
|
||||
self.CollectionClusterInfo = channel.unary_unary(
|
||||
'/qdrant.Collections/CollectionClusterInfo',
|
||||
request_serializer=collections__pb2.CollectionClusterInfoRequest.SerializeToString,
|
||||
response_deserializer=collections__pb2.CollectionClusterInfoResponse.FromString,
|
||||
)
|
||||
self.CollectionExists = channel.unary_unary(
|
||||
'/qdrant.Collections/CollectionExists',
|
||||
request_serializer=collections__pb2.CollectionExistsRequest.SerializeToString,
|
||||
response_deserializer=collections__pb2.CollectionExistsResponse.FromString,
|
||||
)
|
||||
self.UpdateCollectionClusterSetup = channel.unary_unary(
|
||||
'/qdrant.Collections/UpdateCollectionClusterSetup',
|
||||
request_serializer=collections__pb2.UpdateCollectionClusterSetupRequest.SerializeToString,
|
||||
response_deserializer=collections__pb2.UpdateCollectionClusterSetupResponse.FromString,
|
||||
)
|
||||
self.CreateShardKey = channel.unary_unary(
|
||||
'/qdrant.Collections/CreateShardKey',
|
||||
request_serializer=collections__pb2.CreateShardKeyRequest.SerializeToString,
|
||||
response_deserializer=collections__pb2.CreateShardKeyResponse.FromString,
|
||||
)
|
||||
self.DeleteShardKey = channel.unary_unary(
|
||||
'/qdrant.Collections/DeleteShardKey',
|
||||
request_serializer=collections__pb2.DeleteShardKeyRequest.SerializeToString,
|
||||
response_deserializer=collections__pb2.DeleteShardKeyResponse.FromString,
|
||||
)
|
||||
|
||||
|
||||
class CollectionsServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def Get(self, request, context):
|
||||
"""
|
||||
Get detailed information about specified existing collection
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def List(self, request, context):
|
||||
"""
|
||||
Get list name of all existing collections
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Create(self, request, context):
|
||||
"""
|
||||
Create new collection with given parameters
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Update(self, request, context):
|
||||
"""
|
||||
Update parameters of the existing collection
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Delete(self, request, context):
|
||||
"""
|
||||
Drop collection and all associated data
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def UpdateAliases(self, request, context):
|
||||
"""
|
||||
Update Aliases of the existing collection
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def ListCollectionAliases(self, request, context):
|
||||
"""
|
||||
Get list of all aliases for a collection
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def ListAliases(self, request, context):
|
||||
"""
|
||||
Get list of all aliases for all existing collections
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def CollectionClusterInfo(self, request, context):
|
||||
"""
|
||||
Get cluster information for a collection
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def CollectionExists(self, request, context):
|
||||
"""
|
||||
Check the existence of a collection
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def UpdateCollectionClusterSetup(self, request, context):
|
||||
"""
|
||||
Update cluster setup for a collection
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def CreateShardKey(self, request, context):
|
||||
"""
|
||||
Create shard key
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def DeleteShardKey(self, request, context):
|
||||
"""
|
||||
Delete shard key
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_CollectionsServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Get': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Get,
|
||||
request_deserializer=collections__pb2.GetCollectionInfoRequest.FromString,
|
||||
response_serializer=collections__pb2.GetCollectionInfoResponse.SerializeToString,
|
||||
),
|
||||
'List': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.List,
|
||||
request_deserializer=collections__pb2.ListCollectionsRequest.FromString,
|
||||
response_serializer=collections__pb2.ListCollectionsResponse.SerializeToString,
|
||||
),
|
||||
'Create': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Create,
|
||||
request_deserializer=collections__pb2.CreateCollection.FromString,
|
||||
response_serializer=collections__pb2.CollectionOperationResponse.SerializeToString,
|
||||
),
|
||||
'Update': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Update,
|
||||
request_deserializer=collections__pb2.UpdateCollection.FromString,
|
||||
response_serializer=collections__pb2.CollectionOperationResponse.SerializeToString,
|
||||
),
|
||||
'Delete': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Delete,
|
||||
request_deserializer=collections__pb2.DeleteCollection.FromString,
|
||||
response_serializer=collections__pb2.CollectionOperationResponse.SerializeToString,
|
||||
),
|
||||
'UpdateAliases': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.UpdateAliases,
|
||||
request_deserializer=collections__pb2.ChangeAliases.FromString,
|
||||
response_serializer=collections__pb2.CollectionOperationResponse.SerializeToString,
|
||||
),
|
||||
'ListCollectionAliases': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.ListCollectionAliases,
|
||||
request_deserializer=collections__pb2.ListCollectionAliasesRequest.FromString,
|
||||
response_serializer=collections__pb2.ListAliasesResponse.SerializeToString,
|
||||
),
|
||||
'ListAliases': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.ListAliases,
|
||||
request_deserializer=collections__pb2.ListAliasesRequest.FromString,
|
||||
response_serializer=collections__pb2.ListAliasesResponse.SerializeToString,
|
||||
),
|
||||
'CollectionClusterInfo': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.CollectionClusterInfo,
|
||||
request_deserializer=collections__pb2.CollectionClusterInfoRequest.FromString,
|
||||
response_serializer=collections__pb2.CollectionClusterInfoResponse.SerializeToString,
|
||||
),
|
||||
'CollectionExists': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.CollectionExists,
|
||||
request_deserializer=collections__pb2.CollectionExistsRequest.FromString,
|
||||
response_serializer=collections__pb2.CollectionExistsResponse.SerializeToString,
|
||||
),
|
||||
'UpdateCollectionClusterSetup': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.UpdateCollectionClusterSetup,
|
||||
request_deserializer=collections__pb2.UpdateCollectionClusterSetupRequest.FromString,
|
||||
response_serializer=collections__pb2.UpdateCollectionClusterSetupResponse.SerializeToString,
|
||||
),
|
||||
'CreateShardKey': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.CreateShardKey,
|
||||
request_deserializer=collections__pb2.CreateShardKeyRequest.FromString,
|
||||
response_serializer=collections__pb2.CreateShardKeyResponse.SerializeToString,
|
||||
),
|
||||
'DeleteShardKey': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.DeleteShardKey,
|
||||
request_deserializer=collections__pb2.DeleteShardKeyRequest.FromString,
|
||||
response_serializer=collections__pb2.DeleteShardKeyResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'qdrant.Collections', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class Collections(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def Get(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/Get',
|
||||
collections__pb2.GetCollectionInfoRequest.SerializeToString,
|
||||
collections__pb2.GetCollectionInfoResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def List(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/List',
|
||||
collections__pb2.ListCollectionsRequest.SerializeToString,
|
||||
collections__pb2.ListCollectionsResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Create(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/Create',
|
||||
collections__pb2.CreateCollection.SerializeToString,
|
||||
collections__pb2.CollectionOperationResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Update(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/Update',
|
||||
collections__pb2.UpdateCollection.SerializeToString,
|
||||
collections__pb2.CollectionOperationResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Delete(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/Delete',
|
||||
collections__pb2.DeleteCollection.SerializeToString,
|
||||
collections__pb2.CollectionOperationResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def UpdateAliases(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/UpdateAliases',
|
||||
collections__pb2.ChangeAliases.SerializeToString,
|
||||
collections__pb2.CollectionOperationResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def ListCollectionAliases(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/ListCollectionAliases',
|
||||
collections__pb2.ListCollectionAliasesRequest.SerializeToString,
|
||||
collections__pb2.ListAliasesResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def ListAliases(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/ListAliases',
|
||||
collections__pb2.ListAliasesRequest.SerializeToString,
|
||||
collections__pb2.ListAliasesResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def CollectionClusterInfo(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/CollectionClusterInfo',
|
||||
collections__pb2.CollectionClusterInfoRequest.SerializeToString,
|
||||
collections__pb2.CollectionClusterInfoResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def CollectionExists(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/CollectionExists',
|
||||
collections__pb2.CollectionExistsRequest.SerializeToString,
|
||||
collections__pb2.CollectionExistsResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def UpdateCollectionClusterSetup(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/UpdateCollectionClusterSetup',
|
||||
collections__pb2.UpdateCollectionClusterSetupRequest.SerializeToString,
|
||||
collections__pb2.UpdateCollectionClusterSetupResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def CreateShardKey(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/CreateShardKey',
|
||||
collections__pb2.CreateShardKeyRequest.SerializeToString,
|
||||
collections__pb2.CreateShardKeyResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def DeleteShardKey(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Collections/DeleteShardKey',
|
||||
collections__pb2.DeleteShardKeyRequest.SerializeToString,
|
||||
collections__pb2.DeleteShardKeyResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
@@ -0,0 +1,37 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: json_with_int.proto
|
||||
# Protobuf Python Version: 4.25.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13json_with_int.proto\x12\x06qdrant\"r\n\x06Struct\x12*\n\x06\x66ields\x18\x01 \x03(\x0b\x32\x1a.qdrant.Struct.FieldsEntry\x1a<\n\x0b\x46ieldsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1c\n\x05value\x18\x02 \x01(\x0b\x32\r.qdrant.Value:\x02\x38\x01\"\xe8\x01\n\x05Value\x12\'\n\nnull_value\x18\x01 \x01(\x0e\x32\x11.qdrant.NullValueH\x00\x12\x16\n\x0c\x64ouble_value\x18\x02 \x01(\x01H\x00\x12\x17\n\rinteger_value\x18\x03 \x01(\x03H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x14\n\nbool_value\x18\x05 \x01(\x08H\x00\x12&\n\x0cstruct_value\x18\x06 \x01(\x0b\x32\x0e.qdrant.StructH\x00\x12\'\n\nlist_value\x18\x07 \x01(\x0b\x32\x11.qdrant.ListValueH\x00\x42\x06\n\x04kind\"*\n\tListValue\x12\x1d\n\x06values\x18\x01 \x03(\x0b\x32\r.qdrant.Value*\x1b\n\tNullValue\x12\x0e\n\nNULL_VALUE\x10\x00\x42\x15\xaa\x02\x12Qdrant.Client.Grpcb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'json_with_int_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'\252\002\022Qdrant.Client.Grpc'
|
||||
_globals['_STRUCT_FIELDSENTRY']._options = None
|
||||
_globals['_STRUCT_FIELDSENTRY']._serialized_options = b'8\001'
|
||||
_globals['_NULLVALUE']._serialized_start=426
|
||||
_globals['_NULLVALUE']._serialized_end=453
|
||||
_globals['_STRUCT']._serialized_start=31
|
||||
_globals['_STRUCT']._serialized_end=145
|
||||
_globals['_STRUCT_FIELDSENTRY']._serialized_start=85
|
||||
_globals['_STRUCT_FIELDSENTRY']._serialized_end=145
|
||||
_globals['_VALUE']._serialized_start=148
|
||||
_globals['_VALUE']._serialized_end=380
|
||||
_globals['_LISTVALUE']._serialized_start=382
|
||||
_globals['_LISTVALUE']._serialized_end=424
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
Fork of the google.protobuf.Value with explicit support for integer values"""
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.internal.enum_type_wrapper
|
||||
import google.protobuf.message
|
||||
import sys
|
||||
import typing
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
import typing as typing_extensions
|
||||
else:
|
||||
import typing_extensions
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
class _NullValue:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _NullValueEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_NullValue.ValueType], builtins.type): # noqa: F821
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
NULL_VALUE: _NullValue.ValueType # 0
|
||||
"""Null value."""
|
||||
|
||||
class NullValue(_NullValue, metaclass=_NullValueEnumTypeWrapper):
|
||||
"""`NullValue` is a singleton enumeration to represent the null value for the
|
||||
`Value` type union.
|
||||
|
||||
The JSON representation for `NullValue` is JSON `null`.
|
||||
"""
|
||||
|
||||
NULL_VALUE: NullValue.ValueType # 0
|
||||
"""Null value."""
|
||||
global___NullValue = NullValue
|
||||
|
||||
class Struct(google.protobuf.message.Message):
|
||||
"""`Struct` represents a structured data value, consisting of fields
|
||||
which map to dynamically typed values. In some languages, `Struct`
|
||||
might be supported by a native representation. For example, in
|
||||
scripting languages like JS a struct is represented as an
|
||||
object. The details of that representation are described together
|
||||
with the proto support for the language.
|
||||
|
||||
The JSON representation for `Struct` is a JSON object.
|
||||
"""
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class FieldsEntry(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
VALUE_FIELD_NUMBER: builtins.int
|
||||
key: builtins.str
|
||||
@property
|
||||
def value(self) -> global___Value: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
key: builtins.str = ...,
|
||||
value: global___Value | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["value", b"value"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]) -> None: ...
|
||||
|
||||
FIELDS_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def fields(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___Value]:
|
||||
"""Unordered map of dynamically typed values."""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
fields: collections.abc.Mapping[builtins.str, global___Value] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["fields", b"fields"]) -> None: ...
|
||||
|
||||
global___Struct = Struct
|
||||
|
||||
class Value(google.protobuf.message.Message):
|
||||
"""`Value` represents a dynamically typed value which can be either
|
||||
null, a number, a string, a boolean, a recursive struct value, or a
|
||||
list of values. A producer of value is expected to set one of those
|
||||
variants, absence of any variant indicates an error.
|
||||
|
||||
The JSON representation for `Value` is a JSON value.
|
||||
"""
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
NULL_VALUE_FIELD_NUMBER: builtins.int
|
||||
DOUBLE_VALUE_FIELD_NUMBER: builtins.int
|
||||
INTEGER_VALUE_FIELD_NUMBER: builtins.int
|
||||
STRING_VALUE_FIELD_NUMBER: builtins.int
|
||||
BOOL_VALUE_FIELD_NUMBER: builtins.int
|
||||
STRUCT_VALUE_FIELD_NUMBER: builtins.int
|
||||
LIST_VALUE_FIELD_NUMBER: builtins.int
|
||||
null_value: global___NullValue.ValueType
|
||||
"""Represents a null value."""
|
||||
double_value: builtins.float
|
||||
"""Represents a double value."""
|
||||
integer_value: builtins.int
|
||||
"""Represents an integer value"""
|
||||
string_value: builtins.str
|
||||
"""Represents a string value."""
|
||||
bool_value: builtins.bool
|
||||
"""Represents a boolean value."""
|
||||
@property
|
||||
def struct_value(self) -> global___Struct:
|
||||
"""Represents a structured value."""
|
||||
@property
|
||||
def list_value(self) -> global___ListValue:
|
||||
"""Represents a repeated `Value`."""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
null_value: global___NullValue.ValueType = ...,
|
||||
double_value: builtins.float = ...,
|
||||
integer_value: builtins.int = ...,
|
||||
string_value: builtins.str = ...,
|
||||
bool_value: builtins.bool = ...,
|
||||
struct_value: global___Struct | None = ...,
|
||||
list_value: global___ListValue | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["bool_value", b"bool_value", "double_value", b"double_value", "integer_value", b"integer_value", "kind", b"kind", "list_value", b"list_value", "null_value", b"null_value", "string_value", b"string_value", "struct_value", b"struct_value"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["bool_value", b"bool_value", "double_value", b"double_value", "integer_value", b"integer_value", "kind", b"kind", "list_value", b"list_value", "null_value", b"null_value", "string_value", b"string_value", "struct_value", b"struct_value"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["kind", b"kind"]) -> typing_extensions.Literal["null_value", "double_value", "integer_value", "string_value", "bool_value", "struct_value", "list_value"] | None: ...
|
||||
|
||||
global___Value = Value
|
||||
|
||||
class ListValue(google.protobuf.message.Message):
|
||||
"""`ListValue` is a wrapper around a repeated field of values.
|
||||
|
||||
The JSON representation for `ListValue` is a JSON array.
|
||||
"""
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
VALUES_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def values(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Value]:
|
||||
"""Repeated field of dynamically typed values."""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
values: collections.abc.Iterable[global___Value] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["values", b"values"]) -> None: ...
|
||||
|
||||
global___ListValue = ListValue
|
||||
@@ -0,0 +1,4 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,4 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: points_service.proto
|
||||
# Protobuf Python Version: 4.25.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from . import points_pb2 as points__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14points_service.proto\x12\x06qdrant\x1a\x0cpoints.proto2\xfd\x0f\n\x06Points\x12\x41\n\x06Upsert\x12\x14.qdrant.UpsertPoints\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12\x41\n\x06\x44\x65lete\x12\x14.qdrant.DeletePoints\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12/\n\x03Get\x12\x11.qdrant.GetPoints\x1a\x13.qdrant.GetResponse\"\x00\x12N\n\rUpdateVectors\x12\x1a.qdrant.UpdatePointVectors\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12N\n\rDeleteVectors\x12\x1a.qdrant.DeletePointVectors\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12I\n\nSetPayload\x12\x18.qdrant.SetPayloadPoints\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12O\n\x10OverwritePayload\x12\x18.qdrant.SetPayloadPoints\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12O\n\rDeletePayload\x12\x1b.qdrant.DeletePayloadPoints\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12M\n\x0c\x43learPayload\x12\x1a.qdrant.ClearPayloadPoints\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12Y\n\x10\x43reateFieldIndex\x12\".qdrant.CreateFieldIndexCollection\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12Y\n\x10\x44\x65leteFieldIndex\x12\".qdrant.DeleteFieldIndexCollection\x1a\x1f.qdrant.PointsOperationResponse\"\x00\x12\x38\n\x06Search\x12\x14.qdrant.SearchPoints\x1a\x16.qdrant.SearchResponse\"\x00\x12G\n\x0bSearchBatch\x12\x19.qdrant.SearchBatchPoints\x1a\x1b.qdrant.SearchBatchResponse\"\x00\x12I\n\x0cSearchGroups\x12\x19.qdrant.SearchPointGroups\x1a\x1c.qdrant.SearchGroupsResponse\"\x00\x12\x38\n\x06Scroll\x12\x14.qdrant.ScrollPoints\x1a\x16.qdrant.ScrollResponse\"\x00\x12\x41\n\tRecommend\x12\x17.qdrant.RecommendPoints\x1a\x19.qdrant.RecommendResponse\"\x00\x12P\n\x0eRecommendBatch\x12\x1c.qdrant.RecommendBatchPoints\x1a\x1e.qdrant.RecommendBatchResponse\"\x00\x12R\n\x0fRecommendGroups\x12\x1c.qdrant.RecommendPointGroups\x1a\x1f.qdrant.RecommendGroupsResponse\"\x00\x12>\n\x08\x44iscover\x12\x16.qdrant.DiscoverPoints\x1a\x18.qdrant.DiscoverResponse\"\x
00\x12M\n\rDiscoverBatch\x12\x1b.qdrant.DiscoverBatchPoints\x1a\x1d.qdrant.DiscoverBatchResponse\"\x00\x12\x35\n\x05\x43ount\x12\x13.qdrant.CountPoints\x1a\x15.qdrant.CountResponse\"\x00\x12G\n\x0bUpdateBatch\x12\x19.qdrant.UpdateBatchPoints\x1a\x1b.qdrant.UpdateBatchResponse\"\x00\x12\x35\n\x05Query\x12\x13.qdrant.QueryPoints\x1a\x15.qdrant.QueryResponse\"\x00\x12\x44\n\nQueryBatch\x12\x18.qdrant.QueryBatchPoints\x1a\x1a.qdrant.QueryBatchResponse\"\x00\x12\x46\n\x0bQueryGroups\x12\x18.qdrant.QueryPointGroups\x1a\x1b.qdrant.QueryGroupsResponse\"\x00\x12\x35\n\x05\x46\x61\x63\x65t\x12\x13.qdrant.FacetCounts\x1a\x15.qdrant.FacetResponse\"\x00\x12T\n\x11SearchMatrixPairs\x12\x1a.qdrant.SearchMatrixPoints\x1a!.qdrant.SearchMatrixPairsResponse\"\x00\x12X\n\x13SearchMatrixOffsets\x12\x1a.qdrant.SearchMatrixPoints\x1a#.qdrant.SearchMatrixOffsetsResponse\"\x00\x42\x15\xaa\x02\x12Qdrant.Client.Grpcb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'points_service_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'\252\002\022Qdrant.Client.Grpc'
|
||||
_globals['_POINTS']._serialized_start=47
|
||||
_globals['_POINTS']._serialized_end=2092
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
import google.protobuf.descriptor
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
+1027
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,68 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: qdrant_common.proto
|
||||
# Protobuf Python Version: 4.25.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13qdrant_common.proto\x12\x06qdrant\x1a\x1fgoogle/protobuf/timestamp.proto\"<\n\x07PointId\x12\r\n\x03num\x18\x01 \x01(\x04H\x00\x12\x0e\n\x04uuid\x18\x02 \x01(\tH\x00\x42\x12\n\x10point_id_options\"$\n\x08GeoPoint\x12\x0b\n\x03lon\x18\x01 \x01(\x01\x12\x0b\n\x03lat\x18\x02 \x01(\x01\"\xac\x01\n\x06\x46ilter\x12!\n\x06should\x18\x01 \x03(\x0b\x32\x11.qdrant.Condition\x12\x1f\n\x04must\x18\x02 \x03(\x0b\x32\x11.qdrant.Condition\x12#\n\x08must_not\x18\x03 \x03(\x0b\x32\x11.qdrant.Condition\x12*\n\nmin_should\x18\x04 \x01(\x0b\x32\x11.qdrant.MinShouldH\x00\x88\x01\x01\x42\r\n\x0b_min_should\"E\n\tMinShould\x12%\n\nconditions\x18\x01 \x03(\x0b\x32\x11.qdrant.Condition\x12\x11\n\tmin_count\x18\x02 \x01(\x04\"\xcb\x02\n\tCondition\x12\'\n\x05\x66ield\x18\x01 \x01(\x0b\x32\x16.qdrant.FieldConditionH\x00\x12,\n\x08is_empty\x18\x02 \x01(\x0b\x32\x18.qdrant.IsEmptyConditionH\x00\x12(\n\x06has_id\x18\x03 \x01(\x0b\x32\x16.qdrant.HasIdConditionH\x00\x12 \n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x0e.qdrant.FilterH\x00\x12*\n\x07is_null\x18\x05 \x01(\x0b\x32\x17.qdrant.IsNullConditionH\x00\x12)\n\x06nested\x18\x06 \x01(\x0b\x32\x17.qdrant.NestedConditionH\x00\x12\x30\n\nhas_vector\x18\x07 \x01(\x0b\x32\x1a.qdrant.HasVectorConditionH\x00\x42\x12\n\x10\x63ondition_one_of\"\x1f\n\x10IsEmptyCondition\x12\x0b\n\x03key\x18\x01 \x01(\t\"\x1e\n\x0fIsNullCondition\x12\x0b\n\x03key\x18\x01 \x01(\t\"1\n\x0eHasIdCondition\x12\x1f\n\x06has_id\x18\x01 \x03(\x0b\x32\x0f.qdrant.PointId\"(\n\x12HasVectorCondition\x12\x12\n\nhas_vector\x18\x01 \x01(\t\">\n\x0fNestedCondition\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1e\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x0e.qdrant.Filter\"\xfb\x02\n\x0e\x46ieldCondition\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1c\n\x05match\x18\x02 \x01(\x0b\x32\r.qdrant.Match\x12\x1c\n\x05range\x18\x03 \x01(\x0b\x32\r.qdrant.Range\x12\x30\n\x10geo_bounding_box\x18\x04 
\x01(\x0b\x32\x16.qdrant.GeoBoundingBox\x12%\n\ngeo_radius\x18\x05 \x01(\x0b\x32\x11.qdrant.GeoRadius\x12)\n\x0cvalues_count\x18\x06 \x01(\x0b\x32\x13.qdrant.ValuesCount\x12\'\n\x0bgeo_polygon\x18\x07 \x01(\x0b\x32\x12.qdrant.GeoPolygon\x12-\n\x0e\x64\x61tetime_range\x18\x08 \x01(\x0b\x32\x15.qdrant.DatetimeRange\x12\x15\n\x08is_empty\x18\t \x01(\x08H\x00\x88\x01\x01\x12\x14\n\x07is_null\x18\n \x01(\x08H\x01\x88\x01\x01\x42\x0b\n\t_is_emptyB\n\n\x08_is_null\"\xc9\x02\n\x05Match\x12\x11\n\x07keyword\x18\x01 \x01(\tH\x00\x12\x11\n\x07integer\x18\x02 \x01(\x03H\x00\x12\x11\n\x07\x62oolean\x18\x03 \x01(\x08H\x00\x12\x0e\n\x04text\x18\x04 \x01(\tH\x00\x12+\n\x08keywords\x18\x05 \x01(\x0b\x32\x17.qdrant.RepeatedStringsH\x00\x12,\n\x08integers\x18\x06 \x01(\x0b\x32\x18.qdrant.RepeatedIntegersH\x00\x12\x33\n\x0f\x65xcept_integers\x18\x07 \x01(\x0b\x32\x18.qdrant.RepeatedIntegersH\x00\x12\x32\n\x0f\x65xcept_keywords\x18\x08 \x01(\x0b\x32\x17.qdrant.RepeatedStringsH\x00\x12\x10\n\x06phrase\x18\t \x01(\tH\x00\x12\x12\n\x08text_any\x18\n \x01(\tH\x00\x42\r\n\x0bmatch_value\"\"\n\x0fRepeatedStrings\x12\x0f\n\x07strings\x18\x01 \x03(\t\"$\n\x10RepeatedIntegers\x12\x10\n\x08integers\x18\x01 \x03(\x03\"k\n\x05Range\x12\x0f\n\x02lt\x18\x01 \x01(\x01H\x00\x88\x01\x01\x12\x0f\n\x02gt\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12\x10\n\x03gte\x18\x03 \x01(\x01H\x02\x88\x01\x01\x12\x10\n\x03lte\x18\x04 \x01(\x01H\x03\x88\x01\x01\x42\x05\n\x03_ltB\x05\n\x03_gtB\x06\n\x04_gteB\x06\n\x04_lte\"\xe3\x01\n\rDatetimeRange\x12+\n\x02lt\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x00\x88\x01\x01\x12+\n\x02gt\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x01\x88\x01\x01\x12,\n\x03gte\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x02\x88\x01\x01\x12,\n\x03lte\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x03\x88\x01\x01\x42\x05\n\x03_ltB\x05\n\x03_gtB\x06\n\x04_gteB\x06\n\x04_lte\"\\\n\x0eGeoBoundingBox\x12\"\n\x08top_left\x18\x01 
\x01(\x0b\x32\x10.qdrant.GeoPoint\x12&\n\x0c\x62ottom_right\x18\x02 \x01(\x0b\x32\x10.qdrant.GeoPoint\"=\n\tGeoRadius\x12 \n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x10.qdrant.GeoPoint\x12\x0e\n\x06radius\x18\x02 \x01(\x02\"1\n\rGeoLineString\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.qdrant.GeoPoint\"_\n\nGeoPolygon\x12\'\n\x08\x65xterior\x18\x01 \x01(\x0b\x32\x15.qdrant.GeoLineString\x12(\n\tinteriors\x18\x02 \x03(\x0b\x32\x15.qdrant.GeoLineString\"q\n\x0bValuesCount\x12\x0f\n\x02lt\x18\x01 \x01(\x04H\x00\x88\x01\x01\x12\x0f\n\x02gt\x18\x02 \x01(\x04H\x01\x88\x01\x01\x12\x10\n\x03gte\x18\x03 \x01(\x04H\x02\x88\x01\x01\x12\x10\n\x03lte\x18\x04 \x01(\x04H\x03\x88\x01\x01\x42\x05\n\x03_ltB\x05\n\x03_gtB\x06\n\x04_gteB\x06\n\x04_lteB\x1d\x42\x06\x43ommon\xaa\x02\x12Qdrant.Client.Grpcb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'qdrant_common_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'B\006Common\252\002\022Qdrant.Client.Grpc'
|
||||
_globals['_POINTID']._serialized_start=64
|
||||
_globals['_POINTID']._serialized_end=124
|
||||
_globals['_GEOPOINT']._serialized_start=126
|
||||
_globals['_GEOPOINT']._serialized_end=162
|
||||
_globals['_FILTER']._serialized_start=165
|
||||
_globals['_FILTER']._serialized_end=337
|
||||
_globals['_MINSHOULD']._serialized_start=339
|
||||
_globals['_MINSHOULD']._serialized_end=408
|
||||
_globals['_CONDITION']._serialized_start=411
|
||||
_globals['_CONDITION']._serialized_end=742
|
||||
_globals['_ISEMPTYCONDITION']._serialized_start=744
|
||||
_globals['_ISEMPTYCONDITION']._serialized_end=775
|
||||
_globals['_ISNULLCONDITION']._serialized_start=777
|
||||
_globals['_ISNULLCONDITION']._serialized_end=807
|
||||
_globals['_HASIDCONDITION']._serialized_start=809
|
||||
_globals['_HASIDCONDITION']._serialized_end=858
|
||||
_globals['_HASVECTORCONDITION']._serialized_start=860
|
||||
_globals['_HASVECTORCONDITION']._serialized_end=900
|
||||
_globals['_NESTEDCONDITION']._serialized_start=902
|
||||
_globals['_NESTEDCONDITION']._serialized_end=964
|
||||
_globals['_FIELDCONDITION']._serialized_start=967
|
||||
_globals['_FIELDCONDITION']._serialized_end=1346
|
||||
_globals['_MATCH']._serialized_start=1349
|
||||
_globals['_MATCH']._serialized_end=1678
|
||||
_globals['_REPEATEDSTRINGS']._serialized_start=1680
|
||||
_globals['_REPEATEDSTRINGS']._serialized_end=1714
|
||||
_globals['_REPEATEDINTEGERS']._serialized_start=1716
|
||||
_globals['_REPEATEDINTEGERS']._serialized_end=1752
|
||||
_globals['_RANGE']._serialized_start=1754
|
||||
_globals['_RANGE']._serialized_end=1861
|
||||
_globals['_DATETIMERANGE']._serialized_start=1864
|
||||
_globals['_DATETIMERANGE']._serialized_end=2091
|
||||
_globals['_GEOBOUNDINGBOX']._serialized_start=2093
|
||||
_globals['_GEOBOUNDINGBOX']._serialized_end=2185
|
||||
_globals['_GEORADIUS']._serialized_start=2187
|
||||
_globals['_GEORADIUS']._serialized_end=2248
|
||||
_globals['_GEOLINESTRING']._serialized_start=2250
|
||||
_globals['_GEOLINESTRING']._serialized_end=2299
|
||||
_globals['_GEOPOLYGON']._serialized_start=2301
|
||||
_globals['_GEOPOLYGON']._serialized_end=2396
|
||||
_globals['_VALUESCOUNT']._serialized_start=2398
|
||||
_globals['_VALUESCOUNT']._serialized_end=2511
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -0,0 +1,561 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.message
|
||||
import google.protobuf.timestamp_pb2
|
||||
import sys
|
||||
import typing
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
import typing as typing_extensions
|
||||
else:
|
||||
import typing_extensions
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
class PointId(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
NUM_FIELD_NUMBER: builtins.int
|
||||
UUID_FIELD_NUMBER: builtins.int
|
||||
num: builtins.int
|
||||
"""Numerical ID of the point"""
|
||||
uuid: builtins.str
|
||||
"""UUID"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num: builtins.int = ...,
|
||||
uuid: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["num", b"num", "point_id_options", b"point_id_options", "uuid", b"uuid"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["num", b"num", "point_id_options", b"point_id_options", "uuid", b"uuid"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["point_id_options", b"point_id_options"]) -> typing_extensions.Literal["num", "uuid"] | None: ...
|
||||
|
||||
global___PointId = PointId
|
||||
|
||||
class GeoPoint(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
LON_FIELD_NUMBER: builtins.int
|
||||
LAT_FIELD_NUMBER: builtins.int
|
||||
lon: builtins.float
|
||||
lat: builtins.float
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
lon: builtins.float = ...,
|
||||
lat: builtins.float = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["lat", b"lat", "lon", b"lon"]) -> None: ...
|
||||
|
||||
global___GeoPoint = GeoPoint
|
||||
|
||||
class Filter(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
SHOULD_FIELD_NUMBER: builtins.int
|
||||
MUST_FIELD_NUMBER: builtins.int
|
||||
MUST_NOT_FIELD_NUMBER: builtins.int
|
||||
MIN_SHOULD_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def should(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Condition]:
|
||||
"""At least one of those conditions should match"""
|
||||
@property
|
||||
def must(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Condition]:
|
||||
"""All conditions must match"""
|
||||
@property
|
||||
def must_not(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Condition]:
|
||||
"""All conditions must NOT match"""
|
||||
@property
|
||||
def min_should(self) -> global___MinShould:
|
||||
"""At least minimum amount of given conditions should match"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
should: collections.abc.Iterable[global___Condition] | None = ...,
|
||||
must: collections.abc.Iterable[global___Condition] | None = ...,
|
||||
must_not: collections.abc.Iterable[global___Condition] | None = ...,
|
||||
min_should: global___MinShould | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["_min_should", b"_min_should", "min_should", b"min_should"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["_min_should", b"_min_should", "min_should", b"min_should", "must", b"must", "must_not", b"must_not", "should", b"should"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_min_should", b"_min_should"]) -> typing_extensions.Literal["min_should"] | None: ...
|
||||
|
||||
global___Filter = Filter
|
||||
|
||||
class MinShould(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
CONDITIONS_FIELD_NUMBER: builtins.int
|
||||
MIN_COUNT_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def conditions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Condition]: ...
|
||||
min_count: builtins.int
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
conditions: collections.abc.Iterable[global___Condition] | None = ...,
|
||||
min_count: builtins.int = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["conditions", b"conditions", "min_count", b"min_count"]) -> None: ...
|
||||
|
||||
global___MinShould = MinShould
|
||||
|
||||
class Condition(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
FIELD_FIELD_NUMBER: builtins.int
|
||||
IS_EMPTY_FIELD_NUMBER: builtins.int
|
||||
HAS_ID_FIELD_NUMBER: builtins.int
|
||||
FILTER_FIELD_NUMBER: builtins.int
|
||||
IS_NULL_FIELD_NUMBER: builtins.int
|
||||
NESTED_FIELD_NUMBER: builtins.int
|
||||
HAS_VECTOR_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def field(self) -> global___FieldCondition: ...
|
||||
@property
|
||||
def is_empty(self) -> global___IsEmptyCondition: ...
|
||||
@property
|
||||
def has_id(self) -> global___HasIdCondition: ...
|
||||
@property
|
||||
def filter(self) -> global___Filter: ...
|
||||
@property
|
||||
def is_null(self) -> global___IsNullCondition: ...
|
||||
@property
|
||||
def nested(self) -> global___NestedCondition: ...
|
||||
@property
|
||||
def has_vector(self) -> global___HasVectorCondition: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
field: global___FieldCondition | None = ...,
|
||||
is_empty: global___IsEmptyCondition | None = ...,
|
||||
has_id: global___HasIdCondition | None = ...,
|
||||
filter: global___Filter | None = ...,
|
||||
is_null: global___IsNullCondition | None = ...,
|
||||
nested: global___NestedCondition | None = ...,
|
||||
has_vector: global___HasVectorCondition | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["condition_one_of", b"condition_one_of", "field", b"field", "filter", b"filter", "has_id", b"has_id", "has_vector", b"has_vector", "is_empty", b"is_empty", "is_null", b"is_null", "nested", b"nested"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["condition_one_of", b"condition_one_of", "field", b"field", "filter", b"filter", "has_id", b"has_id", "has_vector", b"has_vector", "is_empty", b"is_empty", "is_null", b"is_null", "nested", b"nested"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["condition_one_of", b"condition_one_of"]) -> typing_extensions.Literal["field", "is_empty", "has_id", "filter", "is_null", "nested", "has_vector"] | None: ...
|
||||
|
||||
global___Condition = Condition
|
||||
|
||||
class IsEmptyCondition(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
key: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
key: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["key", b"key"]) -> None: ...
|
||||
|
||||
global___IsEmptyCondition = IsEmptyCondition
|
||||
|
||||
class IsNullCondition(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
key: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
key: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["key", b"key"]) -> None: ...
|
||||
|
||||
global___IsNullCondition = IsNullCondition
|
||||
|
||||
class HasIdCondition(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
HAS_ID_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def has_id(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PointId]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
has_id: collections.abc.Iterable[global___PointId] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["has_id", b"has_id"]) -> None: ...
|
||||
|
||||
global___HasIdCondition = HasIdCondition
|
||||
|
||||
class HasVectorCondition(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
HAS_VECTOR_FIELD_NUMBER: builtins.int
|
||||
has_vector: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
has_vector: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["has_vector", b"has_vector"]) -> None: ...
|
||||
|
||||
global___HasVectorCondition = HasVectorCondition
|
||||
|
||||
class NestedCondition(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
FILTER_FIELD_NUMBER: builtins.int
|
||||
key: builtins.str
|
||||
"""Path to nested object"""
|
||||
@property
|
||||
def filter(self) -> global___Filter:
|
||||
"""Filter condition"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
key: builtins.str = ...,
|
||||
filter: global___Filter | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["filter", b"filter"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["filter", b"filter", "key", b"key"]) -> None: ...
|
||||
|
||||
global___NestedCondition = NestedCondition
|
||||
|
||||
class FieldCondition(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
MATCH_FIELD_NUMBER: builtins.int
|
||||
RANGE_FIELD_NUMBER: builtins.int
|
||||
GEO_BOUNDING_BOX_FIELD_NUMBER: builtins.int
|
||||
GEO_RADIUS_FIELD_NUMBER: builtins.int
|
||||
VALUES_COUNT_FIELD_NUMBER: builtins.int
|
||||
GEO_POLYGON_FIELD_NUMBER: builtins.int
|
||||
DATETIME_RANGE_FIELD_NUMBER: builtins.int
|
||||
IS_EMPTY_FIELD_NUMBER: builtins.int
|
||||
IS_NULL_FIELD_NUMBER: builtins.int
|
||||
key: builtins.str
|
||||
@property
|
||||
def match(self) -> global___Match:
|
||||
"""Check if point has field with a given value"""
|
||||
@property
|
||||
def range(self) -> global___Range:
|
||||
"""Check if points value lies in a given range"""
|
||||
@property
|
||||
def geo_bounding_box(self) -> global___GeoBoundingBox:
|
||||
"""Check if points geolocation lies in a given area"""
|
||||
@property
|
||||
def geo_radius(self) -> global___GeoRadius:
|
||||
"""Check if geo point is within a given radius"""
|
||||
@property
|
||||
def values_count(self) -> global___ValuesCount:
|
||||
"""Check number of values for a specific field"""
|
||||
@property
|
||||
def geo_polygon(self) -> global___GeoPolygon:
|
||||
"""Check if geo point is within a given polygon"""
|
||||
@property
|
||||
def datetime_range(self) -> global___DatetimeRange:
|
||||
"""Check if datetime is within a given range"""
|
||||
is_empty: builtins.bool
|
||||
"""Check if field is empty"""
|
||||
is_null: builtins.bool
|
||||
"""Check if field is null"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
key: builtins.str = ...,
|
||||
match: global___Match | None = ...,
|
||||
range: global___Range | None = ...,
|
||||
geo_bounding_box: global___GeoBoundingBox | None = ...,
|
||||
geo_radius: global___GeoRadius | None = ...,
|
||||
values_count: global___ValuesCount | None = ...,
|
||||
geo_polygon: global___GeoPolygon | None = ...,
|
||||
datetime_range: global___DatetimeRange | None = ...,
|
||||
is_empty: builtins.bool | None = ...,
|
||||
is_null: builtins.bool | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["_is_empty", b"_is_empty", "_is_null", b"_is_null", "datetime_range", b"datetime_range", "geo_bounding_box", b"geo_bounding_box", "geo_polygon", b"geo_polygon", "geo_radius", b"geo_radius", "is_empty", b"is_empty", "is_null", b"is_null", "match", b"match", "range", b"range", "values_count", b"values_count"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["_is_empty", b"_is_empty", "_is_null", b"_is_null", "datetime_range", b"datetime_range", "geo_bounding_box", b"geo_bounding_box", "geo_polygon", b"geo_polygon", "geo_radius", b"geo_radius", "is_empty", b"is_empty", "is_null", b"is_null", "key", b"key", "match", b"match", "range", b"range", "values_count", b"values_count"]) -> None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_is_empty", b"_is_empty"]) -> typing_extensions.Literal["is_empty"] | None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_is_null", b"_is_null"]) -> typing_extensions.Literal["is_null"] | None: ...
|
||||
|
||||
global___FieldCondition = FieldCondition
|
||||
|
||||
class Match(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
KEYWORD_FIELD_NUMBER: builtins.int
|
||||
INTEGER_FIELD_NUMBER: builtins.int
|
||||
BOOLEAN_FIELD_NUMBER: builtins.int
|
||||
TEXT_FIELD_NUMBER: builtins.int
|
||||
KEYWORDS_FIELD_NUMBER: builtins.int
|
||||
INTEGERS_FIELD_NUMBER: builtins.int
|
||||
EXCEPT_INTEGERS_FIELD_NUMBER: builtins.int
|
||||
EXCEPT_KEYWORDS_FIELD_NUMBER: builtins.int
|
||||
PHRASE_FIELD_NUMBER: builtins.int
|
||||
TEXT_ANY_FIELD_NUMBER: builtins.int
|
||||
keyword: builtins.str
|
||||
"""Match string keyword"""
|
||||
integer: builtins.int
|
||||
"""Match integer"""
|
||||
boolean: builtins.bool
|
||||
"""Match boolean"""
|
||||
text: builtins.str
|
||||
"""Match text"""
|
||||
@property
|
||||
def keywords(self) -> global___RepeatedStrings:
|
||||
"""Match multiple keywords"""
|
||||
@property
|
||||
def integers(self) -> global___RepeatedIntegers:
|
||||
"""Match multiple integers"""
|
||||
@property
|
||||
def except_integers(self) -> global___RepeatedIntegers:
|
||||
"""Match any other value except those integers"""
|
||||
@property
|
||||
def except_keywords(self) -> global___RepeatedStrings:
|
||||
"""Match any other value except those keywords"""
|
||||
phrase: builtins.str
|
||||
"""Match phrase text"""
|
||||
text_any: builtins.str
|
||||
"""Match any word in the text"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
keyword: builtins.str = ...,
|
||||
integer: builtins.int = ...,
|
||||
boolean: builtins.bool = ...,
|
||||
text: builtins.str = ...,
|
||||
keywords: global___RepeatedStrings | None = ...,
|
||||
integers: global___RepeatedIntegers | None = ...,
|
||||
except_integers: global___RepeatedIntegers | None = ...,
|
||||
except_keywords: global___RepeatedStrings | None = ...,
|
||||
phrase: builtins.str = ...,
|
||||
text_any: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["boolean", b"boolean", "except_integers", b"except_integers", "except_keywords", b"except_keywords", "integer", b"integer", "integers", b"integers", "keyword", b"keyword", "keywords", b"keywords", "match_value", b"match_value", "phrase", b"phrase", "text", b"text", "text_any", b"text_any"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["boolean", b"boolean", "except_integers", b"except_integers", "except_keywords", b"except_keywords", "integer", b"integer", "integers", b"integers", "keyword", b"keyword", "keywords", b"keywords", "match_value", b"match_value", "phrase", b"phrase", "text", b"text", "text_any", b"text_any"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["match_value", b"match_value"]) -> typing_extensions.Literal["keyword", "integer", "boolean", "text", "keywords", "integers", "except_integers", "except_keywords", "phrase", "text_any"] | None: ...
|
||||
|
||||
global___Match = Match
|
||||
|
||||
class RepeatedStrings(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
STRINGS_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def strings(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
strings: collections.abc.Iterable[builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["strings", b"strings"]) -> None: ...
|
||||
|
||||
global___RepeatedStrings = RepeatedStrings
|
||||
|
||||
class RepeatedIntegers(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
INTEGERS_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def integers(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
integers: collections.abc.Iterable[builtins.int] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["integers", b"integers"]) -> None: ...
|
||||
|
||||
global___RepeatedIntegers = RepeatedIntegers
|
||||
|
||||
class Range(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
LT_FIELD_NUMBER: builtins.int
|
||||
GT_FIELD_NUMBER: builtins.int
|
||||
GTE_FIELD_NUMBER: builtins.int
|
||||
LTE_FIELD_NUMBER: builtins.int
|
||||
lt: builtins.float
|
||||
gt: builtins.float
|
||||
gte: builtins.float
|
||||
lte: builtins.float
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
lt: builtins.float | None = ...,
|
||||
gt: builtins.float | None = ...,
|
||||
gte: builtins.float | None = ...,
|
||||
lte: builtins.float | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["_gt", b"_gt", "_gte", b"_gte", "_lt", b"_lt", "_lte", b"_lte", "gt", b"gt", "gte", b"gte", "lt", b"lt", "lte", b"lte"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["_gt", b"_gt", "_gte", b"_gte", "_lt", b"_lt", "_lte", b"_lte", "gt", b"gt", "gte", b"gte", "lt", b"lt", "lte", b"lte"]) -> None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_gt", b"_gt"]) -> typing_extensions.Literal["gt"] | None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_gte", b"_gte"]) -> typing_extensions.Literal["gte"] | None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_lt", b"_lt"]) -> typing_extensions.Literal["lt"] | None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_lte", b"_lte"]) -> typing_extensions.Literal["lte"] | None: ...
|
||||
|
||||
global___Range = Range
|
||||
|
||||
class DatetimeRange(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
LT_FIELD_NUMBER: builtins.int
|
||||
GT_FIELD_NUMBER: builtins.int
|
||||
GTE_FIELD_NUMBER: builtins.int
|
||||
LTE_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def lt(self) -> google.protobuf.timestamp_pb2.Timestamp: ...
|
||||
@property
|
||||
def gt(self) -> google.protobuf.timestamp_pb2.Timestamp: ...
|
||||
@property
|
||||
def gte(self) -> google.protobuf.timestamp_pb2.Timestamp: ...
|
||||
@property
|
||||
def lte(self) -> google.protobuf.timestamp_pb2.Timestamp: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
lt: google.protobuf.timestamp_pb2.Timestamp | None = ...,
|
||||
gt: google.protobuf.timestamp_pb2.Timestamp | None = ...,
|
||||
gte: google.protobuf.timestamp_pb2.Timestamp | None = ...,
|
||||
lte: google.protobuf.timestamp_pb2.Timestamp | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["_gt", b"_gt", "_gte", b"_gte", "_lt", b"_lt", "_lte", b"_lte", "gt", b"gt", "gte", b"gte", "lt", b"lt", "lte", b"lte"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["_gt", b"_gt", "_gte", b"_gte", "_lt", b"_lt", "_lte", b"_lte", "gt", b"gt", "gte", b"gte", "lt", b"lt", "lte", b"lte"]) -> None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_gt", b"_gt"]) -> typing_extensions.Literal["gt"] | None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_gte", b"_gte"]) -> typing_extensions.Literal["gte"] | None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_lt", b"_lt"]) -> typing_extensions.Literal["lt"] | None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_lte", b"_lte"]) -> typing_extensions.Literal["lte"] | None: ...
|
||||
|
||||
global___DatetimeRange = DatetimeRange
|
||||
|
||||
class GeoBoundingBox(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TOP_LEFT_FIELD_NUMBER: builtins.int
|
||||
BOTTOM_RIGHT_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def top_left(self) -> global___GeoPoint:
|
||||
"""north-west corner"""
|
||||
@property
|
||||
def bottom_right(self) -> global___GeoPoint:
|
||||
"""south-east corner"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
top_left: global___GeoPoint | None = ...,
|
||||
bottom_right: global___GeoPoint | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["bottom_right", b"bottom_right", "top_left", b"top_left"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["bottom_right", b"bottom_right", "top_left", b"top_left"]) -> None: ...
|
||||
|
||||
global___GeoBoundingBox = GeoBoundingBox
|
||||
|
||||
class GeoRadius(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
CENTER_FIELD_NUMBER: builtins.int
|
||||
RADIUS_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def center(self) -> global___GeoPoint:
|
||||
"""Center of the circle"""
|
||||
radius: builtins.float
|
||||
"""In meters"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
center: global___GeoPoint | None = ...,
|
||||
radius: builtins.float = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["center", b"center"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["center", b"center", "radius", b"radius"]) -> None: ...
|
||||
|
||||
global___GeoRadius = GeoRadius
|
||||
|
||||
class GeoLineString(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
POINTS_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def points(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___GeoPoint]:
|
||||
"""Ordered sequence of GeoPoints representing the line"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
points: collections.abc.Iterable[global___GeoPoint] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["points", b"points"]) -> None: ...
|
||||
|
||||
global___GeoLineString = GeoLineString
|
||||
|
||||
class GeoPolygon(google.protobuf.message.Message):
|
||||
"""For a valid GeoPolygon, both the exterior and interior GeoLineStrings must consist of a minimum of 4 points.
|
||||
Additionally, the first and last points of each GeoLineString must be the same.
|
||||
"""
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
EXTERIOR_FIELD_NUMBER: builtins.int
|
||||
INTERIORS_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def exterior(self) -> global___GeoLineString:
|
||||
"""The exterior line bounds the surface"""
|
||||
@property
|
||||
def interiors(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___GeoLineString]:
|
||||
"""Interior lines (if present) bound holes within the surface"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
exterior: global___GeoLineString | None = ...,
|
||||
interiors: collections.abc.Iterable[global___GeoLineString] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["exterior", b"exterior"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["exterior", b"exterior", "interiors", b"interiors"]) -> None: ...
|
||||
|
||||
global___GeoPolygon = GeoPolygon
|
||||
|
||||
class ValuesCount(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
LT_FIELD_NUMBER: builtins.int
|
||||
GT_FIELD_NUMBER: builtins.int
|
||||
GTE_FIELD_NUMBER: builtins.int
|
||||
LTE_FIELD_NUMBER: builtins.int
|
||||
lt: builtins.int
|
||||
gt: builtins.int
|
||||
gte: builtins.int
|
||||
lte: builtins.int
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
lt: builtins.int | None = ...,
|
||||
gt: builtins.int | None = ...,
|
||||
gte: builtins.int | None = ...,
|
||||
lte: builtins.int | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["_gt", b"_gt", "_gte", b"_gte", "_lt", b"_lt", "_lte", b"_lte", "gt", b"gt", "gte", b"gte", "lt", b"lt", "lte", b"lte"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["_gt", b"_gt", "_gte", b"_gte", "_lt", b"_lt", "_lte", b"_lte", "gt", b"gt", "gte", b"gte", "lt", b"lt", "lte", b"lte"]) -> None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_gt", b"_gt"]) -> typing_extensions.Literal["gt"] | None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_gte", b"_gte"]) -> typing_extensions.Literal["gte"] | None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_lt", b"_lt"]) -> typing_extensions.Literal["lt"] | None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_lte", b"_lte"]) -> typing_extensions.Literal["lte"] | None: ...
|
||||
|
||||
global___ValuesCount = ValuesCount
|
||||
@@ -0,0 +1,4 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: qdrant.proto
|
||||
# Protobuf Python Version: 4.25.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from . import collections_service_pb2 as collections__service__pb2
|
||||
from . import points_service_pb2 as points__service__pb2
|
||||
from . import snapshots_service_pb2 as snapshots__service__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cqdrant.proto\x12\x06qdrant\x1a\x19\x63ollections_service.proto\x1a\x14points_service.proto\x1a\x17snapshots_service.proto\"\x14\n\x12HealthCheckRequest\"R\n\x10HealthCheckReply\x12\r\n\x05title\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x13\n\x06\x63ommit\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_commit2O\n\x06Qdrant\x12\x45\n\x0bHealthCheck\x12\x1a.qdrant.HealthCheckRequest\x1a\x18.qdrant.HealthCheckReply\"\x00\x42\x15\xaa\x02\x12Qdrant.Client.Grpcb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'qdrant_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'\252\002\022Qdrant.Client.Grpc'
|
||||
_globals['_HEALTHCHECKREQUEST']._serialized_start=98
|
||||
_globals['_HEALTHCHECKREQUEST']._serialized_end=118
|
||||
_globals['_HEALTHCHECKREPLY']._serialized_start=120
|
||||
_globals['_HEALTHCHECKREPLY']._serialized_end=202
|
||||
_globals['_QDRANT']._serialized_start=204
|
||||
_globals['_QDRANT']._serialized_end=283
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
import builtins
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.message
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
import typing as typing_extensions
|
||||
else:
|
||||
import typing_extensions
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
class HealthCheckRequest(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
) -> None: ...
|
||||
|
||||
global___HealthCheckRequest = HealthCheckRequest
|
||||
|
||||
class HealthCheckReply(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TITLE_FIELD_NUMBER: builtins.int
|
||||
VERSION_FIELD_NUMBER: builtins.int
|
||||
COMMIT_FIELD_NUMBER: builtins.int
|
||||
title: builtins.str
|
||||
version: builtins.str
|
||||
commit: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
title: builtins.str = ...,
|
||||
version: builtins.str = ...,
|
||||
commit: builtins.str | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["_commit", b"_commit", "commit", b"commit"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["_commit", b"_commit", "commit", b"commit", "title", b"title", "version", b"version"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_commit", b"_commit"]) -> typing_extensions.Literal["commit"] | None: ...
|
||||
|
||||
global___HealthCheckReply = HealthCheckReply
|
||||
@@ -0,0 +1,66 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
from . import qdrant_pb2 as qdrant__pb2
|
||||
|
||||
|
||||
class QdrantStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.HealthCheck = channel.unary_unary(
|
||||
'/qdrant.Qdrant/HealthCheck',
|
||||
request_serializer=qdrant__pb2.HealthCheckRequest.SerializeToString,
|
||||
response_deserializer=qdrant__pb2.HealthCheckReply.FromString,
|
||||
)
|
||||
|
||||
|
||||
class QdrantServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def HealthCheck(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_QdrantServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'HealthCheck': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.HealthCheck,
|
||||
request_deserializer=qdrant__pb2.HealthCheckRequest.FromString,
|
||||
response_serializer=qdrant__pb2.HealthCheckReply.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'qdrant.Qdrant', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class Qdrant(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def HealthCheck(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/qdrant.Qdrant/HealthCheck',
|
||||
qdrant__pb2.HealthCheckRequest.SerializeToString,
|
||||
qdrant__pb2.HealthCheckReply.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
@@ -0,0 +1,48 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: snapshots_service.proto
|
||||
# Protobuf Python Version: 4.25.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17snapshots_service.proto\x12\x06qdrant\x1a\x1fgoogle/protobuf/timestamp.proto\"\x1b\n\x19\x43reateFullSnapshotRequest\"\x1a\n\x18ListFullSnapshotsRequest\"2\n\x19\x44\x65leteFullSnapshotRequest\x12\x15\n\rsnapshot_name\x18\x01 \x01(\t\"0\n\x15\x43reateSnapshotRequest\x12\x17\n\x0f\x63ollection_name\x18\x01 \x01(\t\"/\n\x14ListSnapshotsRequest\x12\x17\n\x0f\x63ollection_name\x18\x01 \x01(\t\"G\n\x15\x44\x65leteSnapshotRequest\x12\x17\n\x0f\x63ollection_name\x18\x01 \x01(\t\x12\x15\n\rsnapshot_name\x18\x02 \x01(\t\"\x88\x01\n\x13SnapshotDescription\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\rcreation_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04size\x18\x03 \x01(\x03\x12\x15\n\x08\x63hecksum\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\x0b\n\t_checksum\"a\n\x16\x43reateSnapshotResponse\x12\x39\n\x14snapshot_description\x18\x01 \x01(\x0b\x32\x1b.qdrant.SnapshotDescription\x12\x0c\n\x04time\x18\x02 \x01(\x01\"a\n\x15ListSnapshotsResponse\x12:\n\x15snapshot_descriptions\x18\x01 \x03(\x0b\x32\x1b.qdrant.SnapshotDescription\x12\x0c\n\x04time\x18\x02 \x01(\x01\"&\n\x16\x44\x65leteSnapshotResponse\x12\x0c\n\x04time\x18\x01 \x01(\x01\x32\xdd\x03\n\tSnapshots\x12I\n\x06\x43reate\x12\x1d.qdrant.CreateSnapshotRequest\x1a\x1e.qdrant.CreateSnapshotResponse\"\x00\x12\x45\n\x04List\x12\x1c.qdrant.ListSnapshotsRequest\x1a\x1d.qdrant.ListSnapshotsResponse\"\x00\x12I\n\x06\x44\x65lete\x12\x1d.qdrant.DeleteSnapshotRequest\x1a\x1e.qdrant.DeleteSnapshotResponse\"\x00\x12Q\n\nCreateFull\x12!.qdrant.CreateFullSnapshotRequest\x1a\x1e.qdrant.CreateSnapshotResponse\"\x00\x12M\n\x08ListFull\x12 .qdrant.ListFullSnapshotsRequest\x1a\x1d.qdrant.ListSnapshotsResponse\"\x00\x12Q\n\nDeleteFull\x12!.qdrant.DeleteFullSnapshotRequest\x1a\x1e.qdrant.DeleteSnapshotResponse\"\x00\x42\x15\xaa\x02\x12Qdrant.Client.Grpcb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'snapshots_service_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'\252\002\022Qdrant.Client.Grpc'
|
||||
_globals['_CREATEFULLSNAPSHOTREQUEST']._serialized_start=68
|
||||
_globals['_CREATEFULLSNAPSHOTREQUEST']._serialized_end=95
|
||||
_globals['_LISTFULLSNAPSHOTSREQUEST']._serialized_start=97
|
||||
_globals['_LISTFULLSNAPSHOTSREQUEST']._serialized_end=123
|
||||
_globals['_DELETEFULLSNAPSHOTREQUEST']._serialized_start=125
|
||||
_globals['_DELETEFULLSNAPSHOTREQUEST']._serialized_end=175
|
||||
_globals['_CREATESNAPSHOTREQUEST']._serialized_start=177
|
||||
_globals['_CREATESNAPSHOTREQUEST']._serialized_end=225
|
||||
_globals['_LISTSNAPSHOTSREQUEST']._serialized_start=227
|
||||
_globals['_LISTSNAPSHOTSREQUEST']._serialized_end=274
|
||||
_globals['_DELETESNAPSHOTREQUEST']._serialized_start=276
|
||||
_globals['_DELETESNAPSHOTREQUEST']._serialized_end=347
|
||||
_globals['_SNAPSHOTDESCRIPTION']._serialized_start=350
|
||||
_globals['_SNAPSHOTDESCRIPTION']._serialized_end=486
|
||||
_globals['_CREATESNAPSHOTRESPONSE']._serialized_start=488
|
||||
_globals['_CREATESNAPSHOTRESPONSE']._serialized_end=585
|
||||
_globals['_LISTSNAPSHOTSRESPONSE']._serialized_start=587
|
||||
_globals['_LISTSNAPSHOTSRESPONSE']._serialized_end=684
|
||||
_globals['_DELETESNAPSHOTRESPONSE']._serialized_start=686
|
||||
_globals['_DELETESNAPSHOTRESPONSE']._serialized_end=724
|
||||
_globals['_SNAPSHOTS']._serialized_start=727
|
||||
_globals['_SNAPSHOTS']._serialized_end=1204
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.message
|
||||
import google.protobuf.timestamp_pb2
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
import typing as typing_extensions
|
||||
else:
|
||||
import typing_extensions
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
class CreateFullSnapshotRequest(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
) -> None: ...
|
||||
|
||||
global___CreateFullSnapshotRequest = CreateFullSnapshotRequest
|
||||
|
||||
class ListFullSnapshotsRequest(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
) -> None: ...
|
||||
|
||||
global___ListFullSnapshotsRequest = ListFullSnapshotsRequest
|
||||
|
||||
class DeleteFullSnapshotRequest(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
SNAPSHOT_NAME_FIELD_NUMBER: builtins.int
|
||||
snapshot_name: builtins.str
|
||||
"""Name of the full snapshot"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
snapshot_name: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["snapshot_name", b"snapshot_name"]) -> None: ...
|
||||
|
||||
global___DeleteFullSnapshotRequest = DeleteFullSnapshotRequest
|
||||
|
||||
class CreateSnapshotRequest(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
COLLECTION_NAME_FIELD_NUMBER: builtins.int
|
||||
collection_name: builtins.str
|
||||
"""Name of the collection"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
collection_name: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["collection_name", b"collection_name"]) -> None: ...
|
||||
|
||||
global___CreateSnapshotRequest = CreateSnapshotRequest
|
||||
|
||||
class ListSnapshotsRequest(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
COLLECTION_NAME_FIELD_NUMBER: builtins.int
|
||||
collection_name: builtins.str
|
||||
"""Name of the collection"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
collection_name: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["collection_name", b"collection_name"]) -> None: ...
|
||||
|
||||
global___ListSnapshotsRequest = ListSnapshotsRequest
|
||||
|
||||
class DeleteSnapshotRequest(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
COLLECTION_NAME_FIELD_NUMBER: builtins.int
|
||||
SNAPSHOT_NAME_FIELD_NUMBER: builtins.int
|
||||
collection_name: builtins.str
|
||||
"""Name of the collection"""
|
||||
snapshot_name: builtins.str
|
||||
"""Name of the collection snapshot"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
collection_name: builtins.str = ...,
|
||||
snapshot_name: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["collection_name", b"collection_name", "snapshot_name", b"snapshot_name"]) -> None: ...
|
||||
|
||||
global___DeleteSnapshotRequest = DeleteSnapshotRequest
|
||||
|
||||
class SnapshotDescription(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
NAME_FIELD_NUMBER: builtins.int
|
||||
CREATION_TIME_FIELD_NUMBER: builtins.int
|
||||
SIZE_FIELD_NUMBER: builtins.int
|
||||
CHECKSUM_FIELD_NUMBER: builtins.int
|
||||
name: builtins.str
|
||||
"""Name of the snapshot"""
|
||||
@property
|
||||
def creation_time(self) -> google.protobuf.timestamp_pb2.Timestamp:
|
||||
"""Creation time of the snapshot"""
|
||||
size: builtins.int
|
||||
"""Size of the snapshot in bytes"""
|
||||
checksum: builtins.str
|
||||
"""SHA256 digest of the snapshot file"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: builtins.str = ...,
|
||||
creation_time: google.protobuf.timestamp_pb2.Timestamp | None = ...,
|
||||
size: builtins.int = ...,
|
||||
checksum: builtins.str | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["_checksum", b"_checksum", "checksum", b"checksum", "creation_time", b"creation_time"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["_checksum", b"_checksum", "checksum", b"checksum", "creation_time", b"creation_time", "name", b"name", "size", b"size"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing_extensions.Literal["_checksum", b"_checksum"]) -> typing_extensions.Literal["checksum"] | None: ...
|
||||
|
||||
global___SnapshotDescription = SnapshotDescription
|
||||
|
||||
class CreateSnapshotResponse(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
SNAPSHOT_DESCRIPTION_FIELD_NUMBER: builtins.int
|
||||
TIME_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def snapshot_description(self) -> global___SnapshotDescription: ...
|
||||
time: builtins.float
|
||||
"""Time spent to process"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
snapshot_description: global___SnapshotDescription | None = ...,
|
||||
time: builtins.float = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing_extensions.Literal["snapshot_description", b"snapshot_description"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["snapshot_description", b"snapshot_description", "time", b"time"]) -> None: ...
|
||||
|
||||
global___CreateSnapshotResponse = CreateSnapshotResponse
|
||||
|
||||
class ListSnapshotsResponse(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
SNAPSHOT_DESCRIPTIONS_FIELD_NUMBER: builtins.int
|
||||
TIME_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def snapshot_descriptions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___SnapshotDescription]: ...
|
||||
time: builtins.float
|
||||
"""Time spent to process"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
snapshot_descriptions: collections.abc.Iterable[global___SnapshotDescription] | None = ...,
|
||||
time: builtins.float = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["snapshot_descriptions", b"snapshot_descriptions", "time", b"time"]) -> None: ...
|
||||
|
||||
global___ListSnapshotsResponse = ListSnapshotsResponse
|
||||
|
||||
class DeleteSnapshotResponse(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TIME_FIELD_NUMBER: builtins.int
|
||||
time: builtins.float
|
||||
"""Time spent to process"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
time: builtins.float = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing_extensions.Literal["time", b"time"]) -> None: ...
|
||||
|
||||
global___DeleteSnapshotResponse = DeleteSnapshotResponse
|
||||
+243
@@ -0,0 +1,243 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
from . import snapshots_service_pb2 as snapshots__service__pb2
|
||||
|
||||
|
||||
class SnapshotsStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Create = channel.unary_unary(
|
||||
'/qdrant.Snapshots/Create',
|
||||
request_serializer=snapshots__service__pb2.CreateSnapshotRequest.SerializeToString,
|
||||
response_deserializer=snapshots__service__pb2.CreateSnapshotResponse.FromString,
|
||||
)
|
||||
self.List = channel.unary_unary(
|
||||
'/qdrant.Snapshots/List',
|
||||
request_serializer=snapshots__service__pb2.ListSnapshotsRequest.SerializeToString,
|
||||
response_deserializer=snapshots__service__pb2.ListSnapshotsResponse.FromString,
|
||||
)
|
||||
self.Delete = channel.unary_unary(
|
||||
'/qdrant.Snapshots/Delete',
|
||||
request_serializer=snapshots__service__pb2.DeleteSnapshotRequest.SerializeToString,
|
||||
response_deserializer=snapshots__service__pb2.DeleteSnapshotResponse.FromString,
|
||||
)
|
||||
self.CreateFull = channel.unary_unary(
|
||||
'/qdrant.Snapshots/CreateFull',
|
||||
request_serializer=snapshots__service__pb2.CreateFullSnapshotRequest.SerializeToString,
|
||||
response_deserializer=snapshots__service__pb2.CreateSnapshotResponse.FromString,
|
||||
)
|
||||
self.ListFull = channel.unary_unary(
|
||||
'/qdrant.Snapshots/ListFull',
|
||||
request_serializer=snapshots__service__pb2.ListFullSnapshotsRequest.SerializeToString,
|
||||
response_deserializer=snapshots__service__pb2.ListSnapshotsResponse.FromString,
|
||||
)
|
||||
self.DeleteFull = channel.unary_unary(
|
||||
'/qdrant.Snapshots/DeleteFull',
|
||||
request_serializer=snapshots__service__pb2.DeleteFullSnapshotRequest.SerializeToString,
|
||||
response_deserializer=snapshots__service__pb2.DeleteSnapshotResponse.FromString,
|
||||
)
|
||||
|
||||
|
||||
class SnapshotsServicer(object):
    """Server-side interface for the ``qdrant.Snapshots`` service.

    Subclass this, override the methods below, and register an instance via
    ``add_SnapshotsServicer_to_server``.  Every base implementation reports
    UNIMPLEMENTED to the caller and raises NotImplementedError locally.
    """

    def Create(self, request, context):
        """
        Create collection snapshot
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def List(self, request, context):
        """
        List collection snapshots
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Delete(self, request, context):
        """
        Delete collection snapshot
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def CreateFull(self, request, context):
        """
        Create full storage snapshot
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def ListFull(self, request, context):
        """
        List full storage snapshots
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def DeleteFull(self, request, context):
        """
        Delete full storage snapshot
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_SnapshotsServicer_to_server(servicer, server):
    """Register *servicer*'s RPC handlers on *server* under ``qdrant.Snapshots``.

    Builds one unary-unary handler per RPC, wiring each to the matching
    method on *servicer* with the generated protobuf (de)serializers.
    """
    rpc_method_handlers = {
        'Create': grpc.unary_unary_rpc_method_handler(
            servicer.Create,
            request_deserializer=snapshots__service__pb2.CreateSnapshotRequest.FromString,
            response_serializer=snapshots__service__pb2.CreateSnapshotResponse.SerializeToString,
        ),
        'List': grpc.unary_unary_rpc_method_handler(
            servicer.List,
            request_deserializer=snapshots__service__pb2.ListSnapshotsRequest.FromString,
            response_serializer=snapshots__service__pb2.ListSnapshotsResponse.SerializeToString,
        ),
        'Delete': grpc.unary_unary_rpc_method_handler(
            servicer.Delete,
            request_deserializer=snapshots__service__pb2.DeleteSnapshotRequest.FromString,
            response_serializer=snapshots__service__pb2.DeleteSnapshotResponse.SerializeToString,
        ),
        'CreateFull': grpc.unary_unary_rpc_method_handler(
            servicer.CreateFull,
            request_deserializer=snapshots__service__pb2.CreateFullSnapshotRequest.FromString,
            response_serializer=snapshots__service__pb2.CreateSnapshotResponse.SerializeToString,
        ),
        'ListFull': grpc.unary_unary_rpc_method_handler(
            servicer.ListFull,
            request_deserializer=snapshots__service__pb2.ListFullSnapshotsRequest.FromString,
            response_serializer=snapshots__service__pb2.ListSnapshotsResponse.SerializeToString,
        ),
        'DeleteFull': grpc.unary_unary_rpc_method_handler(
            servicer.DeleteFull,
            request_deserializer=snapshots__service__pb2.DeleteFullSnapshotRequest.FromString,
            response_serializer=snapshots__service__pb2.DeleteSnapshotResponse.SerializeToString,
        ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
        'qdrant.Snapshots', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class Snapshots(object):
    """Static convenience wrappers for the ``qdrant.Snapshots`` RPCs.

    Part of gRPC's EXPERIMENTAL single-call API: each method performs one
    unary-unary call against *target* without requiring the caller to
    manage a channel or stub.  All methods share the same parameter set,
    forwarded to ``grpc.experimental.unary_unary``.
    """

    @staticmethod
    def Create(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/qdrant.Snapshots/Create',
            snapshots__service__pb2.CreateSnapshotRequest.SerializeToString,
            snapshots__service__pb2.CreateSnapshotResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def List(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/qdrant.Snapshots/List',
            snapshots__service__pb2.ListSnapshotsRequest.SerializeToString,
            snapshots__service__pb2.ListSnapshotsResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def Delete(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/qdrant.Snapshots/Delete',
            snapshots__service__pb2.DeleteSnapshotRequest.SerializeToString,
            snapshots__service__pb2.DeleteSnapshotResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def CreateFull(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/qdrant.Snapshots/CreateFull',
            snapshots__service__pb2.CreateFullSnapshotRequest.SerializeToString,
            snapshots__service__pb2.CreateSnapshotResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def ListFull(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/qdrant.Snapshots/ListFull',
            snapshots__service__pb2.ListFullSnapshotsRequest.SerializeToString,
            snapshots__service__pb2.ListSnapshotsResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def DeleteFull(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/qdrant.Snapshots/DeleteFull',
            snapshots__service__pb2.DeleteFullSnapshotRequest.SerializeToString,
            snapshots__service__pb2.DeleteSnapshotResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
@@ -0,0 +1,17 @@
|
||||
import inspect
|
||||
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client._pydantic_compat import update_forward_refs
|
||||
from qdrant_client.http.api_client import ( # noqa F401
|
||||
ApiClient as ApiClient,
|
||||
AsyncApiClient as AsyncApiClient,
|
||||
AsyncApis as AsyncApis,
|
||||
SyncApis as SyncApis,
|
||||
)
|
||||
from qdrant_client.http.models import models as models # noqa F401
|
||||
|
||||
# Resolve forward references on every pydantic model declared in the
# generated ``qdrant_client.http.models.models`` module, so the re-exported
# classes are fully usable at import time.
for _member_name, _member in inspect.getmembers(models, inspect.isclass):
    if _member.__module__ != "qdrant_client.http.models.models":
        continue
    if issubclass(_member, BaseModel):
        update_forward_refs(_member)
|
||||
@@ -0,0 +1,171 @@
|
||||
# flake8: noqa E501
|
||||
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from qdrant_client.http.models import *
|
||||
from qdrant_client.http.models import models as m
|
||||
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||
Model = TypeVar("Model", bound="BaseModel")
|
||||
|
||||
SetIntStr = Set[Union[int, str]]
|
||||
DictIntStrAny = Dict[Union[int, str], Any]
|
||||
file = None
|
||||
|
||||
|
||||
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
    """Serialize *model* to a JSON string.

    Dispatches on the installed pydantic major version: ``model_dump_json``
    under v2, the legacy ``json`` method under v1.
    """
    serialize = model.model_dump_json if PYDANTIC_V2 else model.json
    return serialize(*args, **kwargs)
|
||||
|
||||
|
||||
def jsonable_encoder(
    obj: Any,
    include: Union[SetIntStr, DictIntStrAny, None] = None,
    exclude: Union[SetIntStr, DictIntStrAny, None] = None,
    by_alias: bool = True,
    skip_defaults: Union[bool, None] = None,
    exclude_unset: bool = True,
    exclude_none: bool = True,
) -> Any:
    """Encode *obj* for use as an HTTP request body.

    Pydantic models (anything exposing ``json`` or ``model_dump_json``) are
    serialized to a JSON string via ``to_json``; any other object is returned
    unchanged.  ``skip_defaults`` is a legacy alias that is OR-ed into
    ``exclude_unset``.

    NOTE(review): despite the name, for models this returns a JSON *string*
    rather than a jsonable dict — callers pass the result as raw content.
    """
    if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
        return to_json(
            obj,
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=bool(exclude_unset or skip_defaults),
            exclude_none=exclude_none,
        )

    return obj
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.http.api_client import ApiClient
|
||||
|
||||
|
||||
class _AliasesApi:
    """Request builders for the alias endpoints, shared by the sync and
    async client facades below."""

    def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
        # The concrete client decides whether ``request`` returns a value
        # directly (sync) or an awaitable (async).
        self.api_client = api_client

    def _build_for_get_collection_aliases(
        self,
        collection_name: str,
    ):
        """
        Get list of all aliases for a collection
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        headers = {}
        return self.api_client.request(
            type_=m.InlineResponse2008,
            method="GET",
            url="/collections/{collection_name}/aliases",
            headers=headers if headers else None,
            path_params=path_params,
        )

    def _build_for_get_collections_aliases(
        self,
    ):
        """
        Get list of all existing collections aliases
        """
        headers = {}
        return self.api_client.request(
            type_=m.InlineResponse2008,
            method="GET",
            url="/aliases",
            headers=headers if headers else None,
        )

    def _build_for_update_aliases(
        self,
        timeout: int = None,
        change_aliases_operation: m.ChangeAliasesOperation = None,
    ):
        """
        Submit alias change operations (POST /collections/aliases).
        """
        query_params = {}
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(change_aliases_operation)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse200,
            method="POST",
            url="/collections/aliases",
            headers=headers if headers else None,
            params=query_params,
            content=body,
        )
|
||||
|
||||
|
||||
class AsyncAliasesApi(_AliasesApi):
    """Async facade over the alias request builders."""

    async def get_collection_aliases(
        self,
        collection_name: str,
    ) -> m.InlineResponse2008:
        """
        Get list of all aliases for a collection
        """
        return await self._build_for_get_collection_aliases(
            collection_name=collection_name,
        )

    async def get_collections_aliases(
        self,
    ) -> m.InlineResponse2008:
        """
        Get list of all existing collections aliases
        """
        return await self._build_for_get_collections_aliases()

    async def update_aliases(
        self,
        timeout: int = None,
        change_aliases_operation: m.ChangeAliasesOperation = None,
    ) -> m.InlineResponse200:
        """
        Submit alias change operations for collections
        """
        return await self._build_for_update_aliases(
            timeout=timeout,
            change_aliases_operation=change_aliases_operation,
        )
|
||||
|
||||
|
||||
class SyncAliasesApi(_AliasesApi):
    """Synchronous facade over the alias request builders."""

    def get_collection_aliases(
        self,
        collection_name: str,
    ) -> m.InlineResponse2008:
        """
        Get list of all aliases for a collection
        """
        return self._build_for_get_collection_aliases(
            collection_name=collection_name,
        )

    def get_collections_aliases(
        self,
    ) -> m.InlineResponse2008:
        """
        Get list of all existing collections aliases
        """
        return self._build_for_get_collections_aliases()

    def update_aliases(
        self,
        timeout: int = None,
        change_aliases_operation: m.ChangeAliasesOperation = None,
    ) -> m.InlineResponse200:
        """
        Submit alias change operations for collections
        """
        return self._build_for_update_aliases(
            timeout=timeout,
            change_aliases_operation=change_aliases_operation,
        )
|
||||
@@ -0,0 +1,116 @@
|
||||
# flake8: noqa E501
|
||||
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from qdrant_client.http.models import *
|
||||
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||
Model = TypeVar("Model", bound="BaseModel")
|
||||
|
||||
SetIntStr = Set[Union[int, str]]
|
||||
DictIntStrAny = Dict[Union[int, str], Any]
|
||||
file = None
|
||||
|
||||
|
||||
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
    """Serialize *model* to a JSON string.

    Dispatches on the installed pydantic major version: ``model_dump_json``
    under v2, the legacy ``json`` method under v1.
    """
    serialize = model.model_dump_json if PYDANTIC_V2 else model.json
    return serialize(*args, **kwargs)
|
||||
|
||||
|
||||
def jsonable_encoder(
    obj: Any,
    include: Union[SetIntStr, DictIntStrAny, None] = None,
    exclude: Union[SetIntStr, DictIntStrAny, None] = None,
    by_alias: bool = True,
    skip_defaults: Union[bool, None] = None,
    exclude_unset: bool = True,
    exclude_none: bool = True,
) -> Any:
    """Encode *obj* for use as an HTTP request body.

    Pydantic models (anything exposing ``json`` or ``model_dump_json``) are
    serialized to a JSON string via ``to_json``; any other object is returned
    unchanged.  ``skip_defaults`` is a legacy alias that is OR-ed into
    ``exclude_unset``.

    NOTE(review): despite the name, for models this returns a JSON *string*
    rather than a jsonable dict — callers pass the result as raw content.
    """
    if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
        return to_json(
            obj,
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=bool(exclude_unset or skip_defaults),
            exclude_none=exclude_none,
        )

    return obj
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.http.api_client import ApiClient
|
||||
|
||||
|
||||
class _BetaApi:
    """Request builders for the beta ``/issues`` endpoints, shared by the
    sync and async client facades below."""

    def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
        # The concrete client decides whether ``request`` returns a value
        # directly (sync) or an awaitable (async).
        self.api_client = api_client

    def _build_for_clear_issues(
        self,
    ):
        """
        Removes all issues reported so far
        """
        headers = {}
        return self.api_client.request(
            type_=bool,
            method="DELETE",
            url="/issues",
            headers=headers if headers else None,
        )

    def _build_for_get_issues(
        self,
    ):
        """
        Get a report of performance issues and configuration suggestions
        """
        headers = {}
        return self.api_client.request(
            type_=object,
            method="GET",
            url="/issues",
            headers=headers if headers else None,
        )
|
||||
|
||||
|
||||
class AsyncBetaApi(_BetaApi):
    """Async facade over the beta issues request builders."""

    async def clear_issues(
        self,
    ) -> bool:
        """
        Removes all issues reported so far
        """
        return await self._build_for_clear_issues()

    async def get_issues(
        self,
    ) -> object:
        """
        Get a report of performance issues and configuration suggestions
        """
        return await self._build_for_get_issues()
|
||||
|
||||
|
||||
class SyncBetaApi(_BetaApi):
    """Synchronous facade over the beta issues request builders."""

    def clear_issues(
        self,
    ) -> bool:
        """
        Removes all issues reported so far
        """
        return self._build_for_clear_issues()

    def get_issues(
        self,
    ) -> object:
        """
        Get a report of performance issues and configuration suggestions
        """
        return self._build_for_get_issues()
|
||||
@@ -0,0 +1,345 @@
|
||||
# flake8: noqa E501
|
||||
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from qdrant_client.http.models import *
|
||||
from qdrant_client.http.models import models as m
|
||||
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||
Model = TypeVar("Model", bound="BaseModel")
|
||||
|
||||
SetIntStr = Set[Union[int, str]]
|
||||
DictIntStrAny = Dict[Union[int, str], Any]
|
||||
file = None
|
||||
|
||||
|
||||
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_dump_json(*args, **kwargs)
|
||||
else:
|
||||
return model.json(*args, **kwargs)
|
||||
|
||||
|
||||
def jsonable_encoder(
|
||||
obj: Any,
|
||||
include: Union[SetIntStr, DictIntStrAny] = None,
|
||||
exclude=None,
|
||||
by_alias: bool = True,
|
||||
skip_defaults: bool = None,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = True,
|
||||
):
|
||||
if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
|
||||
return to_json(
|
||||
obj,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=bool(exclude_unset or skip_defaults),
|
||||
exclude_none=exclude_none,
|
||||
)
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.http.api_client import ApiClient
|
||||
|
||||
|
||||
class _CollectionsApi:
    """Request builders for the collection CRUD endpoints, shared by the
    sync and async client facades below."""

    def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
        # The concrete client decides whether ``request`` returns a value
        # directly (sync) or an awaitable (async).
        self.api_client = api_client

    def _build_for_collection_exists(
        self,
        collection_name: str,
    ):
        """
        Returns \"true\" if the given collection name exists, and \"false\" otherwise
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        headers = {}
        return self.api_client.request(
            type_=m.InlineResponse2006,
            method="GET",
            url="/collections/{collection_name}/exists",
            headers=headers if headers else None,
            path_params=path_params,
        )

    def _build_for_create_collection(
        self,
        collection_name: str,
        timeout: int = None,
        create_collection: m.CreateCollection = None,
    ):
        """
        Create new collection with given parameters
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(create_collection)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse200,
            method="PUT",
            url="/collections/{collection_name}",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_delete_collection(
        self,
        collection_name: str,
        timeout: int = None,
    ):
        """
        Drop collection and all associated data
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        return self.api_client.request(
            type_=m.InlineResponse200,
            method="DELETE",
            url="/collections/{collection_name}",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
        )

    def _build_for_get_collection(
        self,
        collection_name: str,
    ):
        """
        Get detailed information about specified existing collection
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        headers = {}
        return self.api_client.request(
            type_=m.InlineResponse2004,
            method="GET",
            url="/collections/{collection_name}",
            headers=headers if headers else None,
            path_params=path_params,
        )

    def _build_for_get_collections(
        self,
    ):
        """
        Get list name of all existing collections
        """
        headers = {}
        return self.api_client.request(
            type_=m.InlineResponse2003,
            method="GET",
            url="/collections",
            headers=headers if headers else None,
        )

    def _build_for_update_collection(
        self,
        collection_name: str,
        timeout: int = None,
        update_collection: m.UpdateCollection = None,
    ):
        """
        Update parameters of the existing collection
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(update_collection)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse200,
            method="PATCH",
            url="/collections/{collection_name}",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )
|
||||
|
||||
|
||||
class AsyncCollectionsApi(_CollectionsApi):
    """Async facade over the collection request builders."""

    async def collection_exists(
        self,
        collection_name: str,
    ) -> m.InlineResponse2006:
        """
        Returns \"true\" if the given collection name exists, and \"false\" otherwise
        """
        return await self._build_for_collection_exists(
            collection_name=collection_name,
        )

    async def create_collection(
        self,
        collection_name: str,
        timeout: int = None,
        create_collection: m.CreateCollection = None,
    ) -> m.InlineResponse200:
        """
        Create new collection with given parameters
        """
        return await self._build_for_create_collection(
            collection_name=collection_name,
            timeout=timeout,
            create_collection=create_collection,
        )

    async def delete_collection(
        self,
        collection_name: str,
        timeout: int = None,
    ) -> m.InlineResponse200:
        """
        Drop collection and all associated data
        """
        return await self._build_for_delete_collection(
            collection_name=collection_name,
            timeout=timeout,
        )

    async def get_collection(
        self,
        collection_name: str,
    ) -> m.InlineResponse2004:
        """
        Get detailed information about specified existing collection
        """
        return await self._build_for_get_collection(
            collection_name=collection_name,
        )

    async def get_collections(
        self,
    ) -> m.InlineResponse2003:
        """
        Get list name of all existing collections
        """
        return await self._build_for_get_collections()

    async def update_collection(
        self,
        collection_name: str,
        timeout: int = None,
        update_collection: m.UpdateCollection = None,
    ) -> m.InlineResponse200:
        """
        Update parameters of the existing collection
        """
        return await self._build_for_update_collection(
            collection_name=collection_name,
            timeout=timeout,
            update_collection=update_collection,
        )
|
||||
|
||||
|
||||
class SyncCollectionsApi(_CollectionsApi):
    """Synchronous facade over the collection request builders."""

    def collection_exists(
        self,
        collection_name: str,
    ) -> m.InlineResponse2006:
        """
        Returns \"true\" if the given collection name exists, and \"false\" otherwise
        """
        return self._build_for_collection_exists(
            collection_name=collection_name,
        )

    def create_collection(
        self,
        collection_name: str,
        timeout: int = None,
        create_collection: m.CreateCollection = None,
    ) -> m.InlineResponse200:
        """
        Create new collection with given parameters
        """
        return self._build_for_create_collection(
            collection_name=collection_name,
            timeout=timeout,
            create_collection=create_collection,
        )

    def delete_collection(
        self,
        collection_name: str,
        timeout: int = None,
    ) -> m.InlineResponse200:
        """
        Drop collection and all associated data
        """
        return self._build_for_delete_collection(
            collection_name=collection_name,
            timeout=timeout,
        )

    def get_collection(
        self,
        collection_name: str,
    ) -> m.InlineResponse2004:
        """
        Get detailed information about specified existing collection
        """
        return self._build_for_get_collection(
            collection_name=collection_name,
        )

    def get_collections(
        self,
    ) -> m.InlineResponse2003:
        """
        Get list name of all existing collections
        """
        return self._build_for_get_collections()

    def update_collection(
        self,
        collection_name: str,
        timeout: int = None,
        update_collection: m.UpdateCollection = None,
    ) -> m.InlineResponse200:
        """
        Update parameters of the existing collection
        """
        return self._build_for_update_collection(
            collection_name=collection_name,
            timeout=timeout,
            update_collection=update_collection,
        )
|
||||
@@ -0,0 +1,365 @@
|
||||
# flake8: noqa E501
|
||||
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from qdrant_client.http.models import *
|
||||
from qdrant_client.http.models import models as m
|
||||
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||
Model = TypeVar("Model", bound="BaseModel")
|
||||
|
||||
SetIntStr = Set[Union[int, str]]
|
||||
DictIntStrAny = Dict[Union[int, str], Any]
|
||||
file = None
|
||||
|
||||
|
||||
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
    """Serialize *model* to a JSON string.

    Dispatches on the installed pydantic major version: ``model_dump_json``
    under v2, the legacy ``json`` method under v1.
    """
    serialize = model.model_dump_json if PYDANTIC_V2 else model.json
    return serialize(*args, **kwargs)
|
||||
|
||||
|
||||
def jsonable_encoder(
    obj: Any,
    include: Union[SetIntStr, DictIntStrAny, None] = None,
    exclude: Union[SetIntStr, DictIntStrAny, None] = None,
    by_alias: bool = True,
    skip_defaults: Union[bool, None] = None,
    exclude_unset: bool = True,
    exclude_none: bool = True,
) -> Any:
    """Encode *obj* for use as an HTTP request body.

    Pydantic models (anything exposing ``json`` or ``model_dump_json``) are
    serialized to a JSON string via ``to_json``; any other object is returned
    unchanged.  ``skip_defaults`` is a legacy alias that is OR-ed into
    ``exclude_unset``.

    NOTE(review): despite the name, for models this returns a JSON *string*
    rather than a jsonable dict — callers pass the result as raw content.
    """
    if hasattr(obj, "json") or hasattr(obj, "model_dump_json"):
        return to_json(
            obj,
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=bool(exclude_unset or skip_defaults),
            exclude_none=exclude_none,
        )

    return obj
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.http.api_client import ApiClient
|
||||
|
||||
|
||||
class _DistributedApi:
    """Request builders for the cluster / distributed-deployment endpoints,
    shared by the sync and async client facades below."""

    def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
        # The concrete client decides whether ``request`` returns a value
        # directly (sync) or an awaitable (async).
        self.api_client = api_client

    def _build_for_cluster_status(
        self,
    ):
        """
        Get information about the current state and composition of the cluster
        """
        headers = {}
        return self.api_client.request(
            type_=m.InlineResponse2002,
            method="GET",
            url="/cluster",
            headers=headers if headers else None,
        )

    def _build_for_collection_cluster_info(
        self,
        collection_name: str,
    ):
        """
        Get cluster information for a collection
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        headers = {}
        return self.api_client.request(
            type_=m.InlineResponse2007,
            method="GET",
            url="/collections/{collection_name}/cluster",
            headers=headers if headers else None,
            path_params=path_params,
        )

    def _build_for_create_shard_key(
        self,
        collection_name: str,
        timeout: int = None,
        create_sharding_key: m.CreateShardingKey = None,
    ):
        """
        Create a sharding key for the collection (PUT /collections/{collection_name}/shards).
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(create_sharding_key)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse200,
            method="PUT",
            url="/collections/{collection_name}/shards",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_delete_shard_key(
        self,
        collection_name: str,
        timeout: int = None,
        drop_sharding_key: m.DropShardingKey = None,
    ):
        """
        Drop a sharding key (POST /collections/{collection_name}/shards/delete).
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(drop_sharding_key)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse200,
            method="POST",
            url="/collections/{collection_name}/shards/delete",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_recover_current_peer(
        self,
    ):
        """
        Trigger recovery of the current peer (POST /cluster/recover).
        """
        headers = {}
        return self.api_client.request(
            type_=m.InlineResponse200,
            method="POST",
            url="/cluster/recover",
            headers=headers if headers else None,
        )

    def _build_for_remove_peer(
        self,
        peer_id: int,
        timeout: int = None,
        force: bool = None,
    ):
        """
        Tries to remove peer from the cluster. Will return an error if peer has shards on it.
        """
        path_params = {
            "peer_id": str(peer_id),
        }

        query_params = {}
        if timeout is not None:
            query_params["timeout"] = str(timeout)
        if force is not None:
            # Booleans are lower-cased to match JSON-style query values.
            query_params["force"] = str(force).lower()

        headers = {}
        return self.api_client.request(
            type_=m.InlineResponse200,
            method="DELETE",
            url="/cluster/peer/{peer_id}",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
        )

    def _build_for_update_collection_cluster(
        self,
        collection_name: str,
        timeout: int = None,
        cluster_operations: m.ClusterOperations = None,
    ):
        """
        Apply cluster operations to a collection (POST /collections/{collection_name}/cluster).
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(cluster_operations)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse200,
            method="POST",
            url="/collections/{collection_name}/cluster",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )
|
||||
|
||||
|
||||
class AsyncDistributedApi(_DistributedApi):
    """Async facade over the cluster/distributed endpoints.

    Each method awaits the request prepared by the corresponding
    ``_build_for_*`` helper inherited from ``_DistributedApi``.
    """

    async def cluster_status(
        self,
    ) -> m.InlineResponse2002:
        """
        Get information about the current state and composition of the cluster
        """
        return await self._build_for_cluster_status()

    async def collection_cluster_info(
        self,
        collection_name: str,
    ) -> m.InlineResponse2007:
        """
        Get cluster information for a collection
        """
        return await self._build_for_collection_cluster_info(
            collection_name=collection_name,
        )

    async def create_shard_key(
        self,
        collection_name: str,
        timeout: int = None,
        create_sharding_key: m.CreateShardingKey = None,
    ) -> m.InlineResponse200:
        """Delegate to ``_build_for_create_shard_key`` with the same arguments."""
        return await self._build_for_create_shard_key(
            collection_name=collection_name,
            timeout=timeout,
            create_sharding_key=create_sharding_key,
        )

    async def delete_shard_key(
        self,
        collection_name: str,
        timeout: int = None,
        drop_sharding_key: m.DropShardingKey = None,
    ) -> m.InlineResponse200:
        """Delegate to ``_build_for_delete_shard_key`` with the same arguments."""
        return await self._build_for_delete_shard_key(
            collection_name=collection_name,
            timeout=timeout,
            drop_sharding_key=drop_sharding_key,
        )

    async def recover_current_peer(
        self,
    ) -> m.InlineResponse200:
        """Delegate to ``_build_for_recover_current_peer``."""
        return await self._build_for_recover_current_peer()

    async def remove_peer(
        self,
        peer_id: int,
        timeout: int = None,
        force: bool = None,
    ) -> m.InlineResponse200:
        """
        Tries to remove peer from the cluster. Will return an error if peer has shards on it.
        """
        return await self._build_for_remove_peer(
            peer_id=peer_id,
            timeout=timeout,
            force=force,
        )

    async def update_collection_cluster(
        self,
        collection_name: str,
        timeout: int = None,
        cluster_operations: m.ClusterOperations = None,
    ) -> m.InlineResponse200:
        """Delegate to ``_build_for_update_collection_cluster`` with the same arguments."""
        return await self._build_for_update_collection_cluster(
            collection_name=collection_name,
            timeout=timeout,
            cluster_operations=cluster_operations,
        )
|
||||
|
||||
|
||||
class SyncDistributedApi(_DistributedApi):
    """Blocking facade over the cluster/distributed endpoints.

    Each method returns the result of the corresponding ``_build_for_*``
    helper inherited from ``_DistributedApi``.
    """

    def cluster_status(
        self,
    ) -> m.InlineResponse2002:
        """
        Get information about the current state and composition of the cluster
        """
        return self._build_for_cluster_status()

    def collection_cluster_info(
        self,
        collection_name: str,
    ) -> m.InlineResponse2007:
        """
        Get cluster information for a collection
        """
        return self._build_for_collection_cluster_info(
            collection_name=collection_name,
        )

    def create_shard_key(
        self,
        collection_name: str,
        timeout: int = None,
        create_sharding_key: m.CreateShardingKey = None,
    ) -> m.InlineResponse200:
        """Delegate to ``_build_for_create_shard_key`` with the same arguments."""
        return self._build_for_create_shard_key(
            collection_name=collection_name,
            timeout=timeout,
            create_sharding_key=create_sharding_key,
        )

    def delete_shard_key(
        self,
        collection_name: str,
        timeout: int = None,
        drop_sharding_key: m.DropShardingKey = None,
    ) -> m.InlineResponse200:
        """Delegate to ``_build_for_delete_shard_key`` with the same arguments."""
        return self._build_for_delete_shard_key(
            collection_name=collection_name,
            timeout=timeout,
            drop_sharding_key=drop_sharding_key,
        )

    def recover_current_peer(
        self,
    ) -> m.InlineResponse200:
        """Delegate to ``_build_for_recover_current_peer``."""
        return self._build_for_recover_current_peer()

    def remove_peer(
        self,
        peer_id: int,
        timeout: int = None,
        force: bool = None,
    ) -> m.InlineResponse200:
        """
        Tries to remove peer from the cluster. Will return an error if peer has shards on it.
        """
        return self._build_for_remove_peer(
            peer_id=peer_id,
            timeout=timeout,
            force=force,
        )

    def update_collection_cluster(
        self,
        collection_name: str,
        timeout: int = None,
        cluster_operations: m.ClusterOperations = None,
    ) -> m.InlineResponse200:
        """Delegate to ``_build_for_update_collection_cluster`` with the same arguments."""
        return self._build_for_update_collection_cluster(
            collection_name=collection_name,
            timeout=timeout,
            cluster_operations=cluster_operations,
        )
|
||||
@@ -0,0 +1,190 @@
|
||||
# flake8: noqa E501
|
||||
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from qdrant_client.http.models import *
|
||||
from qdrant_client.http.models import models as m
|
||||
|
||||
# True when running under pydantic v2; used throughout this module to pick
# between the renamed v1/v2 serialization APIs.
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
# Any BaseModel subclass (forward-referenced so it works under both pydantic versions).
Model = TypeVar("Model", bound="BaseModel")


# Include/exclude selector aliases accepted by jsonable_encoder below.
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
# NOTE(review): generated placeholder assignment; purpose unclear from this
# file — confirm against the code-generator template before removing.
file = None
|
||||
|
||||
|
||||
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
    """Serialize *model* to a JSON string, dispatching on the pydantic major version."""
    if PYDANTIC_V2:
        # pydantic v2 renamed ``.json()`` to ``.model_dump_json()``.
        return model.model_dump_json(*args, **kwargs)
    # pydantic v1 path.
    return model.json(*args, **kwargs)
|
||||
|
||||
|
||||
def jsonable_encoder(
    obj: Any,
    include: Union[SetIntStr, DictIntStrAny] = None,
    exclude=None,
    by_alias: bool = True,
    skip_defaults: bool = None,
    exclude_unset: bool = True,
    exclude_none: bool = True,
):
    """Encode *obj* as a JSON string when it looks like a pydantic model;
    return it unchanged otherwise."""
    looks_like_model = hasattr(obj, "json") or hasattr(obj, "model_dump_json")
    if not looks_like_model:
        # Plain values (dicts, None, ...) are forwarded untouched as the request body.
        return obj
    # ``skip_defaults`` is the legacy alias of ``exclude_unset``; either flag enables it.
    return to_json(
        obj,
        include=include,
        exclude=exclude,
        by_alias=by_alias,
        exclude_unset=bool(exclude_unset or skip_defaults),
        exclude_none=exclude_none,
    )
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.http.api_client import ApiClient
|
||||
|
||||
|
||||
class _IndexesApi:
    """Request builders for the field-index endpoints, shared by the sync and
    async facades."""

    def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
        self.api_client = api_client

    def _build_for_create_field_index(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        create_field_index: m.CreateFieldIndex = None,
    ):
        """
        Create index for field in collection
        """
        path_params = {"collection_name": str(collection_name)}

        # Only explicitly supplied options become query parameters.
        query_params = {}
        if wait is not None:
            query_params["wait"] = str(wait).lower()
        if ordering is not None:
            query_params["ordering"] = str(ordering)

        body = jsonable_encoder(create_field_index)
        # A JSON body is always sent, so the content type is always set.
        request_headers = {"Content-Type": "application/json"}
        return self.api_client.request(
            type_=m.InlineResponse2005,
            method="PUT",
            url="/collections/{collection_name}/index",
            headers=request_headers,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_delete_field_index(
        self,
        collection_name: str,
        field_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
    ):
        """
        Delete field index for collection
        """
        path_params = {
            "collection_name": str(collection_name),
            "field_name": str(field_name),
        }

        query_params = {}
        if wait is not None:
            query_params["wait"] = str(wait).lower()
        if ordering is not None:
            query_params["ordering"] = str(ordering)

        # No headers are ever populated for this endpoint, so None is sent.
        return self.api_client.request(
            type_=m.InlineResponse2005,
            method="DELETE",
            url="/collections/{collection_name}/index/{field_name}",
            headers=None,
            path_params=path_params,
            params=query_params,
        )
|
||||
|
||||
|
||||
class AsyncIndexesApi(_IndexesApi):
    """Async facade over the field-index endpoints; awaits the requests built
    by ``_IndexesApi``."""

    async def create_field_index(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        create_field_index: m.CreateFieldIndex = None,
    ) -> m.InlineResponse2005:
        """
        Create index for field in collection
        """
        return await self._build_for_create_field_index(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            create_field_index=create_field_index,
        )

    async def delete_field_index(
        self,
        collection_name: str,
        field_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
    ) -> m.InlineResponse2005:
        """
        Delete field index for collection
        """
        return await self._build_for_delete_field_index(
            collection_name=collection_name,
            field_name=field_name,
            wait=wait,
            ordering=ordering,
        )
|
||||
|
||||
|
||||
class SyncIndexesApi(_IndexesApi):
    """Blocking facade over the field-index endpoints; returns the results of
    the requests built by ``_IndexesApi``."""

    def create_field_index(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        create_field_index: m.CreateFieldIndex = None,
    ) -> m.InlineResponse2005:
        """
        Create index for field in collection
        """
        return self._build_for_create_field_index(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            create_field_index=create_field_index,
        )

    def delete_field_index(
        self,
        collection_name: str,
        field_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
    ) -> m.InlineResponse2005:
        """
        Delete field index for collection
        """
        return self._build_for_delete_field_index(
            collection_name=collection_name,
            field_name=field_name,
            wait=wait,
            ordering=ordering,
        )
|
||||
@@ -0,0 +1,999 @@
|
||||
# flake8: noqa E501
|
||||
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from qdrant_client.http.models import *
|
||||
from qdrant_client.http.models import models as m
|
||||
|
||||
# True when running under pydantic v2; used throughout this module to pick
# between the renamed v1/v2 serialization APIs.
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
# Any BaseModel subclass (forward-referenced so it works under both pydantic versions).
Model = TypeVar("Model", bound="BaseModel")


# Include/exclude selector aliases accepted by jsonable_encoder below.
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
# NOTE(review): generated placeholder assignment; purpose unclear from this
# file — confirm against the code-generator template before removing.
file = None
|
||||
|
||||
|
||||
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
    """Serialize *model* to a JSON string, dispatching on the pydantic major version."""
    if PYDANTIC_V2:
        # pydantic v2 renamed ``.json()`` to ``.model_dump_json()``.
        return model.model_dump_json(*args, **kwargs)
    # pydantic v1 path.
    return model.json(*args, **kwargs)
|
||||
|
||||
|
||||
def jsonable_encoder(
    obj: Any,
    include: Union[SetIntStr, DictIntStrAny] = None,
    exclude=None,
    by_alias: bool = True,
    skip_defaults: bool = None,
    exclude_unset: bool = True,
    exclude_none: bool = True,
):
    """Encode *obj* as a JSON string when it looks like a pydantic model;
    return it unchanged otherwise."""
    looks_like_model = hasattr(obj, "json") or hasattr(obj, "model_dump_json")
    if not looks_like_model:
        # Plain values (dicts, None, ...) are forwarded untouched as the request body.
        return obj
    # ``skip_defaults`` is the legacy alias of ``exclude_unset``; either flag enables it.
    return to_json(
        obj,
        include=include,
        exclude=exclude,
        by_alias=by_alias,
        exclude_unset=bool(exclude_unset or skip_defaults),
        exclude_none=exclude_none,
    )
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.http.api_client import ApiClient
|
||||
|
||||
|
||||
class _PointsApi:
    """Request builders for the points endpoints, shared by the sync and async
    facades.

    Every public ``_build_for_*`` method assembles path parameters, query
    parameters and (where applicable) a JSON body, then forwards them to
    ``self.api_client.request``. The fourteen generated builders previously
    repeated the same query-param and JSON-body boilerplate verbatim; that
    boilerplate now lives in the private helpers below. Public method names,
    signatures and request payloads are unchanged.
    """

    def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
        self.api_client = api_client

    # --- private helpers --------------------------------------------------

    @staticmethod
    def _write_params(wait: bool = None, ordering: WriteOrdering = None) -> dict:
        """Query params shared by write endpoints; only explicitly set values are sent."""
        query_params = {}
        if wait is not None:
            # Booleans are serialized as lowercase "true"/"false".
            query_params["wait"] = str(wait).lower()
        if ordering is not None:
            query_params["ordering"] = str(ordering)
        return query_params

    @staticmethod
    def _read_params(consistency: m.ReadConsistency = None, timeout: int = None) -> dict:
        """Query params shared by read endpoints; only explicitly set values are sent."""
        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)
        return query_params

    def _json_request(self, *, type_, method, url, path_params, query_params, payload):
        """Encode ``payload`` as a JSON body and dispatch the request."""
        headers = {"Content-Type": "application/json"}
        body = jsonable_encoder(payload)
        return self.api_client.request(
            type_=type_,
            method=method,
            url=url,
            headers=headers,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    # --- request builders -------------------------------------------------

    def _build_for_batch_update(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        update_operations: m.UpdateOperations = None,
    ):
        """
        Apply a series of update operations for points, vectors and payloads
        """
        return self._json_request(
            type_=m.InlineResponse20014,
            method="POST",
            url="/collections/{collection_name}/points/batch",
            path_params={"collection_name": str(collection_name)},
            query_params=self._write_params(wait, ordering),
            payload=update_operations,
        )

    def _build_for_clear_payload(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        points_selector: m.PointsSelector = None,
    ):
        """
        Remove all payload for specified points
        """
        return self._json_request(
            type_=m.InlineResponse2005,
            method="POST",
            url="/collections/{collection_name}/points/payload/clear",
            path_params={"collection_name": str(collection_name)},
            query_params=self._write_params(wait, ordering),
            payload=points_selector,
        )

    def _build_for_count_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        count_request: m.CountRequest = None,
    ):
        """
        Count points which matches given filtering condition
        """
        return self._json_request(
            type_=m.InlineResponse20019,
            method="POST",
            url="/collections/{collection_name}/points/count",
            path_params={"collection_name": str(collection_name)},
            query_params=self._read_params(consistency, timeout),
            payload=count_request,
        )

    def _build_for_delete_payload(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        delete_payload: m.DeletePayload = None,
    ):
        """
        Delete specified key payload for points
        """
        return self._json_request(
            type_=m.InlineResponse2005,
            method="POST",
            url="/collections/{collection_name}/points/payload/delete",
            path_params={"collection_name": str(collection_name)},
            query_params=self._write_params(wait, ordering),
            payload=delete_payload,
        )

    def _build_for_delete_points(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        points_selector: m.PointsSelector = None,
    ):
        """
        Delete points
        """
        return self._json_request(
            type_=m.InlineResponse2005,
            method="POST",
            url="/collections/{collection_name}/points/delete",
            path_params={"collection_name": str(collection_name)},
            query_params=self._write_params(wait, ordering),
            payload=points_selector,
        )

    def _build_for_delete_vectors(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        delete_vectors: m.DeleteVectors = None,
    ):
        """
        Delete named vectors from the given points.
        """
        return self._json_request(
            type_=m.InlineResponse2005,
            method="POST",
            url="/collections/{collection_name}/points/vectors/delete",
            path_params={"collection_name": str(collection_name)},
            query_params=self._write_params(wait, ordering),
            payload=delete_vectors,
        )

    def _build_for_facet(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        facet_request: m.FacetRequest = None,
    ):
        """
        Count points that satisfy the given filter for each unique value of a payload key.
        """
        return self._json_request(
            type_=m.InlineResponse20020,
            method="POST",
            url="/collections/{collection_name}/facet",
            path_params={"collection_name": str(collection_name)},
            query_params=self._read_params(consistency, timeout),
            payload=facet_request,
        )

    def _build_for_get_point(
        self,
        collection_name: str,
        id: m.ExtendedPointId,
        consistency: m.ReadConsistency = None,
    ):
        """
        Retrieve full information of single point by id
        """
        # GET request: no body and no headers are sent.
        return self.api_client.request(
            type_=m.InlineResponse20012,
            method="GET",
            url="/collections/{collection_name}/points/{id}",
            headers=None,
            path_params={
                "collection_name": str(collection_name),
                "id": str(id),
            },
            params=self._read_params(consistency=consistency),
        )

    def _build_for_get_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        point_request: m.PointRequest = None,
    ):
        """
        Retrieve multiple points by specified IDs
        """
        return self._json_request(
            type_=m.InlineResponse20013,
            method="POST",
            url="/collections/{collection_name}/points",
            path_params={"collection_name": str(collection_name)},
            query_params=self._read_params(consistency, timeout),
            payload=point_request,
        )

    def _build_for_overwrite_payload(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        set_payload: m.SetPayload = None,
    ):
        """
        Replace full payload of points with new one
        """
        return self._json_request(
            type_=m.InlineResponse2005,
            method="PUT",
            url="/collections/{collection_name}/points/payload",
            path_params={"collection_name": str(collection_name)},
            query_params=self._write_params(wait, ordering),
            payload=set_payload,
        )

    def _build_for_scroll_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        scroll_request: m.ScrollRequest = None,
    ):
        """
        Scroll request - paginate over all points which matches given filtering condition
        """
        return self._json_request(
            type_=m.InlineResponse20015,
            method="POST",
            url="/collections/{collection_name}/points/scroll",
            path_params={"collection_name": str(collection_name)},
            query_params=self._read_params(consistency, timeout),
            payload=scroll_request,
        )

    def _build_for_set_payload(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        set_payload: m.SetPayload = None,
    ):
        """
        Set payload values for points
        """
        return self._json_request(
            type_=m.InlineResponse2005,
            method="POST",
            url="/collections/{collection_name}/points/payload",
            path_params={"collection_name": str(collection_name)},
            query_params=self._write_params(wait, ordering),
            payload=set_payload,
        )

    def _build_for_update_vectors(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        update_vectors: m.UpdateVectors = None,
    ):
        """
        Update specified named vectors on points, keep unspecified vectors intact.
        """
        return self._json_request(
            type_=m.InlineResponse2005,
            method="PUT",
            url="/collections/{collection_name}/points/vectors",
            path_params={"collection_name": str(collection_name)},
            query_params=self._write_params(wait, ordering),
            payload=update_vectors,
        )

    def _build_for_upsert_points(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        point_insert_operations: m.PointInsertOperations = None,
    ):
        """
        Perform insert + updates on points. If point with given ID already exists - it will be overwritten.
        """
        return self._json_request(
            type_=m.InlineResponse2005,
            method="PUT",
            url="/collections/{collection_name}/points",
            path_params={"collection_name": str(collection_name)},
            query_params=self._write_params(wait, ordering),
            payload=point_insert_operations,
        )
|
||||
|
||||
|
||||
class AsyncPointsApi(_PointsApi):
    """Async facade over the points endpoints.

    Each method awaits the request prepared by the corresponding
    ``_build_for_*`` helper inherited from ``_PointsApi``.
    """

    async def batch_update(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        update_operations: m.UpdateOperations = None,
    ) -> m.InlineResponse20014:
        """
        Apply a series of update operations for points, vectors and payloads
        """
        return await self._build_for_batch_update(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            update_operations=update_operations,
        )

    async def clear_payload(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        points_selector: m.PointsSelector = None,
    ) -> m.InlineResponse2005:
        """
        Remove all payload for specified points
        """
        return await self._build_for_clear_payload(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            points_selector=points_selector,
        )

    async def count_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        count_request: m.CountRequest = None,
    ) -> m.InlineResponse20019:
        """
        Count points which matches given filtering condition
        """
        return await self._build_for_count_points(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            count_request=count_request,
        )

    async def delete_payload(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        delete_payload: m.DeletePayload = None,
    ) -> m.InlineResponse2005:
        """
        Delete specified key payload for points
        """
        return await self._build_for_delete_payload(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            delete_payload=delete_payload,
        )

    async def delete_points(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        points_selector: m.PointsSelector = None,
    ) -> m.InlineResponse2005:
        """
        Delete points
        """
        return await self._build_for_delete_points(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            points_selector=points_selector,
        )

    async def delete_vectors(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        delete_vectors: m.DeleteVectors = None,
    ) -> m.InlineResponse2005:
        """
        Delete named vectors from the given points.
        """
        return await self._build_for_delete_vectors(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            delete_vectors=delete_vectors,
        )

    async def facet(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        facet_request: m.FacetRequest = None,
    ) -> m.InlineResponse20020:
        """
        Count points that satisfy the given filter for each unique value of a payload key.
        """
        return await self._build_for_facet(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            facet_request=facet_request,
        )

    async def get_point(
        self,
        collection_name: str,
        id: m.ExtendedPointId,
        consistency: m.ReadConsistency = None,
    ) -> m.InlineResponse20012:
        """
        Retrieve full information of single point by id
        """
        return await self._build_for_get_point(
            collection_name=collection_name,
            id=id,
            consistency=consistency,
        )

    async def get_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        point_request: m.PointRequest = None,
    ) -> m.InlineResponse20013:
        """
        Retrieve multiple points by specified IDs
        """
        return await self._build_for_get_points(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            point_request=point_request,
        )

    async def overwrite_payload(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        set_payload: m.SetPayload = None,
    ) -> m.InlineResponse2005:
        """
        Replace full payload of points with new one
        """
        return await self._build_for_overwrite_payload(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            set_payload=set_payload,
        )

    async def scroll_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        scroll_request: m.ScrollRequest = None,
    ) -> m.InlineResponse20015:
        """
        Scroll request - paginate over all points which matches given filtering condition
        """
        return await self._build_for_scroll_points(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            scroll_request=scroll_request,
        )

    async def set_payload(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        set_payload: m.SetPayload = None,
    ) -> m.InlineResponse2005:
        """
        Set payload values for points
        """
        return await self._build_for_set_payload(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            set_payload=set_payload,
        )

    async def update_vectors(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        update_vectors: m.UpdateVectors = None,
    ) -> m.InlineResponse2005:
        """
        Update specified named vectors on points, keep unspecified vectors intact.
        """
        return await self._build_for_update_vectors(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            update_vectors=update_vectors,
        )

    async def upsert_points(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        point_insert_operations: m.PointInsertOperations = None,
    ) -> m.InlineResponse2005:
        """
        Perform insert + updates on points. If point with given ID already exists - it will be overwritten.
        """
        return await self._build_for_upsert_points(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            point_insert_operations=point_insert_operations,
        )
|
||||
|
||||
|
||||
class SyncPointsApi(_PointsApi):
    """Blocking facade over the points API.

    Every method simply forwards its arguments to the matching
    ``_build_for_*`` request builder inherited from ``_PointsApi``.
    """

    def batch_update(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        update_operations: m.UpdateOperations = None,
    ) -> m.InlineResponse20014:
        """Apply a series of update operations for points, vectors and payloads."""
        call_args = {"collection_name": collection_name, "wait": wait, "ordering": ordering, "update_operations": update_operations}
        return self._build_for_batch_update(**call_args)

    def clear_payload(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        points_selector: m.PointsSelector = None,
    ) -> m.InlineResponse2005:
        """Remove all payload for the specified points."""
        call_args = {"collection_name": collection_name, "wait": wait, "ordering": ordering, "points_selector": points_selector}
        return self._build_for_clear_payload(**call_args)

    def count_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        count_request: m.CountRequest = None,
    ) -> m.InlineResponse20019:
        """Count points that match the given filtering condition."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "count_request": count_request}
        return self._build_for_count_points(**call_args)

    def delete_payload(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        delete_payload: m.DeletePayload = None,
    ) -> m.InlineResponse2005:
        """Delete the specified payload keys for points."""
        call_args = {"collection_name": collection_name, "wait": wait, "ordering": ordering, "delete_payload": delete_payload}
        return self._build_for_delete_payload(**call_args)

    def delete_points(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        points_selector: m.PointsSelector = None,
    ) -> m.InlineResponse2005:
        """Delete the selected points."""
        call_args = {"collection_name": collection_name, "wait": wait, "ordering": ordering, "points_selector": points_selector}
        return self._build_for_delete_points(**call_args)

    def delete_vectors(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        delete_vectors: m.DeleteVectors = None,
    ) -> m.InlineResponse2005:
        """Delete named vectors from the given points."""
        call_args = {"collection_name": collection_name, "wait": wait, "ordering": ordering, "delete_vectors": delete_vectors}
        return self._build_for_delete_vectors(**call_args)

    def facet(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        facet_request: m.FacetRequest = None,
    ) -> m.InlineResponse20020:
        """Count points that satisfy the filter for each unique value of a payload key."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "facet_request": facet_request}
        return self._build_for_facet(**call_args)

    def get_point(
        self,
        collection_name: str,
        id: m.ExtendedPointId,
        consistency: m.ReadConsistency = None,
    ) -> m.InlineResponse20012:
        """Retrieve the full information of a single point by its id."""
        call_args = {"collection_name": collection_name, "id": id, "consistency": consistency}
        return self._build_for_get_point(**call_args)

    def get_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        point_request: m.PointRequest = None,
    ) -> m.InlineResponse20013:
        """Retrieve multiple points by their IDs."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "point_request": point_request}
        return self._build_for_get_points(**call_args)

    def overwrite_payload(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        set_payload: m.SetPayload = None,
    ) -> m.InlineResponse2005:
        """Replace the full payload of the selected points with a new one."""
        call_args = {"collection_name": collection_name, "wait": wait, "ordering": ordering, "set_payload": set_payload}
        return self._build_for_overwrite_payload(**call_args)

    def scroll_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        scroll_request: m.ScrollRequest = None,
    ) -> m.InlineResponse20015:
        """Paginate over all points that match the given filtering condition."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "scroll_request": scroll_request}
        return self._build_for_scroll_points(**call_args)

    def set_payload(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        set_payload: m.SetPayload = None,
    ) -> m.InlineResponse2005:
        """Set (merge) payload values for the selected points."""
        call_args = {"collection_name": collection_name, "wait": wait, "ordering": ordering, "set_payload": set_payload}
        return self._build_for_set_payload(**call_args)

    def update_vectors(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        update_vectors: m.UpdateVectors = None,
    ) -> m.InlineResponse2005:
        """Update the specified named vectors on points, keeping unspecified vectors intact."""
        call_args = {"collection_name": collection_name, "wait": wait, "ordering": ordering, "update_vectors": update_vectors}
        return self._build_for_update_vectors(**call_args)

    def upsert_points(
        self,
        collection_name: str,
        wait: bool = None,
        ordering: WriteOrdering = None,
        point_insert_operations: m.PointInsertOperations = None,
    ) -> m.InlineResponse2005:
        """Insert or update points; an existing point with the same ID is overwritten."""
        call_args = {"collection_name": collection_name, "wait": wait, "ordering": ordering, "point_insert_operations": point_insert_operations}
        return self._build_for_upsert_points(**call_args)
@@ -0,0 +1,941 @@
|
||||
# flake8: noqa E501
|
||||
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from qdrant_client.http.models import *
|
||||
from qdrant_client.http.models import models as m
|
||||
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||
Model = TypeVar("Model", bound="BaseModel")
|
||||
|
||||
# Type aliases used by ``jsonable_encoder``'s ``include`` parameter.
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
# NOTE(review): generated-code artifact — appears unused in this module and
# shadows the (Python 2 era) ``file`` name; confirm it is dead before removing.
file = None
||||
|
||||
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
    """Serialize *model* to a JSON string across pydantic major versions.

    pydantic v2 renamed ``BaseModel.json()`` to ``model_dump_json()``; this
    shim picks the right method based on the module-level ``PYDANTIC_V2`` flag.
    """
    if not PYDANTIC_V2:
        return model.json(*args, **kwargs)
    return model.model_dump_json(*args, **kwargs)
||||
def jsonable_encoder(
    obj: Any,
    include: Union[SetIntStr, DictIntStrAny] = None,
    exclude=None,
    by_alias: bool = True,
    skip_defaults: bool = None,
    exclude_unset: bool = True,
    exclude_none: bool = True,
):
    """Encode *obj* for an HTTP request body.

    A pydantic model (anything exposing ``json``/``model_dump_json``) is
    serialized to a JSON string via :func:`to_json`; any other value is
    returned unchanged.
    """
    looks_like_model = hasattr(obj, "json") or hasattr(obj, "model_dump_json")
    if not looks_like_model:
        # Plain values (None, str, dict, ...) pass through untouched.
        return obj

    return to_json(
        obj,
        include=include,
        exclude=exclude,
        by_alias=by_alias,
        # legacy ``skip_defaults`` is folded into ``exclude_unset``
        exclude_unset=bool(exclude_unset or skip_defaults),
        exclude_none=exclude_none,
    )
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.http.api_client import ApiClient
|
||||
|
||||
|
||||
class _SearchApi:
    """Shared request builders for the search-related endpoints.

    Each ``_build_for_*`` method assembles path/query parameters and a JSON
    body, then delegates the HTTP call to ``self.api_client.request``.  The
    sync/async public classes wrap these builders.
    """

    def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
        # The client may be sync or async; builders just forward to it.
        self.api_client = api_client

    def _build_for_discover_batch_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        discover_request_batch: m.DiscoverRequestBatch = None,
    ):
        """
        Look for points based on target and/or positive and negative example pairs, in batch.
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        # Query parameters are sent only when set, and always stringified.
        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(discover_request_batch)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20017,
            method="POST",
            url="/collections/{collection_name}/points/discover/batch",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_discover_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        discover_request: m.DiscoverRequest = None,
    ):
        """
        Use context and a target to find the most similar points to the target, constrained by the context. When using only the context (without a target), a special search - called context search - is performed where pairs of points are used to generate a loss that guides the search towards the zone where most positive examples overlap. This means that the score minimizes the scenario of finding a point closer to a negative than to a positive part of a pair. Since the score of a context relates to loss, the maximum score a point can get is 0.0, and it becomes normal that many points can have a score of 0.0. When using target (with or without context), the score behaves a little different: The integer part of the score represents the rank with respect to the context, while the decimal part of the score relates to the distance to the target. The context part of the score for each pair is calculated +1 if the point is closer to a positive than to a negative part of a pair, and -1 otherwise.
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(discover_request)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20016,
            method="POST",
            url="/collections/{collection_name}/points/discover",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_query_batch_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        query_request_batch: m.QueryRequestBatch = None,
    ):
        """
        Universally query points in batch. This endpoint covers all capabilities of search, recommend, discover, filters. But also enables hybrid and multi-stage queries.
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(query_request_batch)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20022,
            method="POST",
            url="/collections/{collection_name}/points/query/batch",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_query_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        query_request: m.QueryRequest = None,
    ):
        """
        Universally query points. This endpoint covers all capabilities of search, recommend, discover, filters. But also enables hybrid and multi-stage queries.
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(query_request)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20021,
            method="POST",
            url="/collections/{collection_name}/points/query",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_query_points_groups(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        query_groups_request: m.QueryGroupsRequest = None,
    ):
        """
        Universally query points, grouped by a given payload field
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(query_groups_request)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20018,
            method="POST",
            url="/collections/{collection_name}/points/query/groups",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_recommend_batch_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        recommend_request_batch: m.RecommendRequestBatch = None,
    ):
        """
        Look for the points which are closer to stored positive examples and at the same time further to negative examples.
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(recommend_request_batch)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20017,
            method="POST",
            url="/collections/{collection_name}/points/recommend/batch",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_recommend_point_groups(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        recommend_groups_request: m.RecommendGroupsRequest = None,
    ):
        """
        Look for the points which are closer to stored positive examples and at the same time further to negative examples, grouped by a given payload field.
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(recommend_groups_request)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20018,
            method="POST",
            url="/collections/{collection_name}/points/recommend/groups",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_recommend_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        recommend_request: m.RecommendRequest = None,
    ):
        """
        Look for the points which are closer to stored positive examples and at the same time further to negative examples.
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(recommend_request)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20016,
            method="POST",
            url="/collections/{collection_name}/points/recommend",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_search_batch_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_request_batch: m.SearchRequestBatch = None,
    ):
        """
        Retrieve by batch the closest points based on vector similarity and given filtering conditions
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(search_request_batch)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20017,
            method="POST",
            url="/collections/{collection_name}/points/search/batch",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_search_matrix_offsets(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_matrix_request: m.SearchMatrixRequest = None,
    ):
        """
        Compute distance matrix for sampled points with an offset based output format
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(search_matrix_request)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20024,
            method="POST",
            url="/collections/{collection_name}/points/search/matrix/offsets",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_search_matrix_pairs(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_matrix_request: m.SearchMatrixRequest = None,
    ):
        """
        Compute distance matrix for sampled points with a pair based output format
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(search_matrix_request)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20023,
            method="POST",
            url="/collections/{collection_name}/points/search/matrix/pairs",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_search_point_groups(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_groups_request: m.SearchGroupsRequest = None,
    ):
        """
        Retrieve closest points based on vector similarity and given filtering conditions, grouped by a given payload field
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(search_groups_request)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20018,
            method="POST",
            url="/collections/{collection_name}/points/search/groups",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )

    def _build_for_search_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_request: m.SearchRequest = None,
    ):
        """
        Retrieve closest points based on vector similarity and given filtering conditions
        """
        path_params = {
            "collection_name": str(collection_name),
        }

        query_params = {}
        if consistency is not None:
            query_params["consistency"] = str(consistency)
        if timeout is not None:
            query_params["timeout"] = str(timeout)

        headers = {}
        body = jsonable_encoder(search_request)
        if "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"
        return self.api_client.request(
            type_=m.InlineResponse20016,
            method="POST",
            url="/collections/{collection_name}/points/search",
            headers=headers if headers else None,
            path_params=path_params,
            params=query_params,
            content=body,
        )
||||
class AsyncSearchApi(_SearchApi):
    """Asynchronous facade over the search API.

    Every coroutine forwards its arguments to the matching
    ``_build_for_*`` request builder inherited from ``_SearchApi``.
    """

    async def discover_batch_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        discover_request_batch: m.DiscoverRequestBatch = None,
    ) -> m.InlineResponse20017:
        """Look for points based on target and/or positive and negative example pairs, in batch."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "discover_request_batch": discover_request_batch}
        return await self._build_for_discover_batch_points(**call_args)

    async def discover_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        discover_request: m.DiscoverRequest = None,
    ) -> m.InlineResponse20016:
        """Find points most similar to a target, constrained by a context of
        positive/negative pairs; with only a context, performs a context search
        whose scores are losses capped at 0.0. With a target, the integer part
        of the score ranks by context and the decimal part reflects distance to
        the target (+1/-1 per pair depending on which side the point is closer to)."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "discover_request": discover_request}
        return await self._build_for_discover_points(**call_args)

    async def query_batch_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        query_request_batch: m.QueryRequestBatch = None,
    ) -> m.InlineResponse20022:
        """Universally query points in batch: search, recommend, discover, filters, plus hybrid and multi-stage queries."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "query_request_batch": query_request_batch}
        return await self._build_for_query_batch_points(**call_args)

    async def query_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        query_request: m.QueryRequest = None,
    ) -> m.InlineResponse20021:
        """Universally query points: search, recommend, discover, filters, plus hybrid and multi-stage queries."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "query_request": query_request}
        return await self._build_for_query_points(**call_args)

    async def query_points_groups(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        query_groups_request: m.QueryGroupsRequest = None,
    ) -> m.InlineResponse20018:
        """Universally query points, grouped by a given payload field."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "query_groups_request": query_groups_request}
        return await self._build_for_query_points_groups(**call_args)

    async def recommend_batch_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        recommend_request_batch: m.RecommendRequestBatch = None,
    ) -> m.InlineResponse20017:
        """Look for points closer to stored positive examples and further from negative ones, in batch."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "recommend_request_batch": recommend_request_batch}
        return await self._build_for_recommend_batch_points(**call_args)

    async def recommend_point_groups(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        recommend_groups_request: m.RecommendGroupsRequest = None,
    ) -> m.InlineResponse20018:
        """Look for points closer to stored positive examples and further from negative ones, grouped by a given payload field."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "recommend_groups_request": recommend_groups_request}
        return await self._build_for_recommend_point_groups(**call_args)

    async def recommend_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        recommend_request: m.RecommendRequest = None,
    ) -> m.InlineResponse20016:
        """Look for points closer to stored positive examples and further from negative ones."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "recommend_request": recommend_request}
        return await self._build_for_recommend_points(**call_args)

    async def search_batch_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_request_batch: m.SearchRequestBatch = None,
    ) -> m.InlineResponse20017:
        """Retrieve, by batch, the closest points by vector similarity under the given filtering conditions."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "search_request_batch": search_request_batch}
        return await self._build_for_search_batch_points(**call_args)

    async def search_matrix_offsets(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_matrix_request: m.SearchMatrixRequest = None,
    ) -> m.InlineResponse20024:
        """Compute a distance matrix for sampled points, returned in offset-based format."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "search_matrix_request": search_matrix_request}
        return await self._build_for_search_matrix_offsets(**call_args)

    async def search_matrix_pairs(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_matrix_request: m.SearchMatrixRequest = None,
    ) -> m.InlineResponse20023:
        """Compute a distance matrix for sampled points, returned in pair-based format."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "search_matrix_request": search_matrix_request}
        return await self._build_for_search_matrix_pairs(**call_args)

    async def search_point_groups(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_groups_request: m.SearchGroupsRequest = None,
    ) -> m.InlineResponse20018:
        """Retrieve the closest points by vector similarity under the given filtering conditions, grouped by a given payload field."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "search_groups_request": search_groups_request}
        return await self._build_for_search_point_groups(**call_args)

    async def search_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_request: m.SearchRequest = None,
    ) -> m.InlineResponse20016:
        """Retrieve the closest points by vector similarity under the given filtering conditions."""
        call_args = {"collection_name": collection_name, "consistency": consistency, "timeout": timeout, "search_request": search_request}
        return await self._build_for_search_points(**call_args)
||||
class SyncSearchApi(_SearchApi):
    """Blocking facade over the search request builders in ``_SearchApi``.

    Every public method forwards its arguments, unchanged, to the matching
    ``_build_for_*`` helper inherited from the base class.
    """

    def discover_batch_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        discover_request_batch: m.DiscoverRequestBatch = None,
    ) -> m.InlineResponse20017:
        """Look for points based on target and/or positive and negative example pairs, in batch."""
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            discover_request_batch=discover_request_batch,
        )
        return self._build_for_discover_batch_points(**kwargs)

    def discover_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        discover_request: m.DiscoverRequest = None,
    ) -> m.InlineResponse20016:
        """Use context and a target to find the most similar points to the target, constrained by the context.

        When using only the context (without a target), a context search is performed where pairs of
        points generate a loss guiding the search towards the zone where most positive examples
        overlap; the maximum score a point can get is 0.0 and many points may share it. When using a
        target (with or without context), the integer part of the score represents the rank with
        respect to the context, while the decimal part relates to the distance to the target. The
        context part of the score for each pair is +1 if the point is closer to a positive than to a
        negative part of a pair, and -1 otherwise.
        """
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            discover_request=discover_request,
        )
        return self._build_for_discover_points(**kwargs)

    def query_batch_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        query_request_batch: m.QueryRequestBatch = None,
    ) -> m.InlineResponse20022:
        """Universally query points in batch. Covers search, recommend, discover, filters, and enables hybrid and multi-stage queries."""
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            query_request_batch=query_request_batch,
        )
        return self._build_for_query_batch_points(**kwargs)

    def query_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        query_request: m.QueryRequest = None,
    ) -> m.InlineResponse20021:
        """Universally query points. Covers search, recommend, discover, filters, and enables hybrid and multi-stage queries."""
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            query_request=query_request,
        )
        return self._build_for_query_points(**kwargs)

    def query_points_groups(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        query_groups_request: m.QueryGroupsRequest = None,
    ) -> m.InlineResponse20018:
        """Universally query points, grouped by a given payload field."""
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            query_groups_request=query_groups_request,
        )
        return self._build_for_query_points_groups(**kwargs)

    def recommend_batch_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        recommend_request_batch: m.RecommendRequestBatch = None,
    ) -> m.InlineResponse20017:
        """Look for the points which are closer to stored positive examples and at the same time further to negative examples."""
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            recommend_request_batch=recommend_request_batch,
        )
        return self._build_for_recommend_batch_points(**kwargs)

    def recommend_point_groups(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        recommend_groups_request: m.RecommendGroupsRequest = None,
    ) -> m.InlineResponse20018:
        """Look for the points which are closer to stored positive examples and further to negative examples, grouped by a given payload field."""
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            recommend_groups_request=recommend_groups_request,
        )
        return self._build_for_recommend_point_groups(**kwargs)

    def recommend_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        recommend_request: m.RecommendRequest = None,
    ) -> m.InlineResponse20016:
        """Look for the points which are closer to stored positive examples and at the same time further to negative examples."""
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            recommend_request=recommend_request,
        )
        return self._build_for_recommend_points(**kwargs)

    def search_batch_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_request_batch: m.SearchRequestBatch = None,
    ) -> m.InlineResponse20017:
        """Retrieve by batch the closest points based on vector similarity and given filtering conditions."""
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            search_request_batch=search_request_batch,
        )
        return self._build_for_search_batch_points(**kwargs)

    def search_matrix_offsets(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_matrix_request: m.SearchMatrixRequest = None,
    ) -> m.InlineResponse20024:
        """Compute distance matrix for sampled points with an offset based output format."""
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            search_matrix_request=search_matrix_request,
        )
        return self._build_for_search_matrix_offsets(**kwargs)

    def search_matrix_pairs(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_matrix_request: m.SearchMatrixRequest = None,
    ) -> m.InlineResponse20023:
        """Compute distance matrix for sampled points with a pair based output format."""
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            search_matrix_request=search_matrix_request,
        )
        return self._build_for_search_matrix_pairs(**kwargs)

    def search_point_groups(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_groups_request: m.SearchGroupsRequest = None,
    ) -> m.InlineResponse20018:
        """Retrieve closest points based on vector similarity and given filtering conditions, grouped by a given payload field."""
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            search_groups_request=search_groups_request,
        )
        return self._build_for_search_point_groups(**kwargs)

    def search_points(
        self,
        collection_name: str,
        consistency: m.ReadConsistency = None,
        timeout: int = None,
        search_request: m.SearchRequest = None,
    ) -> m.InlineResponse20016:
        """Retrieve closest points based on vector similarity and given filtering conditions."""
        kwargs = dict(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            search_request=search_request,
        )
        return self._build_for_search_points(**kwargs)
|
||||
@@ -0,0 +1,268 @@
|
||||
# flake8: noqa E501
|
||||
from typing import TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from qdrant_client.http.models import *
|
||||
from qdrant_client.http.models import models as m
|
||||
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||
Model = TypeVar("Model", bound="BaseModel")
|
||||
|
||||
SetIntStr = Set[Union[int, str]]
|
||||
DictIntStrAny = Dict[Union[int, str], Any]
|
||||
file = None
|
||||
|
||||
|
||||
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
    """Serialize *model* to a JSON string, dispatching on the pydantic major version."""
    if not PYDANTIC_V2:
        # pydantic v1 serialization API.
        return model.json(*args, **kwargs)
    # pydantic v2 serialization API.
    return model.model_dump_json(*args, **kwargs)
|
||||
|
||||
|
||||
def jsonable_encoder(
    obj: Any,
    include: Union[SetIntStr, DictIntStrAny] = None,
    exclude=None,
    by_alias: bool = True,
    skip_defaults: bool = None,
    exclude_unset: bool = True,
    exclude_none: bool = True,
):
    """Encode *obj* for a request body.

    Pydantic models (anything exposing ``json`` or ``model_dump_json``) are
    serialized to a JSON string via :func:`to_json`; every other object is
    returned untouched.
    """
    is_model = hasattr(obj, "json") or hasattr(obj, "model_dump_json")
    if not is_model:
        # Non-model payloads pass through unchanged.
        return obj
    # ``skip_defaults`` is the legacy alias for ``exclude_unset``.
    unset = bool(exclude_unset or skip_defaults)
    return to_json(
        obj,
        include=include,
        exclude=exclude,
        by_alias=by_alias,
        exclude_unset=unset,
        exclude_none=exclude_none,
    )
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.http.api_client import ApiClient
|
||||
|
||||
|
||||
class _ServiceApi:
|
||||
def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
|
||||
self.api_client = api_client
|
||||
|
||||
def _build_for_healthz(
|
||||
self,
|
||||
):
|
||||
"""
|
||||
An endpoint for health checking used in Kubernetes.
|
||||
"""
|
||||
headers = {}
|
||||
return self.api_client.request(
|
||||
type_=str,
|
||||
method="GET",
|
||||
url="/healthz",
|
||||
headers=headers if headers else None,
|
||||
)
|
||||
|
||||
def _build_for_livez(
|
||||
self,
|
||||
):
|
||||
"""
|
||||
An endpoint for health checking used in Kubernetes.
|
||||
"""
|
||||
headers = {}
|
||||
return self.api_client.request(
|
||||
type_=str,
|
||||
method="GET",
|
||||
url="/livez",
|
||||
headers=headers if headers else None,
|
||||
)
|
||||
|
||||
def _build_for_metrics(
|
||||
self,
|
||||
anonymize: bool = None,
|
||||
):
|
||||
"""
|
||||
Collect metrics data including app info, collections info, cluster info and statistics
|
||||
"""
|
||||
query_params = {}
|
||||
if anonymize is not None:
|
||||
query_params["anonymize"] = str(anonymize).lower()
|
||||
|
||||
headers = {}
|
||||
return self.api_client.request(
|
||||
type_=str,
|
||||
method="GET",
|
||||
url="/metrics",
|
||||
headers=headers if headers else None,
|
||||
params=query_params,
|
||||
)
|
||||
|
||||
def _build_for_readyz(
|
||||
self,
|
||||
):
|
||||
"""
|
||||
An endpoint for health checking used in Kubernetes.
|
||||
"""
|
||||
headers = {}
|
||||
return self.api_client.request(
|
||||
type_=str,
|
||||
method="GET",
|
||||
url="/readyz",
|
||||
headers=headers if headers else None,
|
||||
)
|
||||
|
||||
def _build_for_root(
|
||||
self,
|
||||
):
|
||||
"""
|
||||
Returns information about the running Qdrant instance like version and commit id
|
||||
"""
|
||||
headers = {}
|
||||
return self.api_client.request(
|
||||
type_=m.VersionInfo,
|
||||
method="GET",
|
||||
url="/",
|
||||
headers=headers if headers else None,
|
||||
)
|
||||
|
||||
def _build_for_telemetry(
|
||||
self,
|
||||
anonymize: bool = None,
|
||||
details_level: int = None,
|
||||
):
|
||||
"""
|
||||
Collect telemetry data including app info, system info, collections info, cluster info, configs and statistics
|
||||
"""
|
||||
query_params = {}
|
||||
if anonymize is not None:
|
||||
query_params["anonymize"] = str(anonymize).lower()
|
||||
if details_level is not None:
|
||||
query_params["details_level"] = str(details_level)
|
||||
|
||||
headers = {}
|
||||
return self.api_client.request(
|
||||
type_=m.InlineResponse2001,
|
||||
method="GET",
|
||||
url="/telemetry",
|
||||
headers=headers if headers else None,
|
||||
params=query_params,
|
||||
)
|
||||
|
||||
|
||||
class AsyncServiceApi(_ServiceApi):
    """Awaitable facade over the service request builders in ``_ServiceApi``."""

    async def healthz(self) -> str:
        """An endpoint for health checking used in Kubernetes."""
        return await self._build_for_healthz()

    async def livez(self) -> str:
        """An endpoint for health checking used in Kubernetes."""
        return await self._build_for_livez()

    async def metrics(self, anonymize: bool = None) -> str:
        """Collect metrics data including app info, collections info, cluster info and statistics."""
        return await self._build_for_metrics(anonymize=anonymize)

    async def readyz(self) -> str:
        """An endpoint for health checking used in Kubernetes."""
        return await self._build_for_readyz()

    async def root(self) -> m.VersionInfo:
        """Returns information about the running Qdrant instance like version and commit id."""
        return await self._build_for_root()

    async def telemetry(self, anonymize: bool = None, details_level: int = None) -> m.InlineResponse2001:
        """Collect telemetry data including app info, system info, collections info, cluster info, configs and statistics."""
        return await self._build_for_telemetry(anonymize=anonymize, details_level=details_level)
|
||||
|
||||
|
||||
class SyncServiceApi(_ServiceApi):
    """Blocking facade over the service request builders in ``_ServiceApi``."""

    def healthz(self) -> str:
        """An endpoint for health checking used in Kubernetes."""
        return self._build_for_healthz()

    def livez(self) -> str:
        """An endpoint for health checking used in Kubernetes."""
        return self._build_for_livez()

    def metrics(self, anonymize: bool = None) -> str:
        """Collect metrics data including app info, collections info, cluster info and statistics."""
        return self._build_for_metrics(anonymize=anonymize)

    def readyz(self) -> str:
        """An endpoint for health checking used in Kubernetes."""
        return self._build_for_readyz()

    def root(self) -> m.VersionInfo:
        """Returns information about the running Qdrant instance like version and commit id."""
        return self._build_for_root()

    def telemetry(self, anonymize: bool = None, details_level: int = None) -> m.InlineResponse2001:
        """Collect telemetry data including app info, system info, collections info, cluster info, configs and statistics."""
        return self._build_for_telemetry(anonymize=anonymize, details_level=details_level)
|
||||
@@ -0,0 +1,937 @@
|
||||
# flake8: noqa E501
|
||||
from typing import IO, TYPE_CHECKING, Any, Dict, Set, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from qdrant_client.http.models import *
|
||||
from qdrant_client.http.models import models as m
|
||||
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||
Model = TypeVar("Model", bound="BaseModel")
|
||||
|
||||
SetIntStr = Set[Union[int, str]]
|
||||
DictIntStrAny = Dict[Union[int, str], Any]
|
||||
file = None
|
||||
|
||||
|
||||
def to_json(model: BaseModel, *args: Any, **kwargs: Any) -> str:
    """Serialize *model* to a JSON string, dispatching on the pydantic major version."""
    if not PYDANTIC_V2:
        # pydantic v1 serialization API.
        return model.json(*args, **kwargs)
    # pydantic v2 serialization API.
    return model.model_dump_json(*args, **kwargs)
|
||||
|
||||
|
||||
def jsonable_encoder(
    obj: Any,
    include: Union[SetIntStr, DictIntStrAny] = None,
    exclude=None,
    by_alias: bool = True,
    skip_defaults: bool = None,
    exclude_unset: bool = True,
    exclude_none: bool = True,
):
    """Encode *obj* for a request body.

    Pydantic models (anything exposing ``json`` or ``model_dump_json``) are
    serialized to a JSON string via :func:`to_json`; every other object is
    returned untouched.
    """
    is_model = hasattr(obj, "json") or hasattr(obj, "model_dump_json")
    if not is_model:
        # Non-model payloads pass through unchanged.
        return obj
    # ``skip_defaults`` is the legacy alias for ``exclude_unset``.
    unset = bool(exclude_unset or skip_defaults)
    return to_json(
        obj,
        include=include,
        exclude=exclude,
        by_alias=by_alias,
        exclude_unset=unset,
        exclude_none=exclude_none,
    )
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.http.api_client import ApiClient
|
||||
|
||||
|
||||
class _SnapshotsApi:
    """Low-level request builders for the Qdrant snapshot endpoints."""

    def __init__(self, api_client: "Union[ApiClient, AsyncApiClient]"):
        # The client performs the actual HTTP round-trip and response parsing.
        self.api_client = api_client

    def _build_for_create_full_snapshot(self, wait: bool = None):
        """Create new snapshot of the whole storage."""
        params = {}
        if wait is not None:
            # Query values are transported as lowercase string booleans.
            params["wait"] = str(wait).lower()
        return self.api_client.request(
            type_=m.InlineResponse20011,
            method="POST",
            url="/snapshots",
            headers=None,
            params=params,
        )

    def _build_for_create_shard_snapshot(self, collection_name: str, shard_id: int, wait: bool = None):
        """Create new snapshot of a shard for a collection."""
        path = {
            "collection_name": str(collection_name),
            "shard_id": str(shard_id),
        }
        params = {}
        if wait is not None:
            params["wait"] = str(wait).lower()
        return self.api_client.request(
            type_=m.InlineResponse20011,
            method="POST",
            url="/collections/{collection_name}/shards/{shard_id}/snapshots",
            headers=None,
            path_params=path,
            params=params,
        )

    def _build_for_create_snapshot(self, collection_name: str, wait: bool = None):
        """Create new snapshot for a collection."""
        path = {"collection_name": str(collection_name)}
        params = {}
        if wait is not None:
            params["wait"] = str(wait).lower()
        return self.api_client.request(
            type_=m.InlineResponse20011,
            method="POST",
            url="/collections/{collection_name}/snapshots",
            headers=None,
            path_params=path,
            params=params,
        )

    def _build_for_delete_full_snapshot(self, snapshot_name: str, wait: bool = None):
        """Delete snapshot of the whole storage."""
        path = {"snapshot_name": str(snapshot_name)}
        params = {}
        if wait is not None:
            params["wait"] = str(wait).lower()
        return self.api_client.request(
            type_=m.InlineResponse2009,
            method="DELETE",
            url="/snapshots/{snapshot_name}",
            headers=None,
            path_params=path,
            params=params,
        )

    def _build_for_delete_shard_snapshot(self, collection_name: str, shard_id: int, snapshot_name: str, wait: bool = None):
        """Delete snapshot of a shard for a collection."""
        path = {
            "collection_name": str(collection_name),
            "shard_id": str(shard_id),
            "snapshot_name": str(snapshot_name),
        }
        params = {}
        if wait is not None:
            params["wait"] = str(wait).lower()
        return self.api_client.request(
            type_=m.InlineResponse2009,
            method="DELETE",
            url="/collections/{collection_name}/shards/{shard_id}/snapshots/{snapshot_name}",
            headers=None,
            path_params=path,
            params=params,
        )

    def _build_for_delete_snapshot(self, collection_name: str, snapshot_name: str, wait: bool = None):
        """Delete snapshot for a collection."""
        path = {
            "collection_name": str(collection_name),
            "snapshot_name": str(snapshot_name),
        }
        params = {}
        if wait is not None:
            params["wait"] = str(wait).lower()
        return self.api_client.request(
            type_=m.InlineResponse2009,
            method="DELETE",
            url="/collections/{collection_name}/snapshots/{snapshot_name}",
            headers=None,
            path_params=path,
            params=params,
        )

    def _build_for_get_full_snapshot(self, snapshot_name: str):
        """Download specified snapshot of the whole storage as a file."""
        path = {"snapshot_name": str(snapshot_name)}
        return self.api_client.request(
            type_=file,
            method="GET",
            url="/snapshots/{snapshot_name}",
            headers=None,
            path_params=path,
        )

    def _build_for_get_shard_snapshot(self, collection_name: str, shard_id: int, snapshot_name: str):
        """Download specified snapshot of a shard from a collection as a file."""
        path = {
            "collection_name": str(collection_name),
            "shard_id": str(shard_id),
            "snapshot_name": str(snapshot_name),
        }
        return self.api_client.request(
            type_=file,
            method="GET",
            url="/collections/{collection_name}/shards/{shard_id}/snapshots/{snapshot_name}",
            headers=None,
            path_params=path,
        )

    def _build_for_get_snapshot(self, collection_name: str, snapshot_name: str):
        """Download specified snapshot from a collection as a file."""
        path = {
            "collection_name": str(collection_name),
            "snapshot_name": str(snapshot_name),
        }
        return self.api_client.request(
            type_=file,
            method="GET",
            url="/collections/{collection_name}/snapshots/{snapshot_name}",
            headers=None,
            path_params=path,
        )

    def _build_for_list_full_snapshots(self):
        """Get list of snapshots of the whole storage."""
        return self.api_client.request(
            type_=m.InlineResponse20010,
            method="GET",
            url="/snapshots",
            headers=None,
        )

    def _build_for_list_shard_snapshots(self, collection_name: str, shard_id: int):
        """Get list of snapshots for a shard of a collection."""
        path = {
            "collection_name": str(collection_name),
            "shard_id": str(shard_id),
        }
        return self.api_client.request(
            type_=m.InlineResponse20010,
            method="GET",
            url="/collections/{collection_name}/shards/{shard_id}/snapshots",
            headers=None,
            path_params=path,
        )

    def _build_for_list_snapshots(self, collection_name: str):
        """Get list of snapshots for a collection."""
        path = {"collection_name": str(collection_name)}
        return self.api_client.request(
            type_=m.InlineResponse20010,
            method="GET",
            url="/collections/{collection_name}/snapshots",
            headers=None,
            path_params=path,
        )

    def _build_for_recover_from_snapshot(
        self,
        collection_name: str,
        wait: bool = None,
        snapshot_recover: m.SnapshotRecover = None,
    ):
        """Recover local collection data from a snapshot.

        This will overwrite any data, stored on this node, for the collection.
        If collection does not exist - it will be created.
        """
        path = {"collection_name": str(collection_name)}
        params = {}
        if wait is not None:
            params["wait"] = str(wait).lower()
        # JSON body; the original code always ends up setting this header.
        headers = {"Content-Type": "application/json"}
        body = jsonable_encoder(snapshot_recover)
        return self.api_client.request(
            type_=m.InlineResponse2009,
            method="PUT",
            url="/collections/{collection_name}/snapshots/recover",
            headers=headers,
            path_params=path,
            params=params,
            content=body,
        )

    def _build_for_recover_from_uploaded_snapshot(
        self,
        collection_name: str,
        wait: bool = None,
        priority: SnapshotPriority = None,
        checksum: str = None,
        snapshot: IO[Any] = None,
    ):
        """Recover local collection data from an uploaded snapshot.

        This will overwrite any data, stored on this node, for the collection.
        If collection does not exist - it will be created.
        """
        path = {"collection_name": str(collection_name)}
        params = {}
        if wait is not None:
            params["wait"] = str(wait).lower()
        if priority is not None:
            params["priority"] = str(priority)
        if checksum is not None:
            params["checksum"] = str(checksum)
        # Multipart upload payload.
        files: Dict[str, IO[Any]] = {}
        data: Dict[str, Any] = {}
        if snapshot is not None:
            files["snapshot"] = snapshot
        return self.api_client.request(
            type_=m.InlineResponse2009,
            method="POST",
            url="/collections/{collection_name}/snapshots/upload",
            headers=None,
            path_params=path,
            params=params,
            data=data,
            files=files,
        )

    def _build_for_recover_shard_from_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        wait: bool = None,
        shard_snapshot_recover: m.ShardSnapshotRecover = None,
    ):
        """Recover shard of a local collection data from a snapshot.

        This will overwrite any data, stored in this shard, for the collection.
        """
        path = {
            "collection_name": str(collection_name),
            "shard_id": str(shard_id),
        }
        params = {}
        if wait is not None:
            params["wait"] = str(wait).lower()
        headers = {"Content-Type": "application/json"}
        body = jsonable_encoder(shard_snapshot_recover)
        return self.api_client.request(
            type_=m.InlineResponse2009,
            method="PUT",
            url="/collections/{collection_name}/shards/{shard_id}/snapshots/recover",
            headers=headers,
            path_params=path,
            params=params,
            content=body,
        )

    def _build_for_recover_shard_from_uploaded_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        wait: bool = None,
        priority: SnapshotPriority = None,
        checksum: str = None,
        snapshot: IO[Any] = None,
    ):
        """Recover shard of a local collection from an uploaded snapshot.

        This will overwrite any data, stored on this node, for the collection shard.
        """
        path = {
            "collection_name": str(collection_name),
            "shard_id": str(shard_id),
        }
        params = {}
        if wait is not None:
            params["wait"] = str(wait).lower()
        if priority is not None:
            params["priority"] = str(priority)
        if checksum is not None:
            params["checksum"] = str(checksum)
        files: Dict[str, IO[Any]] = {}
        data: Dict[str, Any] = {}
        if snapshot is not None:
            files["snapshot"] = snapshot
        return self.api_client.request(
            type_=m.InlineResponse2009,
            method="POST",
            url="/collections/{collection_name}/shards/{shard_id}/snapshots/upload",
            headers=None,
            path_params=path,
            params=params,
            data=data,
            files=files,
        )
|
||||
|
||||
|
||||
class AsyncSnapshotsApi(_SnapshotsApi):
    """Asynchronous facade for the snapshot endpoints.

    Each method is a thin wrapper that awaits the corresponding
    ``_build_for_*`` helper on ``_SnapshotsApi``, which assembles and
    performs the actual HTTP request via the shared api client.
    """

    async def create_full_snapshot(
        self,
        wait: bool = None,
    ) -> m.InlineResponse20011:
        """
        Create new snapshot of the whole storage
        """
        return await self._build_for_create_full_snapshot(
            wait=wait,
        )

    async def create_shard_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        wait: bool = None,
    ) -> m.InlineResponse20011:
        """
        Create new snapshot of a shard for a collection
        """
        return await self._build_for_create_shard_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            wait=wait,
        )

    async def create_snapshot(
        self,
        collection_name: str,
        wait: bool = None,
    ) -> m.InlineResponse20011:
        """
        Create new snapshot for a collection
        """
        return await self._build_for_create_snapshot(
            collection_name=collection_name,
            wait=wait,
        )

    async def delete_full_snapshot(
        self,
        snapshot_name: str,
        wait: bool = None,
    ) -> m.InlineResponse2009:
        """
        Delete snapshot of the whole storage
        """
        return await self._build_for_delete_full_snapshot(
            snapshot_name=snapshot_name,
            wait=wait,
        )

    async def delete_shard_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        snapshot_name: str,
        wait: bool = None,
    ) -> m.InlineResponse2009:
        """
        Delete snapshot of a shard for a collection
        """
        return await self._build_for_delete_shard_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            snapshot_name=snapshot_name,
            wait=wait,
        )

    async def delete_snapshot(
        self,
        collection_name: str,
        snapshot_name: str,
        wait: bool = None,
    ) -> m.InlineResponse2009:
        """
        Delete snapshot for a collection
        """
        return await self._build_for_delete_snapshot(
            collection_name=collection_name,
            snapshot_name=snapshot_name,
            wait=wait,
        )

    async def get_full_snapshot(
        self,
        snapshot_name: str,
    ) -> file:
        """
        Download specified snapshot of the whole storage as a file
        """
        return await self._build_for_get_full_snapshot(
            snapshot_name=snapshot_name,
        )

    async def get_shard_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        snapshot_name: str,
    ) -> file:
        """
        Download specified snapshot of a shard from a collection as a file
        """
        return await self._build_for_get_shard_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            snapshot_name=snapshot_name,
        )

    async def get_snapshot(
        self,
        collection_name: str,
        snapshot_name: str,
    ) -> file:
        """
        Download specified snapshot from a collection as a file
        """
        return await self._build_for_get_snapshot(
            collection_name=collection_name,
            snapshot_name=snapshot_name,
        )

    async def list_full_snapshots(
        self,
    ) -> m.InlineResponse20010:
        """
        Get list of snapshots of the whole storage
        """
        return await self._build_for_list_full_snapshots()

    async def list_shard_snapshots(
        self,
        collection_name: str,
        shard_id: int,
    ) -> m.InlineResponse20010:
        """
        Get list of snapshots for a shard of a collection
        """
        return await self._build_for_list_shard_snapshots(
            collection_name=collection_name,
            shard_id=shard_id,
        )

    async def list_snapshots(
        self,
        collection_name: str,
    ) -> m.InlineResponse20010:
        """
        Get list of snapshots for a collection
        """
        return await self._build_for_list_snapshots(
            collection_name=collection_name,
        )

    async def recover_from_snapshot(
        self,
        collection_name: str,
        wait: bool = None,
        snapshot_recover: m.SnapshotRecover = None,
    ) -> m.InlineResponse2009:
        """
        Recover local collection data from a snapshot. This will overwrite any data, stored on this node, for the collection. If collection does not exist - it will be created.
        """
        return await self._build_for_recover_from_snapshot(
            collection_name=collection_name,
            wait=wait,
            snapshot_recover=snapshot_recover,
        )

    async def recover_from_uploaded_snapshot(
        self,
        collection_name: str,
        wait: bool = None,
        priority: SnapshotPriority = None,
        checksum: str = None,
        snapshot: IO[Any] = None,
    ) -> m.InlineResponse2009:
        """
        Recover local collection data from an uploaded snapshot. This will overwrite any data, stored on this node, for the collection. If collection does not exist - it will be created.
        """
        return await self._build_for_recover_from_uploaded_snapshot(
            collection_name=collection_name,
            wait=wait,
            priority=priority,
            checksum=checksum,
            snapshot=snapshot,
        )

    async def recover_shard_from_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        wait: bool = None,
        shard_snapshot_recover: m.ShardSnapshotRecover = None,
    ) -> m.InlineResponse2009:
        """
        Recover shard of a local collection data from a snapshot. This will overwrite any data, stored in this shard, for the collection.
        """
        return await self._build_for_recover_shard_from_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            wait=wait,
            shard_snapshot_recover=shard_snapshot_recover,
        )

    async def recover_shard_from_uploaded_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        wait: bool = None,
        priority: SnapshotPriority = None,
        checksum: str = None,
        snapshot: IO[Any] = None,
    ) -> m.InlineResponse2009:
        """
        Recover shard of a local collection from an uploaded snapshot. This will overwrite any data, stored on this node, for the collection shard.
        """
        return await self._build_for_recover_shard_from_uploaded_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            wait=wait,
            priority=priority,
            checksum=checksum,
            snapshot=snapshot,
        )
|
||||
|
||||
|
||||
class SyncSnapshotsApi(_SnapshotsApi):
    """Synchronous facade for the snapshot endpoints.

    Each method is a thin wrapper around the corresponding ``_build_for_*``
    helper on ``_SnapshotsApi``, which assembles and performs the actual
    HTTP request via the shared api client.
    """

    def create_full_snapshot(
        self,
        wait: bool = None,
    ) -> m.InlineResponse20011:
        """
        Create new snapshot of the whole storage
        """
        return self._build_for_create_full_snapshot(
            wait=wait,
        )

    def create_shard_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        wait: bool = None,
    ) -> m.InlineResponse20011:
        """
        Create new snapshot of a shard for a collection
        """
        return self._build_for_create_shard_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            wait=wait,
        )

    def create_snapshot(
        self,
        collection_name: str,
        wait: bool = None,
    ) -> m.InlineResponse20011:
        """
        Create new snapshot for a collection
        """
        return self._build_for_create_snapshot(
            collection_name=collection_name,
            wait=wait,
        )

    def delete_full_snapshot(
        self,
        snapshot_name: str,
        wait: bool = None,
    ) -> m.InlineResponse2009:
        """
        Delete snapshot of the whole storage
        """
        return self._build_for_delete_full_snapshot(
            snapshot_name=snapshot_name,
            wait=wait,
        )

    def delete_shard_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        snapshot_name: str,
        wait: bool = None,
    ) -> m.InlineResponse2009:
        """
        Delete snapshot of a shard for a collection
        """
        return self._build_for_delete_shard_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            snapshot_name=snapshot_name,
            wait=wait,
        )

    def delete_snapshot(
        self,
        collection_name: str,
        snapshot_name: str,
        wait: bool = None,
    ) -> m.InlineResponse2009:
        """
        Delete snapshot for a collection
        """
        return self._build_for_delete_snapshot(
            collection_name=collection_name,
            snapshot_name=snapshot_name,
            wait=wait,
        )

    def get_full_snapshot(
        self,
        snapshot_name: str,
    ) -> file:
        """
        Download specified snapshot of the whole storage as a file
        """
        return self._build_for_get_full_snapshot(
            snapshot_name=snapshot_name,
        )

    def get_shard_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        snapshot_name: str,
    ) -> file:
        """
        Download specified snapshot of a shard from a collection as a file
        """
        return self._build_for_get_shard_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            snapshot_name=snapshot_name,
        )

    def get_snapshot(
        self,
        collection_name: str,
        snapshot_name: str,
    ) -> file:
        """
        Download specified snapshot from a collection as a file
        """
        return self._build_for_get_snapshot(
            collection_name=collection_name,
            snapshot_name=snapshot_name,
        )

    def list_full_snapshots(
        self,
    ) -> m.InlineResponse20010:
        """
        Get list of snapshots of the whole storage
        """
        return self._build_for_list_full_snapshots()

    def list_shard_snapshots(
        self,
        collection_name: str,
        shard_id: int,
    ) -> m.InlineResponse20010:
        """
        Get list of snapshots for a shard of a collection
        """
        return self._build_for_list_shard_snapshots(
            collection_name=collection_name,
            shard_id=shard_id,
        )

    def list_snapshots(
        self,
        collection_name: str,
    ) -> m.InlineResponse20010:
        """
        Get list of snapshots for a collection
        """
        return self._build_for_list_snapshots(
            collection_name=collection_name,
        )

    def recover_from_snapshot(
        self,
        collection_name: str,
        wait: bool = None,
        snapshot_recover: m.SnapshotRecover = None,
    ) -> m.InlineResponse2009:
        """
        Recover local collection data from a snapshot. This will overwrite any data, stored on this node, for the collection. If collection does not exist - it will be created.
        """
        return self._build_for_recover_from_snapshot(
            collection_name=collection_name,
            wait=wait,
            snapshot_recover=snapshot_recover,
        )

    def recover_from_uploaded_snapshot(
        self,
        collection_name: str,
        wait: bool = None,
        priority: SnapshotPriority = None,
        checksum: str = None,
        snapshot: IO[Any] = None,
    ) -> m.InlineResponse2009:
        """
        Recover local collection data from an uploaded snapshot. This will overwrite any data, stored on this node, for the collection. If collection does not exist - it will be created.
        """
        return self._build_for_recover_from_uploaded_snapshot(
            collection_name=collection_name,
            wait=wait,
            priority=priority,
            checksum=checksum,
            snapshot=snapshot,
        )

    def recover_shard_from_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        wait: bool = None,
        shard_snapshot_recover: m.ShardSnapshotRecover = None,
    ) -> m.InlineResponse2009:
        """
        Recover shard of a local collection data from a snapshot. This will overwrite any data, stored in this shard, for the collection.
        """
        return self._build_for_recover_shard_from_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            wait=wait,
            shard_snapshot_recover=shard_snapshot_recover,
        )

    def recover_shard_from_uploaded_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        wait: bool = None,
        priority: SnapshotPriority = None,
        checksum: str = None,
        snapshot: IO[Any] = None,
    ) -> m.InlineResponse2009:
        """
        Recover shard of a local collection from an uploaded snapshot. This will overwrite any data, stored on this node, for the collection shard.
        """
        return self._build_for_recover_shard_from_uploaded_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            wait=wait,
            priority=priority,
            checksum=checksum,
            snapshot=snapshot,
        )
|
||||
@@ -0,0 +1,263 @@
|
||||
from asyncio import get_event_loop
|
||||
from functools import lru_cache
|
||||
from typing import Any, Awaitable, Callable, Dict, Generic, Type, TypeVar, overload
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from httpx import AsyncClient, Client, Request, Response
|
||||
from pydantic import ValidationError
|
||||
from qdrant_client.common.client_exceptions import ResourceExhaustedResponse
|
||||
from qdrant_client.http.api.aliases_api import AsyncAliasesApi, SyncAliasesApi
|
||||
from qdrant_client.http.api.beta_api import AsyncBetaApi, SyncBetaApi
|
||||
from qdrant_client.http.api.collections_api import AsyncCollectionsApi, SyncCollectionsApi
|
||||
from qdrant_client.http.api.distributed_api import AsyncDistributedApi, SyncDistributedApi
|
||||
from qdrant_client.http.api.indexes_api import AsyncIndexesApi, SyncIndexesApi
|
||||
from qdrant_client.http.api.points_api import AsyncPointsApi, SyncPointsApi
|
||||
from qdrant_client.http.api.search_api import AsyncSearchApi, SyncSearchApi
|
||||
from qdrant_client.http.api.service_api import AsyncServiceApi, SyncServiceApi
|
||||
from qdrant_client.http.api.snapshots_api import AsyncSnapshotsApi, SyncSnapshotsApi
|
||||
from qdrant_client.http.exceptions import ResponseHandlingException, UnexpectedResponse
|
||||
|
||||
# Type variables bound to the two transport classes below, so that
# Apis containers can be parameterized by the concrete client type.
ClientT = TypeVar("ClientT", bound="ApiClient")
AsyncClientT = TypeVar("AsyncClientT", bound="AsyncApiClient")
||||
|
||||
|
||||
class AsyncApis(Generic[AsyncClientT]):
    """Container that wires every asynchronous API group to one shared AsyncApiClient."""

    def __init__(self, host: str, **kwargs: Any):
        # Single shared transport; all API groups below reuse it.
        self.client = AsyncApiClient(host, **kwargs)

        self.aliases_api = AsyncAliasesApi(self.client)
        self.beta_api = AsyncBetaApi(self.client)
        self.collections_api = AsyncCollectionsApi(self.client)
        self.distributed_api = AsyncDistributedApi(self.client)
        self.indexes_api = AsyncIndexesApi(self.client)
        self.points_api = AsyncPointsApi(self.client)
        self.search_api = AsyncSearchApi(self.client)
        self.service_api = AsyncServiceApi(self.client)
        self.snapshots_api = AsyncSnapshotsApi(self.client)

    async def aclose(self) -> None:
        # Closes the shared client (and with it, the underlying httpx.AsyncClient).
        await self.client.aclose()
|
||||
|
||||
|
||||
class SyncApis(Generic[ClientT]):
    """Container that wires every synchronous API group to one shared ApiClient."""

    def __init__(self, host: str, **kwargs: Any):
        # Single shared transport; all API groups below reuse it.
        self.client = ApiClient(host, **kwargs)

        self.aliases_api = SyncAliasesApi(self.client)
        self.beta_api = SyncBetaApi(self.client)
        self.collections_api = SyncCollectionsApi(self.client)
        self.distributed_api = SyncDistributedApi(self.client)
        self.indexes_api = SyncIndexesApi(self.client)
        self.points_api = SyncPointsApi(self.client)
        self.search_api = SyncSearchApi(self.client)
        self.service_api = SyncServiceApi(self.client)
        self.snapshots_api = SyncSnapshotsApi(self.client)

    def close(self) -> None:
        # Closes the shared client (and with it, the underlying httpx.Client).
        self.client.close()
|
||||
|
||||
|
||||
# Generic result type for parsed responses.
T = TypeVar("T")
# A "send" callable forwards a request to the next hop and yields the response.
Send = Callable[[Request], Response]
SendAsync = Callable[[Request], Awaitable[Response]]
# A middleware receives the request plus the next "send" in the chain.
MiddlewareT = Callable[[Request, Send], Response]
AsyncMiddlewareT = Callable[[Request, SendAsync], Awaitable[Response]]
|
||||
|
||||
|
||||
class ApiClient:
    """Synchronous transport shared by the generated API groups.

    Wraps an ``httpx.Client``, applies the registered middleware chain, and
    parses successful responses into the requested pydantic model type.
    """

    def __init__(self, host: str, **kwargs: Any) -> None:
        self.host = host
        self.middleware: MiddlewareT = BaseMiddleware()
        self._client = Client(**kwargs)

    @overload
    def request(self, *, type_: Type[T], method: str, url: str, path_params: Dict[str, Any] = None, **kwargs: Any) -> T:
        ...

    @overload  # noqa F811
    def request(self, *, type_: None, method: str, url: str, path_params: Dict[str, Any] = None, **kwargs: Any) -> None:
        ...

    def request(  # noqa F811
        self, *, type_: Any, method: str, url: str, path_params: Dict[str, Any] = None, **kwargs: Any
    ) -> Any:
        """Format the URL, build the request, send it, and parse the response as ``type_``.

        ``path_params`` are substituted into ``url`` placeholders; remaining kwargs
        are forwarded to ``httpx.Client.build_request``.
        """
        if path_params is None:
            path_params = {}

        host = self.host if self.host.endswith("/") else self.host + "/"
        url = url[1:] if url.startswith("/") else url
        # in order to do a correct join, url join requires base_url to end with /, and url to not start with /,
        # since url is treated as an absolute path and might truncate prefix in base_url
        url = urljoin(host, url.format(**path_params))
        if "params" in kwargs and "timeout" in kwargs["params"]:
            # Mirror the server-side "timeout" query parameter onto the HTTP client
            # timeout, so the client does not give up before the server does.
            kwargs["timeout"] = int(kwargs["params"]["timeout"])
        request = self._client.build_request(method, url, **kwargs)
        return self.send(request, type_)

    @overload
    def request_sync(self, *, type_: Type[T], **kwargs: Any) -> T:
        ...

    @overload  # noqa F811
    def request_sync(self, *, type_: None, **kwargs: Any) -> None:
        ...

    def request_sync(self, *, type_: Any, **kwargs: Any) -> Any:  # noqa F811
        """
        This method is not used by the generated apis, but is included for convenience
        """
        # Bugfix: `self.request` is synchronous in this class, so the previous
        # `get_event_loop().run_until_complete(self.request(...))` wrapper passed a
        # plain value (not an awaitable) to run_until_complete and raised TypeError.
        # Call it directly instead.
        return self.request(type_=type_, **kwargs)

    def send(self, request: Request, type_: Type[T]) -> T:
        """Run the middleware chain; parse 200/201/202 responses, raise otherwise.

        Raises:
            ResourceExhaustedResponse: on HTTP 429 with a Retry-After header.
            ResponseHandlingException: when parsing the response body fails.
            UnexpectedResponse: for any other status code.
        """
        response = self.middleware(request, self.send_inner)

        if response.status_code == 429:
            retry_after_s = response.headers.get("Retry-After", None)
            try:
                resp = response.json()
                message = resp["status"]["error"] if resp["status"] and resp["status"]["error"] else ""
            except Exception:
                # Best effort: the 429 body may not be JSON at all.
                message = ""

            if retry_after_s:
                raise ResourceExhaustedResponse(message, retry_after_s)

        if response.status_code in [200, 201, 202]:
            try:
                return parse_as_type(response.json(), type_)
            except ValidationError as e:
                raise ResponseHandlingException(e)
        raise UnexpectedResponse.for_response(response)

    def send_inner(self, request: Request) -> Response:
        """Transmit the request, wrapping any transport error uniformly."""
        try:
            response = self._client.send(request)
        except Exception as e:
            raise ResponseHandlingException(e)
        return response

    def close(self) -> None:
        """Close the underlying httpx client."""
        self._client.close()

    def add_middleware(self, middleware: MiddlewareT) -> None:
        """Register ``middleware`` so it runs outermost (before previously added ones)."""
        current_middleware = self.middleware

        def new_middleware(request: Request, call_next: Send) -> Response:
            def inner_send(request: Request) -> Response:
                return current_middleware(request, call_next)

            return middleware(request, inner_send)

        self.middleware = new_middleware
|
||||
|
||||
|
||||
class AsyncApiClient:
    """Asynchronous counterpart of ``ApiClient``, wrapping an ``httpx.AsyncClient``."""

    def __init__(self, host: str = None, **kwargs: Any) -> None:
        self.host = host
        self.middleware: AsyncMiddlewareT = BaseAsyncMiddleware()
        self._async_client = AsyncClient(**kwargs)

    @overload
    async def request(
        self, *, type_: Type[T], method: str, url: str, path_params: Dict[str, Any] = None, **kwargs: Any
    ) -> T:
        ...

    @overload  # noqa F811
    async def request(
        self, *, type_: None, method: str, url: str, path_params: Dict[str, Any] = None, **kwargs: Any
    ) -> None:
        ...

    async def request(  # noqa F811
        self, *, type_: Any, method: str, url: str, path_params: Dict[str, Any] = None, **kwargs: Any
    ) -> Any:
        """Format the URL, build the request, send it, and parse the response as ``type_``."""
        if path_params is None:
            path_params = {}

        host = self.host if self.host.endswith("/") else self.host + "/"
        url = url[1:] if url.startswith("/") else url
        # in order to do a correct join, url join requires base_url to end with /, and url to not start with /,
        # since url is treated as an absolute path and might truncate prefix in base_url
        url = urljoin(host, url.format(**path_params))
        if "params" in kwargs and "timeout" in kwargs["params"]:
            # Consistency fix: mirror the server-side "timeout" query parameter onto
            # the HTTP client timeout, exactly as the synchronous ApiClient.request
            # does — previously only the sync client honored it.
            kwargs["timeout"] = int(kwargs["params"]["timeout"])
        request = self._async_client.build_request(method, url, **kwargs)
        return await self.send(request, type_)

    @overload
    def request_sync(self, *, type_: Type[T], **kwargs: Any) -> T:
        ...

    @overload  # noqa F811
    def request_sync(self, *, type_: None, **kwargs: Any) -> None:
        ...

    def request_sync(self, *, type_: Any, **kwargs: Any) -> Any:  # noqa F811
        """
        This method is not used by the generated apis, but is included for convenience
        """
        # NOTE(review): get_event_loop() is deprecated for this use since Python 3.10;
        # consider asyncio.run / an explicit loop if this helper is ever relied upon.
        return get_event_loop().run_until_complete(self.request(type_=type_, **kwargs))

    async def send(self, request: Request, type_: Type[T]) -> T:
        """Run the middleware chain; parse 200/201/202 responses, raise otherwise.

        Raises:
            ResourceExhaustedResponse: on HTTP 429 with a Retry-After header.
            ResponseHandlingException: when parsing the response body fails.
            UnexpectedResponse: for any other status code.
        """
        response = await self.middleware(request, self.send_inner)

        if response.status_code == 429:
            retry_after_s = response.headers.get("Retry-After", None)
            try:
                resp = response.json()
                message = resp["status"]["error"] if resp["status"] and resp["status"]["error"] else ""
            except Exception:
                # Best effort: the 429 body may not be JSON at all.
                message = ""

            if retry_after_s:
                raise ResourceExhaustedResponse(message, retry_after_s)

        if response.status_code in [200, 201, 202]:
            try:
                return parse_as_type(response.json(), type_)
            except ValidationError as e:
                raise ResponseHandlingException(e)
        raise UnexpectedResponse.for_response(response)

    async def send_inner(self, request: Request) -> Response:
        """Transmit the request, wrapping any transport error uniformly."""
        try:
            response = await self._async_client.send(request)
        except Exception as e:
            raise ResponseHandlingException(e)
        return response

    async def aclose(self) -> None:
        """Close the underlying httpx async client."""
        await self._async_client.aclose()

    def add_middleware(self, middleware: AsyncMiddlewareT) -> None:
        """Register ``middleware`` so it runs outermost (before previously added ones)."""
        current_middleware = self.middleware

        async def new_middleware(request: Request, call_next: SendAsync) -> Response:
            async def inner_send(request: Request) -> Response:
                return await current_middleware(request, call_next)

            return await middleware(request, inner_send)

        self.middleware = new_middleware
|
||||
|
||||
|
||||
class BaseAsyncMiddleware:
    """Identity middleware: passes the request straight to the next handler."""

    async def __call__(self, request: Request, call_next: SendAsync) -> Response:
        response = await call_next(request)
        return response
|
||||
|
||||
|
||||
class BaseMiddleware:
    """Identity middleware: passes the request straight to the next handler."""

    def __call__(self, request: Request, call_next: Send) -> Response:
        response = call_next(request)
        return response
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def _get_parsing_type(type_: Any, source: str) -> Any:
    """Memoized factory for the single-field pydantic wrapper model used by parse_as_type."""
    # Imported lazily so the module does not pay the cost until parsing is needed.
    from pydantic.main import create_model

    readable_name = getattr(type_, "__name__", str(type_))
    return create_model(f"ParsingModel[{readable_name}] (for {source})", obj=(type_, ...))
|
||||
|
||||
|
||||
def parse_as_type(obj: Any, type_: Type[T]) -> T:
    """Validate ``obj`` against ``type_`` via a throwaway pydantic model and return the result."""
    wrapper_model = _get_parsing_type(type_, source=parse_as_type.__name__)
    parsed = wrapper_model(obj=obj)
    return parsed.obj
|
||||
@@ -0,0 +1,5 @@
|
||||
from typing import Union
|
||||
|
||||
# This is a dirty hack - the proper way is to upgrade the OpenAPI generator to at least 5.0.
# But this upgrade will also require updating all of the templates. Maybe some other day.
# Alias matching the schema name the generator emits for "anyOf: [string, integer]".
AnyOfstringinteger = Union[str, int]
|
||||
@@ -0,0 +1,46 @@
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from httpx import Headers, Response
|
||||
|
||||
# Maximum number of response-body bytes echoed into UnexpectedResponse's message.
MAX_CONTENT = 200
|
||||
|
||||
|
||||
class ApiException(Exception):
    """Base class for exceptions raised by the generated HTTP API client."""
||||
|
||||
|
||||
class UnexpectedResponse(ApiException):
    """Raised when the server replies with a status code the client did not expect."""

    def __init__(self, status_code: Optional[int], reason_phrase: str, content: bytes, headers: Headers) -> None:
        self.status_code = status_code
        self.reason_phrase = reason_phrase
        self.content = content
        self.headers = headers

    @staticmethod
    def for_response(response: Response) -> "ApiException":
        """Build an instance directly from an httpx response object."""
        return UnexpectedResponse(
            status_code=response.status_code,
            reason_phrase=response.reason_phrase,
            content=response.content,
            headers=response.headers,
        )

    def __str__(self) -> str:
        code = f"{self.status_code}" if self.status_code is not None else ""
        # An empty reason phrase alongside a real status code means the code was
        # not recognized by the HTTP layer.
        if self.status_code is not None and self.reason_phrase == "":
            phrase = "(Unrecognized Status Code)"
        else:
            phrase = f"({self.reason_phrase})"
        status = f"{code} {phrase}".strip()
        body = self.content
        if len(body) > MAX_CONTENT:
            # Truncate long bodies so the message stays readable.
            body = body[: MAX_CONTENT - 3] + b" ..."
        return f"Unexpected Response: {status}\n" + f"Raw response content:\n{body!r}"

    def structured(self) -> Dict[str, Any]:
        """Decode the raw response body as JSON."""
        return json.loads(self.content)
|
||||
|
||||
|
||||
class ResponseHandlingException(ApiException):
    """Wraps a lower-level error (transport failure or body-parsing failure)."""

    def __init__(self, source: Exception):
        # The original exception that triggered this one.
        self.source = source
||||
@@ -0,0 +1 @@
|
||||
from .models import *
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,370 @@
|
||||
import math
|
||||
|
||||
|
||||
from qdrant_client.conversions.common_types import get_args_subscribed
|
||||
from qdrant_client.http import models
|
||||
from typing import Union, Any, Tuple
|
||||
|
||||
from qdrant_client.local import datetime_utils
|
||||
from qdrant_client.local.geo import geo_distance
|
||||
from qdrant_client.local.payload_filters import check_condition
|
||||
from qdrant_client.local.payload_value_extractor import value_by_key
|
||||
|
||||
# presumably the fallback score for variables with no value — confirm at call sites
DEFAULT_SCORE = 0.0
# Defaults for decay expressions (see evaluate_decay_params below):
DEFAULT_DECAY_TARGET = 0.0  # target used when a decay expression omits `target`
DEFAULT_DECAY_MIDPOINT = 0.5  # decay value reached at distance `scale` from the target
DEFAULT_DECAY_SCALE = 1.0  # distance at which the decay reaches `midpoint`
|
||||
|
||||
|
||||
def evaluate_expression(
    expression: models.Expression,
    point_id: models.ExtendedPointId,
    scores: list[dict[models.ExtendedPointId, float]],
    payload: models.Payload,
    has_vector: dict[str, bool],
    defaults: dict[str, Any],
) -> float:
    """Recursively evaluate a formula ``expression`` to a float for a single point.

    Args:
        expression: a constant, a variable name, a filter condition, or a
            composite expression node from ``models``.
        point_id: id of the point being scored.
        scores: per-stage score maps — presumably consumed by
            ``evaluate_variable`` for score references; confirm there.
        payload: the point's payload; source for key-based variables.
        has_vector: per-named-vector presence flags, used by condition checks.
        defaults: fallback values for keys missing from the payload.

    Raises:
        ValueError: for unsupported expression types, missing datetimes/geo
            points, or (via ``raise_non_finite_error``, defined elsewhere —
            assumed to always raise) non-finite intermediate results.
    """
    if isinstance(expression, (float, int)):  # Constant
        return float(expression)

    elif isinstance(expression, str):  # Variable
        return evaluate_variable(expression, point_id, scores, payload, defaults)

    elif isinstance(expression, get_args_subscribed(models.Condition)):
        # A filter condition evaluates to an indicator: 1.0 if it matches, else 0.0.
        if check_condition(expression, payload, point_id, has_vector):  # type: ignore
            return 1.0
        return 0.0

    elif isinstance(expression, models.MultExpression):
        factors: list[float] = []

        for expr in expression.mult:
            factor = evaluate_expression(expr, point_id, scores, payload, has_vector, defaults)
            # Return early if any factor is zero
            if factor == 0.0:
                return factor

            factors.append(factor)

        return math.prod(factors)

    elif isinstance(expression, models.SumExpression):
        return sum(
            evaluate_expression(expr, point_id, scores, payload, has_vector, defaults)
            for expr in expression.sum
        )

    elif isinstance(expression, models.NegExpression):
        value = evaluate_expression(
            expression.neg, point_id, scores, payload, has_vector, defaults
        )
        return -value

    elif isinstance(expression, models.AbsExpression):
        return abs(
            evaluate_expression(expression.abs, point_id, scores, payload, has_vector, defaults)
        )

    elif isinstance(expression, models.DivExpression):
        left = evaluate_expression(
            expression.div.left, point_id, scores, payload, has_vector, defaults
        )

        # A zero numerator short-circuits before evaluating (and possibly
        # failing on) the denominator.
        if left == 0.0:
            return left

        right = evaluate_expression(
            expression.div.right, point_id, scores, payload, has_vector, defaults
        )

        if right == 0.0:
            if expression.div.by_zero_default is not None:
                return expression.div.by_zero_default
            raise_non_finite_error(f"{left}/{right}")

        result = left / right
        if math.isfinite(result):
            return result

        raise_non_finite_error(f"{left}/{right}")

    elif isinstance(expression, models.SqrtExpression):
        value = evaluate_expression(
            expression.sqrt, point_id, scores, payload, has_vector, defaults
        )

        if value >= 0:
            return math.sqrt(value)

        raise_non_finite_error(f"√{value}")

    elif isinstance(expression, models.PowExpression):
        base = evaluate_expression(
            expression.pow.base, point_id, scores, payload, has_vector, defaults
        )
        exponent = evaluate_expression(
            expression.pow.exponent, point_id, scores, payload, has_vector, defaults
        )

        # Check for valid input: negative bases are only allowed with integer
        # exponents (otherwise the result is complex).
        if base >= 0 or (base != 0 and exponent.is_integer()):
            try:
                return math.pow(base, exponent)
            except OverflowError:
                pass

        raise_non_finite_error(f"{base}^{exponent}")

    elif isinstance(expression, models.ExpExpression):
        value = evaluate_expression(
            expression.exp, point_id, scores, payload, has_vector, defaults
        )

        try:
            return math.exp(value)
        except OverflowError:
            raise_non_finite_error(f"exp({value})")

    elif isinstance(expression, models.Log10Expression):
        value = evaluate_expression(
            expression.log10, point_id, scores, payload, has_vector, defaults
        )

        if value > 0:
            try:
                return math.log10(value)
            except OverflowError:
                pass

        raise_non_finite_error(f"log10({value})")

    elif isinstance(expression, models.LnExpression):
        value = evaluate_expression(expression.ln, point_id, scores, payload, has_vector, defaults)

        if value > 0:
            try:
                return math.log(value)
            except OverflowError:
                pass

        raise_non_finite_error(f"ln({value})")

    elif isinstance(expression, models.GeoDistance):
        origin = expression.geo_distance.origin
        to = expression.geo_distance.to

        # Get value from payload
        geo_value = try_extract_payload_value(to, payload, defaults)

        if isinstance(geo_value, dict):
            # let this fail if it is not a valid geo point
            destination = models.GeoPoint(**geo_value)
            return geo_distance(origin.lon, origin.lat, destination.lon, destination.lat)

        raise ValueError(
            f"Expected geo point for {to} in the payload and/or in the formula defaults."
        )

    elif isinstance(expression, models.DatetimeExpression):
        # try to parse as datetime
        dt = datetime_utils.parse(expression.datetime)
        if dt is None:
            raise ValueError(f"Expected datetime in supported format for {expression.datetime}")

        return dt.timestamp()

    elif isinstance(expression, models.DatetimeKeyExpression):
        dt_str = try_extract_payload_value(expression.datetime_key, payload, defaults)
        dt = datetime_utils.parse(dt_str)
        if dt is None:
            raise ValueError(
                f"Expected datetime for {expression.datetime_key} in the payload and/or in the formula defaults."
            )

        return dt.timestamp()

    elif isinstance(expression, models.LinDecayExpression):
        x, target, midpoint, scale = evaluate_decay_params(
            expression.lin_decay, point_id, scores, payload, has_vector, defaults
        )

        # Linear ramp from 1.0 at the target down to 0.0, clamped at zero.
        lambda_factor = (1.0 - midpoint) / scale
        diff = abs(x - target)
        return max(0.0, -lambda_factor * diff + 1.0)

    elif isinstance(expression, models.ExpDecayExpression):
        x, target, midpoint, scale = evaluate_decay_params(
            expression.exp_decay, point_id, scores, payload, has_vector, defaults
        )

        # Exponential decay: equals `midpoint` exactly at distance `scale`.
        lambda_factor = math.log(midpoint) / scale
        diff = abs(x - target)
        return math.exp(lambda_factor * diff)

    elif isinstance(expression, models.GaussDecayExpression):
        x, target, midpoint, scale = evaluate_decay_params(
            expression.gauss_decay, point_id, scores, payload, has_vector, defaults
        )

        # Gaussian decay: equals `midpoint` exactly at distance `scale`.
        lambda_factor = math.log(midpoint) / (scale * scale)
        diff = x - target
        return math.exp(lambda_factor * diff * diff)

    raise ValueError(f"Unsupported expression type: {type(expression)}")
|
||||
|
||||
|
||||
def evaluate_decay_params(
    params: models.DecayParamsExpression,
    point_id: models.ExtendedPointId,
    scores: list[dict[models.ExtendedPointId, float]],
    payload: models.Payload,
    has_vector: dict[str, bool],
    defaults: dict[str, Any],
) -> Tuple[float, float, float, float]:
    """Evaluate and validate the shared parameters of a decay expression.

    Returns:
        The tuple ``(x, target, midpoint, scale)``: ``x`` and ``target`` are
        recursively evaluated sub-expressions (``target`` falls back to
        DEFAULT_DECAY_TARGET), while ``midpoint`` and ``scale`` fall back to
        the module defaults when not provided.

    Raises:
        ValueError: if midpoint is outside (0, 1) or scale is not positive.
    """
    x = evaluate_expression(params.x, point_id, scores, payload, has_vector, defaults)

    target = (
        DEFAULT_DECAY_TARGET
        if params.target is None
        else evaluate_expression(params.target, point_id, scores, payload, has_vector, defaults)
    )

    midpoint = DEFAULT_DECAY_MIDPOINT if params.midpoint is None else params.midpoint
    if midpoint <= 0.0 or midpoint >= 1.0:
        raise ValueError(f"Midpoint must be between 0 and 1, got {midpoint}")

    scale = DEFAULT_DECAY_SCALE if params.scale is None else params.scale
    if scale <= 0.0:
        raise ValueError(f"Scale must be non-zero positive, got {scale}")

    return x, target, midpoint, scale
|
||||
|
||||
|
||||
def try_extract_payload_value(key: str, payload: models.Payload, defaults: dict[str, Any]) -> Any:
    """Resolve `key` from the payload, with the formula defaults as fallback.

    Single-element lists are unwrapped to the element itself; an empty list
    counts as "no value".

    Raises:
        ValueError: if neither the payload nor the defaults provide a value.
    """
    candidate = value_by_key(payload, key)

    missing_in_payload = candidate is None or len(candidate) == 0
    if missing_in_payload:
        candidate = defaults.get(key, None)
        # An empty default list is the same as no default at all.
        if isinstance(candidate, list) and len(candidate) == 0:
            candidate = None

    # A one-element list behaves like the plain value.
    if isinstance(candidate, list) and len(candidate) == 1:
        return candidate[0]

    if candidate is None:
        raise ValueError(f"No value found for {key} in the payload nor the formula defaults")

    return candidate
|
||||
|
||||
|
||||
def evaluate_variable(
    variable: str,
    point_id: models.ExtendedPointId,
    scores: list[dict[models.ExtendedPointId, float]],
    payload: models.Payload,
    defaults: dict[str, Any],
) -> float:
    """Resolve a formula variable to a number.

    A payload-path variable is read from the payload (or the defaults);
    a `$score[i]` variable is read from the i-th intermediate score map,
    falling back first to a user default keyed by the raw variable string
    and then to DEFAULT_SCORE.

    Raises:
        ValueError: when a payload variable does not resolve to a number.
    """
    parsed = parse_variable(variable)

    if isinstance(parsed, str):
        # Payload path: the resolved value must be numeric.
        value = try_extract_payload_value(parsed, payload, defaults)
        if not is_number(value):
            raise ValueError(
                f"Expected number value for {parsed} in the payload and/or in the formula defaults. Error: Value is not a number"
            )
        return value

    if isinstance(parsed, int):
        # Score reference: look the point up in the selected score map.
        if parsed < len(scores):
            score = scores[parsed].get(point_id, None)
            if score is not None:
                return score

        # NOTE: defaults are keyed by the raw variable string (e.g. "$score[1]"),
        # not by the parsed index.
        defined_default = defaults.get(variable, None)
        return defined_default if defined_default is not None else DEFAULT_SCORE

    raise ValueError(f"Invalid variable type: {type(parsed)}")
|
||||
|
||||
|
||||
def parse_variable(var: str) -> Union[str, int]:
    """Parse a formula variable.

    Returns:
        An int index for `$score` / `$score[i]` references, or the string
        itself when it is a payload path.

    Raises:
        ValueError: for malformed `$score[...]` patterns.
    """
    if not var.startswith("$score"):
        # Anything that is not a score reference is a payload path.
        return var

    suffix = var[len("$score"):]
    if suffix == "":
        # Bare "$score" refers to the first (index 0) score list.
        return 0

    # A non-empty suffix must be exactly "[<int>]".
    if not suffix.startswith("["):
        raise ValueError(f"Invalid score pattern: {var}")

    inner, closing, extra = suffix[1:].partition("]")
    if closing == "" or extra != "":
        # Missing "]", or trailing characters after it.
        raise ValueError(f"Invalid score pattern: {var}")

    try:
        return int(inner)
    except ValueError:
        raise ValueError(f"Invalid score pattern: {var}")
|
||||
|
||||
|
||||
def raise_non_finite_error(expression: str) -> None:
    """Signal that evaluating `expression` yielded NaN or infinity."""
    message = f"The expression {expression} produced a non-finite number"
    raise ValueError(message)
|
||||
|
||||
|
||||
def is_number(value: Any) -> bool:
    """Return True for int/float values, excluding bool (bool subclasses int)."""
    if isinstance(value, bool):
        return False
    return isinstance(value, (int, float))
|
||||
|
||||
|
||||
def test_parsing_variable() -> None:
    """`$score` patterns parse to indices; malformed ones raise with an exact message."""
    for pattern, expected_idx in (("$score", 0), ("$score[0]", 0), ("$score[1]", 1), ("$score[2]", 2)):
        assert parse_variable(pattern) == expected_idx

    for bad_pattern in ("$score[invalid]", "$score[10].other"):
        try:
            parse_variable(bad_pattern)
            assert False
        except ValueError as e:
            assert str(e) == f"Invalid score pattern: {bad_pattern}"
|
||||
|
||||
|
||||
def test_try_extract_payload_value() -> None:
    """Values resolve identically whether stored in the payload or the defaults."""
    cases = [(1.2, 1.2), ([1.2], 1.2), ([1.2, 2.3], [1.2, 2.3])]
    for stored, expected in cases:
        # Value present in the payload, defaults empty.
        assert try_extract_payload_value("key", {"key": stored}, {}) == expected
        # Value present only in the defaults, payload empty.
        assert try_extract_payload_value("key", {}, {"key": stored}) == expected
|
||||
@@ -0,0 +1,79 @@
|
||||
from typing import Optional
|
||||
|
||||
from qdrant_client.http import models
|
||||
|
||||
|
||||
DEFAULT_RANKING_CONSTANT_K = 2


def reciprocal_rank_fusion(
    responses: list[list[models.ScoredPoint]],
    limit: int = 10,
    ranking_constant_k: Optional[int] = None,
) -> list[models.ScoredPoint]:
    """Fuse ranked result lists with reciprocal rank fusion (RRF).

    Each point contributes ``1 / (k + rank)`` per response it appears in;
    contributions are summed across responses. The first occurrence of a
    point is kept, its ``score`` overwritten with the fused value, and the
    top ``limit`` points are returned in descending fused-score order.
    """
    # The ranking constant mitigates the impact of high rankings by outlier systems.
    k = DEFAULT_RANKING_CONSTANT_K if ranking_constant_k is None else ranking_constant_k

    rrf_scores: dict[models.ExtendedPointId, float] = {}
    first_seen: dict[models.ExtendedPointId, models.ScoredPoint] = {}
    for response in responses:
        for rank, point in enumerate(response):
            contribution = 1 / (k + rank)
            if point.id not in rrf_scores:
                first_seen[point.id] = point
                rrf_scores[point.id] = contribution
            else:
                rrf_scores[point.id] += contribution

    ranked = sorted(rrf_scores.items(), key=lambda pair: pair[1], reverse=True)
    fused: list[models.ScoredPoint] = []
    for point_id, fused_score in ranked[:limit]:
        winner = first_seen[point_id]
        winner.score = fused_score
        fused.append(winner)
    return fused
|
||||
|
||||
|
||||
def distribution_based_score_fusion(
    responses: list[list[models.ScoredPoint]], limit: int
) -> list[models.ScoredPoint]:
    """Fuse result lists with distribution-based score fusion (DBSF).

    Each response's scores are min-max normalized over that response's
    ``[mean - 3*std, mean + 3*std]`` range, then summed per point across
    responses; the top ``limit`` points are returned by fused score
    (descending). Degenerate distributions — a single point, or zero
    variance — normalize every score to 0.5. Points are mutated in place.
    """

    def _normalize_in_place(response: list[models.ScoredPoint]) -> list[models.ScoredPoint]:
        if len(response) == 1:
            response[0].score = 0.5
            return response

        mean = sum(point.score for point in response) / len(response)
        # Sample variance (n - 1 denominator); n == 1 is guarded above.
        variance = sum((point.score - mean) ** 2 for point in response) / (len(response) - 1)

        if variance == 0:
            for point in response:
                point.score = 0.5
            return response

        std_dev = variance**0.5
        lower_bound = mean - 3 * std_dev
        upper_bound = mean + 3 * std_dev

        for point in response:
            point.score = (point.score - lower_bound) / (upper_bound - lower_bound)

        return response

    fused: dict[models.ExtendedPointId, models.ScoredPoint] = {}
    for response in responses:
        # Empty responses are skipped so normalization never divides by zero.
        if not response:
            continue
        for point in _normalize_in_place(response):
            existing = fused.get(point.id)
            if existing is None:
                fused[point.id] = point
            else:
                existing.score += point.score

    best = sorted(fused.values(), key=lambda point: point.score, reverse=True)
    return best[:limit]
|
||||
@@ -0,0 +1,117 @@
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.hybrid.fusion import reciprocal_rank_fusion, distribution_based_score_fusion
|
||||
|
||||
|
||||
def test_reciprocal_rank_fusion() -> None:
    """A point present in several responses outranks single-response points."""
    first_response = [
        models.ScoredPoint(id="1", score=0.1, version=1),
        models.ScoredPoint(id="2", score=0.2, version=1),
        models.ScoredPoint(id="3", score=0.3, version=1),
    ]
    second_response = [
        models.ScoredPoint(id="5", score=12.0, version=1),
        models.ScoredPoint(id="6", score=8.0, version=1),
        models.ScoredPoint(id="7", score=5.0, version=1),
        models.ScoredPoint(id="2", score=3.0, version=1),
    ]

    fused = reciprocal_rank_fusion([first_response, second_response])

    # "2" appears in both responses, so it accumulates the highest RRF score.
    assert fused[0].id == "2"
    # "1" and "5" are each rank 0 in their own response and tie at 1 / (k + 0).
    for runner_up in fused[1:3]:
        assert runner_up.id in ["1", "5"]
        assert np.isclose(runner_up.score, 1 / 2)
|
||||
|
||||
|
||||
def test_distribution_based_score_fusion() -> None:
    """Fused order reflects normalized scores summed across both responses."""
    dense_response = [
        models.ScoredPoint(id=1, version=0, score=85.0),
        models.ScoredPoint(id=0, version=0, score=76.0),
        models.ScoredPoint(id=5, version=0, score=68.0),
    ]
    sparse_response = [
        models.ScoredPoint(id=1, version=0, score=62.0),
        models.ScoredPoint(id=0, version=0, score=61.0),
        models.ScoredPoint(id=4, version=0, score=57.0),
        models.ScoredPoint(id=3, version=0, score=51.0),
        models.ScoredPoint(id=2, version=0, score=44.0),
    ]

    fused = distribution_based_score_fusion([dense_response, sparse_response], limit=3)

    assert [point.id for point in fused] == [1, 0, 4]
|
||||
|
||||
|
||||
def test_reciprocal_rank_fusion_empty_responses() -> None:
    """Empty responses contribute nothing to the fused ranking."""
    assert reciprocal_rank_fusion([[]]) == []

    only_response = [
        models.ScoredPoint(id="1", score=0.1, version=1),
        models.ScoredPoint(id="2", score=0.2, version=1),
        models.ScoredPoint(id="3", score=0.3, version=1),
    ]

    fused = reciprocal_rank_fusion([only_response, []])

    # With a single non-empty response the order is preserved and each
    # score is 1 / (k + rank) with k = 2.
    expected = [("1", 1 / 2), ("2", 1 / 3), ("3", 1 / 4)]
    for point, (expected_id, expected_score) in zip(fused, expected):
        assert point.id == expected_id
        assert np.isclose(point.score, expected_score)
|
||||
|
||||
|
||||
def test_distribution_based_score_fusion_empty_response() -> None:
    """An empty response list is skipped instead of breaking normalization."""
    assert distribution_based_score_fusion([[]], limit=3) == []

    non_empty = [
        models.ScoredPoint(id=1, version=0, score=85.0),
        models.ScoredPoint(id=0, version=0, score=76.0),
        models.ScoredPoint(id=5, version=0, score=68.0),
    ]

    fused = distribution_based_score_fusion([non_empty, []], limit=3)

    assert [point.id for point in fused] == [1, 0, 5]
|
||||
|
||||
|
||||
def test_distribution_based_score_fusion_zero_variance() -> None:
    """Degenerate distributions (single point, zero variance) normalize to 0.5."""
    score = 85.0

    single = [[models.ScoredPoint(id=1, version=0, score=score)]]
    fused = distribution_based_score_fusion(single, limit=3)
    assert fused[0].id == 1
    assert fused[0].score == 0.5

    identical = [
        [
            models.ScoredPoint(id=1, version=0, score=score),
            models.ScoredPoint(id=0, version=0, score=score),
            models.ScoredPoint(id=5, version=0, score=score),
        ],
        [],
    ]
    fused = distribution_based_score_fusion(identical, limit=3)
    assert len(fused) == 3
    assert all(point.score == 0.5 for point in fused)
|
||||
@@ -0,0 +1,973 @@
|
||||
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
|
||||
#
|
||||
# This file is autogenerated. Do not edit it manually.
|
||||
# To regenerate this file, use
|
||||
#
|
||||
# ```
|
||||
# bash -x tools/generate_async_client.sh
|
||||
# ```
|
||||
#
|
||||
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
|
||||
|
||||
import importlib.metadata
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from io import TextIOWrapper
|
||||
from typing import Any, Generator, Iterable, Mapping, Optional, Sequence, Union, get_args
|
||||
from uuid import uuid4
|
||||
import numpy as np
|
||||
import portalocker
|
||||
from qdrant_client.common.client_warnings import show_warning, show_warning_once
|
||||
from qdrant_client._pydantic_compat import to_dict
|
||||
from qdrant_client.async_client_base import AsyncQdrantBase
|
||||
from qdrant_client.conversions import common_types as types
|
||||
from qdrant_client.http import models as rest_models
|
||||
from qdrant_client.local.local_collection import (
|
||||
LocalCollection,
|
||||
DEFAULT_VECTOR_NAME,
|
||||
ignore_mentioned_ids_filter,
|
||||
)
|
||||
|
||||
# Name of the JSON file that persists collection configs and aliases on disk.
META_INFO_FILENAME = "meta.json"
|
||||
|
||||
|
||||
class AsyncQdrantLocal(AsyncQdrantBase):
|
||||
"""
|
||||
Everything Qdrant server can do, but locally.
|
||||
|
||||
Use this implementation to run vector search without running a Qdrant server.
|
||||
Everything that works with local Qdrant will work with server Qdrant as well.
|
||||
|
||||
Use for small-scale data, demos, and tests.
|
||||
If you need more speed or size, use Qdrant server.
|
||||
"""
|
||||
|
||||
LARGE_DATA_THRESHOLD = 20000
|
||||
|
||||
    def __init__(self, location: str, force_disable_check_same_thread: bool = False) -> None:
        """
        Initialize local Qdrant.

        Args:
            location: Where to store data. Can be a path to a directory or `:memory:` for in-memory storage.
            force_disable_check_same_thread: Disable SQLite check_same_thread check. Use only if you know what you are doing.
        """
        super().__init__()
        self.force_disable_check_same_thread = force_disable_check_same_thread
        self.location = location
        # Anything except the ":memory:" sentinel is stored on disk.
        self.persistent = location != ":memory:"
        self.collections: dict[str, LocalCollection] = {}
        # Maps alias name -> collection name.
        self.aliases: dict[str, str] = {}
        # Exclusive lock handle guarding the storage directory (set in _load()).
        self._flock_file: Optional[TextIOWrapper] = None
        self._load()
        self._closed: bool = False
|
||||
|
||||
    @property
    def closed(self) -> bool:
        """Whether this instance has been closed via `close()`."""
        return self._closed
|
||||
|
||||
    async def close(self, **kwargs: Any) -> None:
        """Close all collections and release the storage-directory lock."""
        self._closed = True
        for collection in self.collections.values():
            if collection is not None:
                collection.close()
            else:
                show_warning(
                    message=f"Collection appears to be None before closing. The existing collections are: {list(self.collections.keys())}",
                    category=UserWarning,
                    stacklevel=4,
                )
        try:
            if self._flock_file is not None and (not self._flock_file.closed):
                portalocker.unlock(self._flock_file)
                self._flock_file.close()
        except TypeError:
            # NOTE(review): presumably guards against partially torn-down module
            # globals during interpreter shutdown — confirm.
            pass
|
||||
|
||||
    def _load(self) -> None:
        """Load persisted collections and aliases from disk and acquire the storage lock.

        No-op for in-memory instances. Raises RuntimeError if another client
        instance already holds the storage folder's lock.
        """
        # Config keys that older versions persisted but are no longer accepted.
        deprecated_config_fields = ("init_from",)
        if not self.persistent:
            return
        meta_path = os.path.join(self.location, META_INFO_FILENAME)
        if not os.path.exists(meta_path):
            # Fresh storage dir: create an empty meta file.
            os.makedirs(self.location, exist_ok=True)
            with open(meta_path, "w") as f:
                f.write(json.dumps({"collections": {}, "aliases": {}}))
        else:
            with open(meta_path, "r") as f:
                meta = json.load(f)
                for collection_name, config_json in meta["collections"].items():
                    # Drop deprecated fields before validating the config model.
                    for key in deprecated_config_fields:
                        config_json.pop(key, None)
                    config = rest_models.CreateCollection(**config_json)
                    collection_path = self._collection_path(collection_name)
                    collection = LocalCollection(
                        config,
                        collection_path,
                        force_disable_check_same_thread=self.force_disable_check_same_thread,
                    )
                    self.collections[collection_name] = collection
                    if len(collection.ids) > self.LARGE_DATA_THRESHOLD:
                        show_warning(
                            f"Local mode is not recommended for collections with more than {self.LARGE_DATA_THRESHOLD:,} points. Collection <{collection_name}> contains {len(collection.ids)} points. Consider using Qdrant in Docker or Qdrant Cloud for better performance with large datasets.",
                            category=UserWarning,
                            stacklevel=5,
                        )
                self.aliases = meta["aliases"]
        # Take an exclusive, non-blocking lock so two clients never share the folder.
        lock_file_path = os.path.join(self.location, ".lock")
        if not os.path.exists(lock_file_path):
            os.makedirs(self.location, exist_ok=True)
            with open(lock_file_path, "w") as f:
                f.write("tmp lock file")
        self._flock_file = open(lock_file_path, "r+")
        try:
            portalocker.lock(
                self._flock_file,
                portalocker.LockFlags.EXCLUSIVE | portalocker.LockFlags.NON_BLOCKING,
            )
        except portalocker.exceptions.LockException:
            raise RuntimeError(
                f"Storage folder {self.location} is already accessed by another instance of Qdrant client. If you require concurrent access, use Qdrant server instead."
            )
|
||||
|
||||
    def _save(self) -> None:
        """Persist the current collection configs and aliases to the meta file.

        No-op for in-memory instances; raises RuntimeError once closed.
        """
        if not self.persistent:
            return
        if self.closed:
            raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
        meta_path = os.path.join(self.location, META_INFO_FILENAME)
        with open(meta_path, "w") as f:
            f.write(
                json.dumps(
                    {
                        "collections": {
                            collection_name: to_dict(collection.config)
                            for (collection_name, collection) in self.collections.items()
                        },
                        "aliases": self.aliases,
                    }
                )
            )
|
||||
|
||||
    def _get_collection(self, collection_name: str) -> LocalCollection:
        """Return the collection by name or alias.

        Raises:
            RuntimeError: if the instance has been closed.
            ValueError: if the name matches neither a collection nor an alias.
        """
        if self.closed:
            raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
        if collection_name in self.collections:
            return self.collections[collection_name]
        if collection_name in self.aliases:
            # Aliases resolve indirectly to a real collection name.
            return self.collections[self.aliases[collection_name]]
        raise ValueError(f"Collection {collection_name} not found")
|
||||
|
||||
    def search(
        self,
        collection_name: str,
        query_vector: Union[
            types.NumpyArray,
            Sequence[float],
            tuple[str, list[float]],
            types.NamedVector,
            types.NamedSparseVector,
        ],
        query_filter: Optional[types.Filter] = None,
        search_params: Optional[types.SearchParams] = None,
        limit: int = 10,
        offset: Optional[int] = None,
        with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
        with_vectors: Union[bool, Sequence[str]] = False,
        score_threshold: Optional[float] = None,
        **kwargs: Any,
    ) -> list[types.ScoredPoint]:
        """Search the collection by vector similarity (delegates to LocalCollection).

        NOTE(review): `search_params` and `**kwargs` are accepted but not
        forwarded; also this method is not `async def`, unlike its siblings
        in this autogenerated class — confirm against the generator.
        """
        collection = self._get_collection(collection_name)
        return collection.search(
            query_vector=query_vector,
            query_filter=query_filter,
            limit=limit,
            offset=offset,
            with_payload=with_payload,
            with_vectors=with_vectors,
            score_threshold=score_threshold,
        )
|
||||
|
||||
    async def search_matrix_offsets(
        self,
        collection_name: str,
        query_filter: Optional[types.Filter] = None,
        limit: int = 3,
        sample: int = 10,
        using: Optional[str] = None,
        **kwargs: Any,
    ) -> types.SearchMatrixOffsetsResponse:
        """Compute a pairwise distance matrix (offsets form) over a sample of points.

        Delegates to the local collection; `**kwargs` is accepted but ignored.
        """
        collection = self._get_collection(collection_name)
        return collection.search_matrix_offsets(
            query_filter=query_filter, limit=limit, sample=sample, using=using
        )
|
||||
|
||||
    async def search_matrix_pairs(
        self,
        collection_name: str,
        query_filter: Optional[types.Filter] = None,
        limit: int = 3,
        sample: int = 10,
        using: Optional[str] = None,
        **kwargs: Any,
    ) -> types.SearchMatrixPairsResponse:
        """Compute a pairwise distance matrix (pairs form) over a sample of points.

        Delegates to the local collection; `**kwargs` is accepted but ignored.
        """
        collection = self._get_collection(collection_name)
        return collection.search_matrix_pairs(
            query_filter=query_filter, limit=limit, sample=sample, using=using
        )
|
||||
|
||||
    def _resolve_query_input(
        self,
        collection_name: str,
        query: Optional[types.Query],
        using: Optional[str],
        lookup_from: Optional[types.LookupLocation],
    ) -> tuple[types.Query, set[types.PointId]]:
        """
        Resolves any possible ids into vectors and returns a new query object, along with a set of the mentioned
        point ids that should be filtered when searching.
        """
        # Vectors may be looked up in another collection via `lookup_from`.
        lookup_collection_name = lookup_from.collection if lookup_from else collection_name
        collection = self._get_collection(lookup_collection_name)
        search_in_vector_name = using if using is not None else DEFAULT_VECTOR_NAME
        vector_name = (
            lookup_from.vector
            if lookup_from is not None and lookup_from.vector is not None
            else search_in_vector_name
        )
        # Pick the storage map that actually holds this named vector.
        sparse = vector_name in collection.sparse_vectors
        multi = vector_name in collection.multivectors
        if sparse:
            collection_vectors = collection.sparse_vectors
        elif multi:
            collection_vectors = collection.multivectors
        else:
            collection_vectors = collection.vectors
        # Ids referenced by the query; excluded from results by the caller.
        mentioned_ids: set[types.PointId] = set()

        def input_into_vector(vector_input: types.VectorInput) -> types.VectorInput:
            # Point ids are replaced by the stored vector; raw vectors pass through.
            if isinstance(vector_input, get_args(types.PointId)):
                if isinstance(vector_input, uuid.UUID):
                    vector_input = str(vector_input)
                point_id = vector_input
                if point_id not in collection.ids:
                    raise ValueError(f"Point {point_id} is not found in the collection")
                idx = collection.ids[point_id]
                if vector_name in collection_vectors:
                    vec = collection_vectors[vector_name][idx]
                else:
                    raise ValueError(f"Vector {vector_name} not found")
                if isinstance(vec, np.ndarray):
                    vec = vec.tolist()
                # Only exclude the point when searching the same collection it came from.
                if collection_name == lookup_collection_name:
                    mentioned_ids.add(point_id)
                return vec
            else:
                return vector_input

        # Work on a copy so the caller's query object is never mutated.
        query = deepcopy(query)
        if isinstance(query, rest_models.NearestQuery):
            query.nearest = input_into_vector(query.nearest)
        elif isinstance(query, rest_models.RecommendQuery):
            if query.recommend.negative is not None:
                query.recommend.negative = [
                    input_into_vector(vector_input) for vector_input in query.recommend.negative
                ]
            if query.recommend.positive is not None:
                query.recommend.positive = [
                    input_into_vector(vector_input) for vector_input in query.recommend.positive
                ]
        elif isinstance(query, rest_models.DiscoverQuery):
            query.discover.target = input_into_vector(query.discover.target)
            pairs = (
                query.discover.context
                if isinstance(query.discover.context, list)
                else [query.discover.context]
            )
            query.discover.context = [
                rest_models.ContextPair(
                    positive=input_into_vector(pair.positive),
                    negative=input_into_vector(pair.negative),
                )
                for pair in pairs
            ]
        elif isinstance(query, rest_models.ContextQuery):
            pairs = query.context if isinstance(query.context, list) else [query.context]
            query.context = [
                rest_models.ContextPair(
                    positive=input_into_vector(pair.positive),
                    negative=input_into_vector(pair.negative),
                )
                for pair in pairs
            ]
        elif isinstance(query, rest_models.OrderByQuery):
            # These query kinds carry no vector inputs to resolve.
            pass
        elif isinstance(query, rest_models.FusionQuery):
            pass
        elif isinstance(query, rest_models.RrfQuery):
            pass
        return (query, mentioned_ids)
|
||||
|
||||
    def _resolve_prefetches_input(
        self,
        prefetch: Optional[Union[Sequence[types.Prefetch], types.Prefetch]],
        collection_name: str,
    ) -> list[types.Prefetch]:
        """Normalize a prefetch (single or sequence) into a resolved list.

        NOTE(review): for a single Prefetch, its nested `prefetch` children
        are flattened into the same top-level list — confirm intended.
        """
        if prefetch is None:
            return []
        if isinstance(prefetch, list) and len(prefetch) == 0:
            return []
        prefetches = []
        if isinstance(prefetch, types.Prefetch):
            prefetches = [prefetch]
            prefetches.extend(
                prefetch.prefetch if isinstance(prefetch.prefetch, list) else [prefetch.prefetch]
            )
        elif isinstance(prefetch, Sequence):
            prefetches = list(prefetch)
        return [
            self._resolve_prefetch_input(prefetch, collection_name)
            for prefetch in prefetches
            if prefetch is not None
        ]
|
||||
|
||||
    def _resolve_prefetch_input(
        self, prefetch: types.Prefetch, collection_name: str
    ) -> types.Prefetch:
        """Resolve a single prefetch's query ids into vectors, recursing into children."""
        if prefetch.query is None:
            return prefetch
        # Copy before mutation so the caller's prefetch stays untouched.
        prefetch = deepcopy(prefetch)
        (query, mentioned_ids) = self._resolve_query_input(
            collection_name, prefetch.query, prefetch.using, prefetch.lookup_from
        )
        prefetch.query = query
        # Exclude the points whose vectors were used as query input.
        prefetch.filter = ignore_mentioned_ids_filter(prefetch.filter, list(mentioned_ids))
        prefetch.prefetch = self._resolve_prefetches_input(prefetch.prefetch, collection_name)
        return prefetch
|
||||
|
||||
    async def query_points(
        self,
        collection_name: str,
        query: Optional[types.Query] = None,
        using: Optional[str] = None,
        prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
        query_filter: Optional[types.Filter] = None,
        search_params: Optional[types.SearchParams] = None,
        limit: int = 10,
        offset: Optional[int] = None,
        with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
        with_vectors: Union[bool, Sequence[str]] = False,
        score_threshold: Optional[float] = None,
        lookup_from: Optional[types.LookupLocation] = None,
        **kwargs: Any,
    ) -> types.QueryResponse:
        """Universal query endpoint: resolve id-based inputs to vectors, then delegate.

        NOTE(review): `search_params` and `**kwargs` are accepted but not
        forwarded to the local collection.
        """
        collection = self._get_collection(collection_name)
        if query is not None:
            # Replace point-id query inputs with their stored vectors and
            # exclude those points from the results.
            (query, mentioned_ids) = self._resolve_query_input(
                collection_name, query, using, lookup_from
            )
            query_filter = ignore_mentioned_ids_filter(query_filter, list(mentioned_ids))
        prefetch = self._resolve_prefetches_input(prefetch, collection_name)
        return collection.query_points(
            query=query,
            prefetch=prefetch,
            query_filter=query_filter,
            using=using,
            score_threshold=score_threshold,
            limit=limit,
            offset=offset or 0,
            with_payload=with_payload,
            with_vectors=with_vectors,
        )
|
||||
|
||||
    async def query_batch_points(
        self, collection_name: str, requests: Sequence[types.QueryRequest], **kwargs: Any
    ) -> list[types.QueryResponse]:
        """Execute several query requests sequentially against the same collection."""
        return [
            await self.query_points(
                collection_name=collection_name,
                query=request.query,
                prefetch=request.prefetch,
                query_filter=request.filter,
                limit=request.limit or 10,
                offset=request.offset,
                with_payload=request.with_payload,
                with_vectors=request.with_vector,
                score_threshold=request.score_threshold,
                using=request.using,
                lookup_from=request.lookup_from,
            )
            for request in requests
        ]
|
||||
|
||||
    async def query_points_groups(
        self,
        collection_name: str,
        group_by: str,
        query: Union[
            types.PointId,
            list[float],
            list[list[float]],
            types.SparseVector,
            types.Query,
            types.NumpyArray,
            types.Document,
            types.Image,
            types.InferenceObject,
            None,
        ] = None,
        using: Optional[str] = None,
        prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
        query_filter: Optional[types.Filter] = None,
        search_params: Optional[types.SearchParams] = None,
        limit: int = 10,
        group_size: int = 3,
        with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
        with_vectors: Union[bool, Sequence[str]] = False,
        score_threshold: Optional[float] = None,
        with_lookup: Optional[types.WithLookupInterface] = None,
        lookup_from: Optional[types.LookupLocation] = None,
        **kwargs: Any,
    ) -> types.GroupsResult:
        """Query points grouped by a payload field, optionally joining a lookup collection.

        NOTE(review): unlike `query_points`, `prefetch` is passed through
        without `_resolve_prefetches_input`; `search_params` and `**kwargs`
        are not forwarded — confirm both are intended.
        """
        collection = self._get_collection(collection_name)
        if query is not None:
            # Resolve id-based query inputs and exclude the mentioned points.
            (query, mentioned_ids) = self._resolve_query_input(
                collection_name, query, using, lookup_from
            )
            query_filter = ignore_mentioned_ids_filter(query_filter, list(mentioned_ids))
        with_lookup_collection = None
        if with_lookup is not None:
            # `with_lookup` is either a bare collection name or a WithLookup model.
            if isinstance(with_lookup, str):
                with_lookup_collection = self._get_collection(with_lookup)
            else:
                with_lookup_collection = self._get_collection(with_lookup.collection)
        return collection.query_groups(
            query=query,
            query_filter=query_filter,
            using=using,
            prefetch=prefetch,
            limit=limit,
            group_by=group_by,
            group_size=group_size,
            with_payload=with_payload,
            with_vectors=with_vectors,
            score_threshold=score_threshold,
            with_lookup=with_lookup,
            with_lookup_collection=with_lookup_collection,
        )
|
||||
|
||||
    async def scroll(
        self,
        collection_name: str,
        scroll_filter: Optional[types.Filter] = None,
        limit: int = 10,
        order_by: Optional[types.OrderBy] = None,
        offset: Optional[types.PointId] = None,
        with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
        with_vectors: Union[bool, Sequence[str]] = False,
        **kwargs: Any,
    ) -> tuple[list[types.Record], Optional[types.PointId]]:
        """Page through points; returns the records and the next-page offset.

        Delegates to the local collection; `**kwargs` is accepted but ignored.
        """
        collection = self._get_collection(collection_name)
        return collection.scroll(
            scroll_filter=scroll_filter,
            limit=limit,
            order_by=order_by,
            offset=offset,
            with_payload=with_payload,
            with_vectors=with_vectors,
        )
|
||||
|
||||
    async def count(
        self,
        collection_name: str,
        count_filter: Optional[types.Filter] = None,
        exact: bool = True,
        **kwargs: Any,
    ) -> types.CountResult:
        """Count points matching the filter.

        `exact` is accepted for API parity but not forwarded — local counts
        are always exact.
        """
        collection = self._get_collection(collection_name)
        return collection.count(count_filter=count_filter)
|
||||
|
||||
    async def facet(
        self,
        collection_name: str,
        key: str,
        facet_filter: Optional[types.Filter] = None,
        limit: int = 10,
        exact: bool = False,
        **kwargs: Any,
    ) -> types.FacetResponse:
        """Count payload values for a field.

        `exact` is accepted for API parity but not forwarded to the local
        collection.
        """
        collection = self._get_collection(collection_name)
        return collection.facet(key=key, facet_filter=facet_filter, limit=limit)
|
||||
|
||||
    async def upsert(
        self,
        collection_name: str,
        points: types.Points,
        update_filter: Optional[types.Filter] = None,
        **kwargs: Any,
    ) -> types.UpdateResult:
        """Insert or update points; local updates always complete synchronously."""
        collection = self._get_collection(collection_name)
        collection.upsert(points, update_filter=update_filter)
        return self._default_update_result()
|
||||
|
||||
    async def update_vectors(
        self,
        collection_name: str,
        points: Sequence[types.PointVectors],
        update_filter: Optional[types.Filter] = None,
        **kwargs: Any,
    ) -> types.UpdateResult:
        """Overwrite the specified vectors of existing points."""
        collection = self._get_collection(collection_name)
        collection.update_vectors(points, update_filter=update_filter)
        return self._default_update_result()
|
||||
|
||||
    async def delete_vectors(
        self,
        collection_name: str,
        vectors: Sequence[str],
        points: types.PointsSelector,
        **kwargs: Any,
    ) -> types.UpdateResult:
        """Remove the named vectors from the selected points (payload/ids stay)."""
        collection = self._get_collection(collection_name)
        collection.delete_vectors(vectors, points)
        return self._default_update_result()
|
||||
|
||||
async def retrieve(
    self,
    collection_name: str,
    ids: Sequence[types.PointId],
    with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
    with_vectors: Union[bool, Sequence[str]] = False,
    **kwargs: Any,
) -> list[types.Record]:
    """Fetch points by id, optionally including payloads and vectors."""
    target = self._get_collection(collection_name)
    return target.retrieve(ids, with_payload, with_vectors)
|
||||
|
||||
@classmethod
def _default_update_result(cls, operation_id: int = 0) -> types.UpdateResult:
    """Build the COMPLETED result local ops report (they run synchronously)."""
    completed = rest_models.UpdateStatus.COMPLETED
    return types.UpdateResult(operation_id=operation_id, status=completed)
|
||||
|
||||
async def delete(
    self, collection_name: str, points_selector: types.PointsSelector, **kwargs: Any
) -> types.UpdateResult:
    """Delete the points matched by ``points_selector``."""
    self._get_collection(collection_name).delete(points_selector)
    return self._default_update_result()
|
||||
|
||||
async def set_payload(
    self,
    collection_name: str,
    payload: types.Payload,
    points: types.PointsSelector,
    key: Optional[str] = None,
    **kwargs: Any,
) -> types.UpdateResult:
    """Apply ``payload`` to the selected points (optionally nested under ``key``)."""
    target = self._get_collection(collection_name)
    target.set_payload(payload=payload, selector=points, key=key)
    return self._default_update_result()

async def overwrite_payload(
    self,
    collection_name: str,
    payload: types.Payload,
    points: types.PointsSelector,
    **kwargs: Any,
) -> types.UpdateResult:
    """Replace the payload of the selected points with ``payload``."""
    target = self._get_collection(collection_name)
    target.overwrite_payload(payload=payload, selector=points)
    return self._default_update_result()

async def delete_payload(
    self,
    collection_name: str,
    keys: Sequence[str],
    points: types.PointsSelector,
    **kwargs: Any,
) -> types.UpdateResult:
    """Remove the given payload keys from the selected points."""
    target = self._get_collection(collection_name)
    target.delete_payload(keys=keys, selector=points)
    return self._default_update_result()

async def clear_payload(
    self, collection_name: str, points_selector: types.PointsSelector, **kwargs: Any
) -> types.UpdateResult:
    """Drop the entire payload of the selected points."""
    target = self._get_collection(collection_name)
    target.clear_payload(selector=points_selector)
    return self._default_update_result()
|
||||
|
||||
async def batch_update_points(
    self,
    collection_name: str,
    update_operations: Sequence[types.UpdateOperation],
    **kwargs: Any,
) -> list[types.UpdateResult]:
    """Apply a batch of update operations; one COMPLETED result per operation."""
    target = self._get_collection(collection_name)
    target.batch_update_points(update_operations)
    return [self._default_update_result()] * len(update_operations)
|
||||
|
||||
async def update_collection_aliases(
    self, change_aliases_operations: Sequence[types.AliasOperations], **kwargs: Any
) -> bool:
    """Apply alias create/delete/rename operations and persist the result.

    Raises:
        ValueError: on an operation type that is not recognized.
    """
    for op in change_aliases_operations:
        if isinstance(op, rest_models.CreateAliasOperation):
            request = op.create_alias
            # Resolving the collection validates that the target exists.
            self._get_collection(request.collection_name)
            self.aliases[request.alias_name] = request.collection_name
        elif isinstance(op, rest_models.DeleteAliasOperation):
            # Deleting a missing alias is a no-op.
            self.aliases.pop(op.delete_alias.alias_name, None)
        elif isinstance(op, rest_models.RenameAliasOperation):
            request = op.rename_alias
            self.aliases[request.new_alias_name] = self.aliases.pop(
                request.old_alias_name
            )
        else:
            raise ValueError(f"Unknown operation: {op}")
    self._save()
    return True
|
||||
|
||||
async def get_collection_aliases(
    self, collection_name: str, **kwargs: Any
) -> types.CollectionsAliasesResponse:
    """List the aliases that point at ``collection_name``."""
    if self.closed:
        raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
    matching = [
        rest_models.AliasDescription(alias_name=alias, collection_name=target)
        for alias, target in self.aliases.items()
        if target == collection_name
    ]
    return types.CollectionsAliasesResponse(aliases=matching)

async def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
    """List every alias known to this instance."""
    if self.closed:
        raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
    descriptions = [
        rest_models.AliasDescription(alias_name=alias, collection_name=target)
        for alias, target in self.aliases.items()
    ]
    return types.CollectionsAliasesResponse(aliases=descriptions)

async def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
    """List descriptions of all collections."""
    if self.closed:
        raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
    descriptions = [
        rest_models.CollectionDescription(name=name) for name in self.collections
    ]
    return types.CollectionsResponse(collections=descriptions)
|
||||
|
||||
async def get_collection(self, collection_name: str, **kwargs: Any) -> types.CollectionInfo:
    """Return info for the collection; propagates the lookup error if unknown."""
    return self._get_collection(collection_name).info()

async def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
    """True when ``collection_name`` resolves, False otherwise."""
    try:
        self._get_collection(collection_name)
    except ValueError:
        return False
    return True
|
||||
|
||||
async def update_collection(
    self,
    collection_name: str,
    sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
    metadata: Optional[types.Payload] = None,
    **kwargs: Any,
) -> bool:
    """Update sparse-vector params and/or metadata of an existing collection.

    Returns:
        True when at least one change was applied, False otherwise.
    """
    target = self._get_collection(collection_name)
    changed = False
    if sparse_vectors_config is not None:
        for name, params in sparse_vectors_config.items():
            target.update_sparse_vectors_config(name, params)
            # Set inside the loop: an empty mapping counts as "no change".
            changed = True
    if metadata is not None:
        if target.config.metadata is None:
            # Deep-copy so later caller-side mutation cannot leak in.
            target.config.metadata = deepcopy(metadata)
        else:
            target.config.metadata.update(metadata)
        changed = True
    self._save()
    return changed
|
||||
|
||||
def _collection_path(self, collection_name: str) -> Optional[str]:
    """Return the on-disk directory for a collection, or None in memory mode."""
    if not self.persistent:
        return None
    return os.path.join(self.location, "collection", collection_name)
|
||||
|
||||
async def delete_collection(self, collection_name: str, **kwargs: Any) -> bool:
    """Drop a collection, its aliases, and any on-disk data.

    Deleting a collection that does not exist is a no-op that still
    returns True.
    """
    if self.closed:
        raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
    dropped = self.collections.pop(collection_name, None)
    # Release the collection object before removing its files.
    del dropped
    self.aliases = {
        alias: target
        for alias, target in self.aliases.items()
        if target != collection_name
    }
    storage_dir = self._collection_path(collection_name)
    if storage_dir is not None:
        shutil.rmtree(storage_dir, ignore_errors=True)
    self._save()
    return True
|
||||
|
||||
async def create_collection(
    self,
    collection_name: str,
    vectors_config: Optional[
        Union[types.VectorParams, Mapping[str, types.VectorParams]]
    ] = None,
    sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
    metadata: Optional[types.Payload] = None,
    **kwargs: Any,
) -> bool:
    """Create a new local collection, persisting it to disk when applicable.

    Raises:
        RuntimeError: if this instance has been closed.
        ValueError: if the collection already exists.
    """
    if self.closed:
        raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
    if collection_name in self.collections:
        raise ValueError(f"Collection {collection_name} already exists")

    storage_dir = self._collection_path(collection_name)
    if storage_dir is not None:
        os.makedirs(storage_dir, exist_ok=True)

    self.collections[collection_name] = LocalCollection(
        rest_models.CreateCollection(
            vectors=vectors_config or {},
            sparse_vectors=sparse_vectors_config,
            # Deep-copy so the stored config is isolated from the caller.
            metadata=deepcopy(metadata),
        ),
        location=storage_dir,
        force_disable_check_same_thread=self.force_disable_check_same_thread,
    )
    self._save()
    return True
|
||||
|
||||
async def recreate_collection(
    self,
    collection_name: str,
    vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
    sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
    metadata: Optional[types.Payload] = None,
    **kwargs: Any,
) -> bool:
    """Delete the collection (if present) and create it afresh.

    Returns:
        True on success (mirrors ``create_collection``).
    """
    await self.delete_collection(collection_name)
    return await self.create_collection(
        collection_name, vectors_config, sparse_vectors_config, metadata=metadata
    )
|
||||
|
||||
def upload_points(
    self,
    collection_name: str,
    points: Iterable[types.PointStruct],
    update_filter: Optional[types.Filter] = None,
    **kwargs: Any,
) -> None:
    """Public wrapper around :meth:`_upload_points`."""
    self._upload_points(collection_name, points, update_filter=update_filter)

def _upload_points(
    self,
    collection_name: str,
    points: Iterable[Union[types.PointStruct, types.Record]],
    update_filter: Optional[types.Filter] = None,
) -> None:
    """Normalize incoming points/records into ``PointStruct`` and upsert them.

    Missing vectors/payloads are replaced with empty dicts.
    """
    normalized = [
        rest_models.PointStruct(
            id=item.id,
            vector=item.vector or {},
            payload=item.payload or {},
        )
        for item in points
    ]
    self._get_collection(collection_name).upsert(normalized, update_filter=update_filter)
|
||||
|
||||
def upload_collection(
    self,
    collection_name: str,
    vectors: Union[
        dict[str, types.NumpyArray], types.NumpyArray, Iterable[types.VectorStruct]
    ],
    payload: Optional[Iterable[dict[Any, Any]]] = None,
    ids: Optional[Iterable[types.PointId]] = None,
    update_filter: Optional[types.Filter] = None,
    **kwargs: Any,
) -> None:
    """Bulk-upload vectors (+optional payloads/ids) into a collection.

    Missing ids are filled with fresh UUID strings; missing payloads with
    empty dicts. Named numpy matrices are split row-wise into per-point
    named-vector dicts first.
    """

    def endless_uuids() -> Generator[str, None, None]:
        # Infinite id source, consumed lazily by zip below.
        while True:
            yield str(uuid4())

    collection = self._get_collection(collection_name)

    if isinstance(vectors, dict) and any(
        isinstance(value, np.ndarray) for value in vectors.values()
    ):
        row_counts = {matrix.shape[0] for matrix in vectors.values()}
        assert len(row_counts) == 1, "Each named vector should have the same number of vectors"
        total_rows = next(iter(vectors.values())).shape[0]
        vectors = [
            {name: vectors[name][row].tolist() for name in vectors}
            for row in range(total_rows)
        ]

    id_source = ids or endless_uuids()
    payload_source = payload or itertools.cycle([{}])

    batch = []
    for point_id, vector, point_payload in zip(id_source, iter(vectors), payload_source):
        if isinstance(point_id, uuid.UUID):
            point_id = str(point_id)
        if isinstance(vector, np.ndarray):
            vector = vector.tolist()
        batch.append(
            rest_models.PointStruct(
                id=point_id,
                vector=vector or {},
                payload=point_payload or {},
            )
        )
    collection.upsert(batch, update_filter=update_filter)
|
||||
|
||||
async def create_payload_index(
    self,
    collection_name: str,
    field_name: str,
    field_schema: Optional[types.PayloadSchemaType] = None,
    field_type: Optional[types.PayloadSchemaType] = None,
    **kwargs: Any,
) -> types.UpdateResult:
    """No-op in local mode: payload indexes only exist on the server.

    Emits a one-time warning and reports a completed update so callers
    written against the server client keep working.
    """
    show_warning_once(
        message="Payload indexes have no effect in the local Qdrant. Please use server Qdrant if you need payload indexes.",
        category=UserWarning,
        idx="create-local-payload-indexes",
        stacklevel=5,
    )
    return self._default_update_result()

async def delete_payload_index(
    self, collection_name: str, field_name: str, **kwargs: Any
) -> types.UpdateResult:
    """No-op in local mode: payload indexes only exist on the server."""
    show_warning_once(
        message="Payload indexes have no effect in the local Qdrant. Please use server Qdrant if you need payload indexes.",
        category=UserWarning,
        idx="delete-local-payload-indexes",
        stacklevel=5,
    )
    return self._default_update_result()
|
||||
|
||||
async def list_snapshots(
    self, collection_name: str, **kwargs: Any
) -> list[types.SnapshotDescription]:
    """Local mode keeps no snapshots, so the listing is always empty."""
    return []

async def create_snapshot(
    self, collection_name: str, **kwargs: Any
) -> Optional[types.SnapshotDescription]:
    """Unsupported in local mode; always raises."""
    raise NotImplementedError(
        "Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
    )

async def delete_snapshot(
    self, collection_name: str, snapshot_name: str, **kwargs: Any
) -> bool:
    """Unsupported in local mode; always raises."""
    raise NotImplementedError(
        "Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
    )

async def list_full_snapshots(self, **kwargs: Any) -> list[types.SnapshotDescription]:
    """Local mode keeps no snapshots, so the listing is always empty."""
    return []

async def create_full_snapshot(self, **kwargs: Any) -> types.SnapshotDescription:
    """Unsupported in local mode; always raises."""
    raise NotImplementedError(
        "Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
    )

async def delete_full_snapshot(self, snapshot_name: str, **kwargs: Any) -> bool:
    """Unsupported in local mode; always raises."""
    raise NotImplementedError(
        "Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
    )

async def recover_snapshot(self, collection_name: str, location: str, **kwargs: Any) -> bool:
    """Unsupported in local mode; always raises."""
    raise NotImplementedError(
        "Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
    )

async def list_shard_snapshots(
    self, collection_name: str, shard_id: int, **kwargs: Any
) -> list[types.SnapshotDescription]:
    """Local mode keeps no snapshots, so the listing is always empty."""
    return []

async def create_shard_snapshot(
    self, collection_name: str, shard_id: int, **kwargs: Any
) -> Optional[types.SnapshotDescription]:
    """Unsupported in local mode; always raises."""
    raise NotImplementedError(
        "Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need snapshots."
    )

async def delete_shard_snapshot(
    self, collection_name: str, shard_id: int, snapshot_name: str, **kwargs: Any
) -> bool:
    """Unsupported in local mode; always raises."""
    raise NotImplementedError(
        "Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need snapshots."
    )

async def recover_shard_snapshot(
    self, collection_name: str, shard_id: int, location: str, **kwargs: Any
) -> bool:
    """Unsupported in local mode; always raises."""
    raise NotImplementedError(
        "Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need snapshots."
    )
|
||||
|
||||
async def create_shard_key(
    self,
    collection_name: str,
    shard_key: types.ShardKey,
    shards_number: Optional[int] = None,
    replication_factor: Optional[int] = None,
    placement: Optional[list[int]] = None,
    **kwargs: Any,
) -> bool:
    """Sharding is a server-only feature; always raises in local mode."""
    raise NotImplementedError(
        "Sharding is not supported in the local Qdrant. Please use server Qdrant if you need sharding."
    )

async def delete_shard_key(
    self, collection_name: str, shard_key: types.ShardKey, **kwargs: Any
) -> bool:
    """Sharding is a server-only feature; always raises in local mode."""
    raise NotImplementedError(
        "Sharding is not supported in the local Qdrant. Please use server Qdrant if you need sharding."
    )
|
||||
|
||||
async def info(self) -> types.VersionInfo:
    """Version info mirroring the server's root endpoint response shape."""
    client_version = importlib.metadata.version("qdrant-client")
    return rest_models.VersionInfo(
        title="qdrant - vector search engine",
        version=client_version,
        commit=None,
    )
|
||||
|
||||
async def cluster_collection_update(
    self, collection_name: str, cluster_operation: types.ClusterOperations, **kwargs: Any
) -> bool:
    """Cluster operations are server-only; always raises in local mode."""
    raise NotImplementedError(
        "Cluster collection update is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
    )

async def collection_cluster_info(self, collection_name: str) -> types.CollectionClusterInfo:
    """Cluster operations are server-only; always raises in local mode."""
    raise NotImplementedError(
        "Collection cluster info is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
    )

async def cluster_status(self) -> types.ClusterStatus:
    """Cluster operations are server-only; always raises in local mode."""
    raise NotImplementedError(
        "Cluster status is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
    )

async def recover_current_peer(self) -> bool:
    """Cluster operations are server-only; always raises in local mode."""
    raise NotImplementedError(
        "Recover current peer is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
    )

async def remove_peer(self, peer_id: int, **kwargs: Any) -> bool:
    """Cluster operations are server-only; always raises in local mode."""
    raise NotImplementedError(
        "Remove peer info is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
    )
|
||||
@@ -0,0 +1,50 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
# These are the formats accepted by qdrant core; parse() below tries them in order.
available_formats = [
    # timezone-aware variants
    "%Y-%m-%dT%H:%M:%S.%f%z",
    "%Y-%m-%d %H:%M:%S.%f%z",
    "%Y-%m-%dT%H:%M:%S%z",
    "%Y-%m-%d %H:%M:%S%z",
    # naive variants (parse() assumes UTC for these)
    "%Y-%m-%dT%H:%M:%S.%f",
    "%Y-%m-%d %H:%M:%S.%f",
    "%Y-%m-%dT%H:%M:%S",
    "%Y-%m-%d %H:%M:%S",
    "%Y-%m-%d %H:%M",
    "%Y-%m-%d",
]
|
||||
|
||||
|
||||
def parse(date_str: str) -> Optional[datetime]:
    """Parse ``date_str`` with the datetime formats qdrant core accepts.

    Args:
        date_str: candidate datetime string (see
            https://github.com/qdrant/qdrant/blob/0ed86ce0575d35930268db19e1f7680287072c58/lib/segment/src/types.rs#L1388-L1410)

    Returns:
        Optional[datetime]: timezone-aware datetime on success (naive inputs
        are assumed UTC), otherwise None.
    """

    def try_formats(candidate: str) -> Optional[datetime]:
        for fmt in available_formats:
            try:
                parsed = datetime.strptime(candidate, fmt)
            except ValueError:
                continue
            if parsed.tzinfo is None:
                # Assume UTC if no timezone is provided
                parsed = parsed.replace(tzinfo=timezone.utc)
            return parsed
        return None

    result = try_formats(date_str)
    if result is not None:
        return result

    # strptime cannot handle hour-only offsets like "+01" / "-10", but it can
    # parse "+01:00" — retry with ":00" appended to cover those inputs, e.g.
    # "2021-01-01 00:00:00.000+01".
    return try_formats(date_str + ":00")
|
||||
@@ -0,0 +1,302 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.conversions import common_types as types
|
||||
from qdrant_client.http import models
|
||||
|
||||
EPSILON = 1.1920929e-7 # https://doc.rust-lang.org/std/f32/constant.EPSILON.html
|
||||
# https://github.com/qdrant/qdrant/blob/7164ac4a5987d28f1c93f5712aef8e09e7d93555/lib/segment/src/spaces/simple_avx.rs#L99C10-L99C10
|
||||
|
||||
|
||||
class DistanceOrder(str, Enum):
    """Whether larger or smaller raw scores indicate a closer match."""

    BIGGER_IS_BETTER = "bigger_is_better"
    SMALLER_IS_BETTER = "smaller_is_better"
|
||||
|
||||
|
||||
class RecoQuery:
    """Recommendation query holding positive/negative example vectors."""

    def __init__(
        self,
        positive: Optional[list[list[float]]] = None,
        negative: Optional[list[list[float]]] = None,
        strategy: Optional[models.RecommendStrategy] = None,
    ):
        assert strategy is not None, "Recommend strategy must be provided"
        self.strategy = strategy

        self.positive: list[types.NumpyArray] = [
            np.array(vec) for vec in (positive if positive is not None else [])
        ]
        self.negative: list[types.NumpyArray] = [
            np.array(vec) for vec in (negative if negative is not None else [])
        ]

        # Reject NaNs up front so the distance kernels never see them.
        assert not np.isnan(self.positive).any(), "Positive vectors must not contain NaN"
        assert not np.isnan(self.negative).any(), "Negative vectors must not contain NaN"
|
||||
|
||||
|
||||
class ContextPair:
    """One positive/negative example pair used by context-style queries."""

    def __init__(self, positive: list[float], negative: list[float]):
        self.positive: types.NumpyArray = np.array(positive)
        self.negative: types.NumpyArray = np.array(negative)

        # NaNs would silently poison every downstream distance computation.
        assert not np.isnan(self.positive).any(), "Positive vector must not contain NaN"
        assert not np.isnan(self.negative).any(), "Negative vector must not contain NaN"
|
||||
|
||||
|
||||
class DiscoveryQuery:
    """Discovery query: a target vector plus positive/negative context pairs."""

    def __init__(self, target: list[float], context: list[ContextPair]):
        self.target: types.NumpyArray = np.array(target)
        self.context = context

        assert not np.isnan(self.target).any(), "Target vector must not contain NaN"


class ContextQuery:
    """Query built purely from positive/negative context pairs."""

    def __init__(self, context_pairs: list[ContextPair]):
        self.context_pairs = context_pairs


# Any dense (non-sparse) structured query handled by this module.
DenseQueryVector = Union[
    DiscoveryQuery,
    ContextQuery,
    RecoQuery,
]
|
||||
|
||||
|
||||
def distance_to_order(distance: models.Distance) -> DistanceOrder:
    """Map a distance metric to its score ordering.

    Args:
        distance: distance to convert

    Returns:
        order — SMALLER_IS_BETTER for true distances (Euclid, Manhattan),
        BIGGER_IS_BETTER otherwise.
    """
    if distance in (models.Distance.EUCLID, models.Distance.MANHATTAN):
        return DistanceOrder.SMALLER_IS_BETTER
    return DistanceOrder.BIGGER_IS_BETTER
|
||||
|
||||
|
||||
def cosine_similarity(query: types.NumpyArray, vectors: types.NumpyArray) -> types.NumpyArray:
    """
    Calculate cosine similarity between query and vectors.

    Args:
        query: query vector (1-D) or batch of query vectors (2-D)
        vectors: vectors to calculate similarity with (one per row)

    Returns:
        similarities

    Note:
        Normalization is performed out-of-place. The previous in-place
        ``/=`` mutated the caller's arrays as a side effect and raised
        on integer-dtype inputs.
    """
    vectors_norm = np.linalg.norm(vectors, axis=-1)[:, np.newaxis]
    # Guard zero vectors with EPSILON to avoid division by zero.
    vectors = vectors / np.where(vectors_norm != 0.0, vectors_norm, EPSILON)

    if len(query.shape) == 1:
        query_norm = np.linalg.norm(query)
        query = query / np.where(query_norm != 0.0, query_norm, EPSILON)
        return np.dot(vectors, query)

    query_norm = np.linalg.norm(query, axis=-1)[:, np.newaxis]
    query = query / np.where(query_norm != 0.0, query_norm, EPSILON)
    return np.dot(query, vectors.T)
|
||||
|
||||
|
||||
def dot_product(query: types.NumpyArray, vectors: types.NumpyArray) -> types.NumpyArray:
    """
    Calculate dot product between query and vectors.

    Args:
        query: single query vector (1-D) or batch of queries (2-D)
        vectors: vectors to calculate distance with (one per row)

    Returns:
        distances
    """
    single_query = len(query.shape) == 1
    return np.dot(vectors, query) if single_query else np.dot(query, vectors.T)
|
||||
|
||||
|
||||
def euclidean_distance(query: types.NumpyArray, vectors: types.NumpyArray) -> types.NumpyArray:
    """
    Calculate euclidean (L2) distance between query and vectors.

    Args:
        query: single query vector (1-D) or batch of queries (2-D)
        vectors: vectors to calculate distance with (one per row)

    Returns:
        distances
    """
    if len(query.shape) != 1:
        # Broadcast every query row against every vector.
        return np.linalg.norm(vectors - query[:, np.newaxis], axis=-1)
    return np.linalg.norm(vectors - query, axis=-1)
|
||||
|
||||
|
||||
def manhattan_distance(query: types.NumpyArray, vectors: types.NumpyArray) -> types.NumpyArray:
    """
    Calculate manhattan (L1) distance between query and vectors.

    Args:
        query: single query vector (1-D) or batch of queries (2-D)
        vectors: vectors to calculate distance with (one per row)

    Returns:
        distances
    """
    if len(query.shape) != 1:
        # Broadcast every query row against every vector.
        return np.abs(vectors - query[:, np.newaxis]).sum(axis=-1)
    return np.abs(vectors - query).sum(axis=-1)
|
||||
|
||||
|
||||
def calculate_distance(
    query: types.NumpyArray, vectors: types.NumpyArray, distance_type: models.Distance
) -> types.NumpyArray:
    """Dispatch to the metric implementation selected by ``distance_type``.

    Raises:
        ValueError: on an unrecognized distance type.
    """
    assert not np.isnan(query).any(), "Query vector must not contain NaN"

    if distance_type == models.Distance.COSINE:
        return cosine_similarity(query, vectors)
    if distance_type == models.Distance.DOT:
        return dot_product(query, vectors)
    if distance_type == models.Distance.EUCLID:
        return euclidean_distance(query, vectors)
    if distance_type == models.Distance.MANHATTAN:
        return manhattan_distance(query, vectors)
    raise ValueError(f"Unknown distance type {distance_type}")
|
||||
|
||||
|
||||
def calculate_distance_core(
    query: types.NumpyArray, vectors: types.NumpyArray, distance_type: models.Distance
) -> types.NumpyArray:
    """
    Calculate same internal distances as in core, rather than the final displayed distance.

    For Euclid/Manhattan, core maximizes the negated (squared/absolute) sum
    instead of the displayed distance; other metrics fall through to
    ``calculate_distance``.
    """
    assert not np.isnan(query).any(), "Query vector must not contain NaN"

    if distance_type == models.Distance.EUCLID:
        return -np.square(vectors - query, dtype=np.float32).sum(axis=1, dtype=np.float32)
    if distance_type == models.Distance.MANHATTAN:
        return -np.abs(vectors - query, dtype=np.float32).sum(axis=1, dtype=np.float32)
    return calculate_distance(query, vectors, distance_type)
|
||||
|
||||
|
||||
def fast_sigmoid(x: np.float32) -> np.float32:
    """Cheap sigmoid-like squashing: x / (1 + |x|), mapping into (-1, 1).

    NaN/inf inputs are returned unchanged to avoid
    "invalid value encountered in scalar divide" warnings.
    """
    if np.isnan(x) or np.isinf(x):
        return x
    denominator = np.add(1.0, abs(x))
    return x / denominator


def scaled_fast_sigmoid(x: np.float32) -> np.float32:
    """``fast_sigmoid`` rescaled from (-1, 1) into (0, 1)."""
    shifted = np.add(fast_sigmoid(x), 1.0)
    return 0.5 * shifted
|
||||
|
||||
|
||||
def calculate_recommend_best_scores(
    query: RecoQuery, vectors: types.NumpyArray, distance_type: models.Distance
) -> types.NumpyArray:
    """Best-example recommendation scoring.

    For each vector, take its best score against the positive examples and
    against the negative examples; whichever side wins determines the sign
    of the (sigmoid-squashed) final score.
    """

    def best_against(examples: list[types.NumpyArray]) -> types.NumpyArray:
        vector_count = vectors.shape[0]
        per_example: list[types.NumpyArray] = [
            calculate_distance_core(example, vectors, distance_type)
            for example in examples
        ]
        if not per_example:
            # With no examples, -inf guarantees the other side wins.
            per_example.append(np.full(vector_count, -np.inf))
        return np.array(per_example, dtype=np.float32).max(axis=0)

    pos = best_against(query.positive)
    neg = best_against(query.negative)

    # Squash both sides through the scaled sigmoid; the negative side is
    # negated so vectors closest to negatives rank lowest.
    return np.where(
        pos > neg,
        np.fromiter((scaled_fast_sigmoid(score) for score in pos), pos.dtype),
        np.fromiter((-scaled_fast_sigmoid(score) for score in neg), neg.dtype),
    )
|
||||
|
||||
|
||||
def calculate_recommend_sum_scores(
    query: RecoQuery, vectors: types.NumpyArray, distance_type: models.Distance
) -> types.NumpyArray:
    """Sum-based recommendation scoring: total positive score minus total negative."""

    def summed_against(examples: list[types.NumpyArray]) -> types.NumpyArray:
        vector_count = vectors.shape[0]
        per_example: list[types.NumpyArray] = [
            calculate_distance_core(example, vectors, distance_type)
            for example in examples
        ]
        if not per_example:
            # No examples contribute nothing to the sum.
            per_example.append(np.zeros(vector_count))
        return np.array(per_example, dtype=np.float32).sum(axis=0)

    return summed_against(query.positive) - summed_against(query.negative)
|
||||
|
||||
|
||||
def calculate_discovery_ranks(
    context: list[ContextPair],
    vectors: types.NumpyArray,
    distance_type: models.Distance,
) -> types.NumpyArray:
    """Rank every vector by how many context pairs it satisfies.

    For each pair, a vector earns +1 when it scores strictly better against
    the positive example than the negative, 0 on a tie, and -1 otherwise.

    Returns:
        int32 array of summed per-pair ranks, one entry per vector.

    Note:
        The per-pair encoding is computed with nested ``np.where`` instead
        of a Python-level list comprehension — identical results (including
        the "neither bigger nor equal" branch), vectorized.
    """
    overall_ranks = np.zeros(vectors.shape[0], dtype=np.int32)
    for pair in context:
        # Get distances to positive and negative vectors
        pos = calculate_distance_core(pair.positive, vectors, distance_type)
        neg = calculate_distance_core(pair.negative, vectors, distance_type)

        pair_ranks = np.where(pos > neg, 1, np.where(pos == neg, 0, -1))
        overall_ranks += pair_ranks

    return overall_ranks
|
||||
|
||||
|
||||
def calculate_discovery_scores(
    query: DiscoveryQuery, vectors: types.NumpyArray, distance_type: models.Distance
) -> types.NumpyArray:
    """Discovery scoring: context rank plus a target-distance tiebreak.

    The sigmoid keeps the target-distance contribution inside (0, 1), so it
    can only break ties between vectors of equal rank, never outrank them.
    """
    ranks = calculate_discovery_ranks(query.context, vectors, distance_type)
    to_target = calculate_distance_core(query.target, vectors, distance_type)

    tiebreak = np.fromiter(
        (scaled_fast_sigmoid(score) for score in to_target), np.float32
    )
    return ranks + tiebreak
|
||||
|
||||
|
||||
def calculate_context_scores(
    query: ContextQuery, vectors: types.NumpyArray, distance_type: models.Distance
) -> types.NumpyArray:
    """Context scoring: accumulate clamped per-pair losses (0 is best)."""
    totals = np.zeros(vectors.shape[0], dtype=np.float32)
    for pair in query.context_pairs:
        # Get distances to positive and negative vectors
        pos = calculate_distance_core(pair.positive, vectors, distance_type)
        neg = calculate_distance_core(pair.negative, vectors, distance_type)

        margin = pos - neg - EPSILON
        # Only non-positive margins contribute, squashed into (-1, 0].
        totals += np.fromiter(
            (fast_sigmoid(m) for m in np.minimum(margin, 0.0)), np.float32
        )

    return totals
|
||||
@@ -0,0 +1,90 @@
|
||||
from math import asin, cos, radians, sin, sqrt
|
||||
|
||||
# Radius of earth in meters, [as recommended by the IUGG](ftp://athena.fsv.cvut.cz/ZFG/grs80-Moritz.pdf)
MEAN_EARTH_RADIUS = 6371008.8


def geo_distance(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
    """
    Calculate distance between two points on Earth using Haversine formula.

    Args:
        lon1: longitude of first point
        lat1: latitude of first point
        lon2: longitude of second point
        lat2: latitude of second point

    Returns:
        distance in meters
    """
    # convert decimal degrees to radians
    lon1, lat1, lon2, lat2 = (radians(deg) for deg in (lon1, lat1, lon2, lat2))

    # haversine formula
    half_dlat = (lat2 - lat1) / 2
    half_dlon = (lon2 - lon1) / 2
    h = sin(half_dlat) ** 2 + cos(lat1) * cos(lat2) * sin(half_dlon) ** 2

    return MEAN_EARTH_RADIUS * 2 * asin(sqrt(h))
|
||||
|
||||
|
||||
def test_geo_distance() -> None:
    """Sanity-check geo_distance against known inter-city separations."""
    moscow = {"lon": 37.6173, "lat": 55.7558}
    london = {"lon": -0.1278, "lat": 51.5074}
    berlin = {"lon": 13.4050, "lat": 52.5200}

    def between(a: dict, b: dict) -> float:
        return geo_distance(a["lon"], a["lat"], b["lon"], b["lat"])

    # A point is (numerically) at zero distance from itself.
    assert between(moscow, moscow) < 1.0
    # Moscow-London is roughly 2500 km, Moscow-Berlin roughly 1610 km.
    assert 2400 * 1000 < between(moscow, london) < 2600 * 1000
    assert 1600 * 1000 < between(moscow, berlin) < 1650 * 1000
|
||||
|
||||
|
||||
def boolean_point_in_polygon(
    point: tuple[float, float],
    exterior: list[tuple[float, float]],
    interiors: list[list[tuple[float, float]]],
) -> bool:
    """True when ``point`` lies inside ``exterior`` but in none of the holes.

    Boundary semantics follow in_ring: a point exactly on the outer boundary
    is treated as outside, a point exactly on a hole boundary is treated as
    inside the hole (hence outside the polygon).
    """
    if not in_ring(point, exterior, True):
        return False
    for hole in interiors:
        if in_ring(point, hole, False):
            return False
    return True
|
||||
|
||||
|
||||
def in_ring(
    pt: tuple[float, float], ring: list[tuple[float, float]], ignore_boundary: bool
) -> bool:
    """Ray-casting test for a point against a single ring.

    Args:
        pt: (x, y) coordinates of the point to test
        ring: ring vertices; a duplicated closing vertex equal to the first
            one is tolerated and dropped
        ignore_boundary: controls the result for points lying exactly on an
            edge — True means "boundary counts as outside"

    Returns:
        True if pt is inside the ring (boundary handling per ignore_boundary).
    """
    # drop the duplicated closing vertex, if present
    if ring[0][0] == ring[-1][0] and ring[0][1] == ring[-1][1]:
        ring = ring[:-1]

    inside = False
    j = len(ring) - 1
    for i in range(len(ring)):
        xi, yi = ring[i][0], ring[i][1]
        xj, yj = ring[j][0], ring[j][1]

        # exact boundary test: pt is collinear with edge (i, j) and within
        # the edge's bounding box on both axes
        collinear = pt[1] * (xi - xj) + yi * (xj - pt[0]) + yj * (pt[0] - xi) == 0
        within_x = (xi - pt[0]) * (xj - pt[0]) <= 0
        within_y = (yi - pt[1]) * (yj - pt[1]) <= 0
        if collinear and within_x and within_y:
            return not ignore_boundary

        # does edge (i, j) cross the horizontal ray going right from pt?
        crosses = ((yi > pt[1]) != (yj > pt[1])) and (
            pt[0] < (xj - xi) * (pt[1] - yi) / (yj - yi) + xi
        )
        if crosses:
            inside = not inside

        j = i

    return inside
|
||||
@@ -0,0 +1,151 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class JsonPathItemType(str, Enum):
    """Kinds of steps in a parsed json path."""

    KEY = "key"  # dict access by key, e.g. `a` or `"quoted key"`
    INDEX = "index"  # list access by a concrete index, e.g. `[0]`
    WILDCARD_INDEX = "wildcard_index"  # list access over all elements, `[]`
|
||||
|
||||
|
||||
class JsonPathItem(BaseModel):
    """One step of a parsed json path: either a dict key or a list index."""

    item_type: JsonPathItemType
    # Split into separate `index` and `key` fields instead of a single
    # Union[int, str] value, because pydantic coerces int to str even in case
    # of Union[int, str]. Tested with pydantic==1.10.14.
    index: Optional[int] = (
        None  # set only when item_type == INDEX
    )
    key: Optional[str] = None  # set only when item_type == KEY
|
||||
|
||||
|
||||
def parse_json_path(key: str) -> list[JsonPathItem]:
    """Parse and validate a json path string.

    Args:
        key: json path, e.g. ``a[0][1].b`` or ``"quoted key".c``

    Returns:
        list[JsonPathItem]: the path split into separate key/index items

    Raises:
        ValueError: if the json path is invalid or empty

    Examples:
        # >>> parse_json_path("a[0][1].b")
        # [
        #     JsonPathItem(item_type=<JsonPathItemType.KEY: 'key'>, value='a'),
        #     JsonPathItem(item_type=<JsonPathItemType.INDEX: 'index'>, value=0),
        #     JsonPathItem(item_type=<JsonPathItemType.INDEX: 'index'>, value=1),
        #     JsonPathItem(item_type=<JsonPathItemType.KEY: 'key'>, value='b')
        # ]
    """
    items: list[JsonPathItem] = []
    remainder = key

    while remainder:
        # a segment starts with either a quoted key or a bare key
        head, rest = match_quote(remainder)
        if head is None:
            head, rest = match_key(remainder)
        if head is None:
            raise ValueError("Invalid path")
        items.append(head)

        # consume any number of trailing bracket selectors ([], [i])
        bracket_items, rest = match_brackets(rest)
        items.extend(bracket_items)

        # strip the separator before the next segment; empty means done
        remainder = trunk_sep(rest)
        if not remainder:
            return items

    # only reachable when the original key was empty
    raise ValueError("Invalid path")
|
||||
|
||||
|
||||
def trunk_sep(path: str) -> str:
    """Strip the separator from the head of the remaining path.

    A leading ``.`` is consumed; a leading ``[`` is kept (it belongs to the
    next bracket selector). An empty path passes through unchanged.

    Raises:
        ValueError: for a trailing separator (single remaining char) or any
            other leading character.
    """
    if path == "":
        return path

    # a lone trailing character (e.g. "a.") can never form a valid segment
    if len(path) == 1:
        raise ValueError("Invalid path")

    first = path[0]
    if first == ".":
        return path[1:]
    if first == "[":
        return path
    raise ValueError("Invalid path")
|
||||
|
||||
|
||||
def match_quote(path: str) -> tuple[Optional[JsonPathItem], str]:
    """Try to read a double-quoted key from the head of *path*.

    Returns:
        ``(None, path)`` when the path does not start with a quote; otherwise
        a KEY item holding the quoted text and the remainder after the
        closing quote.

    Raises:
        ValueError: when the opening quote is never closed.
    """
    if path[:1] != '"':
        return None, path

    # an opening quote demands a matching closing quote
    if path.count('"') < 2:
        raise ValueError("Invalid path")

    closing = path.find('"', 1)
    item = JsonPathItem(item_type=JsonPathItemType.KEY, key=path[1:closing])
    return item, path[closing + 1 :]
|
||||
|
||||
|
||||
def match_key(path: str) -> tuple[Optional[JsonPathItem], str]:
    """Try to read a bare (unquoted) key from the head of *path*.

    A bare key is the longest prefix of alphanumerics, underscores and
    hyphens. Returns ``(None, path)`` when no such prefix exists.
    """
    length = 0
    for char in path:
        if char.isalnum() or char in ("_", "-"):
            length += 1
        else:
            break

    if length == 0:
        return None, path

    item = JsonPathItem(item_type=JsonPathItemType.KEY, key=path[:length])
    return item, path[length:]
|
||||
|
||||
|
||||
def match_brackets(rest: str) -> tuple[list[JsonPathItem], str]:
    """Consume consecutive bracket selectors (``[]``, ``[i]``) from *rest*.

    Returns:
        All index/wildcard items read from the head of *rest*, plus whatever
        remains after the last bracket.
    """
    items: list[JsonPathItem] = []
    while rest:
        item, rest = _match_brackets(rest)
        if item is None:
            return items, rest
        items.append(item)
    return items, rest
|
||||
|
||||
|
||||
def _match_brackets(path: str) -> tuple[Optional[JsonPathItem], str]:
    """Try to read a single bracket selector from the head of *path*.

    Recognizes:
        - ``[]``  -> WILDCARD_INDEX item (all array elements)
        - ``[n]`` -> INDEX item with integer index ``n``

    Returns:
        ``(item, remainder)`` on a match, ``(None, path)`` when *path* does
        not start with a bracket.

    Raises:
        ValueError: when the bracket is unclosed or its content is not an
            integer.
    """
    # NOTE: the original `"[" not in path or ...` first clause was redundant —
    # startswith("[") already implies "[" is present.
    if not path.startswith("["):
        return None, path

    right_bracket_pos = path.find("]", 1)
    if right_bracket_pos == -1:
        raise ValueError("Invalid path")

    rest = path[right_bracket_pos + 1 :]

    # "[]" -> wildcard over all array elements
    if right_bracket_pos == 1:
        return JsonPathItem(item_type=JsonPathItemType.WILDCARD_INDEX), rest

    try:
        index = int(path[1:right_bracket_pos])
    except ValueError as e:
        raise ValueError("Invalid path") from e

    return JsonPathItem(item_type=JsonPathItemType.INDEX, index=index), rest
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,220 @@
|
||||
from typing import Optional, Union, Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.conversions import common_types as types
|
||||
from qdrant_client.local.distances import (
|
||||
calculate_distance,
|
||||
scaled_fast_sigmoid,
|
||||
EPSILON,
|
||||
fast_sigmoid,
|
||||
)
|
||||
|
||||
|
||||
class MultiRecoQuery:
    """Recommendation query over multivectors.

    Holds positive and negative example matrices plus the strategy that
    decides how they are combined into a score.
    """

    def __init__(
        self,
        positive: Optional[list[list[list[float]]]] = None,  # list of matrices
        negative: Optional[list[list[list[float]]]] = None,  # list of matrices
        strategy: Optional[models.RecommendStrategy] = None,
    ):
        assert strategy is not None, "Recommend strategy must be provided"
        self.strategy = strategy

        pos_examples = [] if positive is None else positive
        neg_examples = [] if negative is None else negative

        # NaNs would silently poison every distance computation downstream
        for matrix in pos_examples:
            assert not np.isnan(matrix).any(), "Positive vectors must not contain NaN"
        for matrix in neg_examples:
            assert not np.isnan(matrix).any(), "Negative vectors must not contain NaN"

        self.positive: list[types.NumpyArray] = [np.array(m) for m in pos_examples]
        self.negative: list[types.NumpyArray] = [np.array(m) for m in neg_examples]
|
||||
|
||||
|
||||
class MultiContextPair:
    """A positive/negative matrix pair used by context-style queries."""

    def __init__(self, positive: list[list[float]], negative: list[list[float]]):
        pos = np.array(positive)
        neg = np.array(negative)

        # NaNs would silently poison every distance computation downstream
        assert not np.isnan(pos).any(), "Positive vector must not contain NaN"
        assert not np.isnan(neg).any(), "Negative vector must not contain NaN"

        self.positive: types.NumpyArray = pos
        self.negative: types.NumpyArray = neg
|
||||
|
||||
|
||||
class MultiDiscoveryQuery:
    """Discovery query: a target matrix plus context pairs steering the search."""

    def __init__(self, target: list[list[float]], context: list[MultiContextPair]):
        target_matrix = np.array(target)
        assert not np.isnan(target_matrix).any(), "Target vector must not contain NaN"

        self.target: types.NumpyArray = target_matrix
        self.context = context
|
||||
|
||||
|
||||
class MultiContextQuery:
    """Context-only query: scoring is driven purely by positive/negative pairs."""

    def __init__(self, context_pairs: list[MultiContextPair]):
        self.context_pairs = context_pairs
|
||||
|
||||
|
||||
# Union of all multivector query flavours understood by the local query evaluator.
MultiQueryVector = Union[
    MultiDiscoveryQuery,
    MultiContextQuery,
    MultiRecoQuery,
]
|
||||
|
||||
|
||||
def calculate_multi_distance(
    query_matrix: types.NumpyArray,
    matrices: list[types.NumpyArray],
    distance_type: models.Distance,
) -> types.NumpyArray:
    """Score *query_matrix* against each matrix in *matrices*.

    Unlike the ``*_core`` variant, Euclid and Manhattan results are converted
    back from the internal negated similarities to actual (non-negative)
    distances.
    """
    assert not np.isnan(query_matrix).any(), "Query matrix must not contain NaN"
    assert len(query_matrix.shape) == 2, "Query must be a matrix"

    raw = calculate_multi_distance_core(query_matrix, matrices, distance_type)

    if distance_type == models.Distance.EUCLID:
        return np.sqrt(np.abs(raw))
    if distance_type == models.Distance.MANHATTAN:
        return np.abs(raw)
    return raw
|
||||
|
||||
|
||||
def calculate_multi_distance_core(
    query_matrix: types.NumpyArray,
    matrices: list[types.NumpyArray],
    distance_type: models.Distance,
) -> types.NumpyArray:
    """Compute one internal similarity per matrix in *matrices*.

    For each candidate matrix, pairwise similarities between query rows and
    candidate rows are reduced with max over the candidate axis and then
    summed over query rows (max-sim aggregation). Euclid/Manhattan produce
    *negated* distances so that "bigger is better" holds for every metric.
    """

    def euclidean(q: types.NumpyArray, m: types.NumpyArray, *_: Any) -> types.NumpyArray:
        # negated squared euclidean distance; `*_` swallows the unused distance_type
        return -np.square(m - q, dtype=np.float32).sum(axis=-1, dtype=np.float32)

    def manhattan(q: types.NumpyArray, m: types.NumpyArray, *_: Any) -> types.NumpyArray:
        # negated manhattan distance
        return -np.abs(m - q, dtype=np.float32).sum(axis=-1, dtype=np.float32)

    assert not np.isnan(query_matrix).any(), "Query vector must not contain NaN"
    similarities: list[float] = []

    # Euclid and Manhattan are the only ones which are calculated differently during candidate selection
    # in core, here we make sure to use the same internal similarity function as in core.
    if distance_type in [models.Distance.EUCLID, models.Distance.MANHATTAN]:
        # add an axis so each query row broadcasts against all candidate rows
        query_matrix = query_matrix[:, np.newaxis]
        dist_func = euclidean if distance_type == models.Distance.EUCLID else manhattan
    else:
        dist_func = calculate_distance  # type: ignore

    for matrix in matrices:
        sim_matrix = dist_func(query_matrix, matrix, distance_type)
        # max over candidate rows, summed over query rows
        similarity = float(np.sum(np.max(sim_matrix, axis=-1)))
        similarities.append(similarity)
    return np.array(similarities)
|
||||
|
||||
|
||||
def calculate_multi_recommend_best_scores(
    query: MultiRecoQuery, matrices: list[types.NumpyArray], distance_type: models.Distance
) -> types.NumpyArray:
    """Score candidates with the "best score" recommendation strategy.

    For each candidate, take its best similarity to any positive example and
    to any negative example; candidates closer to a positive win with a
    positive sigmoid score, otherwise they get a negated sigmoid score.
    """

    def get_best_scores(examples: list[types.NumpyArray]) -> types.NumpyArray:
        matrix_count = len(matrices)

        # Get scores to all examples
        scores: list[types.NumpyArray] = []
        for example in examples:
            score = calculate_multi_distance_core(example, matrices, distance_type)
            scores.append(score)

        # Keep only max for each vector
        if len(scores) == 0:
            # no examples on this side: -inf guarantees the other side wins
            scores.append(np.full(matrix_count, -np.inf))
        best_scores = np.array(scores, dtype=np.float32).max(axis=0)

        return best_scores

    pos = get_best_scores(query.positive)
    neg = get_best_scores(query.negative)

    # Choose from the best positive or the best negative,
    # in both cases we apply sigmoid and then negate depending on the order
    return np.where(
        pos > neg,
        np.fromiter((scaled_fast_sigmoid(xi) for xi in pos), pos.dtype),
        np.fromiter((-scaled_fast_sigmoid(xi) for xi in neg), neg.dtype),
    )
|
||||
|
||||
|
||||
def calculate_multi_recommend_sum_scores(
    query: MultiRecoQuery, matrices: list[types.NumpyArray], distance_type: models.Distance
) -> types.NumpyArray:
    """Score candidates as (sum of similarities to positives) - (sum to negatives)."""

    def summed(examples: list[types.NumpyArray]) -> types.NumpyArray:
        per_example = [
            calculate_multi_distance_core(example, matrices, distance_type)
            for example in examples
        ]
        # no examples on this side -> contribute a zero vector
        if not per_example:
            per_example = [np.zeros(len(matrices))]
        return np.array(per_example, dtype=np.float32).sum(axis=0)

    return summed(query.positive) - summed(query.negative)
|
||||
|
||||
|
||||
def calculate_multi_discovery_ranks(
    context: list[MultiContextPair],
    matrices: list[types.NumpyArray],
    distance_type: models.Distance,
) -> types.NumpyArray:
    """Accumulate per-candidate integer ranks over all context pairs.

    For each pair, a candidate earns +1 when it is closer to the positive
    side, -1 when closer to the negative side, and 0 on an exact tie.
    """
    overall_ranks: types.NumpyArray = np.zeros(len(matrices), dtype=np.int32)
    for pair in context:
        # Get distances to positive and negative vectors
        pos = calculate_multi_distance_core(pair.positive, matrices, distance_type)
        neg = calculate_multi_distance_core(pair.negative, matrices, distance_type)

        # +1 / 0 / -1 per candidate, from elementwise comparison of the two sides
        pair_ranks = np.array(
            [
                1 if is_bigger else 0 if is_equal else -1
                for is_bigger, is_equal in zip(pos > neg, pos == neg)
            ]
        )

        overall_ranks += pair_ranks

    return overall_ranks
|
||||
|
||||
|
||||
def calculate_multi_discovery_scores(
    query: MultiDiscoveryQuery, matrices: list[types.NumpyArray], distance_type: models.Distance
) -> types.NumpyArray:
    """Combine integer context ranks with sigmoid-squashed target proximity.

    The sigmoid-squashed distance to the target acts as a tie-breaker among
    candidates of equal context rank.
    """
    context_ranks = calculate_multi_discovery_ranks(query.context, matrices, distance_type)

    # Get distances to target, squashed through the sigmoid
    target_scores = calculate_multi_distance_core(query.target, matrices, distance_type)
    squashed = np.fromiter(
        (scaled_fast_sigmoid(score) for score in target_scores), np.float32
    )

    return context_ranks + squashed
|
||||
|
||||
|
||||
def calculate_multi_context_scores(
    query: MultiContextQuery, matrices: list[types.NumpyArray], distance_type: models.Distance
) -> types.NumpyArray:
    """Accumulate, per candidate, a penalty-style score over all context pairs."""
    totals: types.NumpyArray = np.zeros(len(matrices), dtype=np.float32)
    for pair in query.context_pairs:
        # similarity of each candidate to the pair's positive and negative side
        pos = calculate_multi_distance_core(pair.positive, matrices, distance_type)
        neg = calculate_multi_distance_core(pair.negative, matrices, distance_type)

        # only being closer to the negative side is penalized:
        # the margin is clamped at zero before the sigmoid
        margin = pos - neg - EPSILON
        totals += np.fromiter(
            (fast_sigmoid(m) for m in np.minimum(margin, 0.0)), np.float32
        )

    return totals
|
||||
@@ -0,0 +1,30 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from qdrant_client.http.models import OrderValue
|
||||
from qdrant_client.local.datetime_utils import parse
|
||||
|
||||
MICROS_PER_SECOND = 1_000_000
|
||||
|
||||
|
||||
def datetime_to_microseconds(dt: datetime) -> int:
    """Convert a datetime to whole microseconds since the Unix epoch.

    Naive datetimes are interpreted by ``datetime.timestamp`` using the local
    timezone.
    """
    seconds = dt.timestamp()
    return int(seconds * MICROS_PER_SECOND)
|
||||
|
||||
|
||||
def to_order_value(value: Union[None, OrderValue, datetime, str]) -> Optional[OrderValue]:
    """Normalize a raw payload value into an orderable numeric value.

    Numbers pass through unchanged; datetimes and datetime-parsable strings
    become epoch microseconds; anything else (including None and unparsable
    strings) maps to None.
    """
    if isinstance(value, (int, float)):
        # already a valid OrderValue
        return value

    if isinstance(value, datetime):
        return datetime_to_microseconds(value)

    if isinstance(value, str):
        parsed = parse(value)
        if parsed is not None:
            return datetime_to_microseconds(parsed)

    return None
|
||||
@@ -0,0 +1,337 @@
|
||||
from datetime import date, datetime, timezone
|
||||
from typing import Any, Optional, Union, Dict
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.local import datetime_utils
|
||||
from qdrant_client.local.geo import boolean_point_in_polygon, geo_distance
|
||||
from qdrant_client.local.payload_value_extractor import value_by_key
|
||||
from qdrant_client.conversions import common_types as types
|
||||
|
||||
|
||||
def get_value_counts(values: list[Any]) -> list[int]:
    """Return per-value item counts for a list of payload values.

    A None counts as 0, a sized non-string container counts as its length,
    and any other value counts as 1. When every value is None (including the
    empty input), a single ``[0]`` is returned.
    """
    if all(value is None for value in values):
        return [0]

    counts: list[int] = []
    for value in values:
        if value is None:
            counts.append(0)
        elif isinstance(value, str) or not hasattr(value, "__len__"):
            # strings deliberately count as one value, not len(str)
            counts.append(1)
        else:
            counts.append(len(value))
    return counts
|
||||
|
||||
|
||||
def check_values_count(condition: models.ValuesCount, values: Optional[list[Any]]) -> bool:
    """Check a ValuesCount condition against a list of payload values.

    The condition passes when, for every bound that is set, at least one
    value's item count satisfies it.
    """
    if values is None:
        return False

    counts = get_value_counts(values)

    # each configured bound must be met by at least one count
    # (`not any(...)` is the De Morgan form of the original all-violate checks)
    if condition.lt is not None and not any(c < condition.lt for c in counts):
        return False
    if condition.lte is not None and not any(c <= condition.lte for c in counts):
        return False
    if condition.gt is not None and not any(c > condition.gt for c in counts):
        return False
    if condition.gte is not None and not any(c >= condition.gte for c in counts):
        return False
    return True
|
||||
|
||||
|
||||
def check_geo_radius(condition: models.GeoRadius, values: Any) -> bool:
    """True when *values* is a geo-point dict within the condition's radius.

    Anything that is not a dict with both "lat" and "lon" keys fails.
    """
    if not isinstance(values, dict) or "lat" not in values or "lon" not in values:
        return False

    meters = geo_distance(
        lon1=values["lon"],
        lat1=values["lat"],
        lon2=condition.center.lon,
        lat2=condition.center.lat,
    )
    # radius bound is exclusive
    return meters < condition.radius
|
||||
|
||||
|
||||
def check_geo_bounding_box(condition: models.GeoBoundingBox, values: Any) -> bool:
    """True when *values* is a geo-point dict inside the bounding box.

    Supports boxes that cross the anti-meridian (top-left lon greater than
    bottom-right lon).
    """
    if not isinstance(values, dict) or "lat" not in values or "lon" not in values:
        return False

    lat = values["lat"]
    lon = values["lon"]
    top_left = condition.top_left
    bottom_right = condition.bottom_right

    if top_left.lon > bottom_right.lon:
        # box wraps around the anti-meridian: accept either side of it
        lon_ok = top_left.lon <= lon <= 180 or -180 <= lon <= bottom_right.lon
    else:
        lon_ok = top_left.lon <= lon <= bottom_right.lon

    lat_ok = bottom_right.lat <= lat <= top_left.lat

    return lon_ok and lat_ok
|
||||
|
||||
|
||||
def check_geo_polygon(condition: models.GeoPolygon, values: Any) -> bool:
    """True when *values* is a geo-point dict inside the polygon (holes excluded)."""
    if not isinstance(values, dict) or "lat" not in values or "lon" not in values:
        return False

    # rings and the point are handled in (lat, lon) order throughout
    outer = [(p.lat, p.lon) for p in condition.exterior.points]
    if condition.interiors is None:
        holes: list[list[tuple[float, float]]] = []
    else:
        holes = [
            [(p.lat, p.lon) for p in interior.points]
            for interior in condition.interiors
        ]

    return boolean_point_in_polygon(
        point=(values["lat"], values["lon"]), exterior=outer, interiors=holes
    )
|
||||
|
||||
|
||||
def check_range_interface(condition: models.RangeInterface, value: Any) -> bool:
    """Dispatch a range condition to the numeric or datetime implementation.

    Unknown range flavours conservatively evaluate to False.
    """
    if isinstance(condition, models.Range):
        return check_range(condition, value)
    if isinstance(condition, models.DatetimeRange):
        return check_datetime_range(condition, value)
    return False
|
||||
|
||||
|
||||
def check_range(condition: models.Range, value: Any) -> bool:
    """Check a numeric value against the condition's optional lt/lte/gt/gte bounds."""
    if not isinstance(value, (int, float)):
        return False

    if condition.lt is not None and not value < condition.lt:
        return False
    if condition.lte is not None and not value <= condition.lte:
        return False
    if condition.gt is not None and not value > condition.gt:
        return False
    if condition.gte is not None and not value >= condition.gte:
        return False
    return True
|
||||
|
||||
|
||||
def check_datetime_range(condition: models.DatetimeRange, value: Any) -> bool:
    """Check a datetime-encoded string value against a DatetimeRange condition.

    Bounds given as plain dates become midnight datetimes; naive bounds are
    assumed to be UTC.

    Fix: the original mutated ``condition.lt``/``lte``/``gt``/``gte`` in
    place, permanently rewriting the caller's (possibly shared) condition
    model on every evaluation. Normalized bounds are now kept in locals.

    Args:
        condition: range with optional lt/lte/gt/gte datetime (or date) bounds
        value: payload value; only strings parsable as datetimes can match

    Returns:
        True when the parsed value satisfies every configured bound.
    """

    def tz_aware(dt: Optional[Union[datetime, date]]) -> Optional[datetime]:
        # a bare date becomes the midnight datetime of that day
        if isinstance(dt, date) and not isinstance(dt, datetime):
            dt = datetime.combine(dt, datetime.min.time())

        if dt is None or dt.tzinfo is not None:
            return dt

        # Assume UTC if no timezone is provided
        return dt.replace(tzinfo=timezone.utc)

    if not isinstance(value, str):
        return False

    dt = datetime_utils.parse(value)

    if dt is None:
        return False

    # normalize bounds locally instead of writing back into the condition
    lt = tz_aware(condition.lt)
    lte = tz_aware(condition.lte)
    gt = tz_aware(condition.gt)
    gte = tz_aware(condition.gte)

    return (
        (lt is None or dt < lt)
        and (lte is None or dt <= lte)
        and (gt is None or dt > gt)
        and (gte is None or dt >= gte)
    )
|
||||
|
||||
|
||||
def check_match(condition: models.Match, value: Any) -> bool:
    """Evaluate a single match condition against one payload value.

    Raises:
        ValueError: for a match condition type this evaluator does not know.
    """
    if isinstance(condition, models.MatchValue):
        return value == condition.value
    if isinstance(condition, models.MatchText):
        # substring containment; None never matches
        return value is not None and condition.text in value
    if isinstance(condition, models.MatchTextAny):
        if value is None:
            return False
        # whitespace-split tokens; any one contained in the value matches
        return any(token in value for token in condition.text_any.split())
    if isinstance(condition, models.MatchAny):
        return value in condition.any
    if isinstance(condition, models.MatchExcept):
        return value not in condition.except_
    raise ValueError(f"Unknown match condition: {condition}")
|
||||
|
||||
|
||||
def check_nested_filter(nested_filter: models.Filter, values: list[Any]) -> bool:
    """True if at least one element of *values* satisfies *nested_filter*.

    Nested objects have no point id or vectors of their own, hence the
    placeholder arguments.
    """
    for item in values:
        if check_filter(nested_filter, item, point_id=-1, has_vector={}):
            return True
    return False
|
||||
|
||||
|
||||
def check_condition(
    condition: models.Condition,
    payload: dict[str, Any],
    point_id: models.ExtendedPointId,
    has_vector: Dict[str, bool],
) -> bool:
    """Evaluate a single filter condition against one point.

    Args:
        condition: any concrete condition model (is_null, is_empty, has_id,
            has_vector, field condition, nested condition, or a sub-filter)
        payload: the point's payload
        point_id: the point's external id (used by HasIdCondition)
        has_vector: vector name -> True for named vectors present on the point

    Returns:
        True when the condition is satisfied.

    Raises:
        ValueError: for a condition type this evaluator does not know.
    """
    if isinstance(condition, models.IsNullCondition):
        # unflattened fetch: a literal None element means "field exists and is null"
        values = value_by_key(payload, condition.is_null.key, flat=False)
        if values is None:
            return False
        if any(v is None for v in values):
            return True
    elif isinstance(condition, models.IsEmptyCondition):
        # empty means: key missing, no values, or only None / empty-list values
        values = value_by_key(payload, condition.is_empty.key, flat=False)
        if (
            values is None
            or len(values) == 0
            or all((v is None or (isinstance(v, list) and len(v) == 0)) for v in values)
        ):
            return True
    elif isinstance(condition, models.HasIdCondition):
        # UUID ids are normalized to strings before comparison
        ids = [str(id_) if isinstance(id_, UUID) else id_ for id_ in condition.has_id]
        if point_id in ids:
            return True
    elif isinstance(condition, models.HasVectorCondition):
        if condition.has_vector in has_vector and has_vector[condition.has_vector]:
            return True
    elif isinstance(condition, models.FieldCondition):
        # the first set sub-condition decides the result; they are tried in a
        # fixed order below
        values = value_by_key(payload, condition.key)
        if condition.match is not None:
            if values is None:
                return False
            return any(check_match(condition.match, v) for v in values)
        if condition.range is not None:
            if values is None:
                return False
            return any(check_range_interface(condition.range, v) for v in values)
        if condition.geo_bounding_box is not None:
            if values is None:
                return False
            return any(check_geo_bounding_box(condition.geo_bounding_box, v) for v in values)
        if condition.geo_radius is not None:
            if values is None:
                return False
            return any(check_geo_radius(condition.geo_radius, v) for v in values)
        if condition.values_count is not None:
            # values_count needs the unflattened arrays to count their elements
            values = value_by_key(payload, condition.key, flat=False)
            return check_values_count(condition.values_count, values)
        if condition.geo_polygon is not None:
            # NOTE(review): geo_polygon is tested after values_count; if both
            # were ever set, values_count would win — presumably only one
            # sub-condition is set at a time. Confirm against the models.
            if values is None:
                return False
            return any(check_geo_polygon(condition.geo_polygon, v) for v in values)
    elif isinstance(condition, models.NestedCondition):
        values = value_by_key(payload, condition.nested.key)
        if values is None:
            return False
        return check_nested_filter(condition.nested.filter, values)
    elif isinstance(condition, models.Filter):
        # filters may nest arbitrarily
        return check_filter(condition, payload, point_id, has_vector)
    else:
        raise ValueError(f"Unknown condition: {condition}")
    # condition type was recognized but nothing matched
    return False
|
||||
|
||||
|
||||
def check_must(
    conditions: list[models.Condition],
    payload: dict,
    point_id: models.ExtendedPointId,
    has_vector: Dict[str, bool],
) -> bool:
    """True only when every condition holds (logical AND)."""
    for condition in conditions:
        if not check_condition(condition, payload, point_id, has_vector):
            return False
    return True
|
||||
|
||||
|
||||
def check_must_not(
    conditions: list[models.Condition],
    payload: dict,
    point_id: models.ExtendedPointId,
    has_vector: Dict[str, bool],
) -> bool:
    """True only when no condition holds (logical NOR)."""
    for condition in conditions:
        if check_condition(condition, payload, point_id, has_vector):
            return False
    return True
|
||||
|
||||
|
||||
def check_should(
    conditions: list[models.Condition],
    payload: dict,
    point_id: models.ExtendedPointId,
    has_vector: Dict[str, bool],
) -> bool:
    """True when at least one condition holds (logical OR)."""
    for condition in conditions:
        if check_condition(condition, payload, point_id, has_vector):
            return True
    return False
|
||||
|
||||
|
||||
def check_min_should(
    conditions: list[models.Condition],
    payload: dict,
    point_id: models.ExtendedPointId,
    vectors: Dict[str, Any],
    min_count: int,
) -> bool:
    """True when at least *min_count* of the conditions hold."""
    matched = sum(
        1
        for condition in conditions
        if check_condition(condition, payload, point_id, vectors)
    )
    return matched >= min_count
|
||||
|
||||
|
||||
def check_filter(
    payload_filter: models.Filter,
    payload: dict,
    point_id: models.ExtendedPointId,
    has_vector: Dict[str, bool],
) -> bool:
    """Evaluate a full filter (must / must_not / should / min_should) on one point.

    Every configured clause must pass for the filter to pass.
    """

    def as_list(
        condition: Union[models.Condition, list[models.Condition]],
    ) -> list[models.Condition]:
        # clauses accept either one condition or a list of them
        return condition if isinstance(condition, list) else [condition]

    must = payload_filter.must
    if must is not None and not check_must(as_list(must), payload, point_id, has_vector):
        return False

    must_not = payload_filter.must_not
    if must_not is not None and not check_must_not(
        as_list(must_not), payload, point_id, has_vector
    ):
        return False

    should = payload_filter.should
    if should is not None and not check_should(
        as_list(should), payload, point_id, has_vector
    ):
        return False

    min_should = payload_filter.min_should
    if min_should is not None and not check_min_should(
        min_should.conditions,
        payload,
        point_id,
        has_vector,
        min_should.min_count,
    ):
        return False

    return True
|
||||
|
||||
|
||||
def calculate_payload_mask(
    payloads: list[dict],
    payload_filter: Optional[models.Filter],
    ids_inv: list[models.ExtendedPointId],
    deleted_per_vector: Dict[str, np.ndarray],
) -> types.NumpyArray:
    """Build a boolean mask of payloads passing *payload_filter*.

    Args:
        payloads: payloads in internal-offset order
        payload_filter: filter to apply; None selects everything
        ids_inv: external point id per internal offset
        deleted_per_vector: vector name -> per-offset deleted flags

    Returns:
        Boolean numpy array, True where the point passes the filter.
    """
    if payload_filter is None:
        return np.ones(len(payloads), dtype=bool)

    mask: types.NumpyArray = np.zeros(len(payloads), dtype=bool)
    for idx, payload in enumerate(payloads):
        # a vector is "present" for the point when it is not marked deleted
        present = {
            name: True
            for name, deleted in deleted_per_vector.items()
            if not deleted[idx]
        }
        mask[idx] = check_filter(payload_filter, payload, ids_inv[idx], present)
    return mask
|
||||
+92
@@ -0,0 +1,92 @@
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from qdrant_client.local.json_path_parser import (
|
||||
JsonPathItem,
|
||||
JsonPathItemType,
|
||||
parse_json_path,
|
||||
)
|
||||
|
||||
|
||||
def value_by_key(payload: dict[str, Any], key: str, flat: bool = True) -> Optional[list[Any]]:
    """
    Get all values from payload addressed by a json-path key.

    Args:
        payload: arbitrary json-like object
        flat: If True, extend list of values. If False, append. By default, we use True and flatten the arrays,
            we need it for filters, however for `count` method we need to keep the arrays as is.
        key:
            Key or path to value in payload.
            Examples:
                - "name"
                - "address.city"
                - "location[].name"
                - "location[0].name"

    Returns:
        List of values or None if key not found.
    """
    keys = parse_json_path(key)
    # collected matches; closed over by the recursive walker below
    result = []

    def _get_value(data: Any, k_list: list[JsonPathItem]) -> None:
        # Recursively walk `data` along the remaining path items, appending
        # every match to `result`. `k_list` is consumed destructively, so
        # callers pass copies when branching.
        if not k_list:
            return

        current_key = k_list.pop(0)
        if len(k_list) == 0:
            # last path item: collect terminal value(s)
            if isinstance(data, dict) and current_key.item_type == JsonPathItemType.KEY:
                if current_key.key in data:
                    value = data[current_key.key]
                    # `flat` flattens terminal lists into the result
                    if isinstance(value, list) and flat:
                        result.extend(value)
                    else:
                        result.append(value)

            elif isinstance(data, list):
                if current_key.item_type == JsonPathItemType.WILDCARD_INDEX:
                    # "[]" collects every element
                    result.extend(data)

                elif current_key.item_type == JsonPathItemType.INDEX:
                    assert current_key.index is not None

                    # out-of-range indices are silently ignored
                    if current_key.index < len(data):
                        result.append(data[current_key.index])

        elif current_key.item_type == JsonPathItemType.KEY:
            # intermediate dict access; type mismatches end the walk quietly
            if not isinstance(data, dict):
                return

            if current_key.key in data:
                _get_value(data[current_key.key], k_list.copy())

        elif current_key.item_type == JsonPathItemType.INDEX:
            assert current_key.index is not None

            if not isinstance(data, list):
                return

            if current_key.index < len(data):
                _get_value(data[current_key.index], k_list.copy())

        elif current_key.item_type == JsonPathItemType.WILDCARD_INDEX:
            if not isinstance(data, list):
                return

            # fan out over every element; each branch gets its own path copy
            for item in data:
                _get_value(item, k_list.copy())

    _get_value(payload, keys)
    # empty result is normalized to None ("key not found")
    return result if result else None
|
||||
|
||||
|
||||
def parse_uuid(value: Any) -> Optional[uuid.UUID]:
    """Try to interpret an arbitrary value as a UUID.

    Args:
        value: arbitrary value; it is stringified before parsing

    Returns:
        The parsed UUID, or None when the value is not a valid UUID.
    """
    candidate = str(value)
    try:
        return uuid.UUID(candidate)
    except ValueError:
        return None
|
||||
@@ -0,0 +1,247 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from qdrant_client.local.json_path_parser import JsonPathItem, JsonPathItemType
|
||||
|
||||
|
||||
def set_value_by_key(payload: dict, keys: list[JsonPathItem], value: Any) -> None:
    """Set *value* into *payload* at the location described by *keys*.

    Args:
        payload: arbitrary json-like object, modified in place
        keys: pre-parsed json path, e.g. the path "a[0][1].b" parses to:
            [
                JsonPathItem(item_type=<JsonPathItemType.KEY: 'key'>, value='a'),
                JsonPathItem(item_type=<JsonPathItemType.INDEX: 'index'>, value=0),
                JsonPathItem(item_type=<JsonPathItemType.INDEX: 'index'>, value=1),
                JsonPathItem(item_type=<JsonPathItemType.KEY: 'key'>, value='b')
            ]
        value: value to set
    """
    # pass a copy: Setter.set consumes the key list destructively
    Setter.set(payload, list(keys), value, None, None)
|
||||
|
||||
|
||||
class Setter:
    """Base of the json-path setter dispatch.

    Concrete subclasses register themselves in ``SETTERS`` per
    ``JsonPathItemType``; ``set`` pops the next path item and forwards to the
    setter registered for that item type.
    """

    # Container type this setter writes into directly (e.g. dict for KEY items).
    TYPE: Any
    # Registry: path-item type -> concrete setter class. Shared by all setters.
    SETTERS: dict[JsonPathItemType, Type["Setter"]] = {}

    @classmethod
    def add_setter(cls, item_type: JsonPathItemType, setter: Type["Setter"]) -> None:
        """Register *setter* as the handler for *item_type* path items."""
        cls.SETTERS[item_type] = setter

    @classmethod
    def set(
        cls,
        data: Any,
        k_list: list[JsonPathItem],
        value: dict[str, Any],
        prev_data: Any,
        prev_key: Optional[JsonPathItem],
    ) -> None:
        """Consume the next item of *k_list* (mutating it) and dispatch.

        Args:
            data: the container currently being written into
            k_list: remaining path items; the head is popped here
            value: the dict eventually merged/assigned at the path target
            prev_data: parent container of *data* (None at the root)
            prev_key: path item that led from *prev_data* to *data*
        """
        if not k_list:
            return

        current_key = k_list.pop(0)
        cls.SETTERS[current_key.item_type]._set(
            data,
            current_key,
            k_list,
            value,
            prev_data,
            prev_key,
        )

    @classmethod
    def _set(
        cls,
        data: Any,
        current_key: JsonPathItem,
        k_list: list[JsonPathItem],
        value: dict[str, Any],
        prev_data: Any,
        prev_key: Optional[JsonPathItem],
    ) -> None:
        """Write through *current_key*.

        Delegates to the compatible-type path when *data* already has the
        container type this setter expects, otherwise to the incompatible
        path, which rebuilds the container via the parent reference.
        """
        if isinstance(data, cls.TYPE):
            cls._set_compatible_types(
                data=data, current_key=current_key, k_list=k_list, value=value
            )
        else:
            cls._set_incompatible_types(
                current_key=current_key,
                k_list=k_list,
                value=value,
                prev_data=prev_data,
                prev_key=prev_key,
            )

    @classmethod
    def _set_compatible_types(
        cls,
        data: Any,
        current_key: JsonPathItem,
        k_list: list[JsonPathItem],
        value: dict[str, Any],
    ) -> None:
        # Hook for subclasses: *data* matches cls.TYPE and can be written into.
        raise NotImplementedError()

    @classmethod
    def _set_incompatible_types(
        cls,
        current_key: JsonPathItem,
        k_list: list[JsonPathItem],
        value: dict[str, Any],
        prev_data: Any,
        prev_key: Optional[JsonPathItem],
    ) -> None:
        # Hook for subclasses: *data* has the wrong type and must be replaced
        # through *prev_data*/*prev_key*.
        raise NotImplementedError()
|
||||
|
||||
|
||||
class KeySetter(Setter):
    """Setter for KEY path items: writes into dict payloads by string key."""

    TYPE = dict

    @classmethod
    def _set_compatible_types(
        cls,
        data: Any,
        current_key: JsonPathItem,
        k_list: list[JsonPathItem],
        value: dict[str, Any],
    ) -> None:
        # `data` is a dict; make sure the key exists before descending.
        if current_key.key not in data:
            data[current_key.key] = {}

        if len(k_list) == 0:
            # Last path item: merge into an existing dict, otherwise overwrite.
            if isinstance(data[current_key.key], dict):
                data[current_key.key].update(value)
            else:
                data[current_key.key] = value
        else:
            # Descend with the remaining path; `data` becomes the new parent.
            cls.set(data[current_key.key], k_list.copy(), value, data, current_key)

    @classmethod
    def _set_incompatible_types(
        cls,
        current_key: JsonPathItem,
        k_list: list[JsonPathItem],
        value: dict[str, Any],
        prev_data: Any,
        prev_key: Optional[JsonPathItem],
    ) -> None:
        # The target was not a dict: rebuild it inside the parent container.
        assert prev_key is not None

        if len(k_list) == 0:
            if prev_key.item_type == JsonPathItemType.KEY:
                prev_data[prev_key.key] = {current_key.key: value}
            else:  # if prev key was WILDCARD, we need to pass INDEX instead with an index set
                prev_data[prev_key.index] = {current_key.key: value}
        else:
            # More path items remain: create an empty dict and recurse into it.
            if prev_key.item_type == JsonPathItemType.KEY:
                prev_data[prev_key.key] = {current_key.key: {}}
                cls.set(
                    prev_data[prev_key.key][current_key.key],
                    k_list.copy(),
                    value,
                    prev_data[prev_key.key],
                    current_key,
                )
            else:
                prev_data[prev_key.index] = {current_key.key: {}}
                cls.set(
                    prev_data[prev_key.index][current_key.key],
                    k_list.copy(),
                    value,
                    prev_data[prev_key.index],
                    current_key,
                )
|
||||
|
||||
|
||||
class _ListSetter(Setter):
    """Common base for list-targeting setters (fixed index and wildcard)."""

    TYPE = list

    @classmethod
    def _set_incompatible_types(
        cls,
        current_key: JsonPathItem,
        k_list: list[JsonPathItem],
        value: dict[str, Any],
        prev_data: Any,
        prev_key: Optional[JsonPathItem],
    ) -> None:
        # Target was not a list: replace it with an empty list in the parent
        # and stop — there are no elements to write into.
        assert prev_key is not None

        if prev_key.item_type == JsonPathItemType.KEY:
            prev_data[prev_key.key] = []
            return
        else:
            prev_data[prev_key.index] = []
            return

    @classmethod
    def _set_compatible_types(
        cls,
        data: Any,
        current_key: JsonPathItem,
        k_list: list[JsonPathItem],
        value: dict[str, Any],
    ) -> None:
        # Concrete list setters implement the element-writing strategy.
        raise NotImplementedError()
|
||||
|
||||
|
||||
class IndexSetter(_ListSetter):
    """Setter for INDEX path items: writes into a list at a fixed position."""

    @classmethod
    def _set_compatible_types(
        cls,
        data: Any,
        current_key: JsonPathItem,
        k_list: list[JsonPathItem],
        value: dict[str, Any],
    ) -> None:
        assert current_key.index is not None

        # Out-of-range indices are silently ignored (no padding of the list).
        if current_key.index < len(data):
            if len(k_list) == 0:
                # Last path item: merge into an existing dict, otherwise overwrite.
                if isinstance(data[current_key.index], dict):
                    data[current_key.index].update(value)
                else:
                    data[current_key.index] = value
                return

            cls.set(data[current_key.index], k_list.copy(), value, data, current_key)
|
||||
|
||||
|
||||
class WildcardIndexSetter(_ListSetter):
    """Setter for `[]` path items: applies the write to every list element."""

    @classmethod
    def _set_compatible_types(
        cls,
        data: Any,
        current_key: JsonPathItem,
        k_list: list[JsonPathItem],
        value: dict[str, Any],
    ) -> None:
        if len(k_list) == 0:
            # Last path item: merge/overwrite each element of the list.
            for i, item in enumerate(data):
                if isinstance(item, dict):
                    data[i].update(value)
                else:
                    data[i] = value
        else:
            # Descend into every element, rewriting the wildcard as a concrete
            # INDEX item so an incompatible element can be replaced in place.
            for i, item in enumerate(data):
                cls.set(
                    item,
                    k_list.copy(),
                    value,
                    data,
                    JsonPathItem(item_type=JsonPathItemType.INDEX, index=i),
                )
|
||||
|
||||
|
||||
# Register the concrete setter for each json-path item type.
Setter.add_setter(JsonPathItemType.KEY, KeySetter)
Setter.add_setter(JsonPathItemType.INDEX, IndexSetter)
Setter.add_setter(JsonPathItemType.WILDCARD_INDEX, WildcardIndexSetter)
|
||||
@@ -0,0 +1,175 @@
|
||||
import base64
|
||||
import dbm
|
||||
import logging
|
||||
import pickle
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from qdrant_client.http import models
|
||||
|
||||
# Legacy dbm-based storage file used by earlier client versions.
STORAGE_FILE_NAME_OLD = "storage.dbm"
# Current sqlite-based storage file.
STORAGE_FILE_NAME = "storage.sqlite"
|
||||
|
||||
|
||||
def try_migrate_to_sqlite(location: str) -> None:
    """Migrate a legacy dbm storage file in `location` to the sqlite format.

    No-op if the sqlite file already exists or no dbm file is present.
    On success the dbm file is removed; on failure the error is logged and
    re-raised.

    Args:
        location: path to the collection directory.

    Raises:
        Exception: whatever the migration failed with, after logging it.
    """
    dbm_path = Path(location) / STORAGE_FILE_NAME_OLD
    sql_path = Path(location) / STORAGE_FILE_NAME

    if sql_path.exists():
        return

    if not dbm_path.exists():
        return

    try:
        dbm_storage = dbm.open(str(dbm_path), "c")

        con = sqlite3.connect(str(sql_path))
        cur = con.cursor()

        # Create table
        cur.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)")

        for key in dbm_storage.keys():
            value = dbm_storage[key]
            # dbm keys are pickled point ids; normalize to bytes before unpickling.
            if isinstance(key, str):
                key = key.encode("utf-8")
            key = pickle.loads(key)
            sqlite_key = CollectionPersistence.encode_key(key)
            # Insert a row of data
            cur.execute(
                "INSERT INTO points VALUES (?, ?)",
                (
                    sqlite_key,
                    sqlite3.Binary(value),
                ),
            )
        con.commit()
        con.close()
        dbm_storage.close()
        dbm_path.unlink()
    except Exception as e:
        # Fix: use lazy %-formatting — passing `e` as a bare extra argument to
        # logging.error() without a placeholder raises a formatting error
        # inside the logging machinery instead of printing the message.
        logging.error("Failed to migrate dbm to sqlite: %s", e)
        logging.error(
            "Please try to use previous version of qdrant-client or re-create collection"
        )
        raise e
|
||||
|
||||
|
||||
class CollectionPersistence:
    """Sqlite-backed persistence for points of a local (in-process) collection.

    Points are pickled and stored as BLOBs, keyed by a base64-encoded pickle
    of their id.
    """

    # Resolved lazily on first init: whether sqlite connections must stay on
    # the creating thread (depends on sqlite's compiled THREADSAFE mode).
    CHECK_SAME_THREAD: Optional[bool] = None

    @classmethod
    def encode_key(cls, key: models.ExtendedPointId) -> str:
        """Encode a point id into a stable string usable as the sqlite primary key."""
        return base64.b64encode(pickle.dumps(key)).decode("utf-8")

    def __init__(self, location: str, force_disable_check_same_thread: bool = False):
        """
        Create or load a collection from the local storage.

        Args:
            location: path to the collection directory.
            force_disable_check_same_thread: skip the threadsafety probe and
                allow cross-thread use of the connection.
        """
        # Convert legacy dbm storage first, if present.
        try_migrate_to_sqlite(location)

        self.location = Path(location) / STORAGE_FILE_NAME
        self.location.parent.mkdir(exist_ok=True, parents=True)

        if self.CHECK_SAME_THREAD is None and force_disable_check_same_thread is False:
            with sqlite3.connect(":memory:") as tmp_conn:
                # it is unsafe to use `sqlite3.threadsafety` until python3.11 since it was hardcoded to 1, thus we
                # need to fetch threadsafe with a query
                # THREADSAFE = 0: Threads may not share the module
                # THREADSAFE = 1: Threads may share the module, connections and cursors. Default for Linux.
                # THREADSAFE = 2: Threads may share the module, but not connections. Default for macOS.
                threadsafe = tmp_conn.execute(
                    "select * from pragma_compile_options where compile_options like 'THREADSAFE=%'"
                ).fetchone()[0]
                self.__class__.CHECK_SAME_THREAD = threadsafe != "THREADSAFE=1"

        if force_disable_check_same_thread:
            self.__class__.CHECK_SAME_THREAD = False

        self.storage = sqlite3.connect(
            str(self.location), check_same_thread=self.CHECK_SAME_THREAD  # type: ignore
        )

        self._ensure_table()

    def close(self) -> None:
        # Close the underlying sqlite connection.
        self.storage.close()

    def _ensure_table(self) -> None:
        # Create the points table on first use; no-op on later runs.
        cursor = self.storage.cursor()
        cursor.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)")
        self.storage.commit()

    def persist(self, point: models.PointStruct) -> None:
        """
        Persist a point in the local storage.

        Args:
            point: point to persist
        """
        key = self.encode_key(point.id)
        value = pickle.dumps(point)

        cursor = self.storage.cursor()
        # Insert or update by key
        cursor.execute(
            "INSERT OR REPLACE INTO points VALUES (?, ?)",
            (
                key,
                sqlite3.Binary(value),
            ),
        )

        self.storage.commit()

    def delete(self, point_id: models.ExtendedPointId) -> None:
        """
        Delete a point from the local storage.

        Args:
            point_id: id of the point to delete
        """
        key = self.encode_key(point_id)
        cursor = self.storage.cursor()
        cursor.execute(
            "DELETE FROM points WHERE id = ?",
            (key,),
        )
        self.storage.commit()

    def load(self) -> Iterable[models.PointStruct]:
        """
        Load all persisted points from the local storage.

        Yields:
            every stored point, unpickled.
        """
        cursor = self.storage.cursor()
        cursor.execute("SELECT point FROM points")
        for row in cursor.fetchall():
            yield pickle.loads(row[0])
|
||||
|
||||
|
||||
def test_persistence() -> None:
    """Round-trip a point through CollectionPersistence: persist, reload, delete."""
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        storage = CollectionPersistence(tmpdir)
        point = models.PointStruct(id=1, vector=[1.0, 2.0, 3.0], payload={"a": 1})
        storage.persist(point)

        # The persisted point must come back equal.
        for stored in storage.load():
            assert stored == point
            break

        # Re-opening the same directory must see the same data.
        del storage
        storage = CollectionPersistence(tmpdir)
        for stored in storage.load():
            assert stored == point
            break

        # Deletion is idempotent and leaves the storage empty.
        storage.delete(point.id)
        storage.delete(point.id)
        for _ in storage.load():
            assert False, "Should not load anything"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,36 @@
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.http.models import SparseVector
|
||||
|
||||
|
||||
def empty_sparse_vector() -> SparseVector:
    """Return a sparse vector with no indices and no values."""
    return SparseVector(indices=[], values=[])
|
||||
|
||||
|
||||
def validate_sparse_vector(vector: SparseVector) -> None:
|
||||
assert len(vector.indices) == len(
|
||||
vector.values
|
||||
), "Indices and values must have the same length"
|
||||
assert not np.isnan(vector.values).any(), "Values must not contain NaN"
|
||||
assert len(vector.indices) == len(set(vector.indices)), "Indices must be unique"
|
||||
|
||||
|
||||
def is_sorted(vector: SparseVector) -> bool:
|
||||
for i in range(1, len(vector.indices)):
|
||||
if vector.indices[i] < vector.indices[i - 1]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def sort_sparse_vector(vector: SparseVector) -> SparseVector:
    """Return `vector` with indices (and matching values) sorted ascending.

    Already-sorted vectors are returned as-is, without copying.
    """
    if is_sorted(vector):
        return vector

    # Reorder values together with their indices.
    sorted_indices = np.argsort(vector.indices)
    return SparseVector(
        indices=[vector.indices[i] for i in sorted_indices],
        values=[vector.values[i] for i in sorted_indices],
    )
|
||||
@@ -0,0 +1,314 @@
|
||||
from typing import Callable, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.conversions import common_types as types
|
||||
from qdrant_client.http.models import SparseVector
|
||||
from qdrant_client.local.distances import EPSILON, fast_sigmoid, scaled_fast_sigmoid
|
||||
from qdrant_client.local.sparse import (
|
||||
empty_sparse_vector,
|
||||
is_sorted,
|
||||
sort_sparse_vector,
|
||||
validate_sparse_vector,
|
||||
)
|
||||
|
||||
|
||||
class SparseRecoQuery:
    """Recommendation query over sparse vectors.

    Holds validated, index-sorted positive and negative example vectors plus
    the strategy used to combine their scores.
    """

    def __init__(
        self,
        positive: Optional[list[SparseVector]] = None,
        negative: Optional[list[SparseVector]] = None,
        strategy: Optional[types.RecommendStrategy] = None,
    ):
        assert strategy is not None, "Recommend strategy must be provided"

        self.strategy = strategy

        positive = positive if positive is not None else []
        negative = negative if negative is not None else []

        # Normalize all examples: validate and sort their indices.
        for i, vector in enumerate(positive):
            validate_sparse_vector(vector)
            positive[i] = sort_sparse_vector(vector)

        for i, vector in enumerate(negative):
            validate_sparse_vector(vector)
            negative[i] = sort_sparse_vector(vector)

        self.positive = positive
        self.negative = negative

    def transform_sparse(
        self, foo: Callable[["SparseVector"], "SparseVector"]
    ) -> "SparseRecoQuery":
        """Return a new query with `foo` applied to every example vector."""
        return SparseRecoQuery(
            positive=[foo(vector) for vector in self.positive],
            negative=[foo(vector) for vector in self.negative],
            strategy=self.strategy,
        )
|
||||
|
||||
|
||||
class SparseContextPair:
    """A validated, index-sorted positive/negative example pair for context queries."""

    def __init__(self, positive: SparseVector, negative: SparseVector):
        validate_sparse_vector(positive)
        validate_sparse_vector(negative)
        self.positive: SparseVector = sort_sparse_vector(positive)
        self.negative: SparseVector = sort_sparse_vector(negative)
|
||||
|
||||
|
||||
class SparseDiscoveryQuery:
    """Discovery query: a target vector plus context pairs used to rank candidates."""

    def __init__(self, target: SparseVector, context: list[SparseContextPair]):
        validate_sparse_vector(target)
        self.target: SparseVector = sort_sparse_vector(target)
        self.context = context

    def transform_sparse(
        self, foo: Callable[["SparseVector"], "SparseVector"]
    ) -> "SparseDiscoveryQuery":
        """Return a new query with `foo` applied to the target and all context vectors."""
        return SparseDiscoveryQuery(
            target=foo(self.target),
            context=[
                SparseContextPair(foo(pair.positive), foo(pair.negative)) for pair in self.context
            ],
        )
|
||||
|
||||
|
||||
class SparseContextQuery:
    """Context-only query: candidates are scored by positive/negative pairs alone."""

    def __init__(self, context_pairs: list[SparseContextPair]):
        self.context_pairs = context_pairs

    def transform_sparse(
        self, foo: Callable[["SparseVector"], "SparseVector"]
    ) -> "SparseContextQuery":
        """Return a new query with `foo` applied to every pair's vectors."""
        return SparseContextQuery(
            context_pairs=[
                SparseContextPair(foo(pair.positive), foo(pair.negative))
                for pair in self.context_pairs
            ]
        )
|
||||
|
||||
|
||||
# Any vector form accepted by the sparse query pipeline: a raw sparse vector
# or one of the composed query objects defined above.
SparseQueryVector = Union[
    SparseVector,
    SparseDiscoveryQuery,
    SparseContextQuery,
    SparseRecoQuery,
]
|
||||
|
||||
|
||||
def calculate_distance_sparse(
    query: SparseVector, vectors: list[SparseVector], empty_is_zero: bool = False
) -> types.NumpyArray:
    """Calculate distances between a query sparse vector and a list of sparse vectors.

    Args:
        query (SparseVector): The query sparse vector.
        vectors (list[SparseVector]): A list of sparse vectors to compare against.
        empty_is_zero (bool): If True, distance between vectors with no overlap is treated as zero.
            Otherwise, it is treated as negative infinity.
            Simple nearest search requires `empty_is_zero` to be False, while methods like
            recommend, discovery, and context search require True.
    """
    # Score substituted when the query and candidate share no indices.
    no_overlap = np.float32(0.0) if empty_is_zero else np.float32("-inf")

    scores = []
    for candidate in vectors:
        dot = sparse_dot_product(query, candidate)
        scores.append(no_overlap if dot is None else dot)

    return np.array(scores, dtype=np.float32)
|
||||
|
||||
|
||||
# Expects sorted indices
# Returns None if no overlap
def sparse_dot_product(vector1: SparseVector, vector2: SparseVector) -> Optional[np.float32]:
    """Dot product of two index-sorted sparse vectors; None when they share no indices."""
    assert is_sorted(vector1), "Query sparse vector must be sorted"
    assert is_sorted(vector2), "Sparse vector to compare with must be sorted"

    total = 0.0
    found_overlap = False
    i, j = 0, 0
    n1, n2 = len(vector1.indices), len(vector2.indices)

    # Two-pointer walk over both sorted index lists.
    while i < n1 and j < n2:
        idx1, idx2 = vector1.indices[i], vector2.indices[j]
        if idx1 == idx2:
            found_overlap = True
            total += vector1.values[i] * vector2.values[j]
            i += 1
            j += 1
        elif idx1 < idx2:
            i += 1
        else:
            j += 1

    return np.float32(total) if found_overlap else None
|
||||
|
||||
|
||||
def calculate_sparse_discovery_ranks(
    context: list[SparseContextPair],
    vectors: list[SparseVector],
) -> types.NumpyArray:
    """Rank each candidate by its context-pair wins.

    For every pair a candidate earns +1 when it scores higher against the
    positive example, 0 on a tie, and -1 otherwise; ranks sum over all pairs.
    """
    overall_ranks: types.NumpyArray = np.zeros(len(vectors), dtype=np.int32)
    for pair in context:
        # Get distances to positive and negative vectors
        pos = calculate_distance_sparse(pair.positive, vectors, empty_is_zero=True)
        neg = calculate_distance_sparse(pair.negative, vectors, empty_is_zero=True)

        # +1 if closer to positive, 0 on a tie, -1 otherwise.
        pair_ranks = np.array(
            [
                1 if is_bigger else 0 if is_equal else -1
                for is_bigger, is_equal in zip(pos > neg, pos == neg)
            ]
        )

        overall_ranks += pair_ranks

    return overall_ranks
|
||||
|
||||
|
||||
def calculate_sparse_discovery_scores(
    query: SparseDiscoveryQuery, vectors: list[SparseVector]
) -> types.NumpyArray:
    """Score candidates by context-pair ranks plus sigmoid of target distance.

    The sigmoid term presumably stays within a sub-integer range so the
    integer ranks dominate ordering — confirm against `scaled_fast_sigmoid`.
    """
    ranks = calculate_sparse_discovery_ranks(query.context, vectors)

    # Get distances to target
    distances_to_target = calculate_distance_sparse(query.target, vectors, empty_is_zero=True)

    sigmoided_distances = np.fromiter(
        (scaled_fast_sigmoid(xi) for xi in distances_to_target), np.float32
    )

    return ranks + sigmoided_distances
|
||||
|
||||
|
||||
def calculate_sparse_context_scores(
    query: SparseContextQuery, vectors: list[SparseVector]
) -> types.NumpyArray:
    """Score candidates by accumulated, non-positive context-pair penalties.

    Each pair contributes at most 0: candidates scoring below the negative
    example receive a sigmoid-squashed negative penalty.
    """
    overall_scores: types.NumpyArray = np.zeros(len(vectors), dtype=np.float32)
    for pair in query.context_pairs:
        # Get distances to positive and negative vectors
        pos = calculate_distance_sparse(pair.positive, vectors, empty_is_zero=True)
        neg = calculate_distance_sparse(pair.negative, vectors, empty_is_zero=True)

        # EPSILON nudges exact ties below zero — NOTE(review): presumably so
        # ties count as a (small) penalty; confirm against core behavior.
        difference = pos - neg - EPSILON
        pair_scores = np.fromiter(
            (fast_sigmoid(xi) for xi in np.minimum(difference, 0.0)), np.float32
        )
        overall_scores += pair_scores

    return overall_scores
|
||||
|
||||
|
||||
def calculate_sparse_recommend_best_scores(
    query: SparseRecoQuery, vectors: list[SparseVector]
) -> types.NumpyArray:
    """Score candidates by their single best positive vs best negative example.

    Whichever of the two best similarities is larger decides the sign of the
    sigmoid-squashed score.
    """

    def get_sum_scores() -> None:  # placeholder removed
        pass

    def get_best_scores(examples: list[SparseVector]) -> types.NumpyArray:
        # Maximum similarity of each candidate to any of the given examples.
        vector_count = len(vectors)

        # Get scores to all examples
        scores: list[types.NumpyArray] = []
        for example in examples:
            score = calculate_distance_sparse(example, vectors, empty_is_zero=True)
            scores.append(score)

        # Keep only max for each vector
        if len(scores) == 0:
            # No examples: -inf so this side never wins the comparison below.
            scores.append(np.full(vector_count, -np.inf))
        best_scores = np.array(scores, dtype=np.float32).max(axis=0)

        return best_scores

    pos = get_best_scores(query.positive)
    neg = get_best_scores(query.negative)

    # Choose from best positive or best negative,
    # in both cases we apply sigmoid and then negate depending on the order
    return np.where(
        pos > neg,
        np.fromiter((scaled_fast_sigmoid(xi) for xi in pos), pos.dtype),
        np.fromiter((-scaled_fast_sigmoid(xi) for xi in neg), neg.dtype),
    )
|
||||
|
||||
|
||||
def calculate_sparse_recommend_sum_scores(
    query: SparseRecoQuery, vectors: list[SparseVector]
) -> types.NumpyArray:
    """Score candidates as (sum of positive similarities) - (sum of negative ones)."""

    def get_sum_scores(examples: list[SparseVector]) -> types.NumpyArray:
        # Sum of each candidate's similarity to all given examples.
        vector_count = len(vectors)

        scores: list[types.NumpyArray] = []
        for example in examples:
            score = calculate_distance_sparse(example, vectors, empty_is_zero=True)
            scores.append(score)

        # No examples contribute zero to the sum.
        if len(scores) == 0:
            scores.append(np.zeros(vector_count))

        sum_scores = np.array(scores, dtype=np.float32).sum(axis=0)
        return sum_scores

    pos = get_sum_scores(query.positive)
    neg = get_sum_scores(query.negative)

    return pos - neg
|
||||
|
||||
|
||||
# Expects sorted indices
def combine_aggregate(vector1: SparseVector, vector2: SparseVector, op: Callable) -> SparseVector:
    """Merge two index-sorted sparse vectors, applying `op(v1, v2)` per index.

    Indices present in only one vector are combined with 0.0 on the other
    side, so the result covers the union of both index sets.
    """
    result = empty_sparse_vector()
    i, j = 0, 0
    # Two-pointer merge over both sorted index lists.
    while i < len(vector1.indices) and j < len(vector2.indices):
        if vector1.indices[i] == vector2.indices[j]:
            result.indices.append(vector1.indices[i])
            result.values.append(op(vector1.values[i], vector2.values[j]))
            i += 1
            j += 1
        elif vector1.indices[i] < vector2.indices[j]:
            result.indices.append(vector1.indices[i])
            result.values.append(op(vector1.values[i], 0.0))
            i += 1
        else:
            result.indices.append(vector2.indices[j])
            result.values.append(op(0.0, vector2.values[j]))
            j += 1

    # Drain whichever vector still has trailing indices.
    while i < len(vector1.indices):
        result.indices.append(vector1.indices[i])
        result.values.append(op(vector1.values[i], 0.0))
        i += 1

    while j < len(vector2.indices):
        result.indices.append(vector2.indices[j])
        result.values.append(op(0.0, vector2.values[j]))
        j += 1

    return result
|
||||
|
||||
|
||||
# Expects sorted indices
def sparse_avg(vectors: Sequence[SparseVector]) -> SparseVector:
    """Return the element-wise average of `vectors`; empty vector for no input."""
    total = empty_sparse_vector()
    if not vectors:
        return total

    # Accumulate an element-wise sum, then divide once at the end.
    for vector in vectors:
        total = combine_aggregate(total, vector, lambda v1, v2: v1 + v2)

    total.values = np.divide(total.values, len(vectors)).tolist()
    return total
|
||||
|
||||
|
||||
# Expects sorted indices
def merge_positive_and_negative_avg(
    positive: SparseVector, negative: SparseVector
) -> SparseVector:
    """Combine averaged examples per index as `pos + (pos - neg)`.

    NOTE(review): the doubled positive term appears intentional (push the
    positive average away from the negative one) — confirm against core.
    """
    return combine_aggregate(positive, negative, lambda pos, neg: pos + pos - neg)
|
||||
@@ -0,0 +1,57 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from qdrant_client.local.datetime_utils import parse
|
||||
|
||||
|
||||
@pytest.mark.parametrize(  # type: ignore
    "date_str, expected",
    [
        ("2021-01-01T00:00:00", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
        ("2021-01-01T00:00:00Z", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
        ("2021-01-01T00:00:00+00:00", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
        ("2021-01-01T00:00:00.000000", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
        ("2021-01-01T00:00:00.000000Z", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
        (
            "2021-01-01T00:00:00.000000+01:00",
            datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=1))),
        ),
        (
            "2021-01-01T00:00:00.000000-10:00",
            datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=-10))),
        ),
        ("2021-01-01", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
        ("2021-01-01 00:00:00", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
        ("2021-01-01 00:00:00Z", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
        (
            "2021-01-01 00:00:00+0200",
            datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=2))),
        ),
        ("2021-01-01 00:00:00.000000", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
        ("2021-01-01 00:00:00.000000Z", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
        (
            "2021-01-01 00:00:00.000000+00:30",
            datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(minutes=30))),
        ),
        (
            "2021-01-01 00:00:00.000009+00:30",
            datetime(2021, 1, 1, 0, 0, 0, 9, tzinfo=timezone(timedelta(minutes=30))),
        ),
        # this is accepted in core but not here, there is no specifier for only-hour offset
        (
            "2021-01-01 00:00:00.000+01",
            datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=1))),
        ),
        (
            "2021-01-01 00:00:00.000-10",
            datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=-10))),
        ),
        (
            "2021-01-01 00:00:00-03:00",
            datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=-3))),
        ),
    ],
)
def test_parse_dates(date_str: str, expected: datetime) -> None:
    """Each supported datetime string must parse to the expected aware datetime."""
    assert parse(date_str) == expected
|
||||
@@ -0,0 +1,57 @@
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.local.distances import calculate_distance
|
||||
from qdrant_client.local.multi_distances import calculate_multi_distance
|
||||
from qdrant_client.local.sparse_distances import calculate_distance_sparse
|
||||
|
||||
|
||||
def test_distances() -> None:
    """Smoke-test dense, sparse, and multi-vector distance calculations."""
    # Identical dense vectors: dot = self-similarity, euclid/manhattan = 0.
    query = np.array([1.0, 2.0, 3.0])
    vectors = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
    assert np.allclose(calculate_distance(query, vectors, models.Distance.DOT), [14.0, 14.0])
    assert np.allclose(calculate_distance(query, vectors, models.Distance.EUCLID), [0.0, 0.0])
    assert np.allclose(calculate_distance(query, vectors, models.Distance.MANHATTAN), [0.0, 0.0])
    # cosine modifies vectors inplace
    assert np.allclose(calculate_distance(query, vectors, models.Distance.COSINE), [1.0, 1.0])

    # Distinct dense vectors.
    query = np.array([1.0, 0.0, 1.0])
    vectors = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]])

    assert np.allclose(
        calculate_distance(query, vectors, models.Distance.DOT), [4.0, 0.0], atol=0.0001
    )
    assert np.allclose(
        calculate_distance(query, vectors, models.Distance.EUCLID),
        [2.82842712, 1.7320508],
        atol=0.0001,
    )

    assert np.allclose(
        calculate_distance(query, vectors, models.Distance.MANHATTAN),
        [4.0, 3.0],
        atol=0.0001,
    )
    # cosine modifies vectors inplace
    assert np.allclose(
        calculate_distance(query, vectors, models.Distance.COSINE),
        [0.75592895, 0.0],
        atol=0.0001,
    )

    # Sparse: no overlapping indices yields -inf for plain nearest search.
    sparse_query = models.SparseVector(indices=[1, 2], values=[1, 2])
    sparse_vectors = [models.SparseVector(indices=[10, 20], values=[1, 2])]

    assert calculate_distance_sparse(sparse_query, sparse_vectors) == [np.float32("-inf")]

    # Sparse: overlapping indices yield the dot product over shared indices.
    sparse_vectors = [
        models.SparseVector(indices=[1, 2], values=[3, 4]),
        models.SparseVector(indices=[1, 2, 3], values=[1, 2, 3]),
    ]
    assert np.allclose(
        calculate_distance_sparse(sparse_query, sparse_vectors), [11.0, 5], atol=0.0001
    )

    # Multi-vector dot distance against a single document.
    multivector_query = np.array([[1, 2, 3], [3, 4, 5]])
    docs = [np.array([[1, 2, 3], [0, 1, 2]])]
    assert calculate_multi_distance(multivector_query, docs, models.Distance.DOT)[0] == 40.0
|
||||
+189
@@ -0,0 +1,189 @@
|
||||
from qdrant_client.http.models import models
|
||||
from qdrant_client.local.payload_filters import check_filter
|
||||
|
||||
|
||||
def test_nested_payload_filters():
    """Check `check_filter` against nested-object filters on a country payload."""
    payload = {
        "country": {
            "name": "Germany",
            "capital": "Berlin",
            "cities": [
                {
                    "name": "Berlin",
                    "population": 3.7,
                    "location": {"lon": 13.76116, "lat": 52.33826},
                    "sightseeing": ["Brandenburg Gate", "Reichstag"],
                },
                {
                    "name": "Munich",
                    "population": 1.5,
                    "location": {"lon": 11.57549, "lat": 48.13743},
                    "sightseeing": ["Marienplatz", "Olympiapark"],
                },
                {
                    "name": "Hamburg",
                    "population": 1.8,
                    "location": {"lon": 9.99368, "lat": 53.55108},
                    "sightseeing": ["Reeperbahn", "Elbphilharmonie"],
                },
            ],
        }
    }

    def cities_query(inner_filter: dict) -> models.Filter:
        # Wrap an inner filter dict in a nested query over `country.cities`.
        return models.Filter(
            **{
                "must": [
                    {
                        "nested": {
                            "key": "country.cities",
                            "filter": inner_filter,
                        }
                    }
                ]
            }
        )

    # Population passes, but every city lists more than one sight -> must_not fails.
    query = cities_query(
        {
            "must": [{"key": "population", "range": {"gte": 1.0}}],
            "must_not": [{"key": "sightseeing", "values_count": {"gt": 1}}],
        }
    )
    assert check_filter(query, payload, 0, has_vector={}) is False

    # At least one city with population >= 1.0 exists.
    query = cities_query({"must": [{"key": "population", "range": {"gte": 1.0}}]})
    assert check_filter(query, payload, 0, has_vector={}) is True

    # No single city satisfies population >= 1.0 AND more than two sights.
    query = cities_query(
        {
            "must": [
                {"key": "population", "range": {"gte": 1.0}},
                {"key": "sightseeing", "values_count": {"gt": 2}},
            ]
        }
    )
    assert check_filter(query, payload, 0, has_vector={}) is False

    # No city reaches population >= 9.0.
    query = cities_query({"must": [{"key": "population", "range": {"gte": 9.0}}]})
    assert check_filter(query, payload, 0, has_vector={}) is False
|
||||
|
||||
|
||||
def test_geo_polygon_filter_query():
    """`check_filter` must accept points inside a geo polygon and reject outside ones."""
    # Closed square roughly covering lon/lat 55..86.
    exterior_ring = [
        {"lon": 55.455868, "lat": 55.495862},
        {"lon": 86.455868, "lat": 55.495862},
        {"lon": 86.455868, "lat": 86.495862},
        {"lon": 55.455868, "lat": 86.495862},
        {"lon": 55.455868, "lat": 55.495862},
    ]
    query = models.Filter(
        **{
            "must": [
                {
                    "key": "location",
                    "geo_polygon": {"exterior": {"points": exterior_ring}},
                }
            ]
        }
    )

    inside_payload = {"location": [{"lon": 70.0, "lat": 70.0}]}
    assert check_filter(query, inside_payload, 0, has_vector={}) is True

    outside_payload = {"location": [{"lon": 30.693738, "lat": 30.502165}]}
    assert check_filter(query, outside_payload, 0, has_vector={}) is False
|
||||
+549
@@ -0,0 +1,549 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from qdrant_client.local.json_path_parser import (
|
||||
JsonPathItem,
|
||||
JsonPathItemType,
|
||||
parse_json_path,
|
||||
)
|
||||
from qdrant_client.local.payload_value_extractor import value_by_key
|
||||
from qdrant_client.local.payload_value_setter import set_value_by_key
|
||||
|
||||
|
||||
def test_parse_json_path() -> None:
    """Exercise parse_json_path: plain keys, quoted keys, index/wildcard selectors, and rejects."""
    # Single key.
    jp_key = "a"
    keys = parse_json_path(jp_key)
    assert keys == [JsonPathItem(item_type=JsonPathItemType.KEY, key="a")]

    # Dotted key path.
    jp_key = "a.b"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
    ]

    # Double quotes protect special characters ([, ]) inside a key segment.
    jp_key = 'a."a[b]".c'
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a[b]"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="c"),
    ]

    # Concrete array index.
    jp_key = "a[0]"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
    ]

    # Index followed by a key.
    jp_key = "a[0].b"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
    ]

    jp_key = "a[0].b[1]"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=1),
    ]

    # Empty brackets produce wildcard (match-all) index items.
    jp_key = "a[][]"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.WILDCARD_INDEX, index=None),
        JsonPathItem(item_type=JsonPathItemType.WILDCARD_INDEX, index=None),
    ]

    jp_key = "a[0][1]"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=1),
    ]

    jp_key = "a[0][1].b"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=1),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
    ]

    # Quoted segment containing a dot is a single key, not two.
    jp_key = 'a."k.c"'
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="k.c"),
    ]

    # Quoted brackets are literal characters of the key.
    jp_key = 'a."c[][]".b'
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="c[][]"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
    ]

    jp_key = 'a."c..q".b'
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="c..q"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
    ]

    # --- malformed paths must raise ValueError ---
    with pytest.raises(ValueError):
        jp_key = 'a."k.c'  # unterminated quote
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = 'a."k.c".'  # trailing dot
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = 'a."k.c".[]'  # dot before brackets
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a.'k.c'"  # single quotes are not a valid quoting style
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a["  # unterminated bracket
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a]"  # unmatched closing bracket
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a[]]"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a[][]."  # trailing dot after brackets
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a[][]b"  # key must be separated from brackets by a dot
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = ".a"  # leading dot
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a[x]"  # non-numeric index
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = 'a[]""'
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = '""b'
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "[]"  # brackets with no preceding key
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a[.]"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = 'a["1"]'  # quoted index is not allowed
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = ""  # empty path
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a..c"  # empty segment between dots
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a.c[]b[]"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a.c[].[]"
        parse_json_path(jp_key)
||||
|
||||
|
||||
def test_value_by_key() -> None:
    """Exercise value_by_key extraction with flat (default) and non-flat result modes.

    ``flat=True`` flattens list-valued hits into one list; ``flat=False`` keeps
    one entry per matched location.
    """
    payload = {
        "name": "John",
        "age": 25,
        "counts": [1, 2, 3],
        "address": {
            "city": "New York",
        },
        "location": [
            {"name": "home", "counts": [1, 2, 3]},
            {"name": "work", "counts": [4, 5, 6]},
        ],
        "nested": [{"empty": []}, {"empty": []}, {"empty": None}],
        "the_null": None,
        "the": {"nested.key": "cuckoo"},
        "double-nest-array": [[1, 2], [3, 4], [5, 6]],
    }
    # region flat=True
    assert value_by_key(payload, "name") == ["John"]
    assert value_by_key(payload, "address.city") == ["New York"]
    assert value_by_key(payload, "location[].name") == ["home", "work"]
    assert value_by_key(payload, "location[0].name") == ["home"]
    assert value_by_key(payload, "location[1].name") == ["work"]
    assert value_by_key(payload, "location[2].name") is None  # index out of range
    assert value_by_key(payload, "location[].name[0]") is None  # strings are not indexable here
    assert value_by_key(payload, "location[0]") == [{"name": "home", "counts": [1, 2, 3]}]
    # fixed typo: was "not_exits"; matches the flat=False region below
    assert value_by_key(payload, "not_exist") is None
    assert value_by_key(payload, "address") == [{"city": "New York"}]
    assert value_by_key(payload, "address.city[0]") is None
    assert value_by_key(payload, "counts") == [1, 2, 3]
    # flat mode merges list values from every matched location
    assert value_by_key(payload, "location[].counts") == [1, 2, 3, 4, 5, 6]
    # empty lists flatten away; the explicit None survives
    assert value_by_key(payload, "nested[].empty") == [None]
    assert value_by_key(payload, "the_null") == [None]
    assert value_by_key(payload, 'the."nested.key"') == ["cuckoo"]
    assert value_by_key(payload, "double-nest-array[][]") == [1, 2, 3, 4, 5, 6]
    assert value_by_key(payload, "double-nest-array[0][]") == [1, 2]
    # removed a duplicated copy of this assertion
    assert value_by_key(payload, "double-nest-array[0][0]") == [1]
    assert value_by_key(payload, "double-nest-array[][1]") == [2, 4, 6]
    # endregion

    # region flat=False
    assert value_by_key(payload, "name", flat=False) == ["John"]
    assert value_by_key(payload, "address.city", flat=False) == ["New York"]
    assert value_by_key(payload, "location[].name", flat=False) == ["home", "work"]
    assert value_by_key(payload, "location[0].name", flat=False) == ["home"]
    assert value_by_key(payload, "location[1].name", flat=False) == ["work"]
    assert value_by_key(payload, "location[2].name", flat=False) is None
    assert value_by_key(payload, "location[].name[0]", flat=False) is None
    assert value_by_key(payload, "location[0]", flat=False) == [
        {"name": "home", "counts": [1, 2, 3]}
    ]
    assert value_by_key(payload, "not_exist", flat=False) is None
    assert value_by_key(payload, "address", flat=False) == [{"city": "New York"}]
    assert value_by_key(payload, "address.city[0]", flat=False) is None
    # non-flat mode preserves each matched list as-is
    assert value_by_key(payload, "counts", flat=False) == [[1, 2, 3]]
    assert value_by_key(payload, "location[].counts", flat=False) == [
        [1, 2, 3],
        [4, 5, 6],
    ]
    assert value_by_key(payload, "nested[].empty", flat=False) == [[], [], None]
    assert value_by_key(payload, "the_null", flat=False) == [None]

    # traversing through a scalar yields nothing
    assert value_by_key(payload, "age.nested.not_exist") is None
    # endregion
||||
|
||||
|
||||
def test_set_value_by_key() -> None:
    """Exercise set_value_by_key: merge-into-dict semantics across keys, indexes and wildcards.

    Setting merges ``new_value`` into an existing dict at the target path;
    non-dict targets are replaced, and missing intermediate structures are
    created (or the write is silently dropped for out-of-range indexes).
    """
    # region valid keys
    payload: dict[str, Any] = {}
    new_value: dict[str, Any] = {}
    key = "a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {}}, payload

    # merging an empty dict leaves the target untouched
    payload = {"a": {"a": 2}}
    new_value = {}
    key = "a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"a": 2}}, payload

    # new keys are added alongside existing ones
    payload = {"a": {"a": 2}}
    new_value = {"b": 3}
    key = "a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"a": 2, "b": 3}}, payload

    # existing keys are overwritten
    payload = {"a": {"a": 2}}
    new_value = {"a": 3}
    key = "a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"a": 3}}, payload

    # a scalar at the target path is replaced by the new dict
    payload = {"a": {"a": 2}}
    new_value = {"a": 3}
    key = "a.a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"a": {"a": 3}}}, payload

    payload = {"a": {"a": {"a": 1}}}
    new_value = {"b": 2}
    key = "a.a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"a": {"a": 1, "b": 2}}}, payload

    payload = {"a": {"a": {"a": 1}}}
    new_value = {"a": 2}
    key = "a.a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"a": {"a": 2}}}, payload

    # indexing an empty list is a no-op
    payload = {"a": []}
    new_value = {"b": 2}
    key = "a[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": []}, payload

    payload = {"a": [{}]}
    new_value = {"b": 2}
    key = "a[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"b": 2}]}, payload

    payload = {"a": [{"a": 1}]}
    new_value = {"b": 2}
    key = "a[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"a": 1, "b": 2}]}, payload

    # a non-dict element at a valid index is replaced
    payload = {"a": [[]]}
    new_value = {"b": 2}
    key = "a[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"b": 2}]}, payload

    # out-of-range index is a no-op
    payload = {"a": [[]]}
    new_value = {"b": 2}
    key = "a[1]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [[]]}, payload

    payload = {"a": [{"a": []}]}
    new_value = {"b": 2}
    key = "a[0].a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"a": {"b": 2}}]}, payload

    payload = {"a": [{"a": []}]}
    new_value = {"b": 2}
    key = "a[].a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"a": {"b": 2}}]}, payload

    # wildcard applies the write to every element
    payload = {"a": [{"a": []}, {"a": []}]}
    new_value = {"b": 2}
    key = "a[].a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"a": {"b": 2}}, {"a": {"b": 2}}]}, payload

    payload = {"a": 1, "b": 2}
    new_value = {"c": 3}
    key = "c"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": 1, "b": 2, "c": {"c": 3}}, payload

    payload = {"a": {"b": {"c": 1}}}
    new_value = {"d": 2}
    key = "a.b.d"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {"c": 1, "d": {"d": 2}}}}, payload

    payload = {"a": {"b": {"c": 1}}}
    new_value = {"c": 2}
    key = "a.b"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {"c": 2}}}, payload

    payload = {"a": [{"b": 1}, {"b": 2}]}
    new_value = {"c": 3}
    key = "a[1]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"b": 1}, {"b": 2, "c": 3}]}, payload

    payload = {"a": []}
    new_value = {"b": {"c": 1}}
    key = "a[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": []}, payload

    payload = {"a": {"b": {"c": {"d": {"e": 1}}}}}
    new_value = {"f": 2}
    key = "a.b.c.d"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {"c": {"d": {"e": 1, "f": 2}}}}}, payload

    payload = {"a": {"b": {"c": 1}}}
    new_value = {"d": {"e": 2}}
    key = "a.b.c"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {"c": {"d": {"e": 2}}}}}, payload

    payload = {"a": [{"b": 1}]}
    new_value = {"c": 2}
    key = "a[1]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"b": 1}]}, payload

    payload = {"a": {"b": [{"c": 1}, {"c": 2}]}}
    new_value = {"d": 3}
    key = "a.b[0].c"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": [{"c": {"d": 3}}, {"c": 2}]}}, payload

    payload = {"a": {"b": {"c": [{"d": 1}]}}}
    new_value = {"e": {"f": 2}}
    key = "a.b.c[0].d"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {"c": [{"d": {"e": {"f": 2}}}]}}}, payload

    payload = {"a": [[{"b": 1}], [{"b": 2}]]}
    new_value = {"c": 3}
    key = "a[0][0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [[{"b": 1, "c": 3}], [{"b": 2}]]}, payload

    payload = {"a": [[{"b": 1}], [{"b": 2}]]}
    new_value = {"c": 3}
    key = "a[1][0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [[{"b": 1}], [{"b": 2, "c": 3}]]}, payload

    payload = {"a": [[{"b": 1}], [{"b": 2}]]}
    new_value = {"c": 3}
    key = "a[1][1]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [[{"b": 1}], [{"b": 2}]]}, payload

    payload = {"a": [[{"b": 1}], [{"b": 2}]]}
    new_value = {"c": 3}
    key = "a[][0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [[{"b": 1, "c": 3}], [{"b": 2, "c": 3}]]}, payload

    payload = {"a": [[{"b": 1}], [{"b": 2}]]}
    new_value = {"c": 3}
    key = "a[][]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [[{"b": 1, "c": 3}], [{"b": 2, "c": 3}]]}, payload

    # quoted segment: the dot is part of the key
    payload = {"a": []}
    new_value = {"c": 3}
    key = 'a."b.c"'
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b.c": {"c": 3}}}, payload

    payload = {"a": {"c": [1]}}
    new_value = {"a": 1}
    key = "a.c[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"c": [{"a": 1}]}}, payload

    payload = {"a": {"c": [1]}}
    new_value = {"a": 1}
    key = "a.c[0].d"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"c": [{"d": {"a": 1}}]}}, payload

    # the empty string is a legal quoted key
    payload = {"": 2}
    new_value = {"a": 1}
    key = '""'
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"": {"a": 1}}, payload
    # endregion

    # region exceptions
    # Use pytest.raises for consistency with test_parse_json_path instead of
    # try/except + assert False (which is harder to read and loses tracebacks).

    # single quotes are not a valid quoting style
    with pytest.raises(Exception):
        payload = {"a": []}
        new_value = {"c": 3}
        key = "a.'b.c'"
        set_value_by_key(payload, parse_json_path(key), new_value)

    # negative indexation is not supported
    with pytest.raises(Exception):
        payload = {"a": [{"b": 1}, {"b": 2}]}
        new_value = {"c": 3}
        key = "a[-1]"
        set_value_by_key(payload, parse_json_path(key), new_value)

    # unterminated bracket
    with pytest.raises(Exception):
        payload = {"a": [{"b": 1}, {"b": 2}]}
        new_value = {"c": 3}
        key = "a["
        set_value_by_key(payload, parse_json_path(key), new_value)

    # unmatched closing bracket
    with pytest.raises(Exception):
        payload = {"a": [{"b": 1}, {"b": 2}]}
        new_value = {"c": 3}
        key = "a]"
        set_value_by_key(payload, parse_json_path(key), new_value)

    # endregion

    # region wrong keys
    # When the path's expected structure disagrees with the payload,
    # the existing value is replaced with the structure the path implies.
    payload = {"a": []}
    new_value = {}
    key = "a.b[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": []}}, payload

    payload = {"a": []}
    new_value = {}
    key = "a.b"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {}}}, payload

    payload = {"a": []}
    new_value = {"c": 2}
    key = "a.b"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {"c": 2}}}, payload

    payload = {"a": [[{"a": 1}]]}
    new_value = {"a": 2}
    key = "a.b[0][0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": []}}, payload

    payload = {"a": {"c": 2}}
    new_value = {"a": 1}
    key = "a[]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": []}, payload

    payload = {"a": {"c": 2}}
    new_value = {"a": 1}
    key = "a[].b"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": []}, payload

    payload = {"a": {"c": [1]}}
    new_value = {"a": 1}
    key = "a.c[][][0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"c": [[]]}}, payload

    payload = {"a": {"c": [{"d": 1}]}}
    new_value = {"a": 1}
    key = "a.c[][]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"c": [[]]}}, payload
    # endregion
||||
+135
@@ -0,0 +1,135 @@
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
from qdrant_client import models
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
def client():
    """
    Sets up multiple collections with a bunch of points
    """
    # In-memory local instance; nothing touches disk or network.
    client = QdrantLocal(":memory:")
    # A minimal single-vector collection.
    client.create_collection(
        "collection_default",
        vectors_config=models.VectorParams(
            size=4,
            distance=models.Distance.DOT,
        ),
    )

    # A collection covering every vector flavor: unnamed dense (""), uint8
    # dense ("byte"), multivector ("colbert"), and a sparse vector ("sparse").
    client.create_collection(
        "collection_multiple_vectors",
        vectors_config={
            "": models.VectorParams(
                size=4,
                distance=models.Distance.DOT,
            ),
            "byte": models.VectorParams(
                size=4, distance=models.Distance.DOT, datatype=models.Datatype.UINT8
            ),
            "colbert": models.VectorParams(
                size=4,
                distance=models.Distance.DOT,
                multivector_config=models.MultiVectorConfig(
                    comparator=models.MultiVectorComparator.MAX_SIM
                ),
            ),
        },
        sparse_vectors_config={"sparse": models.SparseVectorParams()},
    )

    # One point per collection, all with id=1, so id-based dereferencing
    # in the tests below resolves in either collection.
    client.upsert(
        "collection_default",
        [
            models.PointStruct(id=1, vector=[0.25, 0.0, 0.0, 0.0]),
        ],
    )

    client.upsert(
        "collection_multiple_vectors",
        [
            models.PointStruct(
                id=1,
                vector={
                    "": [0.0, 0.25, 0.0, 0.0],
                    "byte": [0, 25, 0, 0],
                    "colbert": [[0.0, 0.25, 0.0, 0.0], [0.0, 0.25, 0.0, 0.0]],
                    "sparse": models.SparseVector(indices=[1], values=[0.25]),
                },
            ),
        ],
    )

    return client
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "query",
    [
        models.NearestQuery(nearest=1),
        models.RecommendQuery(recommend=models.RecommendInput(positive=[1], negative=[1])),
        models.DiscoverQuery(
            discover=models.DiscoverInput(
                target=1, context=[models.ContextPair(**{"positive": 1, "negative": 1})]
            )
        ),
        models.ContextQuery(context=[models.ContextPair(**{"positive": 1, "negative": 1})]),
        models.OrderByQuery(order_by=models.OrderBy(key="price", direction=models.Direction.ASC)),
        models.FusionQuery(fusion=models.Fusion.RRF),
    ],
)
@pytest.mark.parametrize(
    "using, lookup_from, expected, mentioned",
    [
        # (vector name, lookup collection, vector expected after dereferencing,
        #  whether point id 1 should be reported as mentioned in this collection)
        (None, None, [0.25, 0.0, 0.0, 0.0], True),
        ("", None, [0.25, 0.0, 0.0, 0.0], True),
        (
            "byte",
            models.LookupLocation(collection="collection_multiple_vectors"),
            [0, 25, 0, 0],
            False,
        ),
        (
            "",
            models.LookupLocation(collection="collection_multiple_vectors", vector="colbert"),
            [[0.0, 0.25, 0.0, 0.0], [0.0, 0.25, 0.0, 0.0]],
            False,
        ),
        (
            None,
            models.LookupLocation(collection="collection_multiple_vectors", vector="sparse"),
            models.SparseVector(indices=[1], values=[0.25]),
            False,
        ),
    ],
)
def test_vector_dereferencing(client, query, using, lookup_from, expected, mentioned):
    """Point ids inside queries must be replaced by the corresponding stored vectors."""
    # deepcopy: _resolve_query_input may mutate the query in place, and the
    # parametrized query object is shared across test cases.
    resolved, mentioned_ids = client._resolve_query_input(
        collection_name="collection_default",
        query=copy.deepcopy(query),
        using=using,
        lookup_from=lookup_from,
    )

    # Each id-carrying query type stores the dereferenced vector in a
    # different field; check whichever applies.
    if isinstance(resolved, models.NearestQuery):
        assert resolved.nearest == expected
    elif isinstance(resolved, models.RecommendQuery):
        assert resolved.recommend.positive == [expected]
        assert resolved.recommend.negative == [expected]
    elif isinstance(resolved, models.DiscoverQuery):
        assert resolved.discover.target == expected
        assert resolved.discover.context[0].positive == expected
        assert resolved.discover.context[0].negative == expected
    elif isinstance(resolved, models.ContextQuery):
        assert resolved.context[0].positive == expected
        assert resolved.context[0].negative == expected
    else:
        # OrderBy/Fusion queries carry no ids: they pass through unchanged
        # and cannot mention any point.
        mentioned = False
        assert resolved == query

    if mentioned:
        assert mentioned_ids == {1}
||||
@@ -0,0 +1,21 @@
|
||||
import random
|
||||
|
||||
from qdrant_client import models
|
||||
from qdrant_client.local.local_collection import LocalCollection, DEFAULT_VECTOR_NAME
|
||||
|
||||
|
||||
def test_get_vectors():
    """_get_vectors honors the three with_vectors selector forms."""
    local = LocalCollection(
        models.CreateCollection(
            vectors=models.VectorParams(size=2, distance=models.Distance.MANHATTAN)
        )
    )
    points = [
        models.PointStruct(id=point_id, vector=[random.random(), random.random()])
        for point_id in range(10)
    ]
    local.upsert(points=points)

    # Selecting by vector name or with a blanket True yields a value;
    # with_vectors=False must yield nothing.
    assert local._get_vectors(idx=1, with_vectors=DEFAULT_VECTOR_NAME)
    assert local._get_vectors(idx=2, with_vectors=True)
    assert local._get_vectors(idx=3, with_vectors=False) is None
||||
@@ -0,0 +1 @@
|
||||
from .migrate import migrate
|
||||
@@ -0,0 +1,182 @@
|
||||
import time
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from qdrant_client._pydantic_compat import to_dict, model_fields
|
||||
from qdrant_client.client_base import QdrantBase
|
||||
from qdrant_client.http import models
|
||||
|
||||
|
||||
def upload_with_retry(
    client: QdrantBase,
    collection_name: str,
    points: Iterable[models.PointStruct],
    max_attempts: int = 3,
    pause: float = 3.0,
) -> None:
    """Upload points to a collection, retrying on failure.

    Args:
        client: Destination client to upload with.
        collection_name: Target collection.
        points: Points to upload.
        max_attempts: Total number of tries before giving up.
        pause: Seconds to sleep between consecutive tries.

    Raises:
        Exception: If all attempts fail.
    """
    # NOTE(review): if `points` is a generator, a failed attempt may leave it
    # partially consumed, so a retry would not re-send the same items — confirm
    # callers pass a re-iterable sequence.
    for attempt in range(1, max_attempts + 1):
        try:
            client.upload_points(
                collection_name=collection_name,
                points=points,
                wait=True,
            )
            return
        except Exception as e:
            print(f"Exception: {e}, attempt {attempt}/{max_attempts}")
            if attempt < max_attempts:
                print(f"Next attempt in {pause} seconds")
                time.sleep(pause)

    raise Exception(f"Failed to upload points after {max_attempts} attempts")
||||
|
||||
|
||||
def migrate(
    source_client: QdrantBase,
    dest_client: QdrantBase,
    collection_names: Optional[list[str]] = None,
    recreate_on_collision: bool = False,
    batch_size: int = 100,
) -> None:
    """
    Migrate collections from source client to destination client

    Args:
        source_client (QdrantBase): Source client
        dest_client (QdrantBase): Destination client
        collection_names (list[str], optional): List of collection names to migrate.
            If None - migrate all source client collections. Defaults to None.
        recreate_on_collision (bool, optional): If True - recreate collection if it exists, otherwise
            raise ValueError.
        batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
    """
    collection_names = _select_source_collections(source_client, collection_names)
    if any(
        _has_custom_shards(source_client, collection_name) for collection_name in collection_names
    ):
        raise ValueError("Migration of collections with custom shards is not supported yet")

    collisions = _find_collisions(dest_client, collection_names)
    absent_dest_collections = set(collection_names) - set(collisions)

    # Collisions are only reachable past this point when recreation was requested.
    if collisions and not recreate_on_collision:
        raise ValueError(f"Collections already exist in dest_client: {collisions}")

    # Both groups get the identical recreate-then-copy treatment, so a single
    # loop replaces the two duplicated ones (absent collections first).
    for collection_name in [*absent_dest_collections, *collisions]:
        _recreate_collection(source_client, dest_client, collection_name)
        _migrate_collection(source_client, dest_client, collection_name, batch_size)
||||
|
||||
|
||||
def _has_custom_shards(source_client: QdrantBase, collection_name: str) -> bool:
    """Return True if the collection is configured with custom sharding."""
    params = source_client.get_collection(collection_name).config.params
    # Older servers may not report the field at all; treat absence as non-custom.
    sharding_method = getattr(params, "sharding_method", None)
    return sharding_method == models.ShardingMethod.CUSTOM
||||
|
||||
|
||||
def _select_source_collections(
    source_client: QdrantBase, collection_names: Optional[list[str]] = None
) -> list[str]:
    """Resolve the list of collections to migrate.

    Args:
        source_client: Client to enumerate collections from.
        collection_names: Explicit selection; if None, all source collections
            are migrated.

    Returns:
        The validated list of collection names.

    Raises:
        ValueError: If an explicitly requested collection is absent from the
            source. (Was an ``assert``, which is stripped under ``python -O``;
            input validation must not depend on assertions.)
    """
    source_collections = source_client.get_collections().collections
    source_collection_names = [collection.name for collection in source_collections]

    if collection_names is None:
        return source_collection_names

    missing = set(collection_names) - set(source_collection_names)
    if missing:
        raise ValueError(f"Source client does not have collections: {missing}")
    return collection_names
||||
|
||||
|
||||
def _find_collisions(dest_client: QdrantBase, collection_names: list[str]) -> list[str]:
    """Return the requested collection names that already exist on the destination."""
    existing_names = {
        collection.name for collection in dest_client.get_collections().collections
    }
    return list(existing_names & set(collection_names))
||||
|
||||
|
||||
def _recreate_collection(
    source_client: QdrantBase,
    dest_client: QdrantBase,
    collection_name: str,
) -> None:
    """Recreate `collection_name` on the destination with the source's configuration.

    Drops any pre-existing destination collection of the same name, then
    mirrors vector params, sparse vectors, sharding/replication settings,
    HNSW/optimizer/WAL configs, quantization, strict-mode config, and
    payload indexes.
    """
    src_collection_info = source_client.get_collection(collection_name)
    src_config = src_collection_info.config
    src_payload_schema = src_collection_info.payload_schema
    # Destructive by design: the caller decides whether collisions are allowed.
    if dest_client.collection_exists(collection_name):
        dest_client.delete_collection(collection_name)

    strict_mode_config: Optional[models.StrictModeConfig] = None
    if src_config.strict_mode_config is not None:
        # Re-filter through model_fields so fields unknown to this client
        # version (e.g. from a newer server) are silently dropped.
        strict_mode_config = models.StrictModeConfig(
            **{
                k: v
                for k, v in to_dict(src_config.strict_mode_config).items()
                if k in model_fields(models.StrictModeConfig)
            }
        )
    dest_client.create_collection(
        collection_name,
        vectors_config=src_config.params.vectors,
        sparse_vectors_config=src_config.params.sparse_vectors,
        shard_number=src_config.params.shard_number,
        replication_factor=src_config.params.replication_factor,
        write_consistency_factor=src_config.params.write_consistency_factor,
        on_disk_payload=src_config.params.on_disk_payload,
        # Full configs are converted to their *Diff counterparts expected by
        # create_collection.
        hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)),
        optimizers_config=models.OptimizersConfigDiff(**to_dict(src_config.optimizer_config)),
        wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)),
        quantization_config=src_config.quantization_config,
        strict_mode_config=strict_mode_config,
    )

    _recreate_payload_schema(dest_client, collection_name, src_payload_schema)
||||
|
||||
|
||||
def _recreate_payload_schema(
    dest_client: QdrantBase,
    collection_name: str,
    payload_schema: dict[str, models.PayloadIndexInfo],
) -> None:
    """Re-create every payload index from the source schema on the destination."""
    for field_name, field_info in payload_schema.items():
        # Parametrized indexes carry their own params object; plain ones
        # are described by the bare data type.
        schema = field_info.data_type if field_info.params is None else field_info.params
        dest_client.create_payload_index(
            collection_name,
            field_name=field_name,
            field_schema=schema,
        )
||||
|
||||
|
||||
def _migrate_collection(
    source_client: QdrantBase,
    dest_client: QdrantBase,
    collection_name: str,
    batch_size: int = 100,
) -> None:
    """Migrate collection from source client to destination client

    Args:
        collection_name (str): Collection name
        source_client (QdrantBase): Source client
        dest_client (QdrantBase): Destination client
        batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
    """
    # Fix: the first page was hard-coded to limit=2 (apparent debugging
    # leftover), contradicting the documented batch_size and making the
    # initial upload needlessly small. Pagination is unaffected: subsequent
    # pages continue from next_offset either way.
    records, next_offset = source_client.scroll(
        collection_name, limit=batch_size, with_vectors=True
    )
    upload_with_retry(client=dest_client, collection_name=collection_name, points=records)  # type: ignore
    while next_offset is not None:
        records, next_offset = source_client.scroll(
            collection_name, offset=next_offset, limit=batch_size, with_vectors=True
        )
        upload_with_retry(client=dest_client, collection_name=collection_name, points=records)  # type: ignore
    # Sanity check that every point arrived.
    source_client_vectors_count = source_client.count(collection_name).count
    dest_client_vectors_count = dest_client.count(collection_name).count
    assert (
        source_client_vectors_count == dest_client_vectors_count
    ), f"Migration failed, vectors count are not equal: source vector count {source_client_vectors_count}, dest vector count {dest_client_vectors_count}"
||||
@@ -0,0 +1,3 @@
|
||||
from qdrant_client.http.models import *
|
||||
from qdrant_client.fastembed_common import *
|
||||
from qdrant_client.embed.models import *
|
||||
@@ -0,0 +1,239 @@
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from multiprocessing import Queue, get_context
|
||||
from multiprocessing.context import BaseContext
|
||||
from multiprocessing.process import BaseProcess
|
||||
from multiprocessing.sharedctypes import Synchronized as BaseValue
|
||||
from queue import Empty
|
||||
from typing import Any, Iterable, Optional, Type
|
||||
|
||||
# Single item should be processed in less than:
processing_timeout = 10 * 60  # seconds

# Default per-worker queue-capacity multiplier: ParallelWorkerPool sizes its
# input/output queues as num_workers * max_internal_batch_size.
MAX_INTERNAL_BATCH_SIZE = 200
|
||||
|
||||
|
||||
class QueueSignals(str, Enum):
    """Sentinel values exchanged over the multiprocessing queues to steer
    control flow between the pool and its worker processes.

    Being a ``str`` subclass, members compare equal to their plain string
    values, so they survive pickling across process boundaries.
    """

    stop = "stop"
    confirm = "confirm"
    error = "error"
|
||||
|
||||
|
||||
class Worker:
    """Abstract interface for workers driven by the parallel pool.

    Subclasses must override both methods; the base versions raise
    ``NotImplementedError`` unconditionally.
    """

    @classmethod
    def start(cls, *args: Any, **kwargs: Any) -> "Worker":
        """Construct and initialize a worker instance inside a child process."""
        raise NotImplementedError()

    def process(self, items: Iterable[Any]) -> Iterable[Any]:
        """Consume a stream of items and yield the processed results."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
def _worker(
    worker_class: Type[Worker],
    input_queue: Queue,
    output_queue: Queue,
    num_active_workers: BaseValue,
    worker_id: int,
    kwargs: Optional[dict[str, Any]] = None,
) -> None:
    """
    A worker that pulls data points off the input queue, and places the execution result on the output queue.
    When there are no data points left on the input queue, it decrements
    num_active_workers to signal completion.

    Args:
        worker_class: Worker subclass; its ``start(**kwargs)`` classmethod builds the instance.
        input_queue: Queue of items to process; ``QueueSignals.stop`` ends the loop.
        output_queue: Queue receiving processed items, or ``QueueSignals.error`` on failure.
        num_active_workers: Shared counter decremented once this worker has fully shut down.
        worker_id: Identifier used for logging only.
        kwargs: Keyword arguments forwarded to ``worker_class.start``.
    """

    if kwargs is None:
        kwargs = {}

    logging.info(f"Reader worker: {worker_id} PID: {os.getpid()}")
    try:
        worker = worker_class.start(**kwargs)

        # Keep pulling items until the stop sentinel arrives.
        def input_queue_iterable() -> Iterable[Any]:
            while True:
                item = input_queue.get()
                if item == QueueSignals.stop:
                    break
                yield item

        for processed_item in worker.process(input_queue_iterable()):
            output_queue.put(processed_item)
    except Exception as e:  # pylint: disable=broad-except
        # Report the failure to the consumer instead of dying silently.
        logging.exception(e)
        output_queue.put(QueueSignals.error)
    finally:
        # It's important that we close and join the queue here before
        # decrementing num_active_workers. Otherwise our parent may join us
        # before the queue's feeder thread has passed all buffered items to
        # the underlying pipe resulting in a deadlock.
        #
        # See:
        # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#pipes-and-queues
        # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#programming-guidelines
        input_queue.close()
        output_queue.close()
        input_queue.join_thread()
        output_queue.join_thread()

        with num_active_workers.get_lock():
            num_active_workers.value -= 1

        logging.info(f"Reader worker {worker_id} finished")
|
||||
|
||||
|
||||
class ParallelWorkerPool:
    """Run a ``Worker`` subclass in ``num_workers`` child processes and stream
    items through them via multiprocessing queues.

    Typical usage: construct the pool, then iterate ``unordered_map`` /
    ``semi_ordered_map`` / ``ordered_map`` over an input stream.
    """

    def __init__(
        self,
        num_workers: int,
        worker: Type[Worker],
        start_method: Optional[str] = None,
        max_internal_batch_size: int = MAX_INTERNAL_BATCH_SIZE,
    ):
        # The worker class is instantiated lazily inside each child process
        # (see ``_worker``), not here.
        self.worker_class = worker
        self.num_workers = num_workers
        # Queues are created in start(); None until then.
        self.input_queue: Optional[Queue] = None
        self.output_queue: Optional[Queue] = None
        self.ctx: BaseContext = get_context(start_method)
        self.processes: list[BaseProcess] = []
        # Caps both queue sizes and the number of in-flight items.
        self.queue_size = self.num_workers * max_internal_batch_size
        self.emergency_shutdown = False
        self.num_active_workers: Optional[BaseValue] = None

    def start(self, **kwargs: Any) -> None:
        """Create the queues and spawn ``num_workers`` processes running ``_worker``.

        ``kwargs`` are forwarded (as a per-process copy) to ``Worker.start``.
        """
        self.input_queue = self.ctx.Queue(self.queue_size)
        self.output_queue = self.ctx.Queue(self.queue_size)

        # Shared counter of workers that have not finished shutting down yet.
        ctx_value = self.ctx.Value("i", self.num_workers)
        assert isinstance(ctx_value, BaseValue)
        self.num_active_workers = ctx_value

        for worker_id in range(0, self.num_workers):
            assert hasattr(self.ctx, "Process")
            process = self.ctx.Process(
                target=_worker,
                args=(
                    self.worker_class,
                    self.input_queue,
                    self.output_queue,
                    self.num_active_workers,
                    worker_id,
                    kwargs.copy(),
                ),
            )
            process.start()
            self.processes.append(process)

    def unordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
        """Feed ``stream`` to the workers and yield results as they complete.

        ``kwargs`` are passed through to :meth:`start` (and thus to ``Worker.start``).

        Raises:
            RuntimeError: If a worker reports an error or dies unexpectedly.
        """
        try:
            self.start(**kwargs)

            assert self.input_queue is not None, "Input queue was not initialized"
            assert self.output_queue is not None, "Output queue was not initialized"

            pushed = 0  # items handed to the workers
            read = 0  # results received back
            for item in stream:
                self.check_worker_health()
                if pushed - read < self.queue_size:
                    # Spare capacity: opportunistically drain a result if one
                    # is ready, but never block the producer side.
                    try:
                        out_item = self.output_queue.get_nowait()
                    except Empty:
                        out_item = None
                else:
                    # Back-pressure: too many items in flight, block until a
                    # result frees a slot (bounded by processing_timeout).
                    try:
                        out_item = self.output_queue.get(timeout=processing_timeout)
                    except Empty as e:
                        self.join_or_terminate()
                        raise e

                if out_item is not None:
                    if out_item == QueueSignals.error:
                        self.join_or_terminate()
                        raise RuntimeError("Thread unexpectedly terminated")
                    yield out_item
                    read += 1
                self.input_queue.put(item)
                pushed += 1

            # Tell every worker the stream is exhausted.
            for _ in range(self.num_workers):
                self.input_queue.put(QueueSignals.stop)

            # Drain the remaining results (accounting assumes one result per
            # pushed item).
            while read < pushed:
                out_item = self.output_queue.get(timeout=processing_timeout)
                if out_item == QueueSignals.error:
                    self.join_or_terminate()
                    raise RuntimeError("Thread unexpectedly terminated")
                yield out_item
                read += 1
        finally:
            assert self.input_queue is not None, "Input queue is None"
            assert self.output_queue is not None, "Output queue is None"
            self.join()
            self.input_queue.close()
            self.output_queue.close()
            if self.emergency_shutdown:
                # Do not wait on feeder threads of a crashed run.
                self.input_queue.cancel_join_thread()
                self.output_queue.cancel_join_thread()
            else:
                self.input_queue.join_thread()
                self.output_queue.join_thread()

    def semi_ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
        """Feed ``(index, item)`` pairs to the workers.

        NOTE(review): :meth:`ordered_map` unpacks the results as
        ``(index, result)`` pairs, so the Worker implementation is presumably
        expected to echo the index back — confirm against concrete workers.
        """
        return self.unordered_map(enumerate(stream), *args, **kwargs)

    def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
        """Yield results in the same order as the input ``stream``.

        Out-of-order results are buffered until every earlier index has arrived.
        """
        # Used as a plain dict: membership is always checked before access, so
        # the defaultdict factory never actually fires.
        buffer = defaultdict(int)
        next_expected = 0

        for idx, item in self.semi_ordered_map(stream, *args, **kwargs):
            buffer[idx] = item
            while next_expected in buffer:
                yield buffer.pop(next_expected)
                next_expected += 1

    def check_worker_health(self) -> None:
        """
        Checks if any worker process has terminated unexpectedly (exited with
        a non-zero code); if so, shuts the pool down and raises RuntimeError.
        """
        for process in self.processes:
            if not process.is_alive() and process.exitcode != 0:
                self.emergency_shutdown = True
                self.join_or_terminate()
                raise RuntimeError(
                    f"Worker PID: {process.pid} terminated unexpectedly with code {process.exitcode}"
                )

    def join_or_terminate(self, timeout: Optional[int] = 1) -> None:
        """
        Emergency shutdown: give each process ``timeout`` seconds to join,
        then terminate it if it is still alive.
        """
        self.emergency_shutdown = True
        for process in self.processes:
            process.join(timeout=timeout)
            if process.is_alive():
                process.terminate()
        self.processes.clear()

    def join(self) -> None:
        """Wait for every worker process to exit, then forget them."""
        for process in self.processes:
            process.join()
        self.processes.clear()

    def __del__(self) -> None:
        """
        Terminate processes if the user hasn't joined. This is necessary as
        leaving stray processes running can corrupt shared state. In brief,
        we've observed shared memory counters being reused (when the memory was
        free from the perspective of the parent process) while the stray
        workers still held a reference to them.
        For a discussion of using destructors in Python in this manner, see
        https://eli.thegreenplace.net/2009/06/12/safely-using-destructors-in-python/.
        """
        for process in self.processes:
            if process.is_alive():
                process.terminate()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user