import importlib.metadata import logging import math import platform from multiprocessing import get_all_start_methods from typing import ( Any, Awaitable, Callable, Iterable, Mapping, Optional, Sequence, Type, Union, get_args, ) import httpx from grpc import Compression from urllib3.util import Url, parse_url from urllib.parse import urljoin from qdrant_client.common.client_warnings import show_warning, show_warning_once from qdrant_client import grpc as grpc from qdrant_client._pydantic_compat import construct from qdrant_client.auth import BearerAuth from qdrant_client.client_base import QdrantBase from qdrant_client.common.version_check import is_compatible, get_server_version from qdrant_client.connection import get_channel from qdrant_client.conversions import common_types as types from qdrant_client.conversions.common_types import get_args_subscribed from qdrant_client.conversions.conversion import ( GrpcToRest, RestToGrpc, grpc_payload_schema_to_field_type, ) from qdrant_client.http import ApiClient, SyncApis, models from qdrant_client.parallel_processor import ParallelWorkerPool from qdrant_client.uploader.grpc_uploader import GrpcBatchUploader from qdrant_client.uploader.rest_uploader import RestBatchUploader from qdrant_client.uploader.uploader import BaseUploader class QdrantRemote(QdrantBase): DEFAULT_GRPC_TIMEOUT = 5 # seconds DEFAULT_GRPC_POOL_SIZE = 3 def __init__( self, url: Optional[str] = None, port: Optional[int] = 6333, grpc_port: int = 6334, prefer_grpc: bool = False, https: Optional[bool] = None, api_key: Optional[str] = None, prefix: Optional[str] = None, timeout: Optional[int] = None, host: Optional[str] = None, grpc_options: Optional[dict[str, Any]] = None, auth_token_provider: Optional[ Union[Callable[[], str], Callable[[], Awaitable[str]]] ] = None, check_compatibility: bool = True, pool_size: Optional[int] = None, **kwargs: Any, ): super().__init__(**kwargs) self._prefer_grpc = prefer_grpc self._grpc_port = grpc_port self._grpc_options 
= grpc_options or {} self._https = https if https is not None else api_key is not None self._scheme = "https" if self._https else "http" # Pool size to use. This value should not be accessed directly; use _get_grpc_pool_size() instead. self._pool_size: Optional[int] = None if pool_size is not None: pool_size = max(1, pool_size) # Ensure pool_size is always > 0 self._pool_size = pool_size self._prefix = prefix or "" if len(self._prefix) > 0 and self._prefix[0] != "/": self._prefix = f"/{self._prefix}" if url is not None and host is not None: raise ValueError(f"Only one of (url, host) can be set. url is {url}, host is {host}") if host is not None and (host.startswith("http://") or host.startswith("https://")): raise ValueError( f"`host` param is not expected to contain protocol (http:// or https://). " f"Try to use `url` parameter instead." ) elif url: if url.startswith("localhost"): # Handle for a special case when url is localhost:port # Which is not parsed correctly by urllib url = f"//{url}" parsed_url: Url = parse_url(url) self._host, self._port = parsed_url.host, parsed_url.port if parsed_url.scheme: self._https = parsed_url.scheme == "https" self._scheme = parsed_url.scheme self._port = self._port if self._port else port if self._prefix and parsed_url.path: raise ValueError( "Prefix can be set either in `url` or in `prefix`. " f"url is {url}, prefix is {parsed_url.path}" ) elif parsed_url.path: self._prefix = parsed_url.path if self._scheme not in ("http", "https"): raise ValueError(f"Unknown scheme: {self._scheme}") else: self._host = host or "localhost" self._port = port _timeout = ( math.ceil(timeout) if timeout is not None else None ) # it has been changed from float to int. # convert it to the closest greater or equal int value (e.g. 
0.5 -> 1) self._api_key = api_key self._auth_token_provider = auth_token_provider limits = kwargs.pop("limits", None) if limits is None: if self._host in ["localhost", "127.0.0.1"]: # Disable keep-alive for local connections # Cause in some cases, it may cause extra delays limits = httpx.Limits(max_connections=None, max_keepalive_connections=0) elif self._pool_size is not None: # Set http connection pooling to `self._pool_size`, if no limits are specified. limits = httpx.Limits(max_connections=self._pool_size) elif self._pool_size is not None: raise ValueError( "`pool_size` and `limits` are mutually exclusive. " f"`pool_size`: {pool_size}, `limit`: {limits}" ) http2 = kwargs.pop("http2", False) self._grpc_headers = [] self._rest_headers = {k: v for k, v in kwargs.pop("metadata", {}).items()} if api_key is not None: if self._scheme == "http": show_warning( message="Api key is used with an insecure connection.", category=UserWarning, stacklevel=4, ) # http2 = True self._rest_headers["api-key"] = api_key self._grpc_headers.append(("api-key", api_key)) client_version = importlib.metadata.version("qdrant-client") python_version = platform.python_version() user_agent = f"python-client/{client_version} python/{python_version}" self._rest_headers["User-Agent"] = user_agent self._grpc_options["grpc.primary_user_agent"] = user_agent # GRPC Channel-Level Compression grpc_compression: Optional[Compression] = kwargs.pop("grpc_compression", None) if grpc_compression is not None and not isinstance(grpc_compression, Compression): raise TypeError( f"Expected 'grpc_compression' to be of type " f"grpc.Compression or None, but got {type(grpc_compression)}" ) if grpc_compression == Compression.Deflate: raise ValueError( "grpc.Compression.Deflate is not supported. 
Try grpc.Compression.Gzip or grpc.Compression.NoCompression" ) self._grpc_compression = grpc_compression address = f"{self._host}:{self._port}" if self._port is not None else self._host base_url = f"{self._scheme}://{address}" self.rest_uri = urljoin(base_url, self._prefix) self._rest_args = {"headers": self._rest_headers, "http2": http2, **kwargs} if limits is not None: self._rest_args["limits"] = limits if _timeout is not None: self._rest_args["timeout"] = _timeout self._timeout = _timeout else: self._timeout = self.DEFAULT_GRPC_TIMEOUT if self._auth_token_provider is not None: if self._scheme == "http": show_warning( message="Auth token provider is used with an insecure connection.", category=UserWarning, stacklevel=4, ) bearer_auth = BearerAuth(self._auth_token_provider) self._rest_args["auth"] = bearer_auth self.openapi_client: SyncApis[ApiClient] = SyncApis( host=self.rest_uri, **self._rest_args, ) self._grpc_channel_pool: list[grpc.Channel] = [] self._grpc_points_client_pool: Optional[list[grpc.PointsStub]] = None self._grpc_collections_client_pool: Optional[list[grpc.CollectionsStub]] = None self._grpc_snapshots_client_pool: Optional[list[grpc.SnapshotsStub]] = None self._grpc_root_client_pool: Optional[list[grpc.QdrantStub]] = None self._grpc_client_next_index: int = 0 # The next index to use self._aio_grpc_points_client: Optional[grpc.PointsStub] = None self._aio_grpc_collections_client: Optional[grpc.CollectionsStub] = None self._aio_grpc_snapshots_client: Optional[grpc.SnapshotsStub] = None self._aio_grpc_root_client: Optional[grpc.QdrantStub] = None self._closed: bool = False self.server_version = None if check_compatibility: try: client_version = importlib.metadata.version("qdrant-client") self.server_version = get_server_version( self.rest_uri, self._rest_headers, self._rest_args.get("auth") ) if not self.server_version: show_warning( message="Failed to obtain server version. Unable to check client-server compatibility." 
" Set check_compatibility=False to skip version check.", category=UserWarning, stacklevel=4, ) elif not is_compatible(client_version, self.server_version): show_warning( message=f"Qdrant client version {client_version} is incompatible with server " f"version {self.server_version}. Major versions should match and minor version difference " "must not exceed 1. Set check_compatibility=False to skip version check.", category=UserWarning, stacklevel=4, ) except Exception as er: logging.debug( f"Unable to get server version: {er}, server version defaults to None" ) @property def closed(self) -> bool: return self._closed def close(self, grpc_grace: Optional[float] = None, **kwargs: Any) -> None: if hasattr(self, "_grpc_channel_pool") and len(self._grpc_channel_pool) > 0: for channel in self._grpc_channel_pool: try: channel.close() except AttributeError: show_warning( message="Unable to close grpc_channel. Connection was interrupted on the server side", category=RuntimeWarning, stacklevel=4, ) try: self.openapi_client.close() except Exception: show_warning( message="Unable to close http connection. Connection was interrupted on the server side", category=RuntimeWarning, stacklevel=4, ) self._closed = True @staticmethod def _parse_url(url: str) -> tuple[Optional[str], str, Optional[int], Optional[str]]: parse_result: Url = parse_url(url) scheme, host, port, prefix = ( parse_result.scheme, parse_result.host, parse_result.port, parse_result.path, ) return scheme, host, port, prefix def _get_grpc_pool_size(self) -> int: """ Returns the pool size to use for GRPC connection pool. This method should be preferred over accessing `self._pool_size` directly as it applies the default value if no pool_size was provided. """ if self._pool_size is not None: return self._pool_size else: return self.DEFAULT_GRPC_POOL_SIZE def _init_grpc_channel(self) -> None: if self._closed: raise RuntimeError("Client was closed. 
Please create a new QdrantClient instance.") try: channel_pool = [] if len(self._grpc_channel_pool) == 0: for _ in range(self._get_grpc_pool_size()): channel = get_channel( host=self._host, port=self._grpc_port, ssl=self._https, metadata=self._grpc_headers, options=self._grpc_options, compression=self._grpc_compression, # sync get_channel does not accept coroutine functions, # but we can't check type here, since it'll get into async client as well auth_token_provider=self._auth_token_provider, # type: ignore ) channel_pool.append(channel) # Apply the clients late to prevent half-initialized pools if a channel creation fails. self._grpc_channel_pool = channel_pool except Exception as e: raise RuntimeError(f"Error initializing the grpc connection(s): {e}") def _init_grpc_points_client(self) -> None: self._init_grpc_channel() self._grpc_points_client_pool = [ grpc.PointsStub(channel) for channel in self._grpc_channel_pool ] def _init_grpc_collections_client(self) -> None: self._init_grpc_channel() self._grpc_collections_client_pool = [ grpc.CollectionsStub(channel) for channel in self._grpc_channel_pool ] def _init_grpc_snapshots_client(self) -> None: self._init_grpc_channel() self._grpc_snapshots_client_pool = [ grpc.SnapshotsStub(channel) for channel in self._grpc_channel_pool ] def _init_grpc_root_client(self) -> None: self._init_grpc_channel() self._grpc_root_client_pool = [ grpc.QdrantStub(channel) for channel in self._grpc_channel_pool ] def _next_grpc_client(self) -> int: current_index = self._grpc_client_next_index self._grpc_client_next_index = ( self._grpc_client_next_index + 1 ) % self._get_grpc_pool_size() return current_index @property def grpc_collections(self) -> grpc.CollectionsStub: """gRPC client for collections methods Returns: An instance of raw gRPC client, generated from Protobuf """ if self._grpc_collections_client_pool is None: self._init_grpc_collections_client() assert self._grpc_collections_client_pool is not None return 
self._grpc_collections_client_pool[self._next_grpc_client()] @property def grpc_points(self) -> grpc.PointsStub: """gRPC client for points methods Returns: An instance of raw gRPC client, generated from Protobuf """ if self._grpc_points_client_pool is None: self._init_grpc_points_client() assert self._grpc_points_client_pool is not None return self._grpc_points_client_pool[self._next_grpc_client()] @property def grpc_snapshots(self) -> grpc.SnapshotsStub: """gRPC client for snapshots methods Returns: An instance of raw gRPC client, generated from Protobuf """ if self._grpc_snapshots_client_pool is None: self._init_grpc_snapshots_client() assert self._grpc_snapshots_client_pool is not None return self._grpc_snapshots_client_pool[self._next_grpc_client()] @property def grpc_root(self) -> grpc.QdrantStub: """gRPC client for info methods Returns: An instance of raw gRPC client, generated from Protobuf """ if self._grpc_root_client_pool is None: self._init_grpc_root_client() assert self._grpc_root_client_pool is not None return self._grpc_root_client_pool[self._next_grpc_client()] @property def rest(self) -> SyncApis[ApiClient]: """REST Client Returns: An instance of raw REST API client, generated from OpenAPI schema """ return self.openapi_client @property def http(self) -> SyncApis[ApiClient]: """REST Client Returns: An instance of raw REST API client, generated from OpenAPI schema """ return self.openapi_client def query_points( self, collection_name: str, query: Union[ types.PointId, list[float], list[list[float]], types.SparseVector, types.Query, types.NumpyArray, types.Document, types.Image, types.InferenceObject, None, ] = None, using: Optional[str] = None, prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None, query_filter: Optional[types.Filter] = None, search_params: Optional[types.SearchParams] = None, limit: int = 10, offset: Optional[int] = None, with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True, with_vectors: Union[bool, 
Sequence[str]] = False, score_threshold: Optional[float] = None, lookup_from: Optional[types.LookupLocation] = None, consistency: Optional[types.ReadConsistency] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, timeout: Optional[int] = None, **kwargs: Any, ) -> types.QueryResponse: if self._prefer_grpc: if query is not None: query = RestToGrpc.convert_query(query) if isinstance(prefetch, models.Prefetch): prefetch = [RestToGrpc.convert_prefetch_query(prefetch)] if isinstance(prefetch, list): prefetch = [ RestToGrpc.convert_prefetch_query(p) if isinstance(p, models.Prefetch) else p for p in prefetch ] if isinstance(query_filter, models.Filter): query_filter = RestToGrpc.convert_filter(model=query_filter) if isinstance(search_params, models.SearchParams): search_params = RestToGrpc.convert_search_params(search_params) if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)): with_payload = RestToGrpc.convert_with_payload_interface(with_payload) if isinstance(with_vectors, get_args_subscribed(models.WithVector)): with_vectors = RestToGrpc.convert_with_vectors(with_vectors) if isinstance(lookup_from, models.LookupLocation): lookup_from = RestToGrpc.convert_lookup_location(lookup_from) if isinstance(consistency, get_args_subscribed(models.ReadConsistency)): consistency = RestToGrpc.convert_read_consistency(consistency) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) res: grpc.QueryResponse = self.grpc_points.Query( grpc.QueryPoints( collection_name=collection_name, query=query, prefetch=prefetch, filter=query_filter, limit=limit, offset=offset, with_vectors=with_vectors, with_payload=with_payload, params=search_params, score_threshold=score_threshold, using=using, lookup_from=lookup_from, timeout=timeout, shard_key_selector=shard_key_selector, read_consistency=consistency, ), timeout=timeout if timeout is not None else 
self._timeout, ) scored_points = [GrpcToRest.convert_scored_point(hit) for hit in res.result] return models.QueryResponse(points=scored_points) else: if isinstance(query_filter, grpc.Filter): query_filter = GrpcToRest.convert_filter(model=query_filter) if isinstance(search_params, grpc.SearchParams): search_params = GrpcToRest.convert_search_params(search_params) if isinstance(with_payload, grpc.WithPayloadSelector): with_payload = GrpcToRest.convert_with_payload_selector(with_payload) if isinstance(lookup_from, grpc.LookupLocation): lookup_from = GrpcToRest.convert_lookup_location(lookup_from) query_request = models.QueryRequest( shard_key=shard_key_selector, prefetch=prefetch, query=query, using=using, filter=query_filter, params=search_params, score_threshold=score_threshold, limit=limit, offset=offset, with_vector=with_vectors, with_payload=with_payload, lookup_from=lookup_from, ) query_result = self.http.search_api.query_points( collection_name=collection_name, consistency=consistency, timeout=timeout, query_request=query_request, ) result: Optional[models.QueryResponse] = query_result.result assert result is not None, "Search returned None" return result def query_batch_points( self, collection_name: str, requests: Sequence[types.QueryRequest], consistency: Optional[types.ReadConsistency] = None, timeout: Optional[int] = None, **kwargs: Any, ) -> list[types.QueryResponse]: if self._prefer_grpc: requests = [ ( RestToGrpc.convert_query_request(r, collection_name) if isinstance(r, models.QueryRequest) else r ) for r in requests ] if isinstance(consistency, get_args_subscribed(models.ReadConsistency)): consistency = RestToGrpc.convert_read_consistency(consistency) grpc_res: grpc.QueryBatchResponse = self.grpc_points.QueryBatch( grpc.QueryBatchPoints( collection_name=collection_name, query_points=requests, read_consistency=consistency, timeout=timeout, ), timeout=timeout if timeout is not None else self._timeout, ) return [ models.QueryResponse( 
points=[GrpcToRest.convert_scored_point(hit) for hit in r.result] ) for r in grpc_res.result ] else: http_res: Optional[list[models.QueryResponse]] = ( self.http.search_api.query_batch_points( collection_name=collection_name, consistency=consistency, timeout=timeout, query_request_batch=models.QueryRequestBatch(searches=requests), ).result ) assert http_res is not None, "Query batch returned None" return http_res def query_points_groups( self, collection_name: str, group_by: str, query: Union[ types.PointId, list[float], list[list[float]], types.SparseVector, types.Query, types.NumpyArray, types.Document, types.Image, types.InferenceObject, None, ] = None, using: Optional[str] = None, prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None, query_filter: Optional[types.Filter] = None, search_params: Optional[types.SearchParams] = None, limit: int = 10, group_size: int = 3, with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True, with_vectors: Union[bool, Sequence[str]] = False, score_threshold: Optional[float] = None, with_lookup: Optional[types.WithLookupInterface] = None, lookup_from: Optional[types.LookupLocation] = None, consistency: Optional[types.ReadConsistency] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, timeout: Optional[int] = None, **kwargs: Any, ) -> types.GroupsResult: if self._prefer_grpc: if query is not None: query = RestToGrpc.convert_query(query) if isinstance(prefetch, models.Prefetch): prefetch = [RestToGrpc.convert_prefetch_query(prefetch)] if isinstance(prefetch, list): prefetch = [ RestToGrpc.convert_prefetch_query(p) if isinstance(p, models.Prefetch) else p for p in prefetch ] if isinstance(query_filter, models.Filter): query_filter = RestToGrpc.convert_filter(model=query_filter) if isinstance(search_params, models.SearchParams): search_params = RestToGrpc.convert_search_params(search_params) if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)): with_payload = 
RestToGrpc.convert_with_payload_interface(with_payload) if isinstance(with_vectors, get_args_subscribed(models.WithVector)): with_vectors = RestToGrpc.convert_with_vectors(with_vectors) if isinstance(with_lookup, models.WithLookup): with_lookup = RestToGrpc.convert_with_lookup(with_lookup) if isinstance(with_lookup, str): with_lookup = grpc.WithLookup(collection=with_lookup) if isinstance(lookup_from, models.LookupLocation): lookup_from = RestToGrpc.convert_lookup_location(lookup_from) if isinstance(consistency, get_args_subscribed(models.ReadConsistency)): consistency = RestToGrpc.convert_read_consistency(consistency) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) result: grpc.QueryGroupsResponse = self.grpc_points.QueryGroups( grpc.QueryPointGroups( collection_name=collection_name, query=query, prefetch=prefetch, filter=query_filter, limit=limit, with_vectors=with_vectors, with_payload=with_payload, params=search_params, score_threshold=score_threshold, using=using, group_by=group_by, group_size=group_size, with_lookup=with_lookup, lookup_from=lookup_from, timeout=timeout, shard_key_selector=shard_key_selector, read_consistency=consistency, ), timeout=timeout if timeout is not None else self._timeout, ).result return GrpcToRest.convert_groups_result(result) else: if isinstance(query_filter, grpc.Filter): query_filter = GrpcToRest.convert_filter(model=query_filter) if isinstance(search_params, grpc.SearchParams): search_params = GrpcToRest.convert_search_params(search_params) if isinstance(with_payload, grpc.WithPayloadSelector): with_payload = GrpcToRest.convert_with_payload_selector(with_payload) if isinstance(lookup_from, grpc.LookupLocation): lookup_from = GrpcToRest.convert_lookup_location(lookup_from) query_request = models.QueryGroupsRequest( shard_key=shard_key_selector, prefetch=prefetch, query=query, using=using, filter=query_filter, 
params=search_params, score_threshold=score_threshold, limit=limit, group_by=group_by, group_size=group_size, with_vector=with_vectors, with_payload=with_payload, with_lookup=with_lookup, lookup_from=lookup_from, ) query_result = self.http.search_api.query_points_groups( collection_name=collection_name, consistency=consistency, timeout=timeout, query_groups_request=query_request, ) assert query_result is not None, "Query points groups API returned None" return query_result.result def search_matrix_pairs( self, collection_name: str, query_filter: Optional[types.Filter] = None, limit: int = 3, sample: int = 10, using: Optional[str] = None, consistency: Optional[types.ReadConsistency] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, timeout: Optional[int] = None, **kwargs: Any, ) -> types.SearchMatrixPairsResponse: if self._prefer_grpc: if isinstance(query_filter, models.Filter): query_filter = RestToGrpc.convert_filter(model=query_filter) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) if isinstance(consistency, get_args_subscribed(models.ReadConsistency)): consistency = RestToGrpc.convert_read_consistency(consistency) response = self.grpc_points.SearchMatrixPairs( grpc.SearchMatrixPoints( collection_name=collection_name, filter=query_filter, sample=sample, limit=limit, using=using, timeout=timeout, read_consistency=consistency, shard_key_selector=shard_key_selector, ), timeout=timeout if timeout is not None else self._timeout, ) return GrpcToRest.convert_search_matrix_pairs(response.result) if isinstance(query_filter, grpc.Filter): query_filter = GrpcToRest.convert_filter(model=query_filter) search_matrix_result = self.openapi_client.search_api.search_matrix_pairs( collection_name=collection_name, consistency=consistency, timeout=timeout, search_matrix_request=models.SearchMatrixRequest( shard_key=shard_key_selector, limit=limit, 
sample=sample, using=using, filter=query_filter, ), ).result assert search_matrix_result is not None, "Search matrix pairs returned None result" return search_matrix_result def search_matrix_offsets( self, collection_name: str, query_filter: Optional[types.Filter] = None, limit: int = 3, sample: int = 10, using: Optional[str] = None, consistency: Optional[types.ReadConsistency] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, timeout: Optional[int] = None, **kwargs: Any, ) -> types.SearchMatrixOffsetsResponse: if self._prefer_grpc: if isinstance(query_filter, models.Filter): query_filter = RestToGrpc.convert_filter(model=query_filter) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) if isinstance(consistency, get_args_subscribed(models.ReadConsistency)): consistency = RestToGrpc.convert_read_consistency(consistency) response = self.grpc_points.SearchMatrixOffsets( grpc.SearchMatrixPoints( collection_name=collection_name, filter=query_filter, sample=sample, limit=limit, using=using, timeout=timeout, read_consistency=consistency, shard_key_selector=shard_key_selector, ), timeout=timeout if timeout is not None else self._timeout, ) return GrpcToRest.convert_search_matrix_offsets(response.result) if isinstance(query_filter, grpc.Filter): query_filter = GrpcToRest.convert_filter(model=query_filter) search_matrix_result = self.openapi_client.search_api.search_matrix_offsets( collection_name=collection_name, consistency=consistency, timeout=timeout, search_matrix_request=models.SearchMatrixRequest( shard_key=shard_key_selector, limit=limit, sample=sample, using=using, filter=query_filter, ), ).result assert search_matrix_result is not None, "Search matrix offsets returned None result" return search_matrix_result def scroll( self, collection_name: str, scroll_filter: Optional[types.Filter] = None, limit: int = 10, order_by: 
Optional[types.OrderBy] = None, offset: Optional[types.PointId] = None, with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True, with_vectors: Union[bool, Sequence[str]] = False, consistency: Optional[types.ReadConsistency] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, timeout: Optional[int] = None, **kwargs: Any, ) -> tuple[list[types.Record], Optional[types.PointId]]: if self._prefer_grpc: if isinstance(offset, get_args_subscribed(models.ExtendedPointId)): offset = RestToGrpc.convert_extended_point_id(offset) if isinstance(scroll_filter, models.Filter): scroll_filter = RestToGrpc.convert_filter(model=scroll_filter) if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)): with_payload = RestToGrpc.convert_with_payload_interface(with_payload) if isinstance(with_vectors, get_args_subscribed(models.WithVector)): with_vectors = RestToGrpc.convert_with_vectors(with_vectors) if isinstance(consistency, get_args_subscribed(models.ReadConsistency)): consistency = RestToGrpc.convert_read_consistency(consistency) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) if isinstance(order_by, get_args_subscribed(models.OrderByInterface)): order_by = RestToGrpc.convert_order_by_interface(order_by) res: grpc.ScrollResponse = self.grpc_points.Scroll( grpc.ScrollPoints( collection_name=collection_name, filter=scroll_filter, order_by=order_by, offset=offset, with_vectors=with_vectors, with_payload=with_payload, limit=limit, read_consistency=consistency, shard_key_selector=shard_key_selector, timeout=timeout, ), timeout=timeout if timeout is not None else self._timeout, ) return [GrpcToRest.convert_retrieved_point(point) for point in res.result], ( GrpcToRest.convert_point_id(res.next_page_offset) if res.HasField("next_page_offset") else None ) else: if isinstance(offset, grpc.PointId): offset = 
GrpcToRest.convert_point_id(offset) if isinstance(scroll_filter, grpc.Filter): scroll_filter = GrpcToRest.convert_filter(model=scroll_filter) if isinstance(order_by, grpc.OrderBy): order_by = GrpcToRest.convert_order_by(order_by) if isinstance(with_payload, grpc.WithPayloadSelector): with_payload = GrpcToRest.convert_with_payload_selector(with_payload) scroll_result: Optional[models.ScrollResult] = ( self.openapi_client.points_api.scroll_points( collection_name=collection_name, consistency=consistency, scroll_request=models.ScrollRequest( filter=scroll_filter, limit=limit, order_by=order_by, offset=offset, with_payload=with_payload, with_vector=with_vectors, shard_key=shard_key_selector, ), timeout=timeout, ).result ) assert scroll_result is not None, "Scroll points API returned None result" return scroll_result.points, scroll_result.next_page_offset def count( self, collection_name: str, count_filter: Optional[types.Filter] = None, exact: bool = True, shard_key_selector: Optional[types.ShardKeySelector] = None, timeout: Optional[int] = None, consistency: Optional[types.ReadConsistency] = None, **kwargs: Any, ) -> types.CountResult: if self._prefer_grpc: if isinstance(count_filter, models.Filter): count_filter = RestToGrpc.convert_filter(model=count_filter) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) if isinstance(consistency, get_args_subscribed(models.ReadConsistency)): consistency = RestToGrpc.convert_read_consistency(consistency) response = self.grpc_points.Count( grpc.CountPoints( collection_name=collection_name, filter=count_filter, exact=exact, shard_key_selector=shard_key_selector, timeout=timeout, read_consistency=consistency, ), timeout=timeout if timeout is not None else self._timeout, ).result return GrpcToRest.convert_count_result(response) if isinstance(count_filter, grpc.Filter): count_filter = 
GrpcToRest.convert_filter(model=count_filter) count_result = self.openapi_client.points_api.count_points( collection_name=collection_name, count_request=models.CountRequest( filter=count_filter, exact=exact, shard_key=shard_key_selector, ), consistency=consistency, timeout=timeout, ).result assert count_result is not None, "Count points returned None result" return count_result def facet( self, collection_name: str, key: str, facet_filter: Optional[types.Filter] = None, limit: int = 10, exact: bool = False, timeout: Optional[int] = None, consistency: Optional[types.ReadConsistency] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, **kwargs: Any, ) -> types.FacetResponse: if self._prefer_grpc: if isinstance(facet_filter, models.Filter): facet_filter = RestToGrpc.convert_filter(model=facet_filter) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) if isinstance(consistency, get_args_subscribed(models.ReadConsistency)): consistency = RestToGrpc.convert_read_consistency(consistency) response = self.grpc_points.Facet( grpc.FacetCounts( collection_name=collection_name, key=key, filter=facet_filter, limit=limit, exact=exact, timeout=timeout, read_consistency=consistency, shard_key_selector=shard_key_selector, ), timeout=timeout if timeout is not None else self._timeout, ) return types.FacetResponse( hits=[GrpcToRest.convert_facet_value_hit(hit) for hit in response.hits] ) if isinstance(facet_filter, grpc.Filter): facet_filter = GrpcToRest.convert_filter(model=facet_filter) facet_result = self.openapi_client.points_api.facet( collection_name=collection_name, consistency=consistency, timeout=timeout, facet_request=models.FacetRequest( shard_key=shard_key_selector, key=key, limit=limit, filter=facet_filter, exact=exact, ), ).result assert facet_result is not None, "Facet points returned None result" return facet_result def upsert( self, 
collection_name: str, points: types.Points, wait: bool = True, ordering: Optional[types.WriteOrdering] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, update_filter: Optional[types.Filter] = None, **kwargs: Any, ) -> types.UpdateResult: if self._prefer_grpc: if isinstance(points, models.Batch): vectors_batch: list[grpc.Vectors] = RestToGrpc.convert_batch_vector_struct( points.vectors, len(points.ids) ) points = [ grpc.PointStruct( id=RestToGrpc.convert_extended_point_id(points.ids[idx]), vectors=vectors_batch[idx], payload=( RestToGrpc.convert_payload(points.payloads[idx]) if points.payloads is not None else None ), ) for idx in range(len(points.ids)) ] if isinstance(points, list): points = [ ( RestToGrpc.convert_point_struct(point) if isinstance(point, models.PointStruct) else point ) for point in points ] if isinstance(ordering, models.WriteOrdering): ordering = RestToGrpc.convert_write_ordering(ordering) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) if isinstance(update_filter, models.Filter): update_filter = RestToGrpc.convert_filter(model=update_filter) grpc_result = self.grpc_points.Upsert( grpc.UpsertPoints( collection_name=collection_name, wait=wait, points=points, ordering=ordering, shard_key_selector=shard_key_selector, update_filter=update_filter, ), timeout=self._timeout, ).result assert grpc_result is not None, "Upsert returned None result" return GrpcToRest.convert_update_result(grpc_result) else: if isinstance(update_filter, grpc.Filter): update_filter = GrpcToRest.convert_filter(model=update_filter) if isinstance(points, list): points = [ ( GrpcToRest.convert_point_struct(point) if isinstance(point, grpc.PointStruct) else point ) for point in points ] points = models.PointsList( points=points, shard_key=shard_key_selector, update_filter=update_filter ) if isinstance(points, models.Batch): points = models.PointsBatch( 
batch=points, shard_key=shard_key_selector, update_filter=update_filter ) http_result = self.openapi_client.points_api.upsert_points( collection_name=collection_name, wait=wait, point_insert_operations=points, ordering=ordering, ).result assert http_result is not None, "Upsert returned None result" return http_result def update_vectors( self, collection_name: str, points: Sequence[types.PointVectors], wait: bool = True, ordering: Optional[types.WriteOrdering] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, update_filter: Optional[types.Filter] = None, **kwargs: Any, ) -> types.UpdateResult: if self._prefer_grpc: points = [RestToGrpc.convert_point_vectors(point) for point in points] if isinstance(ordering, models.WriteOrdering): ordering = RestToGrpc.convert_write_ordering(ordering) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) if isinstance(update_filter, models.Filter): update_filter = RestToGrpc.convert_filter(model=update_filter) grpc_result = self.grpc_points.UpdateVectors( grpc.UpdatePointVectors( collection_name=collection_name, wait=wait, points=points, ordering=ordering, shard_key_selector=shard_key_selector, update_filter=update_filter, ), timeout=self._timeout, ).result assert grpc_result is not None, "Upsert returned None result" return GrpcToRest.convert_update_result(grpc_result) else: if isinstance(update_filter, grpc.Filter): update_filter = GrpcToRest.convert_filter(model=update_filter) return self.openapi_client.points_api.update_vectors( collection_name=collection_name, wait=wait, update_vectors=models.UpdateVectors( points=points, shard_key=shard_key_selector, update_filter=update_filter, ), ordering=ordering, ).result def delete_vectors( self, collection_name: str, vectors: Sequence[str], points: types.PointsSelector, wait: bool = True, ordering: Optional[types.WriteOrdering] = None, shard_key_selector: 
Optional[types.ShardKeySelector] = None, **kwargs: Any, ) -> types.UpdateResult: if self._prefer_grpc: points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points) shard_key_selector = shard_key_selector or opt_shard_key_selector if isinstance(ordering, models.WriteOrdering): ordering = RestToGrpc.convert_write_ordering(ordering) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) grpc_result = self.grpc_points.DeleteVectors( grpc.DeletePointVectors( collection_name=collection_name, wait=wait, vectors=grpc.VectorsSelector( names=vectors, ), points_selector=points_selector, ordering=ordering, shard_key_selector=shard_key_selector, ), timeout=self._timeout, ).result assert grpc_result is not None, "Delete vectors returned None result" return GrpcToRest.convert_update_result(grpc_result) else: _points, _filter = self._try_argument_to_rest_points_and_filter(points) return self.openapi_client.points_api.delete_vectors( collection_name=collection_name, wait=wait, ordering=ordering, delete_vectors=construct( models.DeleteVectors, vector=vectors, points=_points, filter=_filter, shard_key=shard_key_selector, ), ).result def retrieve( self, collection_name: str, ids: Sequence[types.PointId], with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True, with_vectors: Union[bool, Sequence[str]] = False, consistency: Optional[types.ReadConsistency] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, timeout: Optional[int] = None, **kwargs: Any, ) -> list[types.Record]: if self._prefer_grpc: if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)): with_payload = RestToGrpc.convert_with_payload_interface(with_payload) ids = [ ( RestToGrpc.convert_extended_point_id(idx) if isinstance(idx, get_args_subscribed(models.ExtendedPointId)) else idx ) for idx in ids ] with_vectors = 
RestToGrpc.convert_with_vectors(with_vectors) if isinstance(consistency, get_args_subscribed(models.ReadConsistency)): consistency = RestToGrpc.convert_read_consistency(consistency) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) result = self.grpc_points.Get( grpc.GetPoints( collection_name=collection_name, ids=ids, with_payload=with_payload, with_vectors=with_vectors, read_consistency=consistency, shard_key_selector=shard_key_selector, timeout=timeout, ), timeout=timeout if timeout is not None else self._timeout, ).result assert result is not None, "Retrieve returned None result" return [GrpcToRest.convert_retrieved_point(record) for record in result] else: if isinstance(with_payload, grpc.WithPayloadSelector): with_payload = GrpcToRest.convert_with_payload_selector(with_payload) ids = [ (GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx) for idx in ids ] http_result = self.openapi_client.points_api.get_points( collection_name=collection_name, consistency=consistency, point_request=models.PointRequest( ids=ids, with_payload=with_payload, with_vector=with_vectors, shard_key=shard_key_selector, ), timeout=timeout, ).result assert http_result is not None, "Retrieve API returned None result" return http_result @classmethod def _try_argument_to_grpc_selector( cls, points: types.PointsSelector ) -> tuple[grpc.PointsSelector, Optional[grpc.ShardKeySelector]]: shard_key_selector = None if isinstance(points, list): points_selector = grpc.PointsSelector( points=grpc.PointsIdsList( ids=[ ( RestToGrpc.convert_extended_point_id(idx) if isinstance(idx, get_args_subscribed(models.ExtendedPointId)) else idx ) for idx in points ] ) ) elif isinstance(points, grpc.PointsSelector): points_selector = points elif isinstance(points, get_args(models.PointsSelector)): if points.shard_key is not None: shard_key_selector = 
RestToGrpc.convert_shard_key_selector(points.shard_key) points_selector = RestToGrpc.convert_points_selector(points) elif isinstance(points, models.Filter): points_selector = RestToGrpc.convert_points_selector( construct(models.FilterSelector, filter=points) ) elif isinstance(points, grpc.Filter): points_selector = grpc.PointsSelector(filter=points) else: raise ValueError(f"Unsupported points selector type: {type(points)}") return points_selector, shard_key_selector @classmethod def _try_argument_to_rest_selector( cls, points: types.PointsSelector, shard_key_selector: Optional[types.ShardKeySelector], ) -> models.PointsSelector: if isinstance(points, list): _points = [ (GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx) for idx in points ] points_selector = construct( models.PointIdsList, points=_points, shard_key=shard_key_selector, ) elif isinstance(points, grpc.PointsSelector): points_selector = GrpcToRest.convert_points_selector(points) points_selector.shard_key = shard_key_selector elif isinstance(points, get_args(models.PointsSelector)): points_selector = points points_selector.shard_key = shard_key_selector elif isinstance(points, models.Filter): points_selector = construct( models.FilterSelector, filter=points, shard_key=shard_key_selector ) elif isinstance(points, grpc.Filter): points_selector = construct( models.FilterSelector, filter=GrpcToRest.convert_filter(points), shard_key=shard_key_selector, ) else: raise ValueError(f"Unsupported points selector type: {type(points)}") return points_selector @classmethod def _points_selector_to_points_list( cls, points_selector: grpc.PointsSelector ) -> list[grpc.PointId]: name = points_selector.WhichOneof("points_selector_one_of") if name is None: return [] val = getattr(points_selector, name) if name == "points": return list(val.ids) return [] @classmethod def _try_argument_to_rest_points_and_filter( cls, points: types.PointsSelector ) -> tuple[Optional[list[models.ExtendedPointId]], 
Optional[models.Filter]]: _points = None _filter = None if isinstance(points, list): _points = [ (GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx) for idx in points ] elif isinstance(points, grpc.PointsSelector): selector = GrpcToRest.convert_points_selector(points) if isinstance(selector, models.PointIdsList): _points = selector.points elif isinstance(selector, models.FilterSelector): _filter = selector.filter elif isinstance(points, models.PointIdsList): _points = points.points elif isinstance(points, models.FilterSelector): _filter = points.filter elif isinstance(points, models.Filter): _filter = points elif isinstance(points, grpc.Filter): _filter = GrpcToRest.convert_filter(points) else: raise ValueError(f"Unsupported points selector type: {type(points)}") return _points, _filter def delete( self, collection_name: str, points_selector: types.PointsSelector, wait: bool = True, ordering: Optional[types.WriteOrdering] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, **kwargs: Any, ) -> types.UpdateResult: if self._prefer_grpc: points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector( points_selector ) shard_key_selector = shard_key_selector or opt_shard_key_selector if isinstance(ordering, models.WriteOrdering): ordering = RestToGrpc.convert_write_ordering(ordering) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) return GrpcToRest.convert_update_result( self.grpc_points.Delete( grpc.DeletePoints( collection_name=collection_name, wait=wait, points=points_selector, ordering=ordering, shard_key_selector=shard_key_selector, ), timeout=self._timeout, ).result ) else: points_selector = self._try_argument_to_rest_selector( points_selector, shard_key_selector ) result: Optional[types.UpdateResult] = self.openapi_client.points_api.delete_points( collection_name=collection_name, wait=wait, 
points_selector=points_selector, ordering=ordering, ).result assert result is not None, "Delete points returned None" return result def set_payload( self, collection_name: str, payload: types.Payload, points: types.PointsSelector, key: Optional[str] = None, wait: bool = True, ordering: Optional[types.WriteOrdering] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, **kwargs: Any, ) -> types.UpdateResult: if self._prefer_grpc: points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points) shard_key_selector = shard_key_selector or opt_shard_key_selector if isinstance(ordering, models.WriteOrdering): ordering = RestToGrpc.convert_write_ordering(ordering) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) return GrpcToRest.convert_update_result( self.grpc_points.SetPayload( grpc.SetPayloadPoints( collection_name=collection_name, wait=wait, payload=RestToGrpc.convert_payload(payload), points_selector=points_selector, ordering=ordering, shard_key_selector=shard_key_selector, key=key, ), timeout=self._timeout, ).result ) else: _points, _filter = self._try_argument_to_rest_points_and_filter(points) result: Optional[types.UpdateResult] = self.openapi_client.points_api.set_payload( collection_name=collection_name, wait=wait, ordering=ordering, set_payload=models.SetPayload( payload=payload, points=_points, filter=_filter, shard_key=shard_key_selector, key=key, ), ).result assert result is not None, "Set payload returned None" return result def overwrite_payload( self, collection_name: str, payload: types.Payload, points: types.PointsSelector, wait: bool = True, ordering: Optional[types.WriteOrdering] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, **kwargs: Any, ) -> types.UpdateResult: if self._prefer_grpc: points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points) 
shard_key_selector = shard_key_selector or opt_shard_key_selector if isinstance(ordering, models.WriteOrdering): ordering = RestToGrpc.convert_write_ordering(ordering) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) return GrpcToRest.convert_update_result( self.grpc_points.OverwritePayload( grpc.SetPayloadPoints( collection_name=collection_name, wait=wait, payload=RestToGrpc.convert_payload(payload), points_selector=points_selector, ordering=ordering, shard_key_selector=shard_key_selector, ), timeout=self._timeout, ).result ) else: _points, _filter = self._try_argument_to_rest_points_and_filter(points) result: Optional[types.UpdateResult] = ( self.openapi_client.points_api.overwrite_payload( collection_name=collection_name, wait=wait, ordering=ordering, set_payload=models.SetPayload( payload=payload, points=_points, filter=_filter, shard_key=shard_key_selector, ), ).result ) assert result is not None, "Overwrite payload returned None" return result def delete_payload( self, collection_name: str, keys: Sequence[str], points: types.PointsSelector, wait: bool = True, ordering: Optional[types.WriteOrdering] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, **kwargs: Any, ) -> types.UpdateResult: if self._prefer_grpc: points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points) shard_key_selector = shard_key_selector or opt_shard_key_selector if isinstance(ordering, models.WriteOrdering): ordering = RestToGrpc.convert_write_ordering(ordering) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) return GrpcToRest.convert_update_result( self.grpc_points.DeletePayload( grpc.DeletePayloadPoints( collection_name=collection_name, wait=wait, keys=keys, points_selector=points_selector, ordering=ordering, 
shard_key_selector=shard_key_selector, ), timeout=self._timeout, ).result ) else: _points, _filter = self._try_argument_to_rest_points_and_filter(points) result: Optional[types.UpdateResult] = self.openapi_client.points_api.delete_payload( collection_name=collection_name, wait=wait, ordering=ordering, delete_payload=models.DeletePayload( keys=keys, points=_points, filter=_filter, shard_key=shard_key_selector, ), ).result assert result is not None, "Delete payload returned None" return result def clear_payload( self, collection_name: str, points_selector: types.PointsSelector, wait: bool = True, ordering: Optional[types.WriteOrdering] = None, shard_key_selector: Optional[types.ShardKeySelector] = None, **kwargs: Any, ) -> types.UpdateResult: if self._prefer_grpc: points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector( points_selector ) shard_key_selector = shard_key_selector or opt_shard_key_selector if isinstance(ordering, models.WriteOrdering): ordering = RestToGrpc.convert_write_ordering(ordering) if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)): shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector) return GrpcToRest.convert_update_result( self.grpc_points.ClearPayload( grpc.ClearPayloadPoints( collection_name=collection_name, wait=wait, points=points_selector, ordering=ordering, shard_key_selector=shard_key_selector, ), timeout=self._timeout, ).result ) else: points_selector = self._try_argument_to_rest_selector( points_selector, shard_key_selector ) result: Optional[types.UpdateResult] = self.openapi_client.points_api.clear_payload( collection_name=collection_name, wait=wait, ordering=ordering, points_selector=points_selector, ).result assert result is not None, "Clear payload returned None" return result def batch_update_points( self, collection_name: str, update_operations: Sequence[types.UpdateOperation], wait: bool = True, ordering: Optional[types.WriteOrdering] = None, **kwargs: 
Any, ) -> list[types.UpdateResult]: if self._prefer_grpc: update_operations = [ RestToGrpc.convert_update_operation(operation) for operation in update_operations ] if isinstance(ordering, models.WriteOrdering): ordering = RestToGrpc.convert_write_ordering(ordering) return [ GrpcToRest.convert_update_result(result) for result in self.grpc_points.UpdateBatch( grpc.UpdateBatchPoints( collection_name=collection_name, wait=wait, operations=update_operations, ordering=ordering, ), timeout=self._timeout, ).result ] else: result: Optional[list[types.UpdateResult]] = ( self.openapi_client.points_api.batch_update( collection_name=collection_name, wait=wait, ordering=ordering, update_operations=models.UpdateOperations(operations=update_operations), ).result ) assert result is not None, "Batch update points returned None" return result def update_collection_aliases( self, change_aliases_operations: Sequence[types.AliasOperations], timeout: Optional[int] = None, **kwargs: Any, ) -> bool: if self._prefer_grpc: change_aliases_operation = [ ( RestToGrpc.convert_alias_operations(operation) if not isinstance(operation, grpc.AliasOperations) else operation ) for operation in change_aliases_operations ] return self.grpc_collections.UpdateAliases( grpc.ChangeAliases( timeout=timeout, actions=change_aliases_operation, ), timeout=timeout if timeout is not None else self._timeout, ).result change_aliases_operation = [ ( GrpcToRest.convert_alias_operations(operation) if isinstance(operation, grpc.AliasOperations) else operation ) for operation in change_aliases_operations ] result: Optional[bool] = self.http.aliases_api.update_aliases( timeout=timeout, change_aliases_operation=models.ChangeAliasesOperation( actions=change_aliases_operation ), ).result assert result is not None, "Update aliases returned None" return result def get_collection_aliases( self, collection_name: str, **kwargs: Any ) -> types.CollectionsAliasesResponse: if self._prefer_grpc: response = 
self.grpc_collections.ListCollectionAliases( grpc.ListCollectionAliasesRequest(collection_name=collection_name), timeout=self._timeout, ).aliases return types.CollectionsAliasesResponse( aliases=[ GrpcToRest.convert_alias_description(description) for description in response ] ) result: Optional[types.CollectionsAliasesResponse] = ( self.http.aliases_api.get_collection_aliases(collection_name=collection_name).result ) assert result is not None, "Get collection aliases returned None" return result def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse: if self._prefer_grpc: response = self.grpc_collections.ListAliases( grpc.ListAliasesRequest(), timeout=self._timeout ).aliases return types.CollectionsAliasesResponse( aliases=[ GrpcToRest.convert_alias_description(description) for description in response ] ) result: Optional[types.CollectionsAliasesResponse] = ( self.http.aliases_api.get_collections_aliases().result ) assert result is not None, "Get aliases returned None" return result def get_collections(self, **kwargs: Any) -> types.CollectionsResponse: if self._prefer_grpc: response = self.grpc_collections.List( grpc.ListCollectionsRequest(), timeout=self._timeout ).collections return types.CollectionsResponse( collections=[ GrpcToRest.convert_collection_description(description) for description in response ] ) result: Optional[types.CollectionsResponse] = ( self.http.collections_api.get_collections().result ) assert result is not None, "Get collections returned None" return result def get_collection(self, collection_name: str, **kwargs: Any) -> types.CollectionInfo: if self._prefer_grpc: return GrpcToRest.convert_collection_info( self.grpc_collections.Get( grpc.GetCollectionInfoRequest(collection_name=collection_name), timeout=self._timeout, ).result ) result: Optional[types.CollectionInfo] = self.http.collections_api.get_collection( collection_name=collection_name ).result assert result is not None, "Get collection returned None" return result 
def collection_exists(self, collection_name: str, **kwargs: Any) -> bool: if self._prefer_grpc: return self.grpc_collections.CollectionExists( grpc.CollectionExistsRequest(collection_name=collection_name), timeout=self._timeout, ).result.exists result: Optional[models.CollectionExistence] = self.http.collections_api.collection_exists( collection_name=collection_name ).result assert result is not None, "Collection exists returned None" return result.exists def update_collection( self, collection_name: str, optimizers_config: Optional[types.OptimizersConfigDiff] = None, collection_params: Optional[types.CollectionParamsDiff] = None, vectors_config: Optional[types.VectorsConfigDiff] = None, hnsw_config: Optional[types.HnswConfigDiff] = None, quantization_config: Optional[types.QuantizationConfigDiff] = None, timeout: Optional[int] = None, sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None, strict_mode_config: Optional[types.StrictModeConfig] = None, metadata: Optional[types.Payload] = None, **kwargs: Any, ) -> bool: if self._prefer_grpc: if isinstance(optimizers_config, models.OptimizersConfigDiff): optimizers_config = RestToGrpc.convert_optimizers_config_diff(optimizers_config) if isinstance(collection_params, models.CollectionParamsDiff): collection_params = RestToGrpc.convert_collection_params_diff(collection_params) if isinstance(vectors_config, dict): vectors_config = RestToGrpc.convert_vectors_config_diff(vectors_config) if isinstance(hnsw_config, models.HnswConfigDiff): hnsw_config = RestToGrpc.convert_hnsw_config_diff(hnsw_config) if isinstance(quantization_config, get_args(models.QuantizationConfigDiff)): quantization_config = RestToGrpc.convert_quantization_config_diff( quantization_config ) if isinstance(sparse_vectors_config, dict): sparse_vectors_config = RestToGrpc.convert_sparse_vector_config( sparse_vectors_config ) if isinstance(strict_mode_config, models.StrictModeConfig): strict_mode_config = 
RestToGrpc.convert_strict_mode_config(strict_mode_config) if isinstance(metadata, dict): metadata = RestToGrpc.convert_payload(metadata) return self.grpc_collections.Update( grpc.UpdateCollection( collection_name=collection_name, optimizers_config=optimizers_config, params=collection_params, vectors_config=vectors_config, hnsw_config=hnsw_config, quantization_config=quantization_config, sparse_vectors_config=sparse_vectors_config, strict_mode_config=strict_mode_config, timeout=timeout, metadata=metadata, ), timeout=timeout if timeout is not None else self._timeout, ).result if isinstance(optimizers_config, grpc.OptimizersConfigDiff): optimizers_config = GrpcToRest.convert_optimizers_config_diff(optimizers_config) if isinstance(collection_params, grpc.CollectionParamsDiff): collection_params = GrpcToRest.convert_collection_params_diff(collection_params) if isinstance(vectors_config, grpc.VectorsConfigDiff): vectors_config = GrpcToRest.convert_vectors_config_diff(vectors_config) if isinstance(hnsw_config, grpc.HnswConfigDiff): hnsw_config = GrpcToRest.convert_hnsw_config_diff(hnsw_config) if isinstance(quantization_config, grpc.QuantizationConfigDiff): quantization_config = GrpcToRest.convert_quantization_config_diff(quantization_config) result: Optional[bool] = self.http.collections_api.update_collection( collection_name, update_collection=models.UpdateCollection( optimizers_config=optimizers_config, params=collection_params, vectors=vectors_config, hnsw_config=hnsw_config, quantization_config=quantization_config, sparse_vectors=sparse_vectors_config, strict_mode_config=strict_mode_config, metadata=metadata, ), timeout=timeout, ).result assert result is not None, "Update collection returned None" return result def delete_collection( self, collection_name: str, timeout: Optional[int] = None, **kwargs: Any ) -> bool: if self._prefer_grpc: return self.grpc_collections.Delete( grpc.DeleteCollection(collection_name=collection_name, timeout=timeout), timeout=timeout if 
timeout is not None else self._timeout, ).result result: Optional[bool] = self.http.collections_api.delete_collection( collection_name, timeout=timeout ).result assert result is not None, "Delete collection returned None" return result def create_collection( self, collection_name: str, vectors_config: Optional[ Union[types.VectorParams, Mapping[str, types.VectorParams]] ] = None, shard_number: Optional[int] = None, replication_factor: Optional[int] = None, write_consistency_factor: Optional[int] = None, on_disk_payload: Optional[bool] = None, hnsw_config: Optional[types.HnswConfigDiff] = None, optimizers_config: Optional[types.OptimizersConfigDiff] = None, wal_config: Optional[types.WalConfigDiff] = None, quantization_config: Optional[types.QuantizationConfig] = None, timeout: Optional[int] = None, sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None, sharding_method: Optional[types.ShardingMethod] = None, strict_mode_config: Optional[types.StrictModeConfig] = None, metadata: Optional[types.Payload] = None, **kwargs: Any, ) -> bool: if self._prefer_grpc: if isinstance(vectors_config, (models.VectorParams, dict)): vectors_config = RestToGrpc.convert_vectors_config(vectors_config) if isinstance(hnsw_config, models.HnswConfigDiff): hnsw_config = RestToGrpc.convert_hnsw_config_diff(hnsw_config) if isinstance(optimizers_config, models.OptimizersConfigDiff): optimizers_config = RestToGrpc.convert_optimizers_config_diff(optimizers_config) if isinstance(wal_config, models.WalConfigDiff): wal_config = RestToGrpc.convert_wal_config_diff(wal_config) if isinstance( quantization_config, get_args(models.QuantizationConfig), ): quantization_config = RestToGrpc.convert_quantization_config(quantization_config) if isinstance(sparse_vectors_config, dict): sparse_vectors_config = RestToGrpc.convert_sparse_vector_config( sparse_vectors_config ) if isinstance(sharding_method, models.ShardingMethod): sharding_method = 
RestToGrpc.convert_sharding_method(sharding_method) if isinstance(strict_mode_config, models.StrictModeConfig): strict_mode_config = RestToGrpc.convert_strict_mode_config(strict_mode_config) if isinstance(metadata, dict): metadata = RestToGrpc.convert_payload(metadata) create_collection = grpc.CreateCollection( collection_name=collection_name, hnsw_config=hnsw_config, wal_config=wal_config, optimizers_config=optimizers_config, shard_number=shard_number, on_disk_payload=on_disk_payload, timeout=timeout, vectors_config=vectors_config, replication_factor=replication_factor, write_consistency_factor=write_consistency_factor, quantization_config=quantization_config, sparse_vectors_config=sparse_vectors_config, sharding_method=sharding_method, strict_mode_config=strict_mode_config, metadata=metadata, ) return self.grpc_collections.Create(create_collection, timeout=self._timeout).result if isinstance(hnsw_config, grpc.HnswConfigDiff): hnsw_config = GrpcToRest.convert_hnsw_config_diff(hnsw_config) if isinstance(optimizers_config, grpc.OptimizersConfigDiff): optimizers_config = GrpcToRest.convert_optimizers_config_diff(optimizers_config) if isinstance(wal_config, grpc.WalConfigDiff): wal_config = GrpcToRest.convert_wal_config_diff(wal_config) if isinstance(quantization_config, grpc.QuantizationConfig): quantization_config = GrpcToRest.convert_quantization_config(quantization_config) create_collection_request = models.CreateCollection( vectors=vectors_config, shard_number=shard_number, replication_factor=replication_factor, write_consistency_factor=write_consistency_factor, on_disk_payload=on_disk_payload, hnsw_config=hnsw_config, optimizers_config=optimizers_config, wal_config=wal_config, quantization_config=quantization_config, sparse_vectors=sparse_vectors_config, sharding_method=sharding_method, strict_mode_config=strict_mode_config, metadata=metadata, ) result: Optional[bool] = self.http.collections_api.create_collection( collection_name=collection_name, 
create_collection=create_collection_request, timeout=timeout, ).result assert result is not None, "Create collection returned None" return result def recreate_collection( self, collection_name: str, vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]], shard_number: Optional[int] = None, replication_factor: Optional[int] = None, write_consistency_factor: Optional[int] = None, on_disk_payload: Optional[bool] = None, hnsw_config: Optional[types.HnswConfigDiff] = None, optimizers_config: Optional[types.OptimizersConfigDiff] = None, wal_config: Optional[types.WalConfigDiff] = None, quantization_config: Optional[types.QuantizationConfig] = None, timeout: Optional[int] = None, sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None, sharding_method: Optional[types.ShardingMethod] = None, strict_mode_config: Optional[types.StrictModeConfig] = None, metadata: Optional[types.Payload] = None, **kwargs: Any, ) -> bool: self.delete_collection(collection_name, timeout=timeout) return self.create_collection( collection_name=collection_name, vectors_config=vectors_config, shard_number=shard_number, replication_factor=replication_factor, write_consistency_factor=write_consistency_factor, on_disk_payload=on_disk_payload, hnsw_config=hnsw_config, optimizers_config=optimizers_config, wal_config=wal_config, quantization_config=quantization_config, timeout=timeout, sparse_vectors_config=sparse_vectors_config, sharding_method=sharding_method, strict_mode_config=strict_mode_config, metadata=metadata, ) @property def _updater_class(self) -> Type[BaseUploader]: if self._prefer_grpc: return GrpcBatchUploader else: return RestBatchUploader def _upload_collection( self, batches_iterator: Iterable, collection_name: str, max_retries: int, parallel: int = 1, method: Optional[str] = None, wait: bool = False, shard_key_selector: Optional[types.ShardKeySelector] = None, update_filter: Optional[types.Filter] = None, ) -> None: if method is not None: if 
method in get_all_start_methods():
                start_method = method
            else:
                raise ValueError(
                    f"Start methods {method} is not available, available methods: {get_all_start_methods()}"
                )
        else:
            # No start method requested: prefer "forkserver" when this platform offers it,
            # otherwise fall back to "spawn".
            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"

        if self._prefer_grpc:
            # Keyword arguments for the uploader when talking gRPC: connection details
            # (host, gRPC port, TLS flag, metadata headers, channel options, timeout)
            # come from this client's own gRPC settings.
            updater_kwargs = {
                "collection_name": collection_name,
                "host": self._host,
                "port": self._grpc_port,
                "max_retries": max_retries,
                "ssl": self._https,
                "metadata": self._grpc_headers,
                "wait": wait,
                "shard_key_selector": shard_key_selector,
                "options": self._grpc_options,
                "timeout": self._timeout,
                "update_filter": update_filter,
            }
        else:
            # Keyword arguments for the uploader when talking REST: the base URI plus any
            # extra HTTP settings stored in self._rest_args.
            updater_kwargs = {
                "collection_name": collection_name,
                "uri": self.rest_uri,
                "max_retries": max_retries,
                "wait": wait,
                "shard_key_selector": shard_key_selector,
                "update_filter": update_filter,
                **self._rest_args,
            }

        if parallel == 1:
            # In-process path: build one uploader and drain its generator; the yielded
            # values are not needed, only the side effect of uploading.
            updater = self._updater_class.start(**updater_kwargs)
            for _ in updater.process(batches_iterator):
                pass
        else:
            # Parallel path: hand the batches to a ParallelWorkerPool of `parallel` workers
            # created with the chosen multiprocessing start method; results are discarded.
            pool = ParallelWorkerPool(parallel, self._updater_class, start_method=start_method)
            for _ in pool.unordered_map(batches_iterator, **updater_kwargs):
                pass

    def upload_points(
        self,
        collection_name: str,
        points: Iterable[types.PointStruct],
        batch_size: int = 64,
        parallel: int = 1,
        method: Optional[str] = None,
        max_retries: int = 3,
        wait: bool = False,
        shard_key_selector: Optional[types.ShardKeySelector] = None,
        update_filter: Optional[types.Filter] = None,
        **kwargs: Any,
    ) -> None:
        """Upload ``points`` into ``collection_name`` in batches.

        The points are grouped by ``self._updater_class.iterate_records_batches`` into
        batches of ``batch_size`` and then uploaded via ``_upload_collection``.

        Args:
            collection_name: Target collection.
            points: Points to upload.
            batch_size: Number of points per batch.
            parallel: Number of workers; ``1`` uploads inline in the current process.
            method: Multiprocessing start method. If given, it must be one of
                ``multiprocessing.get_all_start_methods()``, otherwise ``ValueError`` is
                raised.
            max_retries: Forwarded to the uploader.
            wait: Forwarded to the uploader.
            shard_key_selector: Forwarded to the uploader.
            update_filter: Forwarded to the uploader.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        batches_iterator = self._updater_class.iterate_records_batches(
            records=points, batch_size=batch_size
        )

        self._upload_collection(
            batches_iterator=batches_iterator,
            collection_name=collection_name,
            max_retries=max_retries,
            parallel=parallel,
            method=method,
            wait=wait,
            shard_key_selector=shard_key_selector,
            update_filter=update_filter,
        )

    def upload_collection(
        self,
        collection_name: str,
        vectors: Union[
            dict[str, types.NumpyArray], types.NumpyArray, Iterable[types.VectorStruct]
        ],
        payload: Optional[Iterable[dict[Any, Any]]] = None,
        ids: Optional[Iterable[types.PointId]] = None,
        batch_size: int = 64,
        parallel: int = 1,
        method: Optional[str] = None,
        max_retries: int = 3,
        wait: bool = False,
        shard_key_selector: Optional[types.ShardKeySelector] = None,
        update_filter: Optional[types.Filter] = None,
        **kwargs: Any,
    ) -> None:
        """Upload vectors (with optional payloads and ids) into ``collection_name`` in batches.

        ``vectors``, ``payload`` and ``ids`` are combined into batches of ``batch_size`` by
        ``self._updater_class.iterate_batches`` and then uploaded via ``_upload_collection``.

        Args:
            collection_name: Target collection.
            vectors: A dict of named arrays, a single array, or an iterable of vectors.
            payload: Optional payloads, one per point.
            ids: Optional point ids.
            batch_size: Number of points per batch.
            parallel: Number of workers; ``1`` uploads inline in the current process.
            method: Multiprocessing start method. If given, it must be one of
                ``multiprocessing.get_all_start_methods()``, otherwise ``ValueError`` is
                raised.
            max_retries: Forwarded to the uploader.
            wait: Forwarded to the uploader.
            shard_key_selector: Forwarded to the uploader.
            update_filter: Forwarded to the uploader.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        batches_iterator = self._updater_class.iterate_batches(
            vectors=vectors,
            payload=payload,
            ids=ids,
            batch_size=batch_size,
        )

        self._upload_collection(
            batches_iterator=batches_iterator,
            collection_name=collection_name,
            max_retries=max_retries,
            parallel=parallel,
            method=method,
            wait=wait,
            shard_key_selector=shard_key_selector,
            update_filter=update_filter,
        )

    def create_payload_index(
        self,
        collection_name: str,
        field_name: str,
        field_schema: Optional[types.PayloadSchemaType] = None,
        field_type: Optional[types.PayloadSchemaType] = None,
        wait: bool = True,
        ordering: Optional[types.WriteOrdering] = None,
        **kwargs: Any,
    ) -> types.UpdateResult:
        """Create a payload index on ``field_name`` in ``collection_name``.

        ``field_schema`` may be given in several forms and is converted to whatever the
        active transport needs:

        - REST forms: ``models.PayloadSchemaType``, its string value, or one of
          ``models.PayloadSchemaParams``.
        - gRPC forms: an ``int`` (treated as ``grpc.PayloadSchemaType``) or a
          ``grpc.PayloadIndexParams``.

        Args:
            collection_name: Target collection.
            field_name: Payload field to index.
            field_schema: Index schema or index parameters (see above).
            field_type: Deprecated alias for ``field_schema``. When set it replaces
                ``field_schema`` and a ``DeprecationWarning`` is emitted through
                ``show_warning_once``.
            wait: Forwarded to the server call.
            ordering: Write ordering, forwarded to the server call.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The update result reported by the server.

        Raises:
            AssertionError: If the REST call returns a ``None`` result.
        """
        if field_type is not None:
            show_warning_once(
                message="field_type is deprecated, use field_schema instead",
                category=DeprecationWarning,
                stacklevel=5,
                idx="payload-index-field-type",
            )
            field_schema = field_type

        if self._prefer_grpc:
            # Only populated when the caller supplied full index parameters.
            field_index_params = None
            # The conversions below run in sequence and intentionally fall through:
            # REST enum / string -> gRPC PayloadSchemaType (int) -> gRPC FieldType.
            if isinstance(field_schema, models.PayloadSchemaType):
                field_schema = RestToGrpc.convert_payload_schema_type(field_schema)

            if isinstance(field_schema, str):
                field_schema = RestToGrpc.convert_payload_schema_type(
                    models.PayloadSchemaType(field_schema)
                )

            if isinstance(field_schema, int):
                # There are no means to distinguish grpc.PayloadSchemaType and grpc.FieldType,
                # as both of them are just ints
                # method signature assumes that grpc.PayloadSchemaType is passed,
                # otherwise the value will be corrupted
                field_schema = grpc_payload_schema_to_field_type(field_schema)

            if isinstance(field_schema, get_args(models.PayloadSchemaParams)):
                field_schema = RestToGrpc.convert_payload_schema_params(field_schema)

            if isinstance(field_schema, grpc.PayloadIndexParams):
                # Send the full params, and derive the matching FieldType from whichever
                # "index_params" oneof member is set.
                field_index_params = field_schema
                name = field_index_params.WhichOneof("index_params")
                index_params = getattr(field_index_params, name)
                if isinstance(index_params, grpc.TextIndexParams):
                    field_schema = grpc.FieldType.FieldTypeText
                if isinstance(index_params, grpc.IntegerIndexParams):
                    field_schema = grpc.FieldType.FieldTypeInteger
                if isinstance(index_params, grpc.KeywordIndexParams):
                    field_schema = grpc.FieldType.FieldTypeKeyword
                if isinstance(index_params, grpc.FloatIndexParams):
                    field_schema = grpc.FieldType.FieldTypeFloat
                if isinstance(index_params, grpc.GeoIndexParams):
                    field_schema = grpc.FieldType.FieldTypeGeo
                if isinstance(index_params, grpc.BoolIndexParams):
                    field_schema = grpc.FieldType.FieldTypeBool
                if isinstance(index_params, grpc.DatetimeIndexParams):
                    field_schema = grpc.FieldType.FieldTypeDatetime
                if isinstance(index_params, grpc.UuidIndexParams):
                    field_schema = grpc.FieldType.FieldTypeUuid

            request = grpc.CreateFieldIndexCollection(
                collection_name=collection_name,
                field_name=field_name,
                field_type=field_schema,
                field_index_params=field_index_params,
                wait=wait,
                ordering=ordering,
            )
            return GrpcToRest.convert_update_result(
                self.grpc_points.CreateFieldIndex(request, timeout=self._timeout).result
            )

        # REST path: convert any gRPC-form schema back into its REST model.
        if isinstance(field_schema, int):  # type(grpc.PayloadSchemaType) == int
            field_schema = GrpcToRest.convert_payload_schema_type(field_schema)

        if isinstance(field_schema, grpc.PayloadIndexParams):
            field_schema = GrpcToRest.convert_payload_schema_params(field_schema)

        result: Optional[types.UpdateResult] = self.openapi_client.indexes_api.create_field_index(
            collection_name=collection_name,
            create_field_index=models.CreateFieldIndex(
                field_name=field_name, field_schema=field_schema
            ),
            wait=wait,
            ordering=ordering,
        ).result
        assert result is not None, "Create field index returned None"
        return result

    def delete_payload_index(
        self,
        collection_name: str,
        field_name: str,
        wait: bool = True,
        ordering: Optional[types.WriteOrdering] = None,
        **kwargs: Any,
    ) -> types.UpdateResult:
        """Delete the payload index on ``field_name`` in ``collection_name``.

        Uses gRPC when ``prefer_grpc`` is set, otherwise REST.

        Returns:
            The update result reported by the server.

        Raises:
            AssertionError: If the REST call returns a ``None`` result.
        """
        if self._prefer_grpc:
            request = grpc.DeleteFieldIndexCollection(
                collection_name=collection_name,
                field_name=field_name,
                wait=wait,
                ordering=ordering,
            )
            return GrpcToRest.convert_update_result(
                self.grpc_points.DeleteFieldIndex(request, timeout=self._timeout).result
            )

        result: Optional[types.UpdateResult] = self.openapi_client.indexes_api.delete_field_index(
            collection_name=collection_name,
            field_name=field_name,
            wait=wait,
            ordering=ordering,
        ).result
        assert result is not None, "Delete field index returned None"
        return result

    def list_snapshots(
        self, collection_name: str, **kwargs: Any
    ) -> list[types.SnapshotDescription]:
        """Return descriptions of the snapshots of ``collection_name``.

        On the gRPC path each description is converted to its REST model.

        Raises:
            AssertionError: If the REST call returns a ``None`` result.
        """
        if self._prefer_grpc:
            snapshots = self.grpc_snapshots.List(
                grpc.ListSnapshotsRequest(collection_name=collection_name), timeout=self._timeout
            ).snapshot_descriptions
            return [GrpcToRest.convert_snapshot_description(snapshot) for snapshot in snapshots]

        snapshots = self.openapi_client.snapshots_api.list_snapshots(
            collection_name=collection_name
        ).result
        assert snapshots is not None, "List snapshots API returned None result"
        return snapshots

    def create_snapshot(
        self, collection_name: str, wait: bool = True, **kwargs: Any
    ) -> Optional[types.SnapshotDescription]:
        """Create a snapshot of ``collection_name`` and return its description.

        ``wait`` is only forwarded on the REST path; the gRPC request does not carry it.
        The REST result is returned as-is and may be ``None``.
        """
        if self._prefer_grpc:
            snapshot = self.grpc_snapshots.Create(
                grpc.CreateSnapshotRequest(collection_name=collection_name), timeout=self._timeout
            ).snapshot_description
            return GrpcToRest.convert_snapshot_description(snapshot)

        return self.openapi_client.snapshots_api.create_snapshot(
            collection_name=collection_name, wait=wait
        ).result

    def delete_snapshot(
        self, collection_name: str, snapshot_name: str, wait: bool = True, **kwargs: Any
    ) -> Optional[bool]:
        """Delete snapshot ``snapshot_name`` of ``collection_name``.

        On the gRPC path, returns ``True`` once the call completes without raising; ``wait``
        is not sent there. On the REST path, returns the server's result as-is.
        """
        if self._prefer_grpc:
            self.grpc_snapshots.Delete(
                grpc.DeleteSnapshotRequest(
                    collection_name=collection_name, snapshot_name=snapshot_name
                ),
                timeout=self._timeout,
            )
            return True

        return self.openapi_client.snapshots_api.delete_snapshot(
            collection_name=collection_name,
            snapshot_name=snapshot_name,
            wait=wait,
        ).result

    def list_full_snapshots(self, **kwargs: Any) -> list[types.SnapshotDescription]:
        """Return descriptions of all full-storage snapshots.

        Raises:
            AssertionError: If the REST call returns a ``None`` result.
        """
        if self._prefer_grpc:
            snapshots = self.grpc_snapshots.ListFull(
                grpc.ListFullSnapshotsRequest(),
                timeout=self._timeout,
            ).snapshot_descriptions
            return [GrpcToRest.convert_snapshot_description(snapshot) for snapshot in snapshots]

        snapshots = self.openapi_client.snapshots_api.list_full_snapshots().result
        assert snapshots is not None, "List full snapshots API returned None result"
        return snapshots

    def create_full_snapshot(self, wait: bool = True, **kwargs: Any) -> types.SnapshotDescription:
        """Create a full-storage snapshot and return its description.

        ``wait`` is only forwarded on the REST path. Unlike most methods here, the REST
        result is returned without a ``None`` check.
        """
        if self._prefer_grpc:
            snapshot_description = self.grpc_snapshots.CreateFull(
                grpc.CreateFullSnapshotRequest(), timeout=self._timeout
            ).snapshot_description
            return GrpcToRest.convert_snapshot_description(snapshot_description)

        return self.openapi_client.snapshots_api.create_full_snapshot(wait=wait).result

    def delete_full_snapshot(
        self, snapshot_name: str, wait: bool = True, **kwargs: Any
    ) -> Optional[bool]:
        """Delete the full-storage snapshot ``snapshot_name``.

        On the gRPC path, returns ``True`` once the call completes without raising; ``wait``
        is not sent there. On the REST path, returns the server's result as-is.
        """
        if self._prefer_grpc:
            self.grpc_snapshots.DeleteFull(
                grpc.DeleteFullSnapshotRequest(snapshot_name=snapshot_name),
                timeout=self._timeout,
            )
            return True

        return self.openapi_client.snapshots_api.delete_full_snapshot(
            snapshot_name=snapshot_name, wait=wait
        ).result

    def recover_snapshot(
        self,
        collection_name: str,
        location: str,
        api_key: Optional[str] = None,
        checksum: Optional[str] = None,
        priority: Optional[types.SnapshotPriority] = None,
        wait: bool = True,
        **kwargs: Any,
    ) -> Optional[bool]:
        """Recover ``collection_name`` from the snapshot at ``location``.

        Always uses the REST API, regardless of ``prefer_grpc``. ``location``, ``priority``,
        ``checksum`` and ``api_key`` are passed through in a ``models.SnapshotRecover``.
        """
        return self.openapi_client.snapshots_api.recover_from_snapshot(
            collection_name=collection_name,
            wait=wait,
            snapshot_recover=models.SnapshotRecover(
                location=location,
                priority=priority,
                checksum=checksum,
                api_key=api_key,
            ),
        ).result

    def list_shard_snapshots(
        self, collection_name: str, shard_id: int, **kwargs: Any
    ) -> list[types.SnapshotDescription]:
        """Return descriptions of the snapshots of shard ``shard_id`` (REST only).

        Raises:
            AssertionError: If the REST call returns a ``None`` result.
        """
        snapshots = self.openapi_client.snapshots_api.list_shard_snapshots(
            collection_name=collection_name,
            shard_id=shard_id,
        ).result
        assert snapshots is not None, "List snapshots API returned None result"
        return snapshots

    def create_shard_snapshot(
        self, collection_name: str, shard_id: int, wait: bool = True, **kwargs: Any
    ) -> Optional[types.SnapshotDescription]:
        """Create a snapshot of shard ``shard_id`` of ``collection_name`` (REST only)."""
        return self.openapi_client.snapshots_api.create_shard_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            wait=wait,
        ).result

    def delete_shard_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        snapshot_name: str,
        wait: bool = True,
        **kwargs: Any,
    ) -> Optional[bool]:
        """Delete snapshot ``snapshot_name`` of shard ``shard_id`` (REST only)."""
        return self.openapi_client.snapshots_api.delete_shard_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            snapshot_name=snapshot_name,
            wait=wait,
        ).result

    def recover_shard_snapshot(
        self,
        collection_name: str,
        shard_id: int,
        location: str,
        api_key: Optional[str] = None,
        checksum: Optional[str] = None,
        priority: Optional[types.SnapshotPriority] = None,
        wait: bool = True,
        **kwargs: Any,
    ) -> Optional[bool]:
        """Recover shard ``shard_id`` of ``collection_name`` from the snapshot at ``location``.

        Always uses the REST API. ``location``, ``priority``, ``checksum`` and ``api_key``
        are passed through in a ``models.ShardSnapshotRecover``.
        """
        return self.openapi_client.snapshots_api.recover_shard_from_snapshot(
            collection_name=collection_name,
            shard_id=shard_id,
            wait=wait,
            shard_snapshot_recover=models.ShardSnapshotRecover(
                location=location,
                priority=priority,
                checksum=checksum,
                api_key=api_key,
            ),
        ).result

    def create_shard_key(
        self,
        collection_name: str,
        shard_key: types.ShardKey,
        shards_number: Optional[int] = None,
        replication_factor: Optional[int] = None,
        placement: Optional[list[int]] = None,
        initial_state: Optional[types.ReplicaState] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> bool:
        """Create shard key ``shard_key`` in ``collection_name``.

        ``timeout`` is sent in the request. On the gRPC path it is also used as the call
        deadline, falling back to the client-wide ``self._timeout`` when ``None``.

        Raises:
            AssertionError: If the REST call returns a ``None`` result.
        """
        if self._prefer_grpc:
            # Convert REST-model shard key / replica state into their gRPC counterparts.
            if isinstance(shard_key, get_args_subscribed(models.ShardKey)):
                shard_key = RestToGrpc.convert_shard_key(shard_key)

            if isinstance(initial_state, models.ReplicaState):
                initial_state = RestToGrpc.convert_replica_state(initial_state)

            return self.grpc_collections.CreateShardKey(
                grpc.CreateShardKeyRequest(
                    collection_name=collection_name,
                    timeout=timeout,
                    request=grpc.CreateShardKey(
                        shard_key=shard_key,
                        shards_number=shards_number,
                        replication_factor=replication_factor,
                        # NOTE(review): presumably `or []` is needed because the protobuf
                        # repeated field does not accept None — confirm.
                        placement=placement or [],
                        initial_state=initial_state,
                    ),
                ),
                timeout=timeout if timeout is not None else self._timeout,
            ).result
        else:
            result = self.openapi_client.distributed_api.create_shard_key(
                collection_name=collection_name,
                timeout=timeout,
                create_sharding_key=models.CreateShardingKey(
                    shard_key=shard_key,
                    shards_number=shards_number,
                    replication_factor=replication_factor,
                    placement=placement,
                    initial_state=initial_state,
                ),
            ).result
            assert result is not None, "Create shard key returned None"
            return result

    def delete_shard_key(
        self,
        collection_name: str,
        shard_key: types.ShardKey,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> bool:
        """Delete shard key ``shard_key`` from ``collection_name``.

        ``timeout`` is sent in the request. On the gRPC path it is also used as the call
        deadline, falling back to ``self._timeout`` when ``None``.

        Raises:
            AssertionError: If the REST call returns a ``None`` result.
        """
        if self._prefer_grpc:
            if isinstance(shard_key, get_args_subscribed(models.ShardKey)):
                shard_key = RestToGrpc.convert_shard_key(shard_key)

            return self.grpc_collections.DeleteShardKey(
                grpc.DeleteShardKeyRequest(
                    collection_name=collection_name,
                    timeout=timeout,
                    request=grpc.DeleteShardKey(
                        shard_key=shard_key,
                    ),
                ),
                timeout=timeout if timeout is not None else self._timeout,
            ).result
        else:
            result = self.openapi_client.distributed_api.delete_shard_key(
                collection_name=collection_name,
                timeout=timeout,
                drop_sharding_key=models.DropShardingKey(
                    shard_key=shard_key,
                ),
            ).result
            assert result is not None, "Delete shard key returned None"
            return result

    def info(self) -> types.VersionInfo:
        """Return server version information.

        Uses the gRPC ``HealthCheck`` call (converted to the REST model) when
        ``prefer_grpc`` is set, otherwise the REST root endpoint.

        Raises:
            AssertionError: If the REST call returns ``None``.
        """
        if self._prefer_grpc:
            version_info = self.grpc_root.HealthCheck(
                grpc.HealthCheckRequest(), timeout=self._timeout
            )
            return GrpcToRest.convert_health_check_reply(version_info)
        # Note: the REST root() response is used directly here, without `.result`.
        version_info = self.rest.service_api.root()
        assert version_info is not None, "Healthcheck returned None"
        return version_info

    def cluster_collection_update(
        self,
        collection_name: str,
        cluster_operation: types.ClusterOperations,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> bool:
        """Apply a cluster operation (move/replicate/drop shard, shard keys, ...) to a collection.

        Raises:
            TypeError: On the gRPC path, if the converted operation is not one of the
                supported gRPC message types.
            AssertionError: If the REST call returns a ``None`` result.
        """
        if self._prefer_grpc:
            cluster_operation = RestToGrpc.convert_cluster_operations(cluster_operation)
            # Maps the operation to the keyword of UpdateCollectionClusterSetupRequest
            # that carries it (passed below via **grpc_operation).
            grpc_operation: dict[str, Any] = {}
            if isinstance(cluster_operation, grpc.MoveShard):
                grpc_operation["move_shard"] = cluster_operation
            elif isinstance(cluster_operation, grpc.ReplicateShard):
                grpc_operation["replicate_shard"] = cluster_operation
            elif isinstance(cluster_operation, grpc.AbortShardTransfer):
                grpc_operation["abort_transfer"] = cluster_operation
            elif isinstance(cluster_operation, grpc.Replica):
                grpc_operation["drop_replica"] = cluster_operation
            elif isinstance(cluster_operation, grpc.CreateShardKey):
                grpc_operation["create_shard_key"] = cluster_operation
            elif isinstance(cluster_operation, grpc.DeleteShardKey):
                grpc_operation["delete_shard_key"] = cluster_operation
            elif isinstance(cluster_operation, grpc.RestartTransfer):
                grpc_operation["restart_transfer"] = cluster_operation
            elif isinstance(cluster_operation, grpc.ReplicatePoints):
                grpc_operation["replicate_points"] = cluster_operation
            else:
                raise TypeError(f"Unknown cluster operation: {cluster_operation}")

            return self.grpc_collections.UpdateCollectionClusterSetup(
                grpc.UpdateCollectionClusterSetupRequest(
                    collection_name=collection_name, timeout=timeout, **grpc_operation
                ),
                timeout=timeout if timeout is not None else self._timeout,
            ).result

        update_result = self.rest.distributed_api.update_collection_cluster(
            collection_name=collection_name, cluster_operations=cluster_operation, timeout=timeout
        ).result
        assert update_result is not None, "Cluster collection update returned None"
        return update_result

    def cluster_status(self) -> types.ClusterStatus:
        """Return the cluster status (always via REST).

        Raises:
            AssertionError: If the REST call returns a ``None`` result.
        """
        # grpc does not have cluster status api
        status_result = self.rest.distributed_api.cluster_status().result
        assert status_result is not None, "Cluster status returned None"
        return status_result

    def recover_current_peer(self) -> bool:
        """Call the REST ``recover_current_peer`` endpoint and return its result.

        Raises:
            AssertionError: If the REST call returns a ``None`` result.
        """
        # grpc does not have recover peer api
        recover_result = self.rest.distributed_api.recover_current_peer().result
        assert recover_result is not None, "Recover current peer returned None"
        return recover_result

    def remove_peer(
        self,
        peer_id: int,
        force: Optional[bool] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> bool:
        """Remove peer ``peer_id`` from the cluster (always via REST).

        ``force`` and ``timeout`` are forwarded unchanged to the REST call.

        Raises:
            AssertionError: If the REST call returns a ``None`` result.
        """
        # grpc does not have remove peer api
        update_result = self.rest.distributed_api.remove_peer(
            peer_id=peer_id,
            force=force,
            timeout=timeout,
        ).result
        assert update_result is not None, "Remove peer returned None"
        return update_result

    def collection_cluster_info(self, collection_name: str) -> types.CollectionClusterInfo:
        """Return cluster/shard information for ``collection_name``.

        On the gRPC path the reply is converted to the REST model.

        Raises:
            AssertionError: If the REST call returns a ``None`` result.
        """
        if self._prefer_grpc:
            collection_info = self.grpc_collections.CollectionClusterInfo(
                grpc.CollectionClusterInfoRequest(collection_name=collection_name),
                timeout=self._timeout,
            )
            return GrpcToRest.convert_collection_cluster_info(collection_info)

        collection_info = self.rest.distributed_api.collection_cluster_info(
            collection_name=collection_name
        ).result
        assert collection_info is not None, "Collection cluster info returned None"
        return collection_info