2666 lines
106 KiB
Python
2666 lines
106 KiB
Python
import importlib.metadata
|
|
import logging
|
|
import math
|
|
import platform
|
|
from multiprocessing import get_all_start_methods
|
|
from typing import (
|
|
Any,
|
|
Awaitable,
|
|
Callable,
|
|
Iterable,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Type,
|
|
Union,
|
|
get_args,
|
|
)
|
|
|
|
import httpx
|
|
from grpc import Compression
|
|
from urllib3.util import Url, parse_url
|
|
from urllib.parse import urljoin
|
|
|
|
from qdrant_client.common.client_warnings import show_warning, show_warning_once
|
|
from qdrant_client import grpc as grpc
|
|
from qdrant_client._pydantic_compat import construct
|
|
from qdrant_client.auth import BearerAuth
|
|
from qdrant_client.client_base import QdrantBase
|
|
from qdrant_client.common.version_check import is_compatible, get_server_version
|
|
from qdrant_client.connection import get_channel
|
|
from qdrant_client.conversions import common_types as types
|
|
from qdrant_client.conversions.common_types import get_args_subscribed
|
|
from qdrant_client.conversions.conversion import (
|
|
GrpcToRest,
|
|
RestToGrpc,
|
|
grpc_payload_schema_to_field_type,
|
|
)
|
|
from qdrant_client.http import ApiClient, SyncApis, models
|
|
from qdrant_client.parallel_processor import ParallelWorkerPool
|
|
from qdrant_client.uploader.grpc_uploader import GrpcBatchUploader
|
|
from qdrant_client.uploader.rest_uploader import RestBatchUploader
|
|
from qdrant_client.uploader.uploader import BaseUploader
|
|
|
|
|
|
class QdrantRemote(QdrantBase):
|
|
DEFAULT_GRPC_TIMEOUT = 5 # seconds
|
|
DEFAULT_GRPC_POOL_SIZE = 3
|
|
|
|
def __init__(
    self,
    url: Optional[str] = None,
    port: Optional[int] = 6333,
    grpc_port: int = 6334,
    prefer_grpc: bool = False,
    https: Optional[bool] = None,
    api_key: Optional[str] = None,
    prefix: Optional[str] = None,
    timeout: Optional[int] = None,
    host: Optional[str] = None,
    grpc_options: Optional[dict[str, Any]] = None,
    auth_token_provider: Optional[
        Union[Callable[[], str], Callable[[], Awaitable[str]]]
    ] = None,
    check_compatibility: bool = True,
    pool_size: Optional[int] = None,
    **kwargs: Any,
):
    """Configure a remote Qdrant client speaking REST and (lazily) gRPC.

    Args:
        url: Full endpoint, e.g. ``"https://host:6333/prefix"``.
            Mutually exclusive with ``host``.
        port: REST port; used when ``url`` carries no explicit port.
        grpc_port: Port used for gRPC channels.
        prefer_grpc: When True, API methods use gRPC instead of REST.
        https: Force TLS on/off. When None, TLS is enabled iff ``api_key`` is set.
        api_key: Sent as the ``api-key`` header (REST) / metadata entry (gRPC).
        prefix: URL path prefix; may alternatively be embedded in ``url``
            (setting both raises ValueError).
        timeout: Request timeout in seconds; rounded up to an int via
            ``math.ceil``.
        host: Host name without protocol. Mutually exclusive with ``url``.
        grpc_options: Extra options for the gRPC channel.
        auth_token_provider: Callable producing a bearer token; wrapped in
            ``BearerAuth`` for REST and passed to the gRPC channel factory.
        check_compatibility: If True, fetch the server version and warn when
            client/server versions look incompatible (best-effort; never raises).
        pool_size: Size of the gRPC channel pool, also used as the HTTP
            ``max_connections`` when no explicit ``limits`` kwarg is given.
            Mutually exclusive with the ``limits`` kwarg.
        **kwargs: Additional options. ``limits``, ``http2``, ``metadata`` and
            ``grpc_compression`` are consumed here; the remainder is forwarded
            to the REST (httpx-based) client.
    """
    super().__init__(**kwargs)
    self._prefer_grpc = prefer_grpc
    self._grpc_port = grpc_port
    self._grpc_options = grpc_options or {}
    # TLS defaults to on when an api_key is supplied and `https` is unset.
    self._https = https if https is not None else api_key is not None
    self._scheme = "https" if self._https else "http"

    # Pool size to use. This value should not be accessed directly; use _get_grpc_pool_size() instead.
    self._pool_size: Optional[int] = None

    if pool_size is not None:
        pool_size = max(1, pool_size)  # Ensure pool_size is always > 0
        self._pool_size = pool_size

    # Normalize the prefix so that, when present, it always starts with "/".
    self._prefix = prefix or ""
    if len(self._prefix) > 0 and self._prefix[0] != "/":
        self._prefix = f"/{self._prefix}"

    if url is not None and host is not None:
        raise ValueError(f"Only one of (url, host) can be set. url is {url}, host is {host}")

    if host is not None and (host.startswith("http://") or host.startswith("https://")):
        raise ValueError(
            f"`host` param is not expected to contain protocol (http:// or https://). "
            f"Try to use `url` parameter instead."
        )

    elif url:
        if url.startswith("localhost"):
            # Handle for a special case when url is localhost:port
            # Which is not parsed correctly by urllib
            url = f"//{url}"

        parsed_url: Url = parse_url(url)
        self._host, self._port = parsed_url.host, parsed_url.port

        # An explicit scheme in the URL overrides the https/api_key heuristic.
        if parsed_url.scheme:
            self._https = parsed_url.scheme == "https"
            self._scheme = parsed_url.scheme

        # Fall back to the `port` argument when the URL has no explicit port.
        self._port = self._port if self._port else port

        if self._prefix and parsed_url.path:
            raise ValueError(
                "Prefix can be set either in `url` or in `prefix`. "
                f"url is {url}, prefix is {parsed_url.path}"
            )
        elif parsed_url.path:
            self._prefix = parsed_url.path

        if self._scheme not in ("http", "https"):
            raise ValueError(f"Unknown scheme: {self._scheme}")
    else:
        self._host = host or "localhost"
        self._port = port

    _timeout = (
        math.ceil(timeout) if timeout is not None else None
    )  # it has been changed from float to int.
    # convert it to the closest greater or equal int value (e.g. 0.5 -> 1)
    self._api_key = api_key
    self._auth_token_provider = auth_token_provider

    limits = kwargs.pop("limits", None)
    if limits is None:
        if self._host in ["localhost", "127.0.0.1"]:
            # Disable keep-alive for local connections
            # Cause in some cases, it may cause extra delays
            limits = httpx.Limits(max_connections=None, max_keepalive_connections=0)
        elif self._pool_size is not None:
            # Set http connection pooling to `self._pool_size`, if no limits are specified.
            limits = httpx.Limits(max_connections=self._pool_size)
    elif self._pool_size is not None:
        # Caller supplied both: ambiguous intent, refuse rather than guess.
        raise ValueError(
            "`pool_size` and `limits` are mutually exclusive. "
            f"`pool_size`: {pool_size}, `limit`: {limits}"
        )

    http2 = kwargs.pop("http2", False)
    self._grpc_headers = []
    self._rest_headers = {k: v for k, v in kwargs.pop("metadata", {}).items()}
    if api_key is not None:
        if self._scheme == "http":
            show_warning(
                message="Api key is used with an insecure connection.",
                category=UserWarning,
                stacklevel=4,
            )

        # http2 = True

        # Attach the key to both transports.
        self._rest_headers["api-key"] = api_key
        self._grpc_headers.append(("api-key", api_key))

    # Identify this client (and its Python version) to the server on both transports.
    client_version = importlib.metadata.version("qdrant-client")
    python_version = platform.python_version()
    user_agent = f"python-client/{client_version} python/{python_version}"
    self._rest_headers["User-Agent"] = user_agent
    self._grpc_options["grpc.primary_user_agent"] = user_agent

    # GRPC Channel-Level Compression
    grpc_compression: Optional[Compression] = kwargs.pop("grpc_compression", None)
    if grpc_compression is not None and not isinstance(grpc_compression, Compression):
        raise TypeError(
            f"Expected 'grpc_compression' to be of type "
            f"grpc.Compression or None, but got {type(grpc_compression)}"
        )
    if grpc_compression == Compression.Deflate:
        raise ValueError(
            "grpc.Compression.Deflate is not supported. Try grpc.Compression.Gzip or grpc.Compression.NoCompression"
        )
    self._grpc_compression = grpc_compression

    address = f"{self._host}:{self._port}" if self._port is not None else self._host
    base_url = f"{self._scheme}://{address}"
    self.rest_uri = urljoin(base_url, self._prefix)

    # Remaining kwargs are forwarded verbatim to the httpx-based REST client.
    self._rest_args = {"headers": self._rest_headers, "http2": http2, **kwargs}

    if limits is not None:
        self._rest_args["limits"] = limits

    if _timeout is not None:
        self._rest_args["timeout"] = _timeout
        self._timeout = _timeout
    else:
        # No timeout given: gRPC calls fall back to the class-wide default.
        self._timeout = self.DEFAULT_GRPC_TIMEOUT

    if self._auth_token_provider is not None:
        if self._scheme == "http":
            show_warning(
                message="Auth token provider is used with an insecure connection.",
                category=UserWarning,
                stacklevel=4,
            )

        bearer_auth = BearerAuth(self._auth_token_provider)
        self._rest_args["auth"] = bearer_auth

    self.openapi_client: SyncApis[ApiClient] = SyncApis(
        host=self.rest_uri,
        **self._rest_args,
    )

    # gRPC channels and stubs are created lazily on first use
    # (see the grpc_* properties below).
    self._grpc_channel_pool: list[grpc.Channel] = []
    self._grpc_points_client_pool: Optional[list[grpc.PointsStub]] = None
    self._grpc_collections_client_pool: Optional[list[grpc.CollectionsStub]] = None
    self._grpc_snapshots_client_pool: Optional[list[grpc.SnapshotsStub]] = None
    self._grpc_root_client_pool: Optional[list[grpc.QdrantStub]] = None
    self._grpc_client_next_index: int = 0  # The next index to use

    self._aio_grpc_points_client: Optional[grpc.PointsStub] = None
    self._aio_grpc_collections_client: Optional[grpc.CollectionsStub] = None
    self._aio_grpc_snapshots_client: Optional[grpc.SnapshotsStub] = None
    self._aio_grpc_root_client: Optional[grpc.QdrantStub] = None

    self._closed: bool = False

    # Best-effort version compatibility check: any failure only warns or logs,
    # it never prevents construction.
    self.server_version = None
    if check_compatibility:
        try:
            client_version = importlib.metadata.version("qdrant-client")
            self.server_version = get_server_version(
                self.rest_uri, self._rest_headers, self._rest_args.get("auth")
            )

            if not self.server_version:
                show_warning(
                    message="Failed to obtain server version. Unable to check client-server compatibility."
                    " Set check_compatibility=False to skip version check.",
                    category=UserWarning,
                    stacklevel=4,
                )
            elif not is_compatible(client_version, self.server_version):
                show_warning(
                    message=f"Qdrant client version {client_version} is incompatible with server "
                    f"version {self.server_version}. Major versions should match and minor version difference "
                    "must not exceed 1. Set check_compatibility=False to skip version check.",
                    category=UserWarning,
                    stacklevel=4,
                )
        except Exception as er:
            logging.debug(
                f"Unable to get server version: {er}, server version defaults to None"
            )
|
|
|
|
@property
def closed(self) -> bool:
    """Whether this client has been shut down via :meth:`close`."""
    return self._closed
|
|
|
|
def close(self, grpc_grace: Optional[float] = None, **kwargs: Any) -> None:
    """Close pooled gRPC channels and the REST client, then mark the client closed.

    Failures to close individual connections are reported as RuntimeWarnings
    rather than raised, so close() always completes.
    """
    # `getattr` with a default covers the case where __init__ failed before
    # the pool attribute was created; an empty pool simply skips the loop.
    for channel in getattr(self, "_grpc_channel_pool", []):
        try:
            channel.close()
        except AttributeError:
            show_warning(
                message="Unable to close grpc_channel. Connection was interrupted on the server side",
                category=RuntimeWarning,
                stacklevel=4,
            )

    try:
        self.openapi_client.close()
    except Exception:
        show_warning(
            message="Unable to close http connection. Connection was interrupted on the server side",
            category=RuntimeWarning,
            stacklevel=4,
        )

    self._closed = True
|
|
|
|
@staticmethod
def _parse_url(url: str) -> tuple[Optional[str], str, Optional[int], Optional[str]]:
    """Split *url* into ``(scheme, host, port, prefix-path)`` using urllib3."""
    parsed = parse_url(url)
    return parsed.scheme, parsed.host, parsed.port, parsed.path
|
|
|
|
def _get_grpc_pool_size(self) -> int:
    """Return the effective gRPC pool size.

    Prefer this over reading ``self._pool_size`` directly: it substitutes
    the class default when no explicit pool_size was configured.
    """
    return self.DEFAULT_GRPC_POOL_SIZE if self._pool_size is None else self._pool_size
|
|
|
|
def _init_grpc_channel(self) -> None:
    """Lazily create the pool of gRPC channels.

    No-op when the pool already exists. The pool attribute is assigned only
    after every channel was created, so a mid-creation failure never leaves
    a half-initialized pool behind.

    Raises:
        RuntimeError: if the client was already closed, or if any channel
            fails to be created (with the original error chained as cause).
    """
    if self._closed:
        raise RuntimeError("Client was closed. Please create a new QdrantClient instance.")

    try:
        channel_pool = []

        if len(self._grpc_channel_pool) == 0:
            for _ in range(self._get_grpc_pool_size()):
                channel = get_channel(
                    host=self._host,
                    port=self._grpc_port,
                    ssl=self._https,
                    metadata=self._grpc_headers,
                    options=self._grpc_options,
                    compression=self._grpc_compression,
                    # sync get_channel does not accept coroutine functions,
                    # but we can't check type here, since it'll get into async client as well
                    auth_token_provider=self._auth_token_provider,  # type: ignore
                )
                channel_pool.append(channel)

            # Apply the clients late to prevent half-initialized pools if a channel creation fails.
            self._grpc_channel_pool = channel_pool
    except Exception as e:
        # Chain the original exception explicitly so its traceback is preserved
        # as the cause of the wrapped RuntimeError.
        raise RuntimeError(f"Error initializing the grpc connection(s): {e}") from e
|
|
|
|
def _init_grpc_points_client(self) -> None:
    """Build one PointsStub per pooled channel (creating channels if needed)."""
    self._init_grpc_channel()
    self._grpc_points_client_pool = list(map(grpc.PointsStub, self._grpc_channel_pool))
|
|
|
|
def _init_grpc_collections_client(self) -> None:
    """Build one CollectionsStub per pooled channel (creating channels if needed)."""
    self._init_grpc_channel()
    self._grpc_collections_client_pool = list(
        map(grpc.CollectionsStub, self._grpc_channel_pool)
    )
|
|
|
|
def _init_grpc_snapshots_client(self) -> None:
    """Build one SnapshotsStub per pooled channel (creating channels if needed)."""
    self._init_grpc_channel()
    self._grpc_snapshots_client_pool = list(
        map(grpc.SnapshotsStub, self._grpc_channel_pool)
    )
|
|
|
|
def _init_grpc_root_client(self) -> None:
    """Build one QdrantStub per pooled channel (creating channels if needed)."""
    self._init_grpc_channel()
    self._grpc_root_client_pool = list(map(grpc.QdrantStub, self._grpc_channel_pool))
|
|
|
|
def _next_grpc_client(self) -> int:
    """Round-robin selection: return the current pool index and advance the cursor."""
    index = self._grpc_client_next_index
    self._grpc_client_next_index = (index + 1) % self._get_grpc_pool_size()
    return index
|
|
|
|
@property
def grpc_collections(self) -> grpc.CollectionsStub:
    """Raw gRPC stub for collection operations (generated from Protobuf).

    Returns:
        A CollectionsStub, picked round-robin from the lazily created pool.
    """
    if self._grpc_collections_client_pool is None:
        self._init_grpc_collections_client()
    pool = self._grpc_collections_client_pool
    assert pool is not None
    return pool[self._next_grpc_client()]
|
|
|
|
@property
def grpc_points(self) -> grpc.PointsStub:
    """Raw gRPC stub for point operations (generated from Protobuf).

    Returns:
        A PointsStub, picked round-robin from the lazily created pool.
    """
    if self._grpc_points_client_pool is None:
        self._init_grpc_points_client()
    pool = self._grpc_points_client_pool
    assert pool is not None
    return pool[self._next_grpc_client()]
|
|
|
|
@property
def grpc_snapshots(self) -> grpc.SnapshotsStub:
    """Raw gRPC stub for snapshot operations (generated from Protobuf).

    Returns:
        A SnapshotsStub, picked round-robin from the lazily created pool.
    """
    if self._grpc_snapshots_client_pool is None:
        self._init_grpc_snapshots_client()
    pool = self._grpc_snapshots_client_pool
    assert pool is not None
    return pool[self._next_grpc_client()]
|
|
|
|
@property
def grpc_root(self) -> grpc.QdrantStub:
    """Raw gRPC stub for service-level info methods (generated from Protobuf).

    Returns:
        A QdrantStub, picked round-robin from the lazily created pool.
    """
    if self._grpc_root_client_pool is None:
        self._init_grpc_root_client()
    pool = self._grpc_root_client_pool
    assert pool is not None
    return pool[self._next_grpc_client()]
|
|
|
|
@property
def rest(self) -> SyncApis[ApiClient]:
    """Raw REST API client, generated from the OpenAPI schema."""
    return self.openapi_client
|
|
|
|
@property
def http(self) -> SyncApis[ApiClient]:
    """Raw REST API client, generated from the OpenAPI schema (alias of :attr:`rest`)."""
    return self.openapi_client
|
|
|
|
def query_points(
    self,
    collection_name: str,
    query: Union[
        types.PointId,
        list[float],
        list[list[float]],
        types.SparseVector,
        types.Query,
        types.NumpyArray,
        types.Document,
        types.Image,
        types.InferenceObject,
        None,
    ] = None,
    using: Optional[str] = None,
    prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
    query_filter: Optional[types.Filter] = None,
    search_params: Optional[types.SearchParams] = None,
    limit: int = 10,
    offset: Optional[int] = None,
    with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
    with_vectors: Union[bool, Sequence[str]] = False,
    score_threshold: Optional[float] = None,
    lookup_from: Optional[types.LookupLocation] = None,
    consistency: Optional[types.ReadConsistency] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> types.QueryResponse:
    """Run a points query against *collection_name*.

    Uses gRPC when the client was created with ``prefer_grpc=True``,
    otherwise REST. Arguments may be given in either REST-model or gRPC
    form; they are converted to the transport actually used.

    Returns:
        A ``types.QueryResponse`` with the scored points.
    """
    if self._prefer_grpc:
        # Convert REST-style models into their gRPC counterparts.
        if query is not None:
            query = RestToGrpc.convert_query(query)

        # A single Prefetch is wrapped into a one-element list first.
        if isinstance(prefetch, models.Prefetch):
            prefetch = [RestToGrpc.convert_prefetch_query(prefetch)]

        if isinstance(prefetch, list):
            prefetch = [
                RestToGrpc.convert_prefetch_query(p) if isinstance(p, models.Prefetch) else p
                for p in prefetch
            ]

        if isinstance(query_filter, models.Filter):
            query_filter = RestToGrpc.convert_filter(model=query_filter)

        if isinstance(search_params, models.SearchParams):
            search_params = RestToGrpc.convert_search_params(search_params)

        if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
            with_payload = RestToGrpc.convert_with_payload_interface(with_payload)

        if isinstance(with_vectors, get_args_subscribed(models.WithVector)):
            with_vectors = RestToGrpc.convert_with_vectors(with_vectors)

        if isinstance(lookup_from, models.LookupLocation):
            lookup_from = RestToGrpc.convert_lookup_location(lookup_from)

        if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
            consistency = RestToGrpc.convert_read_consistency(consistency)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        res: grpc.QueryResponse = self.grpc_points.Query(
            grpc.QueryPoints(
                collection_name=collection_name,
                query=query,
                prefetch=prefetch,
                filter=query_filter,
                limit=limit,
                offset=offset,
                with_vectors=with_vectors,
                with_payload=with_payload,
                params=search_params,
                score_threshold=score_threshold,
                using=using,
                lookup_from=lookup_from,
                timeout=timeout,
                shard_key_selector=shard_key_selector,
                read_consistency=consistency,
            ),
            # Per-call deadline: explicit timeout wins, else client default.
            timeout=timeout if timeout is not None else self._timeout,
        )

        # Map gRPC hits back to REST models for a uniform return type.
        scored_points = [GrpcToRest.convert_scored_point(hit) for hit in res.result]
        return models.QueryResponse(points=scored_points)

    else:
        # REST path: convert any gRPC-typed arguments back to REST models.
        if isinstance(query_filter, grpc.Filter):
            query_filter = GrpcToRest.convert_filter(model=query_filter)

        if isinstance(search_params, grpc.SearchParams):
            search_params = GrpcToRest.convert_search_params(search_params)

        if isinstance(with_payload, grpc.WithPayloadSelector):
            with_payload = GrpcToRest.convert_with_payload_selector(with_payload)

        if isinstance(lookup_from, grpc.LookupLocation):
            lookup_from = GrpcToRest.convert_lookup_location(lookup_from)

        query_request = models.QueryRequest(
            shard_key=shard_key_selector,
            prefetch=prefetch,
            query=query,
            using=using,
            filter=query_filter,
            params=search_params,
            score_threshold=score_threshold,
            limit=limit,
            offset=offset,
            with_vector=with_vectors,
            with_payload=with_payload,
            lookup_from=lookup_from,
        )

        query_result = self.http.search_api.query_points(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            query_request=query_request,
        )

        result: Optional[models.QueryResponse] = query_result.result
        assert result is not None, "Search returned None"
        return result
|
|
|
|
def query_batch_points(
    self,
    collection_name: str,
    requests: Sequence[types.QueryRequest],
    consistency: Optional[types.ReadConsistency] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> list[types.QueryResponse]:
    """Execute several point queries against one collection in a single call.

    Returns:
        A list of ``types.QueryResponse``, one element per response entry
        returned by the server.
    """
    if self._prefer_grpc:
        # Convert REST-style requests to gRPC messages; requests that are
        # already gRPC objects are passed through unchanged.
        requests = [
            (
                RestToGrpc.convert_query_request(r, collection_name)
                if isinstance(r, models.QueryRequest)
                else r
            )
            for r in requests
        ]

        if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
            consistency = RestToGrpc.convert_read_consistency(consistency)

        grpc_res: grpc.QueryBatchResponse = self.grpc_points.QueryBatch(
            grpc.QueryBatchPoints(
                collection_name=collection_name,
                query_points=requests,
                read_consistency=consistency,
                timeout=timeout,
            ),
            # Per-call deadline: explicit timeout wins, else client default.
            timeout=timeout if timeout is not None else self._timeout,
        )

        # Convert each per-query result back to REST models.
        return [
            models.QueryResponse(
                points=[GrpcToRest.convert_scored_point(hit) for hit in r.result]
            )
            for r in grpc_res.result
        ]
    else:
        http_res: Optional[list[models.QueryResponse]] = (
            self.http.search_api.query_batch_points(
                collection_name=collection_name,
                consistency=consistency,
                timeout=timeout,
                query_request_batch=models.QueryRequestBatch(searches=requests),
            ).result
        )
        assert http_res is not None, "Query batch returned None"
        return http_res
|
|
|
|
def query_points_groups(
    self,
    collection_name: str,
    group_by: str,
    query: Union[
        types.PointId,
        list[float],
        list[list[float]],
        types.SparseVector,
        types.Query,
        types.NumpyArray,
        types.Document,
        types.Image,
        types.InferenceObject,
        None,
    ] = None,
    using: Optional[str] = None,
    prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
    query_filter: Optional[types.Filter] = None,
    search_params: Optional[types.SearchParams] = None,
    limit: int = 10,
    group_size: int = 3,
    with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
    with_vectors: Union[bool, Sequence[str]] = False,
    score_threshold: Optional[float] = None,
    with_lookup: Optional[types.WithLookupInterface] = None,
    lookup_from: Optional[types.LookupLocation] = None,
    consistency: Optional[types.ReadConsistency] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> types.GroupsResult:
    """Run a points query grouped by the payload field *group_by*.

    Uses gRPC when the client was created with ``prefer_grpc=True``,
    otherwise REST. Arguments may be given in either REST-model or gRPC
    form; they are converted to the transport actually used.

    Returns:
        A ``types.GroupsResult``.
    """
    if self._prefer_grpc:
        # Convert REST-style models into their gRPC counterparts.
        if query is not None:
            query = RestToGrpc.convert_query(query)

        # A single Prefetch is wrapped into a one-element list first.
        if isinstance(prefetch, models.Prefetch):
            prefetch = [RestToGrpc.convert_prefetch_query(prefetch)]

        if isinstance(prefetch, list):
            prefetch = [
                RestToGrpc.convert_prefetch_query(p) if isinstance(p, models.Prefetch) else p
                for p in prefetch
            ]

        if isinstance(query_filter, models.Filter):
            query_filter = RestToGrpc.convert_filter(model=query_filter)

        if isinstance(search_params, models.SearchParams):
            search_params = RestToGrpc.convert_search_params(search_params)

        if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
            with_payload = RestToGrpc.convert_with_payload_interface(with_payload)

        if isinstance(with_vectors, get_args_subscribed(models.WithVector)):
            with_vectors = RestToGrpc.convert_with_vectors(with_vectors)

        if isinstance(with_lookup, models.WithLookup):
            with_lookup = RestToGrpc.convert_with_lookup(with_lookup)

        # A bare string means "look up in this collection".
        if isinstance(with_lookup, str):
            with_lookup = grpc.WithLookup(collection=with_lookup)

        if isinstance(lookup_from, models.LookupLocation):
            lookup_from = RestToGrpc.convert_lookup_location(lookup_from)

        if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
            consistency = RestToGrpc.convert_read_consistency(consistency)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        result: grpc.QueryGroupsResponse = self.grpc_points.QueryGroups(
            grpc.QueryPointGroups(
                collection_name=collection_name,
                query=query,
                prefetch=prefetch,
                filter=query_filter,
                limit=limit,
                with_vectors=with_vectors,
                with_payload=with_payload,
                params=search_params,
                score_threshold=score_threshold,
                using=using,
                group_by=group_by,
                group_size=group_size,
                with_lookup=with_lookup,
                lookup_from=lookup_from,
                timeout=timeout,
                shard_key_selector=shard_key_selector,
                read_consistency=consistency,
            ),
            # Per-call deadline: explicit timeout wins, else client default.
            timeout=timeout if timeout is not None else self._timeout,
        ).result
        return GrpcToRest.convert_groups_result(result)
    else:
        # REST path: convert any gRPC-typed arguments back to REST models.
        if isinstance(query_filter, grpc.Filter):
            query_filter = GrpcToRest.convert_filter(model=query_filter)

        if isinstance(search_params, grpc.SearchParams):
            search_params = GrpcToRest.convert_search_params(search_params)

        if isinstance(with_payload, grpc.WithPayloadSelector):
            with_payload = GrpcToRest.convert_with_payload_selector(with_payload)

        if isinstance(lookup_from, grpc.LookupLocation):
            lookup_from = GrpcToRest.convert_lookup_location(lookup_from)

        query_request = models.QueryGroupsRequest(
            shard_key=shard_key_selector,
            prefetch=prefetch,
            query=query,
            using=using,
            filter=query_filter,
            params=search_params,
            score_threshold=score_threshold,
            limit=limit,
            group_by=group_by,
            group_size=group_size,
            with_vector=with_vectors,
            with_payload=with_payload,
            with_lookup=with_lookup,
            lookup_from=lookup_from,
        )

        query_result = self.http.search_api.query_points_groups(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            query_groups_request=query_request,
        )
        assert query_result is not None, "Query points groups API returned None"
        return query_result.result
|
|
|
|
def search_matrix_pairs(
    self,
    collection_name: str,
    query_filter: Optional[types.Filter] = None,
    limit: int = 3,
    sample: int = 10,
    using: Optional[str] = None,
    consistency: Optional[types.ReadConsistency] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> types.SearchMatrixPairsResponse:
    """Request a search matrix over sampled points, in pairs representation.

    Uses gRPC when the client was created with ``prefer_grpc=True``,
    otherwise REST.

    Returns:
        A ``types.SearchMatrixPairsResponse``.
    """
    if self._prefer_grpc:
        # Convert REST-style models into their gRPC counterparts.
        if isinstance(query_filter, models.Filter):
            query_filter = RestToGrpc.convert_filter(model=query_filter)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
            consistency = RestToGrpc.convert_read_consistency(consistency)

        response = self.grpc_points.SearchMatrixPairs(
            grpc.SearchMatrixPoints(
                collection_name=collection_name,
                filter=query_filter,
                sample=sample,
                limit=limit,
                using=using,
                timeout=timeout,
                read_consistency=consistency,
                shard_key_selector=shard_key_selector,
            ),
            # Per-call deadline: explicit timeout wins, else client default.
            timeout=timeout if timeout is not None else self._timeout,
        )
        return GrpcToRest.convert_search_matrix_pairs(response.result)

    # REST path: convert a gRPC-typed filter back to its REST model.
    if isinstance(query_filter, grpc.Filter):
        query_filter = GrpcToRest.convert_filter(model=query_filter)

    search_matrix_result = self.openapi_client.search_api.search_matrix_pairs(
        collection_name=collection_name,
        consistency=consistency,
        timeout=timeout,
        search_matrix_request=models.SearchMatrixRequest(
            shard_key=shard_key_selector,
            limit=limit,
            sample=sample,
            using=using,
            filter=query_filter,
        ),
    ).result
    assert search_matrix_result is not None, "Search matrix pairs returned None result"

    return search_matrix_result
|
|
|
|
def search_matrix_offsets(
    self,
    collection_name: str,
    query_filter: Optional[types.Filter] = None,
    limit: int = 3,
    sample: int = 10,
    using: Optional[str] = None,
    consistency: Optional[types.ReadConsistency] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> types.SearchMatrixOffsetsResponse:
    """Request a search matrix over sampled points, in offsets representation.

    Uses gRPC when the client was created with ``prefer_grpc=True``,
    otherwise REST.

    Returns:
        A ``types.SearchMatrixOffsetsResponse``.
    """
    if self._prefer_grpc:
        # Convert REST-style models into their gRPC counterparts.
        if isinstance(query_filter, models.Filter):
            query_filter = RestToGrpc.convert_filter(model=query_filter)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
            consistency = RestToGrpc.convert_read_consistency(consistency)

        response = self.grpc_points.SearchMatrixOffsets(
            grpc.SearchMatrixPoints(
                collection_name=collection_name,
                filter=query_filter,
                sample=sample,
                limit=limit,
                using=using,
                timeout=timeout,
                read_consistency=consistency,
                shard_key_selector=shard_key_selector,
            ),
            # Per-call deadline: explicit timeout wins, else client default.
            timeout=timeout if timeout is not None else self._timeout,
        )
        return GrpcToRest.convert_search_matrix_offsets(response.result)

    # REST path: convert a gRPC-typed filter back to its REST model.
    if isinstance(query_filter, grpc.Filter):
        query_filter = GrpcToRest.convert_filter(model=query_filter)

    search_matrix_result = self.openapi_client.search_api.search_matrix_offsets(
        collection_name=collection_name,
        consistency=consistency,
        timeout=timeout,
        search_matrix_request=models.SearchMatrixRequest(
            shard_key=shard_key_selector,
            limit=limit,
            sample=sample,
            using=using,
            filter=query_filter,
        ),
    ).result
    assert search_matrix_result is not None, "Search matrix offsets returned None result"

    return search_matrix_result
|
|
|
|
def scroll(
    self,
    collection_name: str,
    scroll_filter: Optional[types.Filter] = None,
    limit: int = 10,
    order_by: Optional[types.OrderBy] = None,
    offset: Optional[types.PointId] = None,
    with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
    with_vectors: Union[bool, Sequence[str]] = False,
    consistency: Optional[types.ReadConsistency] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> tuple[list[types.Record], Optional[types.PointId]]:
    """Fetch a page of records from *collection_name*.

    Uses gRPC when the client was created with ``prefer_grpc=True``,
    otherwise REST.

    Returns:
        A ``(records, next_page_offset)`` tuple; ``next_page_offset`` is
        None when the server returns no offset for a following page.
    """
    if self._prefer_grpc:
        # Convert REST-style models into their gRPC counterparts.
        if isinstance(offset, get_args_subscribed(models.ExtendedPointId)):
            offset = RestToGrpc.convert_extended_point_id(offset)

        if isinstance(scroll_filter, models.Filter):
            scroll_filter = RestToGrpc.convert_filter(model=scroll_filter)

        if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
            with_payload = RestToGrpc.convert_with_payload_interface(with_payload)

        if isinstance(with_vectors, get_args_subscribed(models.WithVector)):
            with_vectors = RestToGrpc.convert_with_vectors(with_vectors)

        if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
            consistency = RestToGrpc.convert_read_consistency(consistency)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        if isinstance(order_by, get_args_subscribed(models.OrderByInterface)):
            order_by = RestToGrpc.convert_order_by_interface(order_by)

        res: grpc.ScrollResponse = self.grpc_points.Scroll(
            grpc.ScrollPoints(
                collection_name=collection_name,
                filter=scroll_filter,
                order_by=order_by,
                offset=offset,
                with_vectors=with_vectors,
                with_payload=with_payload,
                limit=limit,
                read_consistency=consistency,
                shard_key_selector=shard_key_selector,
                timeout=timeout,
            ),
            # Per-call deadline: explicit timeout wins, else client default.
            timeout=timeout if timeout is not None else self._timeout,
        )

        # `next_page_offset` is an optional proto field: convert it only
        # when it is actually present in the response.
        return [GrpcToRest.convert_retrieved_point(point) for point in res.result], (
            GrpcToRest.convert_point_id(res.next_page_offset)
            if res.HasField("next_page_offset")
            else None
        )
    else:
        # REST path: convert any gRPC-typed arguments back to REST models.
        if isinstance(offset, grpc.PointId):
            offset = GrpcToRest.convert_point_id(offset)

        if isinstance(scroll_filter, grpc.Filter):
            scroll_filter = GrpcToRest.convert_filter(model=scroll_filter)

        if isinstance(order_by, grpc.OrderBy):
            order_by = GrpcToRest.convert_order_by(order_by)

        if isinstance(with_payload, grpc.WithPayloadSelector):
            with_payload = GrpcToRest.convert_with_payload_selector(with_payload)

        scroll_result: Optional[models.ScrollResult] = (
            self.openapi_client.points_api.scroll_points(
                collection_name=collection_name,
                consistency=consistency,
                scroll_request=models.ScrollRequest(
                    filter=scroll_filter,
                    limit=limit,
                    order_by=order_by,
                    offset=offset,
                    with_payload=with_payload,
                    with_vector=with_vectors,
                    shard_key=shard_key_selector,
                ),
                timeout=timeout,
            ).result
        )
        assert scroll_result is not None, "Scroll points API returned None result"

        return scroll_result.points, scroll_result.next_page_offset
|
|
|
|
def count(
    self,
    collection_name: str,
    count_filter: Optional[types.Filter] = None,
    exact: bool = True,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    timeout: Optional[int] = None,
    consistency: Optional[types.ReadConsistency] = None,
    **kwargs: Any,
) -> types.CountResult:
    """Count points in a collection, optionally restricted by a filter.

    Dispatches over gRPC when the client prefers it (REST-model arguments
    are converted to gRPC messages first), otherwise over the REST API
    (a gRPC filter is converted back to its REST model).
    """
    if not self._prefer_grpc:
        # REST path: only the filter may arrive as a gRPC message.
        if isinstance(count_filter, grpc.Filter):
            count_filter = GrpcToRest.convert_filter(model=count_filter)

        rest_result = self.openapi_client.points_api.count_points(
            collection_name=collection_name,
            count_request=models.CountRequest(
                filter=count_filter,
                exact=exact,
                shard_key=shard_key_selector,
            ),
            consistency=consistency,
            timeout=timeout,
        ).result
        assert rest_result is not None, "Count points returned None result"
        return rest_result

    # gRPC path: normalize REST-model arguments into gRPC messages.
    if isinstance(count_filter, models.Filter):
        count_filter = RestToGrpc.convert_filter(model=count_filter)

    if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
        shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

    if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
        consistency = RestToGrpc.convert_read_consistency(consistency)

    grpc_count = self.grpc_points.Count(
        grpc.CountPoints(
            collection_name=collection_name,
            filter=count_filter,
            exact=exact,
            shard_key_selector=shard_key_selector,
            timeout=timeout,
            read_consistency=consistency,
        ),
        # Fall back to the client-level timeout for the call itself.
        timeout=timeout if timeout is not None else self._timeout,
    ).result
    return GrpcToRest.convert_count_result(grpc_count)
|
|
|
|
def facet(
    self,
    collection_name: str,
    key: str,
    facet_filter: Optional[types.Filter] = None,
    limit: int = 10,
    exact: bool = False,
    timeout: Optional[int] = None,
    consistency: Optional[types.ReadConsistency] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    **kwargs: Any,
) -> types.FacetResponse:
    """Count distinct values of the payload field ``key`` across matching points.

    Args:
        collection_name: Collection to facet over.
        key: Payload field whose values are aggregated.
        facet_filter: Optional filter restricting which points are counted.
        limit: Maximum number of value hits to return.
        exact: If True, request exact counts instead of approximate ones.
        timeout: Per-request timeout; the gRPC call falls back to the
            client-level timeout when not set.
        consistency: Read consistency for the request.
        shard_key_selector: Restrict the request to specific shards.

    Returns:
        A ``FacetResponse`` holding the value/count hits.
    """
    if self._prefer_grpc:
        # Convert REST-model arguments into their gRPC counterparts.
        if isinstance(facet_filter, models.Filter):
            facet_filter = RestToGrpc.convert_filter(model=facet_filter)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
            consistency = RestToGrpc.convert_read_consistency(consistency)

        response = self.grpc_points.Facet(
            grpc.FacetCounts(
                collection_name=collection_name,
                key=key,
                filter=facet_filter,
                limit=limit,
                exact=exact,
                timeout=timeout,
                read_consistency=consistency,
                shard_key_selector=shard_key_selector,
            ),
            timeout=timeout if timeout is not None else self._timeout,
        )

        # Convert each gRPC hit back into the REST model for a uniform return type.
        return types.FacetResponse(
            hits=[GrpcToRest.convert_facet_value_hit(hit) for hit in response.hits]
        )

    # REST path: only the filter may arrive as a gRPC message.
    if isinstance(facet_filter, grpc.Filter):
        facet_filter = GrpcToRest.convert_filter(model=facet_filter)

    facet_result = self.openapi_client.points_api.facet(
        collection_name=collection_name,
        consistency=consistency,
        timeout=timeout,
        facet_request=models.FacetRequest(
            shard_key=shard_key_selector,
            key=key,
            limit=limit,
            filter=facet_filter,
            exact=exact,
        ),
    ).result
    assert facet_result is not None, "Facet points returned None result"

    return facet_result
|
|
|
|
def upsert(
    self,
    collection_name: str,
    points: types.Points,
    wait: bool = True,
    ordering: Optional[types.WriteOrdering] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    update_filter: Optional[types.Filter] = None,
    **kwargs: Any,
) -> types.UpdateResult:
    """Insert new points or overwrite existing ones (matched by point id).

    ``points`` may be a columnar ``models.Batch`` or a list of point structs
    (REST or gRPC flavour); either form is normalized for the selected
    transport before the request is sent.

    Args:
        collection_name: Target collection.
        points: Points to write, as a batch or a list of structs.
        wait: If True, wait until the change is applied before returning.
        ordering: Write ordering guarantees for the operation.
        shard_key_selector: Restrict the write to specific shards.
        update_filter: Optional filter constraining which points are updated.

    Returns:
        The server's update result, always as the REST model.
    """
    if self._prefer_grpc:
        if isinstance(points, models.Batch):
            # Expand the columnar batch into one gRPC PointStruct per id.
            vectors_batch: list[grpc.Vectors] = RestToGrpc.convert_batch_vector_struct(
                points.vectors, len(points.ids)
            )
            points = [
                grpc.PointStruct(
                    id=RestToGrpc.convert_extended_point_id(points.ids[idx]),
                    vectors=vectors_batch[idx],
                    payload=(
                        RestToGrpc.convert_payload(points.payloads[idx])
                        if points.payloads is not None
                        else None
                    ),
                )
                for idx in range(len(points.ids))
            ]
        if isinstance(points, list):
            # Convert REST point structs; pass through already-gRPC ones untouched.
            points = [
                (
                    RestToGrpc.convert_point_struct(point)
                    if isinstance(point, models.PointStruct)
                    else point
                )
                for point in points
            ]

        if isinstance(ordering, models.WriteOrdering):
            ordering = RestToGrpc.convert_write_ordering(ordering)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        if isinstance(update_filter, models.Filter):
            update_filter = RestToGrpc.convert_filter(model=update_filter)

        grpc_result = self.grpc_points.Upsert(
            grpc.UpsertPoints(
                collection_name=collection_name,
                wait=wait,
                points=points,
                ordering=ordering,
                shard_key_selector=shard_key_selector,
                update_filter=update_filter,
            ),
            timeout=self._timeout,
        ).result

        assert grpc_result is not None, "Upsert returned None result"
        return GrpcToRest.convert_update_result(grpc_result)
    else:
        if isinstance(update_filter, grpc.Filter):
            update_filter = GrpcToRest.convert_filter(model=update_filter)

        if isinstance(points, list):
            # Convert any gRPC point structs back to REST models, then wrap
            # the list in the REST request envelope.
            points = [
                (
                    GrpcToRest.convert_point_struct(point)
                    if isinstance(point, grpc.PointStruct)
                    else point
                )
                for point in points
            ]

            points = models.PointsList(
                points=points, shard_key=shard_key_selector, update_filter=update_filter
            )

        if isinstance(points, models.Batch):
            # A columnar batch goes to the REST API as-is, inside its envelope.
            points = models.PointsBatch(
                batch=points, shard_key=shard_key_selector, update_filter=update_filter
            )

        http_result = self.openapi_client.points_api.upsert_points(
            collection_name=collection_name,
            wait=wait,
            point_insert_operations=points,
            ordering=ordering,
        ).result
        assert http_result is not None, "Upsert returned None result"
        return http_result
|
|
|
|
def update_vectors(
    self,
    collection_name: str,
    points: Sequence[types.PointVectors],
    wait: bool = True,
    ordering: Optional[types.WriteOrdering] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    update_filter: Optional[types.Filter] = None,
    **kwargs: Any,
) -> types.UpdateResult:
    """Update only the vectors of existing points, leaving payloads intact.

    Args:
        collection_name: Target collection.
        points: Per-point vector updates (REST or gRPC flavour).
        wait: If True, wait until the change is applied before returning.
        ordering: Write ordering guarantees for the operation.
        shard_key_selector: Restrict the write to specific shards.
        update_filter: Optional filter constraining which points are updated.

    Returns:
        The server's update result, always as the REST model.
    """
    if self._prefer_grpc:
        # Convert only REST-model entries; pass through gRPC ones untouched
        # (mirrors the guard used by ``upsert``).
        points = [
            (
                RestToGrpc.convert_point_vectors(point)
                if isinstance(point, models.PointVectors)
                else point
            )
            for point in points
        ]

        if isinstance(ordering, models.WriteOrdering):
            ordering = RestToGrpc.convert_write_ordering(ordering)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        if isinstance(update_filter, models.Filter):
            update_filter = RestToGrpc.convert_filter(model=update_filter)

        grpc_result = self.grpc_points.UpdateVectors(
            grpc.UpdatePointVectors(
                collection_name=collection_name,
                wait=wait,
                points=points,
                ordering=ordering,
                shard_key_selector=shard_key_selector,
                update_filter=update_filter,
            ),
            timeout=self._timeout,
        ).result
        # Fixed: the message previously said "Upsert", copied from ``upsert``.
        assert grpc_result is not None, "Update vectors returned None result"
        return GrpcToRest.convert_update_result(grpc_result)
    else:
        if isinstance(update_filter, grpc.Filter):
            update_filter = GrpcToRest.convert_filter(model=update_filter)

        # Convert any gRPC point-vector entries back to REST models so the
        # REST request body is well-formed (symmetry with ``upsert``).
        points = [
            (
                GrpcToRest.convert_point_vectors(point)
                if isinstance(point, grpc.PointVectors)
                else point
            )
            for point in points
        ]

        http_result = self.openapi_client.points_api.update_vectors(
            collection_name=collection_name,
            wait=wait,
            update_vectors=models.UpdateVectors(
                points=points,
                shard_key=shard_key_selector,
                update_filter=update_filter,
            ),
            ordering=ordering,
        ).result
        # Added to match every sibling write method's None check.
        assert http_result is not None, "Update vectors returned None result"
        return http_result
|
|
|
|
def delete_vectors(
    self,
    collection_name: str,
    vectors: Sequence[str],
    points: types.PointsSelector,
    wait: bool = True,
    ordering: Optional[types.WriteOrdering] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    **kwargs: Any,
) -> types.UpdateResult:
    """Delete the named vectors from the selected points, keeping the points.

    Args:
        collection_name: Target collection.
        vectors: Names of the vectors to remove from each matched point.
        points: Selector for the affected points (id list, selector model,
            or filter; REST or gRPC flavour).
        wait: If True, wait until the change is applied before returning.
        ordering: Write ordering guarantees for the operation.
        shard_key_selector: Restrict the write to specific shards.

    Returns:
        The server's update result, always as the REST model.
    """
    if self._prefer_grpc:
        points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
        # An explicit argument wins over a shard key embedded in the selector.
        shard_key_selector = shard_key_selector or opt_shard_key_selector

        if isinstance(ordering, models.WriteOrdering):
            ordering = RestToGrpc.convert_write_ordering(ordering)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        grpc_result = self.grpc_points.DeleteVectors(
            grpc.DeletePointVectors(
                collection_name=collection_name,
                wait=wait,
                vectors=grpc.VectorsSelector(
                    names=vectors,
                ),
                points_selector=points_selector,
                ordering=ordering,
                shard_key_selector=shard_key_selector,
            ),
            timeout=self._timeout,
        ).result

        assert grpc_result is not None, "Delete vectors returned None result"

        return GrpcToRest.convert_update_result(grpc_result)
    else:
        _points, _filter = self._try_argument_to_rest_points_and_filter(points)
        http_result = self.openapi_client.points_api.delete_vectors(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            delete_vectors=construct(
                models.DeleteVectors,
                vector=vectors,
                points=_points,
                filter=_filter,
                shard_key=shard_key_selector,
            ),
        ).result
        # Added None check to match the gRPC branch and sibling methods.
        assert http_result is not None, "Delete vectors returned None result"
        return http_result
|
|
|
|
def retrieve(
    self,
    collection_name: str,
    ids: Sequence[types.PointId],
    with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
    with_vectors: Union[bool, Sequence[str]] = False,
    consistency: Optional[types.ReadConsistency] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> list[types.Record]:
    """Fetch points by id, with configurable payload/vector inclusion.

    Args:
        collection_name: Collection to read from.
        ids: Point ids to fetch (REST or gRPC flavour).
        with_payload: Whether/which payload fields to include in the result.
        with_vectors: Whether/which named vectors to include in the result.
        consistency: Read consistency for the request.
        shard_key_selector: Restrict the request to specific shards.
        timeout: Per-request timeout; the gRPC call falls back to the
            client-level timeout when not set.

    Returns:
        The fetched records, always as REST models.
    """
    if self._prefer_grpc:
        # Normalize every REST-flavoured argument into its gRPC counterpart.
        if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
            with_payload = RestToGrpc.convert_with_payload_interface(with_payload)

        ids = [
            (
                RestToGrpc.convert_extended_point_id(idx)
                if isinstance(idx, get_args_subscribed(models.ExtendedPointId))
                else idx
            )
            for idx in ids
        ]

        with_vectors = RestToGrpc.convert_with_vectors(with_vectors)

        if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
            consistency = RestToGrpc.convert_read_consistency(consistency)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        result = self.grpc_points.Get(
            grpc.GetPoints(
                collection_name=collection_name,
                ids=ids,
                with_payload=with_payload,
                with_vectors=with_vectors,
                read_consistency=consistency,
                shard_key_selector=shard_key_selector,
                timeout=timeout,
            ),
            timeout=timeout if timeout is not None else self._timeout,
        ).result

        assert result is not None, "Retrieve returned None result"

        return [GrpcToRest.convert_retrieved_point(record) for record in result]

    else:
        # REST path: convert any gRPC-flavoured arguments back to REST models.
        if isinstance(with_payload, grpc.WithPayloadSelector):
            with_payload = GrpcToRest.convert_with_payload_selector(with_payload)

        ids = [
            (GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx)
            for idx in ids
        ]

        http_result = self.openapi_client.points_api.get_points(
            collection_name=collection_name,
            consistency=consistency,
            point_request=models.PointRequest(
                ids=ids,
                with_payload=with_payload,
                with_vector=with_vectors,
                shard_key=shard_key_selector,
            ),
            timeout=timeout,
        ).result
        assert http_result is not None, "Retrieve API returned None result"
        return http_result
|
|
|
|
@classmethod
def _try_argument_to_grpc_selector(
    cls, points: types.PointsSelector
) -> tuple[grpc.PointsSelector, Optional[grpc.ShardKeySelector]]:
    """Normalize any supported points-selector argument into a gRPC selector.

    Accepts: a list of point ids (REST or gRPC), a ready-made gRPC selector,
    a REST selector model, or a bare filter (REST or gRPC). When a REST
    selector carries its own ``shard_key``, that key is returned separately
    so the caller can merge it with an explicitly passed shard key.

    Returns:
        ``(points_selector, shard_key_selector)`` where the second item is
        only set for REST selector models with a shard key.

    Raises:
        ValueError: If ``points`` is none of the supported types.
    """
    shard_key_selector = None
    if isinstance(points, list):
        # Plain id list: convert REST-flavoured ids, keep gRPC ids as-is.
        points_selector = grpc.PointsSelector(
            points=grpc.PointsIdsList(
                ids=[
                    (
                        RestToGrpc.convert_extended_point_id(idx)
                        if isinstance(idx, get_args_subscribed(models.ExtendedPointId))
                        else idx
                    )
                    for idx in points
                ]
            )
        )
    elif isinstance(points, grpc.PointsSelector):
        points_selector = points
    elif isinstance(points, get_args(models.PointsSelector)):
        # REST selector model: lift its embedded shard key out, then convert.
        if points.shard_key is not None:
            shard_key_selector = RestToGrpc.convert_shard_key_selector(points.shard_key)
        points_selector = RestToGrpc.convert_points_selector(points)
    elif isinstance(points, models.Filter):
        # Bare REST filter: wrap it in a FilterSelector before converting.
        points_selector = RestToGrpc.convert_points_selector(
            construct(models.FilterSelector, filter=points)
        )
    elif isinstance(points, grpc.Filter):
        points_selector = grpc.PointsSelector(filter=points)
    else:
        raise ValueError(f"Unsupported points selector type: {type(points)}")
    return points_selector, shard_key_selector
|
|
|
|
@classmethod
def _try_argument_to_rest_selector(
    cls,
    points: types.PointsSelector,
    shard_key_selector: Optional[types.ShardKeySelector],
) -> models.PointsSelector:
    """Normalize any supported points-selector argument into a REST selector.

    Accepts: a list of point ids (REST or gRPC), a gRPC selector, a REST
    selector model, or a bare filter (REST or gRPC). ``shard_key_selector``
    is attached to the resulting selector.

    Note: when ``points`` is already a REST selector model it is mutated in
    place; an explicit ``shard_key_selector`` overrides its embedded shard
    key, but a ``None`` argument no longer erases an existing one (matching
    the merge semantics of ``_try_argument_to_grpc_selector``).

    Raises:
        ValueError: If ``points`` is none of the supported types.
    """
    if isinstance(points, list):
        _points = [
            (GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx)
            for idx in points
        ]
        points_selector = construct(
            models.PointIdsList,
            points=_points,
            shard_key=shard_key_selector,
        )
    elif isinstance(points, grpc.PointsSelector):
        points_selector = GrpcToRest.convert_points_selector(points)
        points_selector.shard_key = shard_key_selector
    elif isinstance(points, get_args(models.PointsSelector)):
        points_selector = points
        # Fixed: previously this assigned unconditionally, so a None argument
        # clobbered a shard key already set on the selector — unlike the gRPC
        # path, which merges with `shard_key_selector or embedded_key`.
        if shard_key_selector is not None:
            points_selector.shard_key = shard_key_selector
    elif isinstance(points, models.Filter):
        points_selector = construct(
            models.FilterSelector, filter=points, shard_key=shard_key_selector
        )
    elif isinstance(points, grpc.Filter):
        points_selector = construct(
            models.FilterSelector,
            filter=GrpcToRest.convert_filter(points),
            shard_key=shard_key_selector,
        )
    else:
        raise ValueError(f"Unsupported points selector type: {type(points)}")
    return points_selector
|
|
|
|
@classmethod
def _points_selector_to_points_list(
    cls, points_selector: grpc.PointsSelector
) -> list[grpc.PointId]:
    """Extract the explicit point ids from a gRPC points selector.

    Returns an empty list when the selector is unset or is filter-based
    rather than id-based.
    """
    variant = points_selector.WhichOneof("points_selector_one_of")
    if variant != "points":
        # Covers both an unset oneof and the filter variant.
        return []
    return list(getattr(points_selector, variant).ids)
|
|
|
|
@classmethod
def _try_argument_to_rest_points_and_filter(
    cls, points: types.PointsSelector
) -> tuple[Optional[list[models.ExtendedPointId]], Optional[models.Filter]]:
    """Split a points selector into an explicit id list and/or a REST filter.

    At most one of the returned items is populated; the other stays ``None``.

    Raises:
        ValueError: If ``points`` is none of the supported selector types.
    """
    ids: Optional[list[models.ExtendedPointId]] = None
    flt: Optional[models.Filter] = None

    if isinstance(points, list):
        # Plain id list: convert gRPC-flavoured ids, keep REST ids as-is.
        ids = [
            (GrpcToRest.convert_point_id(item) if isinstance(item, grpc.PointId) else item)
            for item in points
        ]
    elif isinstance(points, grpc.PointsSelector):
        converted = GrpcToRest.convert_points_selector(points)
        if isinstance(converted, models.PointIdsList):
            ids = converted.points
        elif isinstance(converted, models.FilterSelector):
            flt = converted.filter
    elif isinstance(points, models.PointIdsList):
        ids = points.points
    elif isinstance(points, models.FilterSelector):
        flt = points.filter
    elif isinstance(points, models.Filter):
        flt = points
    elif isinstance(points, grpc.Filter):
        flt = GrpcToRest.convert_filter(points)
    else:
        raise ValueError(f"Unsupported points selector type: {type(points)}")

    return ids, flt
|
|
|
|
def delete(
    self,
    collection_name: str,
    points_selector: types.PointsSelector,
    wait: bool = True,
    ordering: Optional[types.WriteOrdering] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    **kwargs: Any,
) -> types.UpdateResult:
    """Delete the points matched by the selector (id list or filter).

    Args:
        collection_name: Target collection.
        points_selector: Which points to delete (id list, selector model,
            or filter; REST or gRPC flavour).
        wait: If True, wait until the change is applied before returning.
        ordering: Write ordering guarantees for the operation.
        shard_key_selector: Restrict the write to specific shards.

    Returns:
        The server's update result, always as the REST model.
    """
    if self._prefer_grpc:
        # Normalize the selector; a REST selector may carry its own shard key.
        points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(
            points_selector
        )
        # An explicit argument wins over the shard key embedded in the selector.
        shard_key_selector = shard_key_selector or opt_shard_key_selector

        if isinstance(ordering, models.WriteOrdering):
            ordering = RestToGrpc.convert_write_ordering(ordering)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        return GrpcToRest.convert_update_result(
            self.grpc_points.Delete(
                grpc.DeletePoints(
                    collection_name=collection_name,
                    wait=wait,
                    points=points_selector,
                    ordering=ordering,
                    shard_key_selector=shard_key_selector,
                ),
                timeout=self._timeout,
            ).result
        )
    else:
        points_selector = self._try_argument_to_rest_selector(
            points_selector, shard_key_selector
        )
        result: Optional[types.UpdateResult] = self.openapi_client.points_api.delete_points(
            collection_name=collection_name,
            wait=wait,
            points_selector=points_selector,
            ordering=ordering,
        ).result
        assert result is not None, "Delete points returned None"
        return result
|
|
|
|
def set_payload(
    self,
    collection_name: str,
    payload: types.Payload,
    points: types.PointsSelector,
    key: Optional[str] = None,
    wait: bool = True,
    ordering: Optional[types.WriteOrdering] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    **kwargs: Any,
) -> types.UpdateResult:
    """Merge ``payload`` into the payload of the selected points.

    Args:
        collection_name: Target collection.
        payload: Key/value pairs to set on each matched point.
        points: Selector for the affected points (id list, selector model,
            or filter; REST or gRPC flavour).
        key: Optional nested payload key under which the payload is set.
        wait: If True, wait until the change is applied before returning.
        ordering: Write ordering guarantees for the operation.
        shard_key_selector: Restrict the write to specific shards.

    Returns:
        The server's update result, always as the REST model.
    """
    if self._prefer_grpc:
        points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
        # An explicit argument wins over a shard key embedded in the selector.
        shard_key_selector = shard_key_selector or opt_shard_key_selector

        if isinstance(ordering, models.WriteOrdering):
            ordering = RestToGrpc.convert_write_ordering(ordering)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        return GrpcToRest.convert_update_result(
            self.grpc_points.SetPayload(
                grpc.SetPayloadPoints(
                    collection_name=collection_name,
                    wait=wait,
                    payload=RestToGrpc.convert_payload(payload),
                    points_selector=points_selector,
                    ordering=ordering,
                    shard_key_selector=shard_key_selector,
                    key=key,
                ),
                timeout=self._timeout,
            ).result
        )
    else:
        _points, _filter = self._try_argument_to_rest_points_and_filter(points)
        result: Optional[types.UpdateResult] = self.openapi_client.points_api.set_payload(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            set_payload=models.SetPayload(
                payload=payload,
                points=_points,
                filter=_filter,
                shard_key=shard_key_selector,
                key=key,
            ),
        ).result
        assert result is not None, "Set payload returned None"
        return result
|
|
|
|
def overwrite_payload(
    self,
    collection_name: str,
    payload: types.Payload,
    points: types.PointsSelector,
    wait: bool = True,
    ordering: Optional[types.WriteOrdering] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    **kwargs: Any,
) -> types.UpdateResult:
    """Replace the entire payload of the selected points with ``payload``.

    Unlike ``set_payload`` (which merges), this discards any existing payload
    keys not present in ``payload``.

    Args:
        collection_name: Target collection.
        payload: The new payload for each matched point.
        points: Selector for the affected points (id list, selector model,
            or filter; REST or gRPC flavour).
        wait: If True, wait until the change is applied before returning.
        ordering: Write ordering guarantees for the operation.
        shard_key_selector: Restrict the write to specific shards.

    Returns:
        The server's update result, always as the REST model.
    """
    if self._prefer_grpc:
        points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
        # An explicit argument wins over a shard key embedded in the selector.
        shard_key_selector = shard_key_selector or opt_shard_key_selector

        if isinstance(ordering, models.WriteOrdering):
            ordering = RestToGrpc.convert_write_ordering(ordering)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        return GrpcToRest.convert_update_result(
            self.grpc_points.OverwritePayload(
                # The gRPC API reuses SetPayloadPoints for both set and overwrite.
                grpc.SetPayloadPoints(
                    collection_name=collection_name,
                    wait=wait,
                    payload=RestToGrpc.convert_payload(payload),
                    points_selector=points_selector,
                    ordering=ordering,
                    shard_key_selector=shard_key_selector,
                ),
                timeout=self._timeout,
            ).result
        )
    else:
        _points, _filter = self._try_argument_to_rest_points_and_filter(points)
        result: Optional[types.UpdateResult] = (
            self.openapi_client.points_api.overwrite_payload(
                collection_name=collection_name,
                wait=wait,
                ordering=ordering,
                set_payload=models.SetPayload(
                    payload=payload,
                    points=_points,
                    filter=_filter,
                    shard_key=shard_key_selector,
                ),
            ).result
        )
        assert result is not None, "Overwrite payload returned None"
        return result
|
|
|
|
def delete_payload(
    self,
    collection_name: str,
    keys: Sequence[str],
    points: types.PointsSelector,
    wait: bool = True,
    ordering: Optional[types.WriteOrdering] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    **kwargs: Any,
) -> types.UpdateResult:
    """Remove the given payload keys from the selected points.

    Args:
        collection_name: Target collection.
        keys: Payload keys to delete from each matched point.
        points: Selector for the affected points (id list, selector model,
            or filter; REST or gRPC flavour).
        wait: If True, wait until the change is applied before returning.
        ordering: Write ordering guarantees for the operation.
        shard_key_selector: Restrict the write to specific shards.

    Returns:
        The server's update result, always as the REST model.
    """
    if self._prefer_grpc:
        points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
        # An explicit argument wins over a shard key embedded in the selector.
        shard_key_selector = shard_key_selector or opt_shard_key_selector

        if isinstance(ordering, models.WriteOrdering):
            ordering = RestToGrpc.convert_write_ordering(ordering)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        return GrpcToRest.convert_update_result(
            self.grpc_points.DeletePayload(
                grpc.DeletePayloadPoints(
                    collection_name=collection_name,
                    wait=wait,
                    keys=keys,
                    points_selector=points_selector,
                    ordering=ordering,
                    shard_key_selector=shard_key_selector,
                ),
                timeout=self._timeout,
            ).result
        )
    else:
        _points, _filter = self._try_argument_to_rest_points_and_filter(points)
        result: Optional[types.UpdateResult] = self.openapi_client.points_api.delete_payload(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            delete_payload=models.DeletePayload(
                keys=keys,
                points=_points,
                filter=_filter,
                shard_key=shard_key_selector,
            ),
        ).result
        assert result is not None, "Delete payload returned None"
        return result
|
|
|
|
def clear_payload(
    self,
    collection_name: str,
    points_selector: types.PointsSelector,
    wait: bool = True,
    ordering: Optional[types.WriteOrdering] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    **kwargs: Any,
) -> types.UpdateResult:
    """Remove the entire payload from the selected points.

    Args:
        collection_name: Target collection.
        points_selector: Which points to clear (id list, selector model,
            or filter; REST or gRPC flavour).
        wait: If True, wait until the change is applied before returning.
        ordering: Write ordering guarantees for the operation.
        shard_key_selector: Restrict the write to specific shards.

    Returns:
        The server's update result, always as the REST model.
    """
    if self._prefer_grpc:
        points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(
            points_selector
        )
        # An explicit argument wins over a shard key embedded in the selector.
        shard_key_selector = shard_key_selector or opt_shard_key_selector

        if isinstance(ordering, models.WriteOrdering):
            ordering = RestToGrpc.convert_write_ordering(ordering)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        return GrpcToRest.convert_update_result(
            self.grpc_points.ClearPayload(
                grpc.ClearPayloadPoints(
                    collection_name=collection_name,
                    wait=wait,
                    points=points_selector,
                    ordering=ordering,
                    shard_key_selector=shard_key_selector,
                ),
                timeout=self._timeout,
            ).result
        )
    else:
        points_selector = self._try_argument_to_rest_selector(
            points_selector, shard_key_selector
        )
        result: Optional[types.UpdateResult] = self.openapi_client.points_api.clear_payload(
            collection_name=collection_name,
            wait=wait,
            ordering=ordering,
            points_selector=points_selector,
        ).result
        assert result is not None, "Clear payload returned None"
        return result
|
|
|
|
def batch_update_points(
    self,
    collection_name: str,
    update_operations: Sequence[types.UpdateOperation],
    wait: bool = True,
    ordering: Optional[types.WriteOrdering] = None,
    **kwargs: Any,
) -> list[types.UpdateResult]:
    """Apply several point update operations in one request.

    Args:
        collection_name: Target collection.
        update_operations: The operations to apply.
        wait: If True, wait until the changes are applied before returning.
        ordering: Write ordering guarantees for the operations.

    Returns:
        One update result per operation, always as REST models.
    """
    if self._prefer_grpc:
        # NOTE(review): operations are converted unconditionally, so they are
        # expected to be REST models here (no isinstance guard as in upsert) —
        # confirm against callers.
        update_operations = [
            RestToGrpc.convert_update_operation(operation) for operation in update_operations
        ]

        if isinstance(ordering, models.WriteOrdering):
            ordering = RestToGrpc.convert_write_ordering(ordering)

        return [
            GrpcToRest.convert_update_result(result)
            for result in self.grpc_points.UpdateBatch(
                grpc.UpdateBatchPoints(
                    collection_name=collection_name,
                    wait=wait,
                    operations=update_operations,
                    ordering=ordering,
                ),
                timeout=self._timeout,
            ).result
        ]
    else:
        result: Optional[list[types.UpdateResult]] = (
            self.openapi_client.points_api.batch_update(
                collection_name=collection_name,
                wait=wait,
                ordering=ordering,
                update_operations=models.UpdateOperations(operations=update_operations),
            ).result
        )
        assert result is not None, "Batch update points returned None"
        return result
|
|
|
|
def update_collection_aliases(
    self,
    change_aliases_operations: Sequence[types.AliasOperations],
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> bool:
    """Apply a batch of alias create/rename/delete operations.

    Operations may be given as REST models or gRPC messages; each is
    converted to the flavour required by the active transport.
    """
    if self._prefer_grpc:
        grpc_actions = []
        for operation in change_aliases_operations:
            if isinstance(operation, grpc.AliasOperations):
                grpc_actions.append(operation)
            else:
                grpc_actions.append(RestToGrpc.convert_alias_operations(operation))
        return self.grpc_collections.UpdateAliases(
            grpc.ChangeAliases(
                timeout=timeout,
                actions=grpc_actions,
            ),
            timeout=timeout if timeout is not None else self._timeout,
        ).result

    rest_actions = []
    for operation in change_aliases_operations:
        if isinstance(operation, grpc.AliasOperations):
            rest_actions.append(GrpcToRest.convert_alias_operations(operation))
        else:
            rest_actions.append(operation)
    result: Optional[bool] = self.http.aliases_api.update_aliases(
        timeout=timeout,
        change_aliases_operation=models.ChangeAliasesOperation(actions=rest_actions),
    ).result
    assert result is not None, "Update aliases returned None"
    return result
|
|
|
|
def get_collection_aliases(
    self, collection_name: str, **kwargs: Any
) -> types.CollectionsAliasesResponse:
    """List the aliases that point at the given collection."""
    if self._prefer_grpc:
        grpc_aliases = self.grpc_collections.ListCollectionAliases(
            grpc.ListCollectionAliasesRequest(collection_name=collection_name),
            timeout=self._timeout,
        ).aliases
        converted = [GrpcToRest.convert_alias_description(item) for item in grpc_aliases]
        return types.CollectionsAliasesResponse(aliases=converted)

    rest_result: Optional[types.CollectionsAliasesResponse] = (
        self.http.aliases_api.get_collection_aliases(collection_name=collection_name).result
    )
    assert rest_result is not None, "Get collection aliases returned None"
    return rest_result
|
|
|
|
def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
    """List every alias defined on the server, across all collections."""
    if self._prefer_grpc:
        grpc_aliases = self.grpc_collections.ListAliases(
            grpc.ListAliasesRequest(), timeout=self._timeout
        ).aliases
        converted = [GrpcToRest.convert_alias_description(item) for item in grpc_aliases]
        return types.CollectionsAliasesResponse(aliases=converted)

    rest_result: Optional[types.CollectionsAliasesResponse] = (
        self.http.aliases_api.get_collections_aliases().result
    )
    assert rest_result is not None, "Get aliases returned None"
    return rest_result
|
|
|
|
def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
    """List all collections known to the server."""
    if self._prefer_grpc:
        descriptions = self.grpc_collections.List(
            grpc.ListCollectionsRequest(), timeout=self._timeout
        ).collections
        converted = [
            GrpcToRest.convert_collection_description(item) for item in descriptions
        ]
        return types.CollectionsResponse(collections=converted)

    listing: Optional[types.CollectionsResponse] = (
        self.http.collections_api.get_collections().result
    )
    assert listing is not None, "Get collections returned None"
    return listing
|
|
|
|
def get_collection(self, collection_name: str, **kwargs: Any) -> types.CollectionInfo:
    """Fetch detailed information about a single collection."""
    if self._prefer_grpc:
        grpc_info = self.grpc_collections.Get(
            grpc.GetCollectionInfoRequest(collection_name=collection_name),
            timeout=self._timeout,
        ).result
        return GrpcToRest.convert_collection_info(grpc_info)

    info: Optional[types.CollectionInfo] = self.http.collections_api.get_collection(
        collection_name=collection_name
    ).result
    assert info is not None, "Get collection returned None"
    return info
|
|
|
|
def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
    """Check whether a collection with the given name exists."""
    if self._prefer_grpc:
        reply = self.grpc_collections.CollectionExists(
            grpc.CollectionExistsRequest(collection_name=collection_name),
            timeout=self._timeout,
        )
        return reply.result.exists

    existence: Optional[models.CollectionExistence] = (
        self.http.collections_api.collection_exists(collection_name=collection_name).result
    )
    assert existence is not None, "Collection exists returned None"
    return existence.exists
|
|
|
|
def update_collection(
    self,
    collection_name: str,
    optimizers_config: Optional[types.OptimizersConfigDiff] = None,
    collection_params: Optional[types.CollectionParamsDiff] = None,
    vectors_config: Optional[types.VectorsConfigDiff] = None,
    hnsw_config: Optional[types.HnswConfigDiff] = None,
    quantization_config: Optional[types.QuantizationConfigDiff] = None,
    timeout: Optional[int] = None,
    sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
    strict_mode_config: Optional[types.StrictModeConfig] = None,
    metadata: Optional[types.Payload] = None,
    **kwargs: Any,
) -> bool:
    """Update configuration of an existing collection.

    Only the supplied config diffs are changed; omitted (``None``) arguments
    leave the corresponding settings untouched. Arguments may be REST models
    or gRPC messages; each is converted to the flavour required by the
    active transport.

    Returns:
        True if the update was acknowledged by the server.
    """
    if self._prefer_grpc:
        # Convert each REST-model diff into its gRPC counterpart.
        if isinstance(optimizers_config, models.OptimizersConfigDiff):
            optimizers_config = RestToGrpc.convert_optimizers_config_diff(optimizers_config)

        if isinstance(collection_params, models.CollectionParamsDiff):
            collection_params = RestToGrpc.convert_collection_params_diff(collection_params)

        if isinstance(vectors_config, dict):
            vectors_config = RestToGrpc.convert_vectors_config_diff(vectors_config)

        if isinstance(hnsw_config, models.HnswConfigDiff):
            hnsw_config = RestToGrpc.convert_hnsw_config_diff(hnsw_config)

        if isinstance(quantization_config, get_args(models.QuantizationConfigDiff)):
            quantization_config = RestToGrpc.convert_quantization_config_diff(
                quantization_config
            )

        if isinstance(sparse_vectors_config, dict):
            sparse_vectors_config = RestToGrpc.convert_sparse_vector_config(
                sparse_vectors_config
            )

        if isinstance(strict_mode_config, models.StrictModeConfig):
            strict_mode_config = RestToGrpc.convert_strict_mode_config(strict_mode_config)

        if isinstance(metadata, dict):
            metadata = RestToGrpc.convert_payload(metadata)

        return self.grpc_collections.Update(
            grpc.UpdateCollection(
                collection_name=collection_name,
                optimizers_config=optimizers_config,
                params=collection_params,
                vectors_config=vectors_config,
                hnsw_config=hnsw_config,
                quantization_config=quantization_config,
                sparse_vectors_config=sparse_vectors_config,
                strict_mode_config=strict_mode_config,
                timeout=timeout,
                metadata=metadata,
            ),
            timeout=timeout if timeout is not None else self._timeout,
        ).result

    # REST path: convert any gRPC-message diffs back to REST models.
    if isinstance(optimizers_config, grpc.OptimizersConfigDiff):
        optimizers_config = GrpcToRest.convert_optimizers_config_diff(optimizers_config)

    if isinstance(collection_params, grpc.CollectionParamsDiff):
        collection_params = GrpcToRest.convert_collection_params_diff(collection_params)

    if isinstance(vectors_config, grpc.VectorsConfigDiff):
        vectors_config = GrpcToRest.convert_vectors_config_diff(vectors_config)

    if isinstance(hnsw_config, grpc.HnswConfigDiff):
        hnsw_config = GrpcToRest.convert_hnsw_config_diff(hnsw_config)

    if isinstance(quantization_config, grpc.QuantizationConfigDiff):
        quantization_config = GrpcToRest.convert_quantization_config_diff(quantization_config)

    result: Optional[bool] = self.http.collections_api.update_collection(
        collection_name,
        update_collection=models.UpdateCollection(
            optimizers_config=optimizers_config,
            params=collection_params,
            vectors=vectors_config,
            hnsw_config=hnsw_config,
            quantization_config=quantization_config,
            sparse_vectors=sparse_vectors_config,
            strict_mode_config=strict_mode_config,
            metadata=metadata,
        ),
        timeout=timeout,
    ).result
    assert result is not None, "Update collection returned None"
    return result
|
|
|
|
def delete_collection(
|
|
self, collection_name: str, timeout: Optional[int] = None, **kwargs: Any
|
|
) -> bool:
|
|
if self._prefer_grpc:
|
|
return self.grpc_collections.Delete(
|
|
grpc.DeleteCollection(collection_name=collection_name, timeout=timeout),
|
|
timeout=timeout if timeout is not None else self._timeout,
|
|
).result
|
|
|
|
result: Optional[bool] = self.http.collections_api.delete_collection(
|
|
collection_name, timeout=timeout
|
|
).result
|
|
assert result is not None, "Delete collection returned None"
|
|
return result
|
|
|
|
def create_collection(
|
|
self,
|
|
collection_name: str,
|
|
vectors_config: Optional[
|
|
Union[types.VectorParams, Mapping[str, types.VectorParams]]
|
|
] = None,
|
|
shard_number: Optional[int] = None,
|
|
replication_factor: Optional[int] = None,
|
|
write_consistency_factor: Optional[int] = None,
|
|
on_disk_payload: Optional[bool] = None,
|
|
hnsw_config: Optional[types.HnswConfigDiff] = None,
|
|
optimizers_config: Optional[types.OptimizersConfigDiff] = None,
|
|
wal_config: Optional[types.WalConfigDiff] = None,
|
|
quantization_config: Optional[types.QuantizationConfig] = None,
|
|
timeout: Optional[int] = None,
|
|
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
|
|
sharding_method: Optional[types.ShardingMethod] = None,
|
|
strict_mode_config: Optional[types.StrictModeConfig] = None,
|
|
metadata: Optional[types.Payload] = None,
|
|
**kwargs: Any,
|
|
) -> bool:
|
|
if self._prefer_grpc:
|
|
if isinstance(vectors_config, (models.VectorParams, dict)):
|
|
vectors_config = RestToGrpc.convert_vectors_config(vectors_config)
|
|
|
|
if isinstance(hnsw_config, models.HnswConfigDiff):
|
|
hnsw_config = RestToGrpc.convert_hnsw_config_diff(hnsw_config)
|
|
|
|
if isinstance(optimizers_config, models.OptimizersConfigDiff):
|
|
optimizers_config = RestToGrpc.convert_optimizers_config_diff(optimizers_config)
|
|
|
|
if isinstance(wal_config, models.WalConfigDiff):
|
|
wal_config = RestToGrpc.convert_wal_config_diff(wal_config)
|
|
|
|
if isinstance(
|
|
quantization_config,
|
|
get_args(models.QuantizationConfig),
|
|
):
|
|
quantization_config = RestToGrpc.convert_quantization_config(quantization_config)
|
|
|
|
if isinstance(sparse_vectors_config, dict):
|
|
sparse_vectors_config = RestToGrpc.convert_sparse_vector_config(
|
|
sparse_vectors_config
|
|
)
|
|
|
|
if isinstance(sharding_method, models.ShardingMethod):
|
|
sharding_method = RestToGrpc.convert_sharding_method(sharding_method)
|
|
|
|
if isinstance(strict_mode_config, models.StrictModeConfig):
|
|
strict_mode_config = RestToGrpc.convert_strict_mode_config(strict_mode_config)
|
|
|
|
if isinstance(metadata, dict):
|
|
metadata = RestToGrpc.convert_payload(metadata)
|
|
|
|
create_collection = grpc.CreateCollection(
|
|
collection_name=collection_name,
|
|
hnsw_config=hnsw_config,
|
|
wal_config=wal_config,
|
|
optimizers_config=optimizers_config,
|
|
shard_number=shard_number,
|
|
on_disk_payload=on_disk_payload,
|
|
timeout=timeout,
|
|
vectors_config=vectors_config,
|
|
replication_factor=replication_factor,
|
|
write_consistency_factor=write_consistency_factor,
|
|
quantization_config=quantization_config,
|
|
sparse_vectors_config=sparse_vectors_config,
|
|
sharding_method=sharding_method,
|
|
strict_mode_config=strict_mode_config,
|
|
metadata=metadata,
|
|
)
|
|
return self.grpc_collections.Create(create_collection, timeout=self._timeout).result
|
|
|
|
if isinstance(hnsw_config, grpc.HnswConfigDiff):
|
|
hnsw_config = GrpcToRest.convert_hnsw_config_diff(hnsw_config)
|
|
|
|
if isinstance(optimizers_config, grpc.OptimizersConfigDiff):
|
|
optimizers_config = GrpcToRest.convert_optimizers_config_diff(optimizers_config)
|
|
|
|
if isinstance(wal_config, grpc.WalConfigDiff):
|
|
wal_config = GrpcToRest.convert_wal_config_diff(wal_config)
|
|
|
|
if isinstance(quantization_config, grpc.QuantizationConfig):
|
|
quantization_config = GrpcToRest.convert_quantization_config(quantization_config)
|
|
|
|
create_collection_request = models.CreateCollection(
|
|
vectors=vectors_config,
|
|
shard_number=shard_number,
|
|
replication_factor=replication_factor,
|
|
write_consistency_factor=write_consistency_factor,
|
|
on_disk_payload=on_disk_payload,
|
|
hnsw_config=hnsw_config,
|
|
optimizers_config=optimizers_config,
|
|
wal_config=wal_config,
|
|
quantization_config=quantization_config,
|
|
sparse_vectors=sparse_vectors_config,
|
|
sharding_method=sharding_method,
|
|
strict_mode_config=strict_mode_config,
|
|
metadata=metadata,
|
|
)
|
|
|
|
result: Optional[bool] = self.http.collections_api.create_collection(
|
|
collection_name=collection_name,
|
|
create_collection=create_collection_request,
|
|
timeout=timeout,
|
|
).result
|
|
|
|
assert result is not None, "Create collection returned None"
|
|
return result
|
|
|
|
def recreate_collection(
|
|
self,
|
|
collection_name: str,
|
|
vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
|
|
shard_number: Optional[int] = None,
|
|
replication_factor: Optional[int] = None,
|
|
write_consistency_factor: Optional[int] = None,
|
|
on_disk_payload: Optional[bool] = None,
|
|
hnsw_config: Optional[types.HnswConfigDiff] = None,
|
|
optimizers_config: Optional[types.OptimizersConfigDiff] = None,
|
|
wal_config: Optional[types.WalConfigDiff] = None,
|
|
quantization_config: Optional[types.QuantizationConfig] = None,
|
|
timeout: Optional[int] = None,
|
|
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
|
|
sharding_method: Optional[types.ShardingMethod] = None,
|
|
strict_mode_config: Optional[types.StrictModeConfig] = None,
|
|
metadata: Optional[types.Payload] = None,
|
|
**kwargs: Any,
|
|
) -> bool:
|
|
self.delete_collection(collection_name, timeout=timeout)
|
|
|
|
return self.create_collection(
|
|
collection_name=collection_name,
|
|
vectors_config=vectors_config,
|
|
shard_number=shard_number,
|
|
replication_factor=replication_factor,
|
|
write_consistency_factor=write_consistency_factor,
|
|
on_disk_payload=on_disk_payload,
|
|
hnsw_config=hnsw_config,
|
|
optimizers_config=optimizers_config,
|
|
wal_config=wal_config,
|
|
quantization_config=quantization_config,
|
|
timeout=timeout,
|
|
sparse_vectors_config=sparse_vectors_config,
|
|
sharding_method=sharding_method,
|
|
strict_mode_config=strict_mode_config,
|
|
metadata=metadata,
|
|
)
|
|
|
|
@property
|
|
def _updater_class(self) -> Type[BaseUploader]:
|
|
if self._prefer_grpc:
|
|
return GrpcBatchUploader
|
|
else:
|
|
return RestBatchUploader
|
|
|
|
def _upload_collection(
|
|
self,
|
|
batches_iterator: Iterable,
|
|
collection_name: str,
|
|
max_retries: int,
|
|
parallel: int = 1,
|
|
method: Optional[str] = None,
|
|
wait: bool = False,
|
|
shard_key_selector: Optional[types.ShardKeySelector] = None,
|
|
update_filter: Optional[types.Filter] = None,
|
|
) -> None:
|
|
if method is not None:
|
|
if method in get_all_start_methods():
|
|
start_method = method
|
|
else:
|
|
raise ValueError(
|
|
f"Start methods {method} is not available, available methods: {get_all_start_methods()}"
|
|
)
|
|
else:
|
|
start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
|
|
|
|
if self._prefer_grpc:
|
|
updater_kwargs = {
|
|
"collection_name": collection_name,
|
|
"host": self._host,
|
|
"port": self._grpc_port,
|
|
"max_retries": max_retries,
|
|
"ssl": self._https,
|
|
"metadata": self._grpc_headers,
|
|
"wait": wait,
|
|
"shard_key_selector": shard_key_selector,
|
|
"options": self._grpc_options,
|
|
"timeout": self._timeout,
|
|
"update_filter": update_filter,
|
|
}
|
|
else:
|
|
updater_kwargs = {
|
|
"collection_name": collection_name,
|
|
"uri": self.rest_uri,
|
|
"max_retries": max_retries,
|
|
"wait": wait,
|
|
"shard_key_selector": shard_key_selector,
|
|
"update_filter": update_filter,
|
|
**self._rest_args,
|
|
}
|
|
|
|
if parallel == 1:
|
|
updater = self._updater_class.start(**updater_kwargs)
|
|
for _ in updater.process(batches_iterator):
|
|
pass
|
|
else:
|
|
pool = ParallelWorkerPool(parallel, self._updater_class, start_method=start_method)
|
|
for _ in pool.unordered_map(batches_iterator, **updater_kwargs):
|
|
pass
|
|
|
|
def upload_points(
|
|
self,
|
|
collection_name: str,
|
|
points: Iterable[types.PointStruct],
|
|
batch_size: int = 64,
|
|
parallel: int = 1,
|
|
method: Optional[str] = None,
|
|
max_retries: int = 3,
|
|
wait: bool = False,
|
|
shard_key_selector: Optional[types.ShardKeySelector] = None,
|
|
update_filter: Optional[types.Filter] = None,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
batches_iterator = self._updater_class.iterate_records_batches(
|
|
records=points, batch_size=batch_size
|
|
)
|
|
|
|
self._upload_collection(
|
|
batches_iterator=batches_iterator,
|
|
collection_name=collection_name,
|
|
max_retries=max_retries,
|
|
parallel=parallel,
|
|
method=method,
|
|
wait=wait,
|
|
shard_key_selector=shard_key_selector,
|
|
update_filter=update_filter,
|
|
)
|
|
|
|
def upload_collection(
|
|
self,
|
|
collection_name: str,
|
|
vectors: Union[
|
|
dict[str, types.NumpyArray], types.NumpyArray, Iterable[types.VectorStruct]
|
|
],
|
|
payload: Optional[Iterable[dict[Any, Any]]] = None,
|
|
ids: Optional[Iterable[types.PointId]] = None,
|
|
batch_size: int = 64,
|
|
parallel: int = 1,
|
|
method: Optional[str] = None,
|
|
max_retries: int = 3,
|
|
wait: bool = False,
|
|
shard_key_selector: Optional[types.ShardKeySelector] = None,
|
|
update_filter: Optional[types.Filter] = None,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
batches_iterator = self._updater_class.iterate_batches(
|
|
vectors=vectors,
|
|
payload=payload,
|
|
ids=ids,
|
|
batch_size=batch_size,
|
|
)
|
|
|
|
self._upload_collection(
|
|
batches_iterator=batches_iterator,
|
|
collection_name=collection_name,
|
|
max_retries=max_retries,
|
|
parallel=parallel,
|
|
method=method,
|
|
wait=wait,
|
|
shard_key_selector=shard_key_selector,
|
|
update_filter=update_filter,
|
|
)
|
|
|
|
    def create_payload_index(
        self,
        collection_name: str,
        field_name: str,
        field_schema: Optional[types.PayloadSchemaType] = None,
        field_type: Optional[types.PayloadSchemaType] = None,
        wait: bool = True,
        ordering: Optional[types.WriteOrdering] = None,
        **kwargs: Any,
    ) -> types.UpdateResult:
        """Create an index over a payload field.

        ``field_schema`` may arrive in several equivalent forms (REST enum, its
        string value, a raw gRPC int, or structured index params); the cascade
        of isinstance checks below coerces it step by step into whichever
        representation the chosen transport needs. The order of the checks
        matters: each one may rebind ``field_schema`` into the form the next
        check inspects.
        """
        if field_type is not None:
            # field_type is the deprecated alias of field_schema.
            show_warning_once(
                message="field_type is deprecated, use field_schema instead",
                category=DeprecationWarning,
                stacklevel=5,
                idx="payload-index-field-type",
            )
            field_schema = field_type

        if self._prefer_grpc:
            field_index_params = None
            if isinstance(field_schema, models.PayloadSchemaType):
                field_schema = RestToGrpc.convert_payload_schema_type(field_schema)

            # A bare string is treated as the value of the REST enum.
            if isinstance(field_schema, str):
                field_schema = RestToGrpc.convert_payload_schema_type(
                    models.PayloadSchemaType(field_schema)
                )

            if isinstance(field_schema, int):
                # There are no means to distinguish grpc.PayloadSchemaType and grpc.FieldType,
                # as both of them are just ints
                # method signature assumes that grpc.PayloadSchemaType is passed,
                # otherwise the value will be corrupted
                field_schema = grpc_payload_schema_to_field_type(field_schema)

            if isinstance(field_schema, get_args(models.PayloadSchemaParams)):
                field_schema = RestToGrpc.convert_payload_schema_params(field_schema)

            if isinstance(field_schema, grpc.PayloadIndexParams):
                # Structured params: keep them as the request's index params and
                # derive the plain field type from the populated oneof variant.
                # NOTE(review): WhichOneof may return None for an empty message,
                # in which case getattr would raise — presumably callers always
                # pass a populated params object; confirm upstream.
                field_index_params = field_schema
                name = field_index_params.WhichOneof("index_params")
                index_params = getattr(field_index_params, name)
                if isinstance(index_params, grpc.TextIndexParams):
                    field_schema = grpc.FieldType.FieldTypeText

                if isinstance(index_params, grpc.IntegerIndexParams):
                    field_schema = grpc.FieldType.FieldTypeInteger

                if isinstance(index_params, grpc.KeywordIndexParams):
                    field_schema = grpc.FieldType.FieldTypeKeyword

                if isinstance(index_params, grpc.FloatIndexParams):
                    field_schema = grpc.FieldType.FieldTypeFloat

                if isinstance(index_params, grpc.GeoIndexParams):
                    field_schema = grpc.FieldType.FieldTypeGeo

                if isinstance(index_params, grpc.BoolIndexParams):
                    field_schema = grpc.FieldType.FieldTypeBool

                if isinstance(index_params, grpc.DatetimeIndexParams):
                    field_schema = grpc.FieldType.FieldTypeDatetime

                if isinstance(index_params, grpc.UuidIndexParams):
                    field_schema = grpc.FieldType.FieldTypeUuid

            request = grpc.CreateFieldIndexCollection(
                collection_name=collection_name,
                field_name=field_name,
                field_type=field_schema,
                field_index_params=field_index_params,
                wait=wait,
                ordering=ordering,
            )
            return GrpcToRest.convert_update_result(
                self.grpc_points.CreateFieldIndex(request, timeout=self._timeout).result
            )

        # REST path: gRPC inputs are converted back to REST models.
        if isinstance(field_schema, int):  # type(grpc.PayloadSchemaType) == int
            field_schema = GrpcToRest.convert_payload_schema_type(field_schema)

        if isinstance(field_schema, grpc.PayloadIndexParams):
            field_schema = GrpcToRest.convert_payload_schema_params(field_schema)

        result: Optional[types.UpdateResult] = self.openapi_client.indexes_api.create_field_index(
            collection_name=collection_name,
            create_field_index=models.CreateFieldIndex(
                field_name=field_name, field_schema=field_schema
            ),
            wait=wait,
            ordering=ordering,
        ).result
        assert result is not None, "Create field index returned None"
        return result
def delete_payload_index(
|
|
self,
|
|
collection_name: str,
|
|
field_name: str,
|
|
wait: bool = True,
|
|
ordering: Optional[types.WriteOrdering] = None,
|
|
**kwargs: Any,
|
|
) -> types.UpdateResult:
|
|
if self._prefer_grpc:
|
|
request = grpc.DeleteFieldIndexCollection(
|
|
collection_name=collection_name,
|
|
field_name=field_name,
|
|
wait=wait,
|
|
ordering=ordering,
|
|
)
|
|
return GrpcToRest.convert_update_result(
|
|
self.grpc_points.DeleteFieldIndex(request, timeout=self._timeout).result
|
|
)
|
|
|
|
result: Optional[types.UpdateResult] = self.openapi_client.indexes_api.delete_field_index(
|
|
collection_name=collection_name,
|
|
field_name=field_name,
|
|
wait=wait,
|
|
ordering=ordering,
|
|
).result
|
|
assert result is not None, "Delete field index returned None"
|
|
return result
|
|
|
|
def list_snapshots(
|
|
self, collection_name: str, **kwargs: Any
|
|
) -> list[types.SnapshotDescription]:
|
|
if self._prefer_grpc:
|
|
snapshots = self.grpc_snapshots.List(
|
|
grpc.ListSnapshotsRequest(collection_name=collection_name), timeout=self._timeout
|
|
).snapshot_descriptions
|
|
return [GrpcToRest.convert_snapshot_description(snapshot) for snapshot in snapshots]
|
|
|
|
snapshots = self.openapi_client.snapshots_api.list_snapshots(
|
|
collection_name=collection_name
|
|
).result
|
|
assert snapshots is not None, "List snapshots API returned None result"
|
|
return snapshots
|
|
|
|
def create_snapshot(
|
|
self, collection_name: str, wait: bool = True, **kwargs: Any
|
|
) -> Optional[types.SnapshotDescription]:
|
|
if self._prefer_grpc:
|
|
snapshot = self.grpc_snapshots.Create(
|
|
grpc.CreateSnapshotRequest(collection_name=collection_name), timeout=self._timeout
|
|
).snapshot_description
|
|
return GrpcToRest.convert_snapshot_description(snapshot)
|
|
|
|
return self.openapi_client.snapshots_api.create_snapshot(
|
|
collection_name=collection_name, wait=wait
|
|
).result
|
|
|
|
def delete_snapshot(
|
|
self, collection_name: str, snapshot_name: str, wait: bool = True, **kwargs: Any
|
|
) -> Optional[bool]:
|
|
if self._prefer_grpc:
|
|
self.grpc_snapshots.Delete(
|
|
grpc.DeleteSnapshotRequest(
|
|
collection_name=collection_name, snapshot_name=snapshot_name
|
|
),
|
|
timeout=self._timeout,
|
|
)
|
|
return True
|
|
|
|
return self.openapi_client.snapshots_api.delete_snapshot(
|
|
collection_name=collection_name,
|
|
snapshot_name=snapshot_name,
|
|
wait=wait,
|
|
).result
|
|
|
|
def list_full_snapshots(self, **kwargs: Any) -> list[types.SnapshotDescription]:
|
|
if self._prefer_grpc:
|
|
snapshots = self.grpc_snapshots.ListFull(
|
|
grpc.ListFullSnapshotsRequest(),
|
|
timeout=self._timeout,
|
|
).snapshot_descriptions
|
|
return [GrpcToRest.convert_snapshot_description(snapshot) for snapshot in snapshots]
|
|
|
|
snapshots = self.openapi_client.snapshots_api.list_full_snapshots().result
|
|
assert snapshots is not None, "List full snapshots API returned None result"
|
|
return snapshots
|
|
|
|
def create_full_snapshot(self, wait: bool = True, **kwargs: Any) -> types.SnapshotDescription:
|
|
if self._prefer_grpc:
|
|
snapshot_description = self.grpc_snapshots.CreateFull(
|
|
grpc.CreateFullSnapshotRequest(), timeout=self._timeout
|
|
).snapshot_description
|
|
return GrpcToRest.convert_snapshot_description(snapshot_description)
|
|
|
|
return self.openapi_client.snapshots_api.create_full_snapshot(wait=wait).result
|
|
|
|
def delete_full_snapshot(
|
|
self, snapshot_name: str, wait: bool = True, **kwargs: Any
|
|
) -> Optional[bool]:
|
|
if self._prefer_grpc:
|
|
self.grpc_snapshots.DeleteFull(
|
|
grpc.DeleteFullSnapshotRequest(snapshot_name=snapshot_name),
|
|
timeout=self._timeout,
|
|
)
|
|
return True
|
|
|
|
return self.openapi_client.snapshots_api.delete_full_snapshot(
|
|
snapshot_name=snapshot_name, wait=wait
|
|
).result
|
|
|
|
def recover_snapshot(
|
|
self,
|
|
collection_name: str,
|
|
location: str,
|
|
api_key: Optional[str] = None,
|
|
checksum: Optional[str] = None,
|
|
priority: Optional[types.SnapshotPriority] = None,
|
|
wait: bool = True,
|
|
**kwargs: Any,
|
|
) -> Optional[bool]:
|
|
return self.openapi_client.snapshots_api.recover_from_snapshot(
|
|
collection_name=collection_name,
|
|
wait=wait,
|
|
snapshot_recover=models.SnapshotRecover(
|
|
location=location,
|
|
priority=priority,
|
|
checksum=checksum,
|
|
api_key=api_key,
|
|
),
|
|
).result
|
|
|
|
def list_shard_snapshots(
|
|
self, collection_name: str, shard_id: int, **kwargs: Any
|
|
) -> list[types.SnapshotDescription]:
|
|
snapshots = self.openapi_client.snapshots_api.list_shard_snapshots(
|
|
collection_name=collection_name,
|
|
shard_id=shard_id,
|
|
).result
|
|
assert snapshots is not None, "List snapshots API returned None result"
|
|
return snapshots
|
|
|
|
def create_shard_snapshot(
|
|
self, collection_name: str, shard_id: int, wait: bool = True, **kwargs: Any
|
|
) -> Optional[types.SnapshotDescription]:
|
|
return self.openapi_client.snapshots_api.create_shard_snapshot(
|
|
collection_name=collection_name,
|
|
shard_id=shard_id,
|
|
wait=wait,
|
|
).result
|
|
|
|
def delete_shard_snapshot(
|
|
self,
|
|
collection_name: str,
|
|
shard_id: int,
|
|
snapshot_name: str,
|
|
wait: bool = True,
|
|
**kwargs: Any,
|
|
) -> Optional[bool]:
|
|
return self.openapi_client.snapshots_api.delete_shard_snapshot(
|
|
collection_name=collection_name,
|
|
shard_id=shard_id,
|
|
snapshot_name=snapshot_name,
|
|
wait=wait,
|
|
).result
|
|
|
|
def recover_shard_snapshot(
|
|
self,
|
|
collection_name: str,
|
|
shard_id: int,
|
|
location: str,
|
|
api_key: Optional[str] = None,
|
|
checksum: Optional[str] = None,
|
|
priority: Optional[types.SnapshotPriority] = None,
|
|
wait: bool = True,
|
|
**kwargs: Any,
|
|
) -> Optional[bool]:
|
|
return self.openapi_client.snapshots_api.recover_shard_from_snapshot(
|
|
collection_name=collection_name,
|
|
shard_id=shard_id,
|
|
wait=wait,
|
|
shard_snapshot_recover=models.ShardSnapshotRecover(
|
|
location=location,
|
|
priority=priority,
|
|
checksum=checksum,
|
|
api_key=api_key,
|
|
),
|
|
).result
|
|
|
|
def create_shard_key(
|
|
self,
|
|
collection_name: str,
|
|
shard_key: types.ShardKey,
|
|
shards_number: Optional[int] = None,
|
|
replication_factor: Optional[int] = None,
|
|
placement: Optional[list[int]] = None,
|
|
initial_state: Optional[types.ReplicaState] = None,
|
|
timeout: Optional[int] = None,
|
|
**kwargs: Any,
|
|
) -> bool:
|
|
if self._prefer_grpc:
|
|
if isinstance(shard_key, get_args_subscribed(models.ShardKey)):
|
|
shard_key = RestToGrpc.convert_shard_key(shard_key)
|
|
|
|
if isinstance(initial_state, models.ReplicaState):
|
|
initial_state = RestToGrpc.convert_replica_state(initial_state)
|
|
|
|
return self.grpc_collections.CreateShardKey(
|
|
grpc.CreateShardKeyRequest(
|
|
collection_name=collection_name,
|
|
timeout=timeout,
|
|
request=grpc.CreateShardKey(
|
|
shard_key=shard_key,
|
|
shards_number=shards_number,
|
|
replication_factor=replication_factor,
|
|
placement=placement or [],
|
|
initial_state=initial_state,
|
|
),
|
|
),
|
|
timeout=timeout if timeout is not None else self._timeout,
|
|
).result
|
|
else:
|
|
result = self.openapi_client.distributed_api.create_shard_key(
|
|
collection_name=collection_name,
|
|
timeout=timeout,
|
|
create_sharding_key=models.CreateShardingKey(
|
|
shard_key=shard_key,
|
|
shards_number=shards_number,
|
|
replication_factor=replication_factor,
|
|
placement=placement,
|
|
initial_state=initial_state,
|
|
),
|
|
).result
|
|
assert result is not None, "Create shard key returned None"
|
|
return result
|
|
|
|
def delete_shard_key(
|
|
self,
|
|
collection_name: str,
|
|
shard_key: types.ShardKey,
|
|
timeout: Optional[int] = None,
|
|
**kwargs: Any,
|
|
) -> bool:
|
|
if self._prefer_grpc:
|
|
if isinstance(shard_key, get_args_subscribed(models.ShardKey)):
|
|
shard_key = RestToGrpc.convert_shard_key(shard_key)
|
|
|
|
return self.grpc_collections.DeleteShardKey(
|
|
grpc.DeleteShardKeyRequest(
|
|
collection_name=collection_name,
|
|
timeout=timeout,
|
|
request=grpc.DeleteShardKey(
|
|
shard_key=shard_key,
|
|
),
|
|
),
|
|
timeout=timeout if timeout is not None else self._timeout,
|
|
).result
|
|
else:
|
|
result = self.openapi_client.distributed_api.delete_shard_key(
|
|
collection_name=collection_name,
|
|
timeout=timeout,
|
|
drop_sharding_key=models.DropShardingKey(
|
|
shard_key=shard_key,
|
|
),
|
|
).result
|
|
assert result is not None, "Delete shard key returned None"
|
|
return result
|
|
|
|
def info(self) -> types.VersionInfo:
|
|
if self._prefer_grpc:
|
|
version_info = self.grpc_root.HealthCheck(
|
|
grpc.HealthCheckRequest(), timeout=self._timeout
|
|
)
|
|
return GrpcToRest.convert_health_check_reply(version_info)
|
|
version_info = self.rest.service_api.root()
|
|
assert version_info is not None, "Healthcheck returned None"
|
|
return version_info
|
|
|
|
def cluster_collection_update(
|
|
self,
|
|
collection_name: str,
|
|
cluster_operation: types.ClusterOperations,
|
|
timeout: Optional[int] = None,
|
|
**kwargs: Any,
|
|
) -> bool:
|
|
if self._prefer_grpc:
|
|
cluster_operation = RestToGrpc.convert_cluster_operations(cluster_operation)
|
|
grpc_operation = {}
|
|
|
|
if isinstance(cluster_operation, grpc.MoveShard):
|
|
grpc_operation["move_shard"] = cluster_operation
|
|
elif isinstance(cluster_operation, grpc.ReplicateShard):
|
|
grpc_operation["replicate_shard"] = cluster_operation
|
|
elif isinstance(cluster_operation, grpc.AbortShardTransfer):
|
|
grpc_operation["abort_transfer"] = cluster_operation
|
|
elif isinstance(cluster_operation, grpc.Replica):
|
|
grpc_operation["drop_replica"] = cluster_operation
|
|
elif isinstance(cluster_operation, grpc.CreateShardKey):
|
|
grpc_operation["create_shard_key"] = cluster_operation
|
|
elif isinstance(cluster_operation, grpc.DeleteShardKey):
|
|
grpc_operation["delete_shard_key"] = cluster_operation
|
|
elif isinstance(cluster_operation, grpc.RestartTransfer):
|
|
grpc_operation["restart_transfer"] = cluster_operation
|
|
elif isinstance(cluster_operation, grpc.ReplicatePoints):
|
|
grpc_operation["replicate_points"] = cluster_operation
|
|
else:
|
|
raise TypeError(f"Unknown cluster operation: {cluster_operation}")
|
|
|
|
return self.grpc_collections.UpdateCollectionClusterSetup(
|
|
grpc.UpdateCollectionClusterSetupRequest(
|
|
collection_name=collection_name, timeout=timeout, **grpc_operation
|
|
),
|
|
timeout=timeout if timeout is not None else self._timeout,
|
|
).result
|
|
update_result = self.rest.distributed_api.update_collection_cluster(
|
|
collection_name=collection_name, cluster_operations=cluster_operation, timeout=timeout
|
|
).result
|
|
assert update_result is not None, "Cluster collection update returned None"
|
|
return update_result
|
|
|
|
def cluster_status(self) -> types.ClusterStatus:
|
|
# grpc does not have cluster status api
|
|
status_result = self.rest.distributed_api.cluster_status().result
|
|
assert status_result is not None, "Cluster status returned None"
|
|
return status_result
|
|
|
|
def recover_current_peer(self) -> bool:
|
|
# grpc does not have recover peer api
|
|
recover_result = self.rest.distributed_api.recover_current_peer().result
|
|
assert recover_result is not None, "Recover current peer returned None"
|
|
return recover_result
|
|
|
|
def remove_peer(
|
|
self,
|
|
peer_id: int,
|
|
force: Optional[bool] = None,
|
|
timeout: Optional[int] = None,
|
|
**kwargs: Any,
|
|
) -> bool:
|
|
# grpc does not have remove peer api
|
|
update_result = self.rest.distributed_api.remove_peer(
|
|
peer_id=peer_id,
|
|
force=force,
|
|
timeout=timeout,
|
|
).result
|
|
assert update_result is not None, "Remove peer returned None"
|
|
return update_result
|
|
|
|
def collection_cluster_info(self, collection_name: str) -> types.CollectionClusterInfo:
|
|
if self._prefer_grpc:
|
|
collection_info = self.grpc_collections.CollectionClusterInfo(
|
|
grpc.CollectionClusterInfoRequest(collection_name=collection_name),
|
|
timeout=self._timeout,
|
|
)
|
|
return GrpcToRest.convert_collection_cluster_info(collection_info)
|
|
collection_info = self.rest.distributed_api.collection_cluster_info(
|
|
collection_name=collection_name
|
|
).result
|
|
assert collection_info is not None, "Collection cluster info returned None"
|
|
return collection_info
|