refactor: excel parse
@@ -0,0 +1,973 @@
|
||||
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
|
||||
#
|
||||
# This file is autogenerated. Do not edit it manually.
|
||||
# To regenerate this file, use
|
||||
#
|
||||
# ```
|
||||
# bash -x tools/generate_async_client.sh
|
||||
# ```
|
||||
#
|
||||
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
|
||||
|
||||
import importlib.metadata
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from io import TextIOWrapper
|
||||
from typing import Any, Generator, Iterable, Mapping, Optional, Sequence, Union, get_args
|
||||
from uuid import uuid4
|
||||
import numpy as np
|
||||
import portalocker
|
||||
from qdrant_client.common.client_warnings import show_warning, show_warning_once
|
||||
from qdrant_client._pydantic_compat import to_dict
|
||||
from qdrant_client.async_client_base import AsyncQdrantBase
|
||||
from qdrant_client.conversions import common_types as types
|
||||
from qdrant_client.http import models as rest_models
|
||||
from qdrant_client.local.local_collection import (
|
||||
LocalCollection,
|
||||
DEFAULT_VECTOR_NAME,
|
||||
ignore_mentioned_ids_filter,
|
||||
)
|
||||
|
||||
META_INFO_FILENAME = "meta.json"
|
||||
|
||||
|
||||
class AsyncQdrantLocal(AsyncQdrantBase):
|
||||
"""
|
||||
Everything Qdrant server can do, but locally.
|
||||
|
||||
Use this implementation to run vector search without running a Qdrant server.
|
||||
Everything that works with local Qdrant will work with server Qdrant as well.
|
||||
|
||||
Use for small-scale data, demos, and tests.
|
||||
If you need more speed or size, use Qdrant server.
|
||||
"""
|
||||
|
||||
LARGE_DATA_THRESHOLD = 20000
|
||||
|
||||
def __init__(self, location: str, force_disable_check_same_thread: bool = False) -> None:
|
||||
"""
|
||||
Initialize local Qdrant.
|
||||
|
||||
Args:
|
||||
location: Where to store data. Can be a path to a directory or `:memory:` for in-memory storage.
|
||||
force_disable_check_same_thread: Disable SQLite check_same_thread check. Use only if you know what you are doing.
|
||||
"""
|
||||
super().__init__()
|
||||
self.force_disable_check_same_thread = force_disable_check_same_thread
|
||||
self.location = location
|
||||
self.persistent = location != ":memory:"
|
||||
self.collections: dict[str, LocalCollection] = {}
|
||||
self.aliases: dict[str, str] = {}
|
||||
self._flock_file: Optional[TextIOWrapper] = None
|
||||
self._load()
|
||||
self._closed: bool = False
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
async def close(self, **kwargs: Any) -> None:
|
||||
self._closed = True
|
||||
for collection in self.collections.values():
|
||||
if collection is not None:
|
||||
collection.close()
|
||||
else:
|
||||
show_warning(
|
||||
message=f"Collection appears to be None before closing. The existing collections are: {list(self.collections.keys())}",
|
||||
category=UserWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
try:
|
||||
if self._flock_file is not None and (not self._flock_file.closed):
|
||||
portalocker.unlock(self._flock_file)
|
||||
self._flock_file.close()
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
def _load(self) -> None:
|
||||
deprecated_config_fields = ("init_from",)
|
||||
if not self.persistent:
|
||||
return
|
||||
meta_path = os.path.join(self.location, META_INFO_FILENAME)
|
||||
if not os.path.exists(meta_path):
|
||||
os.makedirs(self.location, exist_ok=True)
|
||||
with open(meta_path, "w") as f:
|
||||
f.write(json.dumps({"collections": {}, "aliases": {}}))
|
||||
else:
|
||||
with open(meta_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
for collection_name, config_json in meta["collections"].items():
|
||||
for key in deprecated_config_fields:
|
||||
config_json.pop(key, None)
|
||||
config = rest_models.CreateCollection(**config_json)
|
||||
collection_path = self._collection_path(collection_name)
|
||||
collection = LocalCollection(
|
||||
config,
|
||||
collection_path,
|
||||
force_disable_check_same_thread=self.force_disable_check_same_thread,
|
||||
)
|
||||
self.collections[collection_name] = collection
|
||||
if len(collection.ids) > self.LARGE_DATA_THRESHOLD:
|
||||
show_warning(
|
||||
f"Local mode is not recommended for collections with more than {self.LARGE_DATA_THRESHOLD:,} points. Collection <{collection_name}> contains {len(collection.ids)} points. Consider using Qdrant in Docker or Qdrant Cloud for better performance with large datasets.",
|
||||
category=UserWarning,
|
||||
stacklevel=5,
|
||||
)
|
||||
self.aliases = meta["aliases"]
|
||||
lock_file_path = os.path.join(self.location, ".lock")
|
||||
if not os.path.exists(lock_file_path):
|
||||
os.makedirs(self.location, exist_ok=True)
|
||||
with open(lock_file_path, "w") as f:
|
||||
f.write("tmp lock file")
|
||||
self._flock_file = open(lock_file_path, "r+")
|
||||
try:
|
||||
portalocker.lock(
|
||||
self._flock_file,
|
||||
portalocker.LockFlags.EXCLUSIVE | portalocker.LockFlags.NON_BLOCKING,
|
||||
)
|
||||
except portalocker.exceptions.LockException:
|
||||
raise RuntimeError(
|
||||
f"Storage folder {self.location} is already accessed by another instance of Qdrant client. If you require concurrent access, use Qdrant server instead."
|
||||
)
|
||||
|
||||
def _save(self) -> None:
|
||||
if not self.persistent:
|
||||
return
|
||||
if self.closed:
|
||||
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
|
||||
meta_path = os.path.join(self.location, META_INFO_FILENAME)
|
||||
with open(meta_path, "w") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"collections": {
|
||||
collection_name: to_dict(collection.config)
|
||||
for (collection_name, collection) in self.collections.items()
|
||||
},
|
||||
"aliases": self.aliases,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def _get_collection(self, collection_name: str) -> LocalCollection:
|
||||
if self.closed:
|
||||
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
|
||||
if collection_name in self.collections:
|
||||
return self.collections[collection_name]
|
||||
if collection_name in self.aliases:
|
||||
return self.collections[self.aliases[collection_name]]
|
||||
raise ValueError(f"Collection {collection_name} not found")
|
||||
|
||||
def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_vector: Union[
|
||||
types.NumpyArray,
|
||||
Sequence[float],
|
||||
tuple[str, list[float]],
|
||||
types.NamedVector,
|
||||
types.NamedSparseVector,
|
||||
],
|
||||
query_filter: Optional[types.Filter] = None,
|
||||
search_params: Optional[types.SearchParams] = None,
|
||||
limit: int = 10,
|
||||
offset: Optional[int] = None,
|
||||
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
|
||||
with_vectors: Union[bool, Sequence[str]] = False,
|
||||
score_threshold: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[types.ScoredPoint]:
|
||||
collection = self._get_collection(collection_name)
|
||||
return collection.search(
|
||||
query_vector=query_vector,
|
||||
query_filter=query_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
with_payload=with_payload,
|
||||
with_vectors=with_vectors,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
async def search_matrix_offsets(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_filter: Optional[types.Filter] = None,
|
||||
limit: int = 3,
|
||||
sample: int = 10,
|
||||
using: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.SearchMatrixOffsetsResponse:
|
||||
collection = self._get_collection(collection_name)
|
||||
return collection.search_matrix_offsets(
|
||||
query_filter=query_filter, limit=limit, sample=sample, using=using
|
||||
)
|
||||
|
||||
async def search_matrix_pairs(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_filter: Optional[types.Filter] = None,
|
||||
limit: int = 3,
|
||||
sample: int = 10,
|
||||
using: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.SearchMatrixPairsResponse:
|
||||
collection = self._get_collection(collection_name)
|
||||
return collection.search_matrix_pairs(
|
||||
query_filter=query_filter, limit=limit, sample=sample, using=using
|
||||
)
|
||||
|
||||
def _resolve_query_input(
|
||||
self,
|
||||
collection_name: str,
|
||||
query: Optional[types.Query],
|
||||
using: Optional[str],
|
||||
lookup_from: Optional[types.LookupLocation],
|
||||
) -> tuple[types.Query, set[types.PointId]]:
|
||||
"""
|
||||
Resolves any possible ids into vectors and returns a new query object, along with a set of the mentioned
|
||||
point ids that should be filtered when searching.
|
||||
"""
|
||||
lookup_collection_name = lookup_from.collection if lookup_from else collection_name
|
||||
collection = self._get_collection(lookup_collection_name)
|
||||
search_in_vector_name = using if using is not None else DEFAULT_VECTOR_NAME
|
||||
vector_name = (
|
||||
lookup_from.vector
|
||||
if lookup_from is not None and lookup_from.vector is not None
|
||||
else search_in_vector_name
|
||||
)
|
||||
sparse = vector_name in collection.sparse_vectors
|
||||
multi = vector_name in collection.multivectors
|
||||
if sparse:
|
||||
collection_vectors = collection.sparse_vectors
|
||||
elif multi:
|
||||
collection_vectors = collection.multivectors
|
||||
else:
|
||||
collection_vectors = collection.vectors
|
||||
mentioned_ids: set[types.PointId] = set()
|
||||
|
||||
def input_into_vector(vector_input: types.VectorInput) -> types.VectorInput:
|
||||
if isinstance(vector_input, get_args(types.PointId)):
|
||||
if isinstance(vector_input, uuid.UUID):
|
||||
vector_input = str(vector_input)
|
||||
point_id = vector_input
|
||||
if point_id not in collection.ids:
|
||||
raise ValueError(f"Point {point_id} is not found in the collection")
|
||||
idx = collection.ids[point_id]
|
||||
if vector_name in collection_vectors:
|
||||
vec = collection_vectors[vector_name][idx]
|
||||
else:
|
||||
raise ValueError(f"Vector {vector_name} not found")
|
||||
if isinstance(vec, np.ndarray):
|
||||
vec = vec.tolist()
|
||||
if collection_name == lookup_collection_name:
|
||||
mentioned_ids.add(point_id)
|
||||
return vec
|
||||
else:
|
||||
return vector_input
|
||||
|
||||
query = deepcopy(query)
|
||||
if isinstance(query, rest_models.NearestQuery):
|
||||
query.nearest = input_into_vector(query.nearest)
|
||||
elif isinstance(query, rest_models.RecommendQuery):
|
||||
if query.recommend.negative is not None:
|
||||
query.recommend.negative = [
|
||||
input_into_vector(vector_input) for vector_input in query.recommend.negative
|
||||
]
|
||||
if query.recommend.positive is not None:
|
||||
query.recommend.positive = [
|
||||
input_into_vector(vector_input) for vector_input in query.recommend.positive
|
||||
]
|
||||
elif isinstance(query, rest_models.DiscoverQuery):
|
||||
query.discover.target = input_into_vector(query.discover.target)
|
||||
pairs = (
|
||||
query.discover.context
|
||||
if isinstance(query.discover.context, list)
|
||||
else [query.discover.context]
|
||||
)
|
||||
query.discover.context = [
|
||||
rest_models.ContextPair(
|
||||
positive=input_into_vector(pair.positive),
|
||||
negative=input_into_vector(pair.negative),
|
||||
)
|
||||
for pair in pairs
|
||||
]
|
||||
elif isinstance(query, rest_models.ContextQuery):
|
||||
pairs = query.context if isinstance(query.context, list) else [query.context]
|
||||
query.context = [
|
||||
rest_models.ContextPair(
|
||||
positive=input_into_vector(pair.positive),
|
||||
negative=input_into_vector(pair.negative),
|
||||
)
|
||||
for pair in pairs
|
||||
]
|
||||
elif isinstance(query, rest_models.OrderByQuery):
|
||||
pass
|
||||
elif isinstance(query, rest_models.FusionQuery):
|
||||
pass
|
||||
elif isinstance(query, rest_models.RrfQuery):
|
||||
pass
|
||||
return (query, mentioned_ids)
|
||||
|
||||
def _resolve_prefetches_input(
|
||||
self,
|
||||
prefetch: Optional[Union[Sequence[types.Prefetch], types.Prefetch]],
|
||||
collection_name: str,
|
||||
) -> list[types.Prefetch]:
|
||||
if prefetch is None:
|
||||
return []
|
||||
if isinstance(prefetch, list) and len(prefetch) == 0:
|
||||
return []
|
||||
prefetches = []
|
||||
if isinstance(prefetch, types.Prefetch):
|
||||
prefetches = [prefetch]
|
||||
prefetches.extend(
|
||||
prefetch.prefetch if isinstance(prefetch.prefetch, list) else [prefetch.prefetch]
|
||||
)
|
||||
elif isinstance(prefetch, Sequence):
|
||||
prefetches = list(prefetch)
|
||||
return [
|
||||
self._resolve_prefetch_input(prefetch, collection_name)
|
||||
for prefetch in prefetches
|
||||
if prefetch is not None
|
||||
]
|
||||
|
||||
def _resolve_prefetch_input(
|
||||
self, prefetch: types.Prefetch, collection_name: str
|
||||
) -> types.Prefetch:
|
||||
if prefetch.query is None:
|
||||
return prefetch
|
||||
prefetch = deepcopy(prefetch)
|
||||
(query, mentioned_ids) = self._resolve_query_input(
|
||||
collection_name, prefetch.query, prefetch.using, prefetch.lookup_from
|
||||
)
|
||||
prefetch.query = query
|
||||
prefetch.filter = ignore_mentioned_ids_filter(prefetch.filter, list(mentioned_ids))
|
||||
prefetch.prefetch = self._resolve_prefetches_input(prefetch.prefetch, collection_name)
|
||||
return prefetch
|
||||
|
||||
async def query_points(
|
||||
self,
|
||||
collection_name: str,
|
||||
query: Optional[types.Query] = None,
|
||||
using: Optional[str] = None,
|
||||
prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
|
||||
query_filter: Optional[types.Filter] = None,
|
||||
search_params: Optional[types.SearchParams] = None,
|
||||
limit: int = 10,
|
||||
offset: Optional[int] = None,
|
||||
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
|
||||
with_vectors: Union[bool, Sequence[str]] = False,
|
||||
score_threshold: Optional[float] = None,
|
||||
lookup_from: Optional[types.LookupLocation] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.QueryResponse:
|
||||
collection = self._get_collection(collection_name)
|
||||
if query is not None:
|
||||
(query, mentioned_ids) = self._resolve_query_input(
|
||||
collection_name, query, using, lookup_from
|
||||
)
|
||||
query_filter = ignore_mentioned_ids_filter(query_filter, list(mentioned_ids))
|
||||
prefetch = self._resolve_prefetches_input(prefetch, collection_name)
|
||||
return collection.query_points(
|
||||
query=query,
|
||||
prefetch=prefetch,
|
||||
query_filter=query_filter,
|
||||
using=using,
|
||||
score_threshold=score_threshold,
|
||||
limit=limit,
|
||||
offset=offset or 0,
|
||||
with_payload=with_payload,
|
||||
with_vectors=with_vectors,
|
||||
)
|
||||
|
||||
async def query_batch_points(
|
||||
self, collection_name: str, requests: Sequence[types.QueryRequest], **kwargs: Any
|
||||
) -> list[types.QueryResponse]:
|
||||
return [
|
||||
await self.query_points(
|
||||
collection_name=collection_name,
|
||||
query=request.query,
|
||||
prefetch=request.prefetch,
|
||||
query_filter=request.filter,
|
||||
limit=request.limit or 10,
|
||||
offset=request.offset,
|
||||
with_payload=request.with_payload,
|
||||
with_vectors=request.with_vector,
|
||||
score_threshold=request.score_threshold,
|
||||
using=request.using,
|
||||
lookup_from=request.lookup_from,
|
||||
)
|
||||
for request in requests
|
||||
]
|
||||
|
||||
async def query_points_groups(
|
||||
self,
|
||||
collection_name: str,
|
||||
group_by: str,
|
||||
query: Union[
|
||||
types.PointId,
|
||||
list[float],
|
||||
list[list[float]],
|
||||
types.SparseVector,
|
||||
types.Query,
|
||||
types.NumpyArray,
|
||||
types.Document,
|
||||
types.Image,
|
||||
types.InferenceObject,
|
||||
None,
|
||||
] = None,
|
||||
using: Optional[str] = None,
|
||||
prefetch: Union[types.Prefetch, list[types.Prefetch], None] = None,
|
||||
query_filter: Optional[types.Filter] = None,
|
||||
search_params: Optional[types.SearchParams] = None,
|
||||
limit: int = 10,
|
||||
group_size: int = 3,
|
||||
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
|
||||
with_vectors: Union[bool, Sequence[str]] = False,
|
||||
score_threshold: Optional[float] = None,
|
||||
with_lookup: Optional[types.WithLookupInterface] = None,
|
||||
lookup_from: Optional[types.LookupLocation] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.GroupsResult:
|
||||
collection = self._get_collection(collection_name)
|
||||
if query is not None:
|
||||
(query, mentioned_ids) = self._resolve_query_input(
|
||||
collection_name, query, using, lookup_from
|
||||
)
|
||||
query_filter = ignore_mentioned_ids_filter(query_filter, list(mentioned_ids))
|
||||
with_lookup_collection = None
|
||||
if with_lookup is not None:
|
||||
if isinstance(with_lookup, str):
|
||||
with_lookup_collection = self._get_collection(with_lookup)
|
||||
else:
|
||||
with_lookup_collection = self._get_collection(with_lookup.collection)
|
||||
return collection.query_groups(
|
||||
query=query,
|
||||
query_filter=query_filter,
|
||||
using=using,
|
||||
prefetch=prefetch,
|
||||
limit=limit,
|
||||
group_by=group_by,
|
||||
group_size=group_size,
|
||||
with_payload=with_payload,
|
||||
with_vectors=with_vectors,
|
||||
score_threshold=score_threshold,
|
||||
with_lookup=with_lookup,
|
||||
with_lookup_collection=with_lookup_collection,
|
||||
)
|
||||
|
||||
async def scroll(
|
||||
self,
|
||||
collection_name: str,
|
||||
scroll_filter: Optional[types.Filter] = None,
|
||||
limit: int = 10,
|
||||
order_by: Optional[types.OrderBy] = None,
|
||||
offset: Optional[types.PointId] = None,
|
||||
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
|
||||
with_vectors: Union[bool, Sequence[str]] = False,
|
||||
**kwargs: Any,
|
||||
) -> tuple[list[types.Record], Optional[types.PointId]]:
|
||||
collection = self._get_collection(collection_name)
|
||||
return collection.scroll(
|
||||
scroll_filter=scroll_filter,
|
||||
limit=limit,
|
||||
order_by=order_by,
|
||||
offset=offset,
|
||||
with_payload=with_payload,
|
||||
with_vectors=with_vectors,
|
||||
)
|
||||
|
||||
async def count(
|
||||
self,
|
||||
collection_name: str,
|
||||
count_filter: Optional[types.Filter] = None,
|
||||
exact: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> types.CountResult:
|
||||
collection = self._get_collection(collection_name)
|
||||
return collection.count(count_filter=count_filter)
|
||||
|
||||
async def facet(
|
||||
self,
|
||||
collection_name: str,
|
||||
key: str,
|
||||
facet_filter: Optional[types.Filter] = None,
|
||||
limit: int = 10,
|
||||
exact: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> types.FacetResponse:
|
||||
collection = self._get_collection(collection_name)
|
||||
return collection.facet(key=key, facet_filter=facet_filter, limit=limit)
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
collection_name: str,
|
||||
points: types.Points,
|
||||
update_filter: Optional[types.Filter] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
collection = self._get_collection(collection_name)
|
||||
collection.upsert(points, update_filter=update_filter)
|
||||
return self._default_update_result()
|
||||
|
||||
async def update_vectors(
|
||||
self,
|
||||
collection_name: str,
|
||||
points: Sequence[types.PointVectors],
|
||||
update_filter: Optional[types.Filter] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
collection = self._get_collection(collection_name)
|
||||
collection.update_vectors(points, update_filter=update_filter)
|
||||
return self._default_update_result()
|
||||
|
||||
async def delete_vectors(
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors: Sequence[str],
|
||||
points: types.PointsSelector,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
collection = self._get_collection(collection_name)
|
||||
collection.delete_vectors(vectors, points)
|
||||
return self._default_update_result()
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Sequence[types.PointId],
|
||||
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
|
||||
with_vectors: Union[bool, Sequence[str]] = False,
|
||||
**kwargs: Any,
|
||||
) -> list[types.Record]:
|
||||
collection = self._get_collection(collection_name)
|
||||
return collection.retrieve(ids, with_payload, with_vectors)
|
||||
|
||||
@classmethod
|
||||
def _default_update_result(cls, operation_id: int = 0) -> types.UpdateResult:
|
||||
return types.UpdateResult(
|
||||
operation_id=operation_id, status=rest_models.UpdateStatus.COMPLETED
|
||||
)
|
||||
|
||||
async def delete(
|
||||
self, collection_name: str, points_selector: types.PointsSelector, **kwargs: Any
|
||||
) -> types.UpdateResult:
|
||||
collection = self._get_collection(collection_name)
|
||||
collection.delete(points_selector)
|
||||
return self._default_update_result()
|
||||
|
||||
async def set_payload(
|
||||
self,
|
||||
collection_name: str,
|
||||
payload: types.Payload,
|
||||
points: types.PointsSelector,
|
||||
key: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
collection = self._get_collection(collection_name)
|
||||
collection.set_payload(payload=payload, selector=points, key=key)
|
||||
return self._default_update_result()
|
||||
|
||||
async def overwrite_payload(
|
||||
self,
|
||||
collection_name: str,
|
||||
payload: types.Payload,
|
||||
points: types.PointsSelector,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
collection = self._get_collection(collection_name)
|
||||
collection.overwrite_payload(payload=payload, selector=points)
|
||||
return self._default_update_result()
|
||||
|
||||
async def delete_payload(
|
||||
self,
|
||||
collection_name: str,
|
||||
keys: Sequence[str],
|
||||
points: types.PointsSelector,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
collection = self._get_collection(collection_name)
|
||||
collection.delete_payload(keys=keys, selector=points)
|
||||
return self._default_update_result()
|
||||
|
||||
async def clear_payload(
|
||||
self, collection_name: str, points_selector: types.PointsSelector, **kwargs: Any
|
||||
) -> types.UpdateResult:
|
||||
collection = self._get_collection(collection_name)
|
||||
collection.clear_payload(selector=points_selector)
|
||||
return self._default_update_result()
|
||||
|
||||
async def batch_update_points(
|
||||
self,
|
||||
collection_name: str,
|
||||
update_operations: Sequence[types.UpdateOperation],
|
||||
**kwargs: Any,
|
||||
) -> list[types.UpdateResult]:
|
||||
collection = self._get_collection(collection_name)
|
||||
collection.batch_update_points(update_operations)
|
||||
return [self._default_update_result()] * len(update_operations)
|
||||
|
||||
async def update_collection_aliases(
|
||||
self, change_aliases_operations: Sequence[types.AliasOperations], **kwargs: Any
|
||||
) -> bool:
|
||||
for operation in change_aliases_operations:
|
||||
if isinstance(operation, rest_models.CreateAliasOperation):
|
||||
self._get_collection(operation.create_alias.collection_name)
|
||||
self.aliases[operation.create_alias.alias_name] = (
|
||||
operation.create_alias.collection_name
|
||||
)
|
||||
elif isinstance(operation, rest_models.DeleteAliasOperation):
|
||||
self.aliases.pop(operation.delete_alias.alias_name, None)
|
||||
elif isinstance(operation, rest_models.RenameAliasOperation):
|
||||
new_name = operation.rename_alias.new_alias_name
|
||||
old_name = operation.rename_alias.old_alias_name
|
||||
self.aliases[new_name] = self.aliases.pop(old_name)
|
||||
else:
|
||||
raise ValueError(f"Unknown operation: {operation}")
|
||||
self._save()
|
||||
return True
|
||||
|
||||
async def get_collection_aliases(
|
||||
self, collection_name: str, **kwargs: Any
|
||||
) -> types.CollectionsAliasesResponse:
|
||||
if self.closed:
|
||||
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
|
||||
return types.CollectionsAliasesResponse(
|
||||
aliases=[
|
||||
rest_models.AliasDescription(alias_name=alias_name, collection_name=name)
|
||||
for (alias_name, name) in self.aliases.items()
|
||||
if name == collection_name
|
||||
]
|
||||
)
|
||||
|
||||
async def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
|
||||
if self.closed:
|
||||
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
|
||||
return types.CollectionsAliasesResponse(
|
||||
aliases=[
|
||||
rest_models.AliasDescription(alias_name=alias_name, collection_name=name)
|
||||
for (alias_name, name) in self.aliases.items()
|
||||
]
|
||||
)
|
||||
|
||||
async def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
|
||||
if self.closed:
|
||||
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
|
||||
return types.CollectionsResponse(
|
||||
collections=[
|
||||
rest_models.CollectionDescription(name=name)
|
||||
for (name, _) in self.collections.items()
|
||||
]
|
||||
)
|
||||
|
||||
async def get_collection(self, collection_name: str, **kwargs: Any) -> types.CollectionInfo:
|
||||
collection = self._get_collection(collection_name)
|
||||
return collection.info()
|
||||
|
||||
async def collection_exists(self, collection_name: str, **kwargs: Any) -> bool:
|
||||
try:
|
||||
self._get_collection(collection_name)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
async def update_collection(
|
||||
self,
|
||||
collection_name: str,
|
||||
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
|
||||
metadata: Optional[types.Payload] = None,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
_collection = self._get_collection(collection_name)
|
||||
updated = False
|
||||
if sparse_vectors_config is not None:
|
||||
for vector_name, vector_params in sparse_vectors_config.items():
|
||||
_collection.update_sparse_vectors_config(vector_name, vector_params)
|
||||
updated = True
|
||||
if metadata is not None:
|
||||
if _collection.config.metadata is not None:
|
||||
_collection.config.metadata.update(metadata)
|
||||
else:
|
||||
_collection.config.metadata = deepcopy(metadata)
|
||||
updated = True
|
||||
self._save()
|
||||
return updated
|
||||
|
||||
def _collection_path(self, collection_name: str) -> Optional[str]:
|
||||
if self.persistent:
|
||||
return os.path.join(self.location, "collection", collection_name)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def delete_collection(self, collection_name: str, **kwargs: Any) -> bool:
|
||||
if self.closed:
|
||||
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
|
||||
_collection = self.collections.pop(collection_name, None)
|
||||
del _collection
|
||||
self.aliases = {
|
||||
alias_name: name
|
||||
for (alias_name, name) in self.aliases.items()
|
||||
if name != collection_name
|
||||
}
|
||||
collection_path = self._collection_path(collection_name)
|
||||
if collection_path is not None:
|
||||
shutil.rmtree(collection_path, ignore_errors=True)
|
||||
self._save()
|
||||
return True
|
||||
|
||||
async def create_collection(
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors_config: Optional[
|
||||
Union[types.VectorParams, Mapping[str, types.VectorParams]]
|
||||
] = None,
|
||||
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
|
||||
metadata: Optional[types.Payload] = None,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
if self.closed:
|
||||
raise RuntimeError("QdrantLocal instance is closed. Please create a new instance.")
|
||||
if collection_name in self.collections:
|
||||
raise ValueError(f"Collection {collection_name} already exists")
|
||||
collection_path = self._collection_path(collection_name)
|
||||
if collection_path is not None:
|
||||
os.makedirs(collection_path, exist_ok=True)
|
||||
collection = LocalCollection(
|
||||
rest_models.CreateCollection(
|
||||
vectors=vectors_config or {},
|
||||
sparse_vectors=sparse_vectors_config,
|
||||
metadata=deepcopy(metadata),
|
||||
),
|
||||
location=collection_path,
|
||||
force_disable_check_same_thread=self.force_disable_check_same_thread,
|
||||
)
|
||||
self.collections[collection_name] = collection
|
||||
self._save()
|
||||
return True
|
||||
|
||||
async def recreate_collection(
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors_config: Union[types.VectorParams, Mapping[str, types.VectorParams]],
|
||||
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
|
||||
metadata: Optional[types.Payload] = None,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
await self.delete_collection(collection_name)
|
||||
return await self.create_collection(
|
||||
collection_name, vectors_config, sparse_vectors_config, metadata=metadata
|
||||
)
|
||||
|
||||
def upload_points(
|
||||
self,
|
||||
collection_name: str,
|
||||
points: Iterable[types.PointStruct],
|
||||
update_filter: Optional[types.Filter] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._upload_points(collection_name, points, update_filter=update_filter)
|
||||
|
||||
def _upload_points(
|
||||
self,
|
||||
collection_name: str,
|
||||
points: Iterable[Union[types.PointStruct, types.Record]],
|
||||
update_filter: Optional[types.Filter] = None,
|
||||
) -> None:
|
||||
collection = self._get_collection(collection_name)
|
||||
collection.upsert(
|
||||
[
|
||||
rest_models.PointStruct(
|
||||
id=point.id, vector=point.vector or {}, payload=point.payload or {}
|
||||
)
|
||||
for point in points
|
||||
],
|
||||
update_filter=update_filter,
|
||||
)
|
||||
|
||||
def upload_collection(
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors: Union[
|
||||
dict[str, types.NumpyArray], types.NumpyArray, Iterable[types.VectorStruct]
|
||||
],
|
||||
payload: Optional[Iterable[dict[Any, Any]]] = None,
|
||||
ids: Optional[Iterable[types.PointId]] = None,
|
||||
update_filter: Optional[types.Filter] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
def uuid_generator() -> Generator[str, None, None]:
|
||||
while True:
|
||||
yield str(uuid4())
|
||||
|
||||
collection = self._get_collection(collection_name)
|
||||
if isinstance(vectors, dict) and any(
|
||||
(isinstance(v, np.ndarray) for v in vectors.values())
|
||||
):
|
||||
assert (
|
||||
len(set([arr.shape[0] for arr in vectors.values()])) == 1
|
||||
), "Each named vector should have the same number of vectors"
|
||||
num_vectors = next(iter(vectors.values())).shape[0]
|
||||
vectors = [
|
||||
{name: vectors[name][i].tolist() for name in vectors.keys()}
|
||||
for i in range(num_vectors)
|
||||
]
|
||||
collection.upsert(
|
||||
[
|
||||
rest_models.PointStruct(
|
||||
id=str(point_id) if isinstance(point_id, uuid.UUID) else point_id,
|
||||
vector=(vector.tolist() if isinstance(vector, np.ndarray) else vector) or {},
|
||||
payload=payload or {},
|
||||
)
|
||||
for (point_id, vector, payload) in zip(
|
||||
ids or uuid_generator(), iter(vectors), payload or itertools.cycle([{}])
|
||||
)
|
||||
],
|
||||
update_filter=update_filter,
|
||||
)
|
||||
|
||||
async def create_payload_index(
|
||||
self,
|
||||
collection_name: str,
|
||||
field_name: str,
|
||||
field_schema: Optional[types.PayloadSchemaType] = None,
|
||||
field_type: Optional[types.PayloadSchemaType] = None,
|
||||
**kwargs: Any,
|
||||
) -> types.UpdateResult:
|
||||
show_warning_once(
|
||||
message="Payload indexes have no effect in the local Qdrant. Please use server Qdrant if you need payload indexes.",
|
||||
category=UserWarning,
|
||||
idx="create-local-payload-indexes",
|
||||
stacklevel=5,
|
||||
)
|
||||
return self._default_update_result()
|
||||
|
||||
async def delete_payload_index(
|
||||
self, collection_name: str, field_name: str, **kwargs: Any
|
||||
) -> types.UpdateResult:
|
||||
show_warning_once(
|
||||
message="Payload indexes have no effect in the local Qdrant. Please use server Qdrant if you need payload indexes.",
|
||||
category=UserWarning,
|
||||
idx="delete-local-payload-indexes",
|
||||
stacklevel=5,
|
||||
)
|
||||
return self._default_update_result()
|
||||
|
||||
async def list_snapshots(
|
||||
self, collection_name: str, **kwargs: Any
|
||||
) -> list[types.SnapshotDescription]:
|
||||
return []
|
||||
|
||||
async def create_snapshot(
|
||||
self, collection_name: str, **kwargs: Any
|
||||
) -> Optional[types.SnapshotDescription]:
|
||||
raise NotImplementedError(
|
||||
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
|
||||
)
|
||||
|
||||
async def delete_snapshot(
|
||||
self, collection_name: str, snapshot_name: str, **kwargs: Any
|
||||
) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
|
||||
)
|
||||
|
||||
async def list_full_snapshots(self, **kwargs: Any) -> list[types.SnapshotDescription]:
|
||||
return []
|
||||
|
||||
async def create_full_snapshot(self, **kwargs: Any) -> types.SnapshotDescription:
|
||||
raise NotImplementedError(
|
||||
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
|
||||
)
|
||||
|
||||
async def delete_full_snapshot(self, snapshot_name: str, **kwargs: Any) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
|
||||
)
|
||||
|
||||
async def recover_snapshot(self, collection_name: str, location: str, **kwargs: Any) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need full snapshots."
|
||||
)
|
||||
|
||||
async def list_shard_snapshots(
|
||||
self, collection_name: str, shard_id: int, **kwargs: Any
|
||||
) -> list[types.SnapshotDescription]:
|
||||
return []
|
||||
|
||||
async def create_shard_snapshot(
|
||||
self, collection_name: str, shard_id: int, **kwargs: Any
|
||||
) -> Optional[types.SnapshotDescription]:
|
||||
raise NotImplementedError(
|
||||
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need snapshots."
|
||||
)
|
||||
|
||||
async def delete_shard_snapshot(
|
||||
self, collection_name: str, shard_id: int, snapshot_name: str, **kwargs: Any
|
||||
) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need snapshots."
|
||||
)
|
||||
|
||||
async def recover_shard_snapshot(
|
||||
self, collection_name: str, shard_id: int, location: str, **kwargs: Any
|
||||
) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Snapshots are not supported in the local Qdrant. Please use server Qdrant if you need snapshots."
|
||||
)
|
||||
|
||||
async def create_shard_key(
|
||||
self,
|
||||
collection_name: str,
|
||||
shard_key: types.ShardKey,
|
||||
shards_number: Optional[int] = None,
|
||||
replication_factor: Optional[int] = None,
|
||||
placement: Optional[list[int]] = None,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Sharding is not supported in the local Qdrant. Please use server Qdrant if you need sharding."
|
||||
)
|
||||
|
||||
async def delete_shard_key(
|
||||
self, collection_name: str, shard_key: types.ShardKey, **kwargs: Any
|
||||
) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Sharding is not supported in the local Qdrant. Please use server Qdrant if you need sharding."
|
||||
)
|
||||
|
||||
async def info(self) -> types.VersionInfo:
|
||||
version = importlib.metadata.version("qdrant-client")
|
||||
return rest_models.VersionInfo(
|
||||
title="qdrant - vector search engine", version=version, commit=None
|
||||
)
|
||||
|
||||
async def cluster_collection_update(
|
||||
self, collection_name: str, cluster_operation: types.ClusterOperations, **kwargs: Any
|
||||
) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Cluster collection update is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
|
||||
)
|
||||
|
||||
async def collection_cluster_info(self, collection_name: str) -> types.CollectionClusterInfo:
|
||||
raise NotImplementedError(
|
||||
"Collection cluster info is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
|
||||
)
|
||||
|
||||
async def cluster_status(self) -> types.ClusterStatus:
|
||||
raise NotImplementedError(
|
||||
"Cluster status is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
|
||||
)
|
||||
|
||||
async def recover_current_peer(self) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Recover current peer is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
|
||||
)
|
||||
|
||||
async def remove_peer(self, peer_id: int, **kwargs: Any) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Remove peer info is not supported in the local Qdrant. Please use server Qdrant if you need a cluster"
|
||||
)
|
||||
@@ -0,0 +1,50 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
# These are the formats accepted by qdrant core
|
||||
available_formats = [
|
||||
"%Y-%m-%dT%H:%M:%S.%f%z",
|
||||
"%Y-%m-%d %H:%M:%S.%f%z",
|
||||
"%Y-%m-%dT%H:%M:%S%z",
|
||||
"%Y-%m-%d %H:%M:%S%z",
|
||||
"%Y-%m-%dT%H:%M:%S.%f",
|
||||
"%Y-%m-%d %H:%M:%S.%f",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y-%m-%d",
|
||||
]
|
||||
|
||||
|
||||
def parse(date_str: str) -> Optional[datetime]:
|
||||
"""Parses one section of the date string at a time.
|
||||
|
||||
Args:
|
||||
date_str (str): Accepts any of the formats in qdrant core (see https://github.com/qdrant/qdrant/blob/0ed86ce0575d35930268db19e1f7680287072c58/lib/segment/src/types.rs#L1388-L1410)
|
||||
|
||||
Returns:
|
||||
Optional[datetime]: the datetime if the string is valid, otherwise None
|
||||
"""
|
||||
|
||||
def parse_available_formats(datetime_str: str) -> Optional[datetime]:
|
||||
for fmt in available_formats:
|
||||
try:
|
||||
dt = datetime.strptime(datetime_str, fmt)
|
||||
if dt.tzinfo is None:
|
||||
# Assume UTC if no timezone is provided
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
except ValueError:
|
||||
pass
|
||||
return None
|
||||
|
||||
parsed_dt = parse_available_formats(date_str)
|
||||
if parsed_dt is not None:
|
||||
return parsed_dt
|
||||
|
||||
# Python can't parse timezones containing only hours (+HH), but it can parse timezones with hours and minutes
|
||||
# So we add :00 to the assumed timezone and try parsing it again
|
||||
# dt examples to handle:
|
||||
# "2021-01-01 00:00:00.000+01"
|
||||
# "2021-01-01 00:00:00.000-10"
|
||||
return parse_available_formats(date_str + ":00")
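# Illustrative usage sketch (not part of the module): both full and hour-only
# UTC offsets resolve to timezone-aware datetimes, and naive inputs are assumed
# to be UTC.
#
#   parse("2021-01-01 00:00:00.000+01")  # -> 2021-01-01 00:00:00+01:00
#   parse("2021-01-01")                  # -> 2021-01-01 00:00:00+00:00
#   parse("not a date")                  # -> None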
|
||||
@@ -0,0 +1,302 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.conversions import common_types as types
|
||||
from qdrant_client.http import models
|
||||
|
||||
EPSILON = 1.1920929e-7 # https://doc.rust-lang.org/std/f32/constant.EPSILON.html
|
||||
# https://github.com/qdrant/qdrant/blob/7164ac4a5987d28f1c93f5712aef8e09e7d93555/lib/segment/src/spaces/simple_avx.rs#L99C10-L99C10
|
||||
|
||||
|
||||
class DistanceOrder(str, Enum):
|
||||
BIGGER_IS_BETTER = "bigger_is_better"
|
||||
SMALLER_IS_BETTER = "smaller_is_better"
|
||||
|
||||
|
||||
class RecoQuery:
|
||||
def __init__(
|
||||
self,
|
||||
positive: Optional[list[list[float]]] = None,
|
||||
negative: Optional[list[list[float]]] = None,
|
||||
strategy: Optional[models.RecommendStrategy] = None,
|
||||
):
|
||||
assert strategy is not None, "Recommend strategy must be provided"
|
||||
|
||||
self.strategy = strategy
|
||||
positive = positive if positive is not None else []
|
||||
negative = negative if negative is not None else []
|
||||
|
||||
self.positive: list[types.NumpyArray] = [np.array(vector) for vector in positive]
|
||||
self.negative: list[types.NumpyArray] = [np.array(vector) for vector in negative]
|
||||
|
||||
assert not np.isnan(self.positive).any(), "Positive vectors must not contain NaN"
|
||||
assert not np.isnan(self.negative).any(), "Negative vectors must not contain NaN"
|
||||
|
||||
|
||||
class ContextPair:
|
||||
def __init__(self, positive: list[float], negative: list[float]):
|
||||
self.positive: types.NumpyArray = np.array(positive)
|
||||
self.negative: types.NumpyArray = np.array(negative)
|
||||
|
||||
assert not np.isnan(self.positive).any(), "Positive vector must not contain NaN"
|
||||
assert not np.isnan(self.negative).any(), "Negative vector must not contain NaN"
|
||||
|
||||
|
||||
class DiscoveryQuery:
|
||||
def __init__(self, target: list[float], context: list[ContextPair]):
|
||||
self.target: types.NumpyArray = np.array(target)
|
||||
self.context = context
|
||||
|
||||
assert not np.isnan(self.target).any(), "Target vector must not contain NaN"
|
||||
|
||||
|
||||
class ContextQuery:
|
||||
def __init__(self, context_pairs: list[ContextPair]):
|
||||
self.context_pairs = context_pairs
|
||||
|
||||
|
||||
DenseQueryVector = Union[
|
||||
DiscoveryQuery,
|
||||
ContextQuery,
|
||||
RecoQuery,
|
||||
]
|
||||
|
||||
|
||||
def distance_to_order(distance: models.Distance) -> DistanceOrder:
|
||||
"""
|
||||
Convert distance to order
|
||||
Args:
|
||||
distance: distance to convert
|
||||
Returns:
|
||||
order
|
||||
"""
|
||||
if distance == models.Distance.EUCLID:
|
||||
return DistanceOrder.SMALLER_IS_BETTER
|
||||
elif distance == models.Distance.MANHATTAN:
|
||||
return DistanceOrder.SMALLER_IS_BETTER
|
||||
|
||||
return DistanceOrder.BIGGER_IS_BETTER
|
||||
|
||||
|
||||
def cosine_similarity(query: types.NumpyArray, vectors: types.NumpyArray) -> types.NumpyArray:
|
||||
"""
|
||||
Calculate cosine similarity between query and vectors
|
||||
Args:
|
||||
query: query vector
|
||||
vectors: vectors to calculate distance with
|
||||
Returns:
|
||||
similarities
|
||||
"""
|
||||
vectors_norm = np.linalg.norm(vectors, axis=-1)[:, np.newaxis]
|
||||
vectors /= np.where(vectors_norm != 0.0, vectors_norm, EPSILON)
|
||||
|
||||
if len(query.shape) == 1:
|
||||
query_norm = np.linalg.norm(query)
|
||||
query /= np.where(query_norm != 0.0, query_norm, EPSILON)
|
||||
return np.dot(vectors, query)
|
||||
|
||||
query_norm = np.linalg.norm(query, axis=-1)[:, np.newaxis]
|
||||
query /= np.where(query_norm != 0.0, query_norm, EPSILON)
|
||||
return np.dot(query, vectors.T)
|
||||
|
||||
|
||||
def dot_product(query: types.NumpyArray, vectors: types.NumpyArray) -> types.NumpyArray:
|
||||
"""
|
||||
Calculate dot product between query and vectors
|
||||
Args:
|
||||
query: query vector.
|
||||
vectors: vectors to calculate distance with
|
||||
Returns:
|
||||
distances
|
||||
"""
|
||||
if len(query.shape) == 1:
|
||||
return np.dot(vectors, query)
|
||||
else:
|
||||
return np.dot(query, vectors.T)
|
||||
|
||||
|
||||
def euclidean_distance(query: types.NumpyArray, vectors: types.NumpyArray) -> types.NumpyArray:
|
||||
"""
|
||||
Calculate euclidean distance between query and vectors
|
||||
Args:
|
||||
query: query vector.
|
||||
vectors: vectors to calculate distance with
|
||||
Returns:
|
||||
distances
|
||||
"""
|
||||
if len(query.shape) == 1:
|
||||
return np.linalg.norm(vectors - query, axis=-1)
|
||||
else:
|
||||
return np.linalg.norm(vectors - query[:, np.newaxis], axis=-1)
|
||||
|
||||
|
||||
def manhattan_distance(query: types.NumpyArray, vectors: types.NumpyArray) -> types.NumpyArray:
|
||||
"""
|
||||
Calculate manhattan distance between query and vectors
|
||||
Args:
|
||||
query: query vector.
|
||||
vectors: vectors to calculate distance with
|
||||
Returns:
|
||||
distances
|
||||
"""
|
||||
if len(query.shape) == 1:
|
||||
return np.sum(np.abs(vectors - query), axis=-1)
|
||||
else:
|
||||
return np.sum(np.abs(vectors - query[:, np.newaxis]), axis=-1)
|
||||
|
||||
|
||||
def calculate_distance(
|
||||
query: types.NumpyArray, vectors: types.NumpyArray, distance_type: models.Distance
|
||||
) -> types.NumpyArray:
|
||||
assert not np.isnan(query).any(), "Query vector must not contain NaN"
|
||||
|
||||
if distance_type == models.Distance.COSINE:
|
||||
return cosine_similarity(query, vectors)
|
||||
elif distance_type == models.Distance.DOT:
|
||||
return dot_product(query, vectors)
|
||||
elif distance_type == models.Distance.EUCLID:
|
||||
return euclidean_distance(query, vectors)
|
||||
elif distance_type == models.Distance.MANHATTAN:
|
||||
return manhattan_distance(query, vectors)
|
||||
else:
|
||||
raise ValueError(f"Unknown distance type {distance_type}")
|
||||
|
||||
|
||||
def calculate_distance_core(
|
||||
query: types.NumpyArray, vectors: types.NumpyArray, distance_type: models.Distance
|
||||
) -> types.NumpyArray:
|
||||
"""
|
||||
Calculate the same internal scores as qdrant core uses, rather than the final displayed distance (e.g. negated squared euclidean distance, so that a larger score is always better)
|
||||
"""
|
||||
assert not np.isnan(query).any(), "Query vector must not contain NaN"
|
||||
|
||||
if distance_type == models.Distance.EUCLID:
|
||||
return -np.square(vectors - query, dtype=np.float32).sum(axis=1, dtype=np.float32)
|
||||
if distance_type == models.Distance.MANHATTAN:
|
||||
return -np.abs(vectors - query, dtype=np.float32).sum(axis=1, dtype=np.float32)
|
||||
else:
|
||||
return calculate_distance(query, vectors, distance_type)
|
||||
|
||||
|
||||
def fast_sigmoid(x: np.float32) -> np.float32:
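# Monotonically squashes a raw score into (-1, 1) as x / (1 + |x|); cheaper than
# an exponential sigmoid while preserving the ordering of scores.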
|
||||
if np.isnan(x) or np.isinf(x):
|
||||
# Avoid dividing NaN or inf values, which would otherwise emit: RuntimeWarning: invalid value encountered in scalar divide
|
||||
return x
|
||||
return x / np.add(1.0, abs(x))
|
||||
|
||||
|
||||
def scaled_fast_sigmoid(x: np.float32) -> np.float32:
|
||||
return 0.5 * (np.add(fast_sigmoid(x), 1.0))
|
||||
|
||||
|
||||
def calculate_recommend_best_scores(
|
||||
query: RecoQuery, vectors: types.NumpyArray, distance_type: models.Distance
|
||||
) -> types.NumpyArray:
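# "Best score" strategy: a point keeps the squashed score of its best positive
# example when that beats its best negative example; otherwise it gets the
# negated squashed score of its best negative, so it ranks below positive-dominated points.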
|
||||
def get_best_scores(examples: list[types.NumpyArray]) -> types.NumpyArray:
|
||||
vector_count = vectors.shape[0]
|
||||
|
||||
# Get scores to all examples
|
||||
scores: list[types.NumpyArray] = []
|
||||
for example in examples:
|
||||
score = calculate_distance_core(example, vectors, distance_type)
|
||||
scores.append(score)
|
||||
|
||||
# Keep only max for each vector
|
||||
if len(scores) == 0:
|
||||
scores.append(np.full(vector_count, -np.inf))
|
||||
best_scores = np.array(scores, dtype=np.float32).max(axis=0)
|
||||
|
||||
return best_scores
|
||||
|
||||
pos = get_best_scores(query.positive)
|
||||
neg = get_best_scores(query.negative)
|
||||
|
||||
# Choose from best positive or best negative,
|
||||
# in both cases we apply a sigmoid, negating the result when the best negative wins
|
||||
return np.where(
|
||||
pos > neg,
|
||||
np.fromiter((scaled_fast_sigmoid(xi) for xi in pos), pos.dtype),
|
||||
np.fromiter((-scaled_fast_sigmoid(xi) for xi in neg), neg.dtype),
|
||||
)
|
||||
|
||||
|
||||
def calculate_recommend_sum_scores(
|
||||
query: RecoQuery, vectors: types.NumpyArray, distance_type: models.Distance
|
||||
) -> types.NumpyArray:
|
||||
def get_sum_scores(examples: list[types.NumpyArray]) -> types.NumpyArray:
|
||||
vector_count = vectors.shape[0]
|
||||
|
||||
scores: list[types.NumpyArray] = []
|
||||
for example in examples:
|
||||
score = calculate_distance_core(example, vectors, distance_type)
|
||||
scores.append(score)
|
||||
|
||||
if len(scores) == 0:
|
||||
scores.append(np.zeros(vector_count))
|
||||
|
||||
sum_scores = np.array(scores, dtype=np.float32).sum(axis=0)
|
||||
|
||||
return sum_scores
|
||||
|
||||
pos = get_sum_scores(query.positive)
|
||||
neg = get_sum_scores(query.negative)
|
||||
|
||||
return pos - neg
|
||||
|
||||
|
||||
def calculate_discovery_ranks(
|
||||
context: list[ContextPair],
|
||||
vectors: types.NumpyArray,
|
||||
distance_type: models.Distance,
|
||||
) -> types.NumpyArray:
|
||||
overall_ranks = np.zeros(vectors.shape[0], dtype=np.int32)
|
||||
for pair in context:
|
||||
# Get distances to positive and negative vectors
|
||||
pos = calculate_distance_core(pair.positive, vectors, distance_type)
|
||||
neg = calculate_distance_core(pair.negative, vectors, distance_type)
|
||||
|
||||
pair_ranks = np.array(
|
||||
[
|
||||
1 if is_bigger else 0 if is_equal else -1
|
||||
for is_bigger, is_equal in zip(pos > neg, pos == neg)
|
||||
]
|
||||
)
|
||||
|
||||
overall_ranks += pair_ranks
|
||||
|
||||
return overall_ranks
|
||||
|
||||
|
||||
def calculate_discovery_scores(
|
||||
query: DiscoveryQuery, vectors: types.NumpyArray, distance_type: models.Distance
|
||||
) -> types.NumpyArray:
|
||||
ranks = calculate_discovery_ranks(query.context, vectors, distance_type)
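# Each context pair contributes +1, 0 or -1 to a point's rank; the sigmoided
# distance to the target lies in (0, 1), so it only breaks ties between points
# with the same rank.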
|
||||
|
||||
# Get distances to target
|
||||
distances_to_target = calculate_distance_core(query.target, vectors, distance_type)
|
||||
|
||||
sigmoided_distances = np.fromiter(
|
||||
(scaled_fast_sigmoid(xi) for xi in distances_to_target), np.float32
|
||||
)
|
||||
|
||||
return ranks + sigmoided_distances
|
||||
|
||||
|
||||
def calculate_context_scores(
|
||||
query: ContextQuery, vectors: types.NumpyArray, distance_type: models.Distance
|
||||
) -> types.NumpyArray:
|
||||
overall_scores = np.zeros(vectors.shape[0], dtype=np.float32)
|
||||
for pair in query.context_pairs:
|
||||
# Get distances to positive and negative vectors
|
||||
pos = calculate_distance_core(pair.positive, vectors, distance_type)
|
||||
neg = calculate_distance_core(pair.negative, vectors, distance_type)
|
||||
|
||||
difference = pos - neg - EPSILON
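# A pair contributes 0 when the point scores better against the positive example
# and a penalty in (-1, 0) otherwise, so the best achievable context score is 0.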
|
||||
pair_scores = np.fromiter(
|
||||
(fast_sigmoid(xi) for xi in np.minimum(difference, 0.0)), np.float32
|
||||
)
|
||||
overall_scores += pair_scores
|
||||
|
||||
return overall_scores
|
||||
@@ -0,0 +1,90 @@
|
||||
from math import asin, cos, radians, sin, sqrt
|
||||
|
||||
# Radius of earth in meters, [as recommended by the IUGG](ftp://athena.fsv.cvut.cz/ZFG/grs80-Moritz.pdf)
|
||||
MEAN_EARTH_RADIUS = 6371008.8
|
||||
|
||||
|
||||
def geo_distance(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
|
||||
"""
|
||||
Calculate distance between two points on Earth using Haversine formula.
|
||||
|
||||
Args:
|
||||
lon1: longitude of first point
|
||||
lat1: latitude of first point
|
||||
lon2: longitude of second point
|
||||
lat2: latitude of second point
|
||||
|
||||
Returns:
|
||||
distance in meters
|
||||
"""
|
||||
|
||||
# convert decimal degrees to radians
|
||||
lon1, lat1, lon2, lat2 = map(radians, [lon1, lat1, lon2, lat2])
|
||||
# haversine formula
|
||||
dlon = lon2 - lon1
|
||||
dlat = lat2 - lat1
|
||||
a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
|
||||
c = 2 * asin(sqrt(a))
|
||||
|
||||
return MEAN_EARTH_RADIUS * c
|
||||
|
||||
|
||||
def test_geo_distance() -> None:
|
||||
moscow = {"lon": 37.6173, "lat": 55.7558}
|
||||
london = {"lon": -0.1278, "lat": 51.5074}
|
||||
berlin = {"lon": 13.4050, "lat": 52.5200}
|
||||
|
||||
assert geo_distance(moscow["lon"], moscow["lat"], moscow["lon"], moscow["lat"]) < 1.0
|
||||
|
||||
assert geo_distance(moscow["lon"], moscow["lat"], london["lon"], london["lat"]) > 2400 * 1000
|
||||
assert geo_distance(moscow["lon"], moscow["lat"], london["lon"], london["lat"]) < 2600 * 1000
|
||||
assert geo_distance(moscow["lon"], moscow["lat"], berlin["lon"], berlin["lat"]) > 1600 * 1000
|
||||
assert geo_distance(moscow["lon"], moscow["lat"], berlin["lon"], berlin["lat"]) < 1650 * 1000
|
||||
|
||||
|
||||
def boolean_point_in_polygon(
|
||||
point: tuple[float, float],
|
||||
exterior: list[tuple[float, float]],
|
||||
interiors: list[list[tuple[float, float]]],
|
||||
) -> bool:
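# A point counts as inside when it lies within the exterior ring and within none
# of the interior rings (holes).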
|
||||
inside_poly = False
|
||||
|
||||
if in_ring(point, exterior, True):
|
||||
in_hole = False
|
||||
k = 0
|
||||
while k < len(interiors) and not in_hole:
|
||||
if in_ring(point, interiors[k], False):
|
||||
in_hole = True
|
||||
k += 1
|
||||
if not in_hole:
|
||||
inside_poly = True
|
||||
|
||||
return inside_poly
|
||||
|
||||
|
||||
def in_ring(
|
||||
pt: tuple[float, float], ring: list[tuple[float, float]], ignore_boundary: bool
|
||||
) -> bool:
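# Even-odd (ray casting) point-in-ring test: walk the ring edges and toggle
# `is_inside` every time a horizontal ray from `pt` crosses an edge; points that
# fall exactly on the boundary return `not ignore_boundary`.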
|
||||
is_inside = False
|
||||
if ring[0][0] == ring[len(ring) - 1][0] and ring[0][1] == ring[len(ring) - 1][1]:
|
||||
ring = ring[0 : len(ring) - 1]
|
||||
j = len(ring) - 1
|
||||
for i in range(0, len(ring)):
|
||||
xi = ring[i][0]
|
||||
yi = ring[i][1]
|
||||
xj = ring[j][0]
|
||||
yj = ring[j][1]
|
||||
on_boundary = (
|
||||
(pt[1] * (xi - xj) + yi * (xj - pt[0]) + yj * (pt[0] - xi) == 0)
|
||||
and ((xi - pt[0]) * (xj - pt[0]) <= 0)
|
||||
and ((yi - pt[1]) * (yj - pt[1]) <= 0)
|
||||
)
|
||||
if on_boundary:
|
||||
return not ignore_boundary
|
||||
intersect = ((yi > pt[1]) != (yj > pt[1])) and (
|
||||
pt[0] < (xj - xi) * (pt[1] - yi) / (yj - yi) + xi
|
||||
)
|
||||
if intersect:
|
||||
is_inside = not is_inside
|
||||
j = i
|
||||
return is_inside
|
||||
@@ -0,0 +1,151 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class JsonPathItemType(str, Enum):
|
||||
KEY = "key"
|
||||
INDEX = "index"
|
||||
WILDCARD_INDEX = "wildcard_index"
|
||||
|
||||
|
||||
class JsonPathItem(BaseModel):
|
||||
item_type: JsonPathItemType
|
||||
index: Optional[int] = (
|
||||
None # split into index and key instead of using Union, because pydantic coerces
|
||||
)
|
||||
# int to str even in case of Union[int, str]. Tested with pydantic==1.10.14
|
||||
key: Optional[str] = None
|
||||
|
||||
|
||||
def parse_json_path(key: str) -> list[JsonPathItem]:
|
||||
"""Parse and validate json path
|
||||
|
||||
Args:
|
||||
key: json path
|
||||
|
||||
Returns:
|
||||
list[JsonPathItem]: json path split into separate keys
|
||||
|
||||
Raises:
|
||||
ValueError: if json path is invalid or empty
|
||||
|
||||
Examples:
|
||||
|
||||
# >>> parse_json_path("a[0][1].b")
|
||||
# [
|
||||
# JsonPathItem(item_type=<JsonPathItemType.KEY: 'key'>, key='a'),
|
||||
# JsonPathItem(item_type=<JsonPathItemType.INDEX: 'index'>, index=0),
|
||||
# JsonPathItem(item_type=<JsonPathItemType.INDEX: 'index'>, index=1),
|
||||
# JsonPathItem(item_type=<JsonPathItemType.KEY: 'key'>, key='b')
|
||||
# ]
|
||||
"""
|
||||
keys = []
|
||||
json_path = key
|
||||
while json_path:
|
||||
json_path_item, rest = match_quote(json_path)
|
||||
if json_path_item is None:
|
||||
json_path_item, rest = match_key(json_path)
|
||||
|
||||
if json_path_item is None:
|
||||
raise ValueError("Invalid path")
|
||||
|
||||
keys.append(json_path_item)
|
||||
brackets_chunks, rest = match_brackets(rest)
|
||||
keys.extend(brackets_chunks)
|
||||
json_path = trunk_sep(rest)
|
||||
if not json_path:
|
||||
return keys
|
||||
continue
|
||||
|
||||
raise ValueError("Invalid path")
|
||||
|
||||
|
||||
def trunk_sep(path: str) -> str:
|
||||
if not path:
|
||||
return path
|
||||
|
||||
if len(path) == 1:
|
||||
raise ValueError("Invalid path")
|
||||
|
||||
if path.startswith("."):
|
||||
return path[1:]
|
||||
|
||||
elif path.startswith("["):
|
||||
return path
|
||||
else:
|
||||
raise ValueError("Invalid path")
|
||||
|
||||
|
||||
def match_quote(path: str) -> tuple[Optional[JsonPathItem], str]:
|
||||
if not path.startswith('"'):
|
||||
return None, path
|
||||
|
||||
left_quote_pos = 0
|
||||
right_quote_pos = path.find('"', 1)
|
||||
|
||||
if path.count('"') < 2:
|
||||
raise ValueError("Invalid path")
|
||||
|
||||
return (
|
||||
JsonPathItem(
|
||||
item_type=JsonPathItemType.KEY, key=path[left_quote_pos + 1 : right_quote_pos]
|
||||
),
|
||||
path[right_quote_pos + 1 :],
|
||||
)
|
||||
|
||||
|
||||
def match_key(path: str) -> tuple[Optional[JsonPathItem], str]:
|
||||
char_counter = 0
|
||||
for char in path:
|
||||
if not char.isalnum() and char not in ["_", "-"]:
|
||||
break
|
||||
char_counter += 1
|
||||
if char_counter == 0:
|
||||
return None, path
|
||||
|
||||
return (
|
||||
JsonPathItem(item_type=JsonPathItemType.KEY, key=path[:char_counter]),
|
||||
path[char_counter:],
|
||||
)
|
||||
|
||||
|
||||
def match_brackets(rest: str) -> tuple[list[JsonPathItem], str]:
|
||||
keys = []
|
||||
|
||||
while rest:
|
||||
json_path_item, rest = _match_brackets(rest)
|
||||
|
||||
if json_path_item is None:
|
||||
break
|
||||
|
||||
keys.append(json_path_item)
|
||||
|
||||
return keys, rest
|
||||
|
||||
|
||||
def _match_brackets(path: str) -> tuple[Optional[JsonPathItem], str]:
|
||||
if "[" not in path or not path.startswith("["):
|
||||
return None, path
|
||||
|
||||
left_bracket_pos = 0
|
||||
right_bracket_pos = path.find("]", left_bracket_pos + 1)
|
||||
|
||||
if right_bracket_pos == -1:
|
||||
raise ValueError("Invalid path")
|
||||
|
||||
if right_bracket_pos == (left_bracket_pos + 1):
|
||||
return (
|
||||
JsonPathItem(item_type=JsonPathItemType.WILDCARD_INDEX),
|
||||
path[right_bracket_pos + 1 :],
|
||||
)
|
||||
|
||||
try:
|
||||
index = int(path[left_bracket_pos + 1 : right_bracket_pos])
|
||||
return (
|
||||
JsonPathItem(item_type=JsonPathItemType.INDEX, index=index),
|
||||
path[right_bracket_pos + 1 :],
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValueError("Invalid path") from e
|
||||
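A minimal usage sketch of the parser above (illustrative only; it uses only names defined in this diff):

from qdrant_client.local.json_path_parser import JsonPathItemType, parse_json_path

items = parse_json_path('a."b.c"[0][].d')
assert [i.item_type for i in items] == [
    JsonPathItemType.KEY,             # a
    JsonPathItemType.KEY,             # quoted key "b.c"; dots inside quotes are not separators
    JsonPathItemType.INDEX,           # [0]
    JsonPathItemType.WILDCARD_INDEX,  # []
    JsonPathItemType.KEY,             # d
]
assert items[0].key == "a" and items[2].index == 0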
File diff suppressed because it is too large
@@ -0,0 +1,220 @@
|
||||
from typing import Optional, Union, Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.conversions import common_types as types
|
||||
from qdrant_client.local.distances import (
|
||||
calculate_distance,
|
||||
scaled_fast_sigmoid,
|
||||
EPSILON,
|
||||
fast_sigmoid,
|
||||
)
|
||||
|
||||
|
||||
class MultiRecoQuery:
|
||||
def __init__(
|
||||
self,
|
||||
positive: Optional[list[list[list[float]]]] = None, # list of matrices
|
||||
negative: Optional[list[list[list[float]]]] = None, # list of matrices
|
||||
strategy: Optional[models.RecommendStrategy] = None,
|
||||
):
|
||||
assert strategy is not None, "Recommend strategy must be provided"
|
||||
|
||||
self.strategy = strategy
|
||||
|
||||
positive = positive if positive is not None else []
|
||||
negative = negative if negative is not None else []
|
||||
|
||||
for vector in positive:
|
||||
assert not np.isnan(vector).any(), "Positive vectors must not contain NaN"
|
||||
for vector in negative:
|
||||
assert not np.isnan(vector).any(), "Negative vectors must not contain NaN"
|
||||
|
||||
self.positive: list[types.NumpyArray] = [np.array(vector) for vector in positive]
|
||||
self.negative: list[types.NumpyArray] = [np.array(vector) for vector in negative]
|
||||
|
||||
|
||||
class MultiContextPair:
|
||||
def __init__(self, positive: list[list[float]], negative: list[list[float]]):
|
||||
self.positive: types.NumpyArray = np.array(positive)
|
||||
self.negative: types.NumpyArray = np.array(negative)
|
||||
|
||||
assert not np.isnan(self.positive).any(), "Positive vector must not contain NaN"
|
||||
assert not np.isnan(self.negative).any(), "Negative vector must not contain NaN"
|
||||
|
||||
|
||||
class MultiDiscoveryQuery:
|
||||
def __init__(self, target: list[list[float]], context: list[MultiContextPair]):
|
||||
self.target: types.NumpyArray = np.array(target)
|
||||
self.context = context
|
||||
|
||||
assert not np.isnan(self.target).any(), "Target vector must not contain NaN"
|
||||
|
||||
|
||||
class MultiContextQuery:
|
||||
def __init__(self, context_pairs: list[MultiContextPair]):
|
||||
self.context_pairs = context_pairs
|
||||
|
||||
|
||||
MultiQueryVector = Union[
|
||||
MultiDiscoveryQuery,
|
||||
MultiContextQuery,
|
||||
MultiRecoQuery,
|
||||
]
|
||||
|
||||
|
||||
def calculate_multi_distance(
|
||||
query_matrix: types.NumpyArray,
|
||||
matrices: list[types.NumpyArray],
|
||||
distance_type: models.Distance,
|
||||
) -> types.NumpyArray:
|
||||
assert not np.isnan(query_matrix).any(), "Query matrix must not contain NaN"
|
||||
assert len(query_matrix.shape) == 2, "Query must be a matrix"
|
||||
|
||||
distances = calculate_multi_distance_core(query_matrix, matrices, distance_type)
|
||||
|
||||
if distance_type == models.Distance.EUCLID:
|
||||
distances = np.sqrt(np.abs(distances))
|
||||
elif distance_type == models.Distance.MANHATTAN:
|
||||
distances = np.abs(distances)
|
||||
return distances
|
||||
|
||||
|
||||
def calculate_multi_distance_core(
|
||||
query_matrix: types.NumpyArray,
|
||||
matrices: list[types.NumpyArray],
|
||||
distance_type: models.Distance,
|
||||
) -> types.NumpyArray:
|
||||
def euclidean(q: types.NumpyArray, m: types.NumpyArray, *_: Any) -> types.NumpyArray:
|
||||
return -np.square(m - q, dtype=np.float32).sum(axis=-1, dtype=np.float32)
|
||||
|
||||
def manhattan(q: types.NumpyArray, m: types.NumpyArray, *_: Any) -> types.NumpyArray:
|
||||
return -np.abs(m - q, dtype=np.float32).sum(axis=-1, dtype=np.float32)
|
||||
|
||||
assert not np.isnan(query_matrix).any(), "Query vector must not contain NaN"
|
||||
similarities: list[float] = []
|
||||
|
||||
# Euclid and Manhattan are the only ones which are calculated differently during candidate selection
|
||||
# in core, here we make sure to use the same internal similarity function as in core.
|
||||
if distance_type in [models.Distance.EUCLID, models.Distance.MANHATTAN]:
|
||||
query_matrix = query_matrix[:, np.newaxis]
|
||||
dist_func = euclidean if distance_type == models.Distance.EUCLID else manhattan
|
||||
else:
|
||||
dist_func = calculate_distance # type: ignore
|
||||
|
||||
for matrix in matrices:
|
||||
sim_matrix = dist_func(query_matrix, matrix, distance_type)
|
||||
similarity = float(np.sum(np.max(sim_matrix, axis=-1)))
|
||||
similarities.append(similarity)
|
||||
return np.array(similarities)
|
||||
|
||||
|
||||
def calculate_multi_recommend_best_scores(
|
||||
query: MultiRecoQuery, matrices: list[types.NumpyArray], distance_type: models.Distance
|
||||
) -> types.NumpyArray:
|
||||
def get_best_scores(examples: list[types.NumpyArray]) -> types.NumpyArray:
|
||||
matrix_count = len(matrices)
|
||||
|
||||
# Get scores to all examples
|
||||
scores: list[types.NumpyArray] = []
|
||||
for example in examples:
|
||||
score = calculate_multi_distance_core(example, matrices, distance_type)
|
||||
scores.append(score)
|
||||
|
||||
# Keep only max for each vector
|
||||
if len(scores) == 0:
|
||||
scores.append(np.full(matrix_count, -np.inf))
|
||||
best_scores = np.array(scores, dtype=np.float32).max(axis=0)
|
||||
|
||||
return best_scores
|
||||
|
||||
pos = get_best_scores(query.positive)
|
||||
neg = get_best_scores(query.negative)
|
||||
|
||||
# Choose from the best positive or the best negative,
|
||||
# in both cases we apply sigmoid and then negate depending on the order
|
||||
return np.where(
|
||||
pos > neg,
|
||||
np.fromiter((scaled_fast_sigmoid(xi) for xi in pos), pos.dtype),
|
||||
np.fromiter((-scaled_fast_sigmoid(xi) for xi in neg), neg.dtype),
|
||||
)
|
||||
|
||||
|
||||
def calculate_multi_recommend_sum_scores(
|
||||
query: MultiRecoQuery, matrices: list[types.NumpyArray], distance_type: models.Distance
|
||||
) -> types.NumpyArray:
|
||||
def get_sum_scores(examples: list[types.NumpyArray]) -> types.NumpyArray:
|
||||
matrix_count = len(matrices)
|
||||
|
||||
scores: list[types.NumpyArray] = []
|
||||
for example in examples:
|
||||
score = calculate_multi_distance_core(example, matrices, distance_type)
|
||||
scores.append(score)
|
||||
|
||||
if len(scores) == 0:
|
||||
scores.append(np.zeros(matrix_count))
|
||||
|
||||
sum_scores = np.array(scores, dtype=np.float32).sum(axis=0)
|
||||
return sum_scores
|
||||
|
||||
pos = get_sum_scores(query.positive)
|
||||
neg = get_sum_scores(query.negative)
|
||||
|
||||
return pos - neg
|
||||
|
||||
|
||||
def calculate_multi_discovery_ranks(
|
||||
context: list[MultiContextPair],
|
||||
matrices: list[types.NumpyArray],
|
||||
distance_type: models.Distance,
|
||||
) -> types.NumpyArray:
|
||||
overall_ranks: types.NumpyArray = np.zeros(len(matrices), dtype=np.int32)
|
||||
for pair in context:
|
||||
# Get distances to positive and negative vectors
|
||||
pos = calculate_multi_distance_core(pair.positive, matrices, distance_type)
|
||||
neg = calculate_multi_distance_core(pair.negative, matrices, distance_type)
|
||||
|
||||
pair_ranks = np.array(
|
||||
[
|
||||
1 if is_bigger else 0 if is_equal else -1
|
||||
for is_bigger, is_equal in zip(pos > neg, pos == neg)
|
||||
]
|
||||
)
|
||||
|
||||
overall_ranks += pair_ranks
|
||||
|
||||
return overall_ranks
|
||||
|
||||
|
||||
def calculate_multi_discovery_scores(
|
||||
query: MultiDiscoveryQuery, matrices: list[types.NumpyArray], distance_type: models.Distance
|
||||
) -> types.NumpyArray:
|
||||
ranks = calculate_multi_discovery_ranks(query.context, matrices, distance_type)
|
||||
|
||||
# Get distances to target
|
||||
distances_to_target = calculate_multi_distance_core(query.target, matrices, distance_type)
|
||||
|
||||
sigmoided_distances = np.fromiter(
|
||||
(scaled_fast_sigmoid(xi) for xi in distances_to_target), np.float32
|
||||
)
|
||||
|
||||
return ranks + sigmoided_distances
|
||||
|
||||
|
||||
def calculate_multi_context_scores(
|
||||
query: MultiContextQuery, matrices: list[types.NumpyArray], distance_type: models.Distance
|
||||
) -> types.NumpyArray:
|
||||
overall_scores: types.NumpyArray = np.zeros(len(matrices), dtype=np.float32)
|
||||
for pair in query.context_pairs:
|
||||
# Get distances to positive and negative vectors
|
||||
pos = calculate_multi_distance_core(pair.positive, matrices, distance_type)
|
||||
neg = calculate_multi_distance_core(pair.negative, matrices, distance_type)
|
||||
|
||||
difference = pos - neg - EPSILON
|
||||
pair_scores = np.fromiter(
|
||||
(fast_sigmoid(xi) for xi in np.minimum(difference, 0.0)), np.float32
|
||||
)
|
||||
overall_scores += pair_scores
|
||||
|
||||
return overall_scores
|
||||
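For intuition, the multi-vector score computed above is a max-sim aggregation: for every query row, take the best-matching document row and sum those maxima. A minimal sketch with dot-product similarity (the same numbers appear in the distance tests later in this diff):

import numpy as np
from qdrant_client.http import models
from qdrant_client.local.multi_distances import calculate_multi_distance

query = np.array([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])   # two query token vectors
docs = [np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 2.0]])]  # one document with two rows
# Row 1 best dot product: 14 (vs 8); row 2 best: 26 (vs 14); max-sim total: 14 + 26 = 40
assert calculate_multi_distance(query, docs, models.Distance.DOT)[0] == 40.0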
@@ -0,0 +1,30 @@
from datetime import datetime
from typing import Optional, Union

from qdrant_client.http.models import OrderValue
from qdrant_client.local.datetime_utils import parse

MICROS_PER_SECOND = 1_000_000


def datetime_to_microseconds(dt: datetime) -> int:
    return int(dt.timestamp() * MICROS_PER_SECOND)


def to_order_value(value: Union[None, OrderValue, datetime, str]) -> Optional[OrderValue]:
    if value is None:
        return None

    # check if OrderValue
    if isinstance(value, (int, float)):
        return value

    if isinstance(value, datetime):
        return datetime_to_microseconds(value)

    if isinstance(value, str):
        dt = parse(value)
        if dt is not None:
            return datetime_to_microseconds(dt)

    return None
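A short sketch of the conversion above. The module name for this file is not visible in the diff, so the import path below is an assumption:

from datetime import datetime, timezone

# Assumed import path; the actual module name is not shown in this diff.
from qdrant_client.local.order_by import to_order_value

assert to_order_value(42) == 42                      # numeric OrderValue passes through
dt = datetime(2021, 1, 1, tzinfo=timezone.utc)
assert to_order_value(dt) == 1_609_459_200_000_000   # datetime -> microseconds since epoch
assert to_order_value("2021-01-01T00:00:00Z") == to_order_value(dt)
assert to_order_value("not a date") is None          # assuming parse() returns None for unparseable strings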
@@ -0,0 +1,337 @@
|
||||
from datetime import date, datetime, timezone
|
||||
from typing import Any, Optional, Union, Dict
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.local import datetime_utils
|
||||
from qdrant_client.local.geo import boolean_point_in_polygon, geo_distance
|
||||
from qdrant_client.local.payload_value_extractor import value_by_key
|
||||
from qdrant_client.conversions import common_types as types
|
||||
|
||||
|
||||
def get_value_counts(values: list[Any]) -> list[int]:
|
||||
counts = []
|
||||
|
||||
if all(value is None for value in values):
|
||||
counts.append(0)
|
||||
else:
|
||||
for value in values:
|
||||
if value is None:
|
||||
counts.append(0)
|
||||
elif hasattr(value, "__len__") and not isinstance(value, str):
|
||||
counts.append(len(value))
|
||||
else:
|
||||
counts.append(1)
|
||||
return counts
|
||||
|
||||
|
||||
def check_values_count(condition: models.ValuesCount, values: Optional[list[Any]]) -> bool:
|
||||
if values is None:
|
||||
return False
|
||||
|
||||
counts = get_value_counts(values)
|
||||
|
||||
if condition.lt is not None and all(count >= condition.lt for count in counts):
|
||||
return False
|
||||
if condition.lte is not None and all(count > condition.lte for count in counts):
|
||||
return False
|
||||
if condition.gt is not None and all(count <= condition.gt for count in counts):
|
||||
return False
|
||||
if condition.gte is not None and all(count < condition.gte for count in counts):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_geo_radius(condition: models.GeoRadius, values: Any) -> bool:
|
||||
if isinstance(values, dict) and "lat" in values and "lon" in values:
|
||||
lat = values["lat"]
|
||||
lon = values["lon"]
|
||||
|
||||
distance = geo_distance(
|
||||
lon1=lon,
|
||||
lat1=lat,
|
||||
lon2=condition.center.lon,
|
||||
lat2=condition.center.lat,
|
||||
)
|
||||
|
||||
return distance < condition.radius
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_geo_bounding_box(condition: models.GeoBoundingBox, values: Any) -> bool:
|
||||
if isinstance(values, dict) and "lat" in values and "lon" in values:
|
||||
lat = values["lat"]
|
||||
lon = values["lon"]
|
||||
|
||||
# handle anti-meridian crossing case
|
||||
if condition.top_left.lon > condition.bottom_right.lon:
|
||||
longitude_condition = (
|
||||
condition.top_left.lon <= lon <= 180 or -180 <= lon <= condition.bottom_right.lon
|
||||
)
|
||||
else:
|
||||
longitude_condition = condition.top_left.lon <= lon <= condition.bottom_right.lon
|
||||
|
||||
latitude_condition = condition.top_left.lat >= lat >= condition.bottom_right.lat
|
||||
|
||||
return longitude_condition and latitude_condition
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_geo_polygon(condition: models.GeoPolygon, values: Any) -> bool:
|
||||
if isinstance(values, dict) and "lat" in values and "lon" in values:
|
||||
lat = values["lat"]
|
||||
lon = values["lon"]
|
||||
exterior = [(point.lat, point.lon) for point in condition.exterior.points]
|
||||
interiors = []
|
||||
if condition.interiors is not None:
|
||||
interiors = [
|
||||
[(point.lat, point.lon) for point in interior.points]
|
||||
for interior in condition.interiors
|
||||
]
|
||||
return boolean_point_in_polygon(point=(lat, lon), exterior=exterior, interiors=interiors)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_range_interface(condition: models.RangeInterface, value: Any) -> bool:
|
||||
if isinstance(condition, models.Range):
|
||||
return check_range(condition, value)
|
||||
if isinstance(condition, models.DatetimeRange):
|
||||
return check_datetime_range(condition, value)
|
||||
return False
|
||||
|
||||
|
||||
def check_range(condition: models.Range, value: Any) -> bool:
|
||||
if not isinstance(value, (int, float)):
|
||||
return False
|
||||
return (
|
||||
(condition.lt is None or value < condition.lt)
|
||||
and (condition.lte is None or value <= condition.lte)
|
||||
and (condition.gt is None or value > condition.gt)
|
||||
and (condition.gte is None or value >= condition.gte)
|
||||
)
|
||||
|
||||
|
||||
def check_datetime_range(condition: models.DatetimeRange, value: Any) -> bool:
|
||||
def make_condition_tz_aware(dt: Optional[Union[datetime, date]]) -> Optional[datetime]:
|
||||
if isinstance(dt, date) and not isinstance(dt, datetime):
|
||||
dt = datetime.combine(dt, datetime.min.time())
|
||||
|
||||
if dt is None or dt.tzinfo is not None:
|
||||
return dt
|
||||
|
||||
# Assume UTC if no timezone is provided
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
|
||||
if not isinstance(value, str):
|
||||
return False
|
||||
|
||||
dt = datetime_utils.parse(value)
|
||||
|
||||
if dt is None:
|
||||
return False
|
||||
|
||||
condition.lt = make_condition_tz_aware(condition.lt)
|
||||
condition.lte = make_condition_tz_aware(condition.lte)
|
||||
condition.gt = make_condition_tz_aware(condition.gt)
|
||||
condition.gte = make_condition_tz_aware(condition.gte)
|
||||
|
||||
return (
|
||||
(condition.lt is None or dt < condition.lt)
|
||||
and (condition.lte is None or dt <= condition.lte)
|
||||
and (condition.gt is None or dt > condition.gt)
|
||||
and (condition.gte is None or dt >= condition.gte)
|
||||
)
|
||||
|
||||
|
||||
def check_match(condition: models.Match, value: Any) -> bool:
|
||||
if isinstance(condition, models.MatchValue):
|
||||
return value == condition.value
|
||||
if isinstance(condition, models.MatchText):
|
||||
return value is not None and condition.text in value
|
||||
if isinstance(condition, models.MatchTextAny):
|
||||
return value is not None and any(word in value for word in condition.text_any.split())
|
||||
if isinstance(condition, models.MatchAny):
|
||||
return value in condition.any
|
||||
if isinstance(condition, models.MatchExcept):
|
||||
return value not in condition.except_
|
||||
raise ValueError(f"Unknown match condition: {condition}")
|
||||
|
||||
|
||||
def check_nested_filter(nested_filter: models.Filter, values: list[Any]) -> bool:
|
||||
return any(check_filter(nested_filter, v, point_id=-1, has_vector={}) for v in values)
|
||||
|
||||
|
||||
def check_condition(
|
||||
condition: models.Condition,
|
||||
payload: dict[str, Any],
|
||||
point_id: models.ExtendedPointId,
|
||||
has_vector: Dict[str, bool],
|
||||
) -> bool:
|
||||
if isinstance(condition, models.IsNullCondition):
|
||||
values = value_by_key(payload, condition.is_null.key, flat=False)
|
||||
if values is None:
|
||||
return False
|
||||
if any(v is None for v in values):
|
||||
return True
|
||||
elif isinstance(condition, models.IsEmptyCondition):
|
||||
values = value_by_key(payload, condition.is_empty.key, flat=False)
|
||||
if (
|
||||
values is None
|
||||
or len(values) == 0
|
||||
or all((v is None or (isinstance(v, list) and len(v) == 0)) for v in values)
|
||||
):
|
||||
return True
|
||||
elif isinstance(condition, models.HasIdCondition):
|
||||
ids = [str(id_) if isinstance(id_, UUID) else id_ for id_ in condition.has_id]
|
||||
if point_id in ids:
|
||||
return True
|
||||
elif isinstance(condition, models.HasVectorCondition):
|
||||
if condition.has_vector in has_vector and has_vector[condition.has_vector]:
|
||||
return True
|
||||
elif isinstance(condition, models.FieldCondition):
|
||||
values = value_by_key(payload, condition.key)
|
||||
if condition.match is not None:
|
||||
if values is None:
|
||||
return False
|
||||
return any(check_match(condition.match, v) for v in values)
|
||||
if condition.range is not None:
|
||||
if values is None:
|
||||
return False
|
||||
return any(check_range_interface(condition.range, v) for v in values)
|
||||
if condition.geo_bounding_box is not None:
|
||||
if values is None:
|
||||
return False
|
||||
return any(check_geo_bounding_box(condition.geo_bounding_box, v) for v in values)
|
||||
if condition.geo_radius is not None:
|
||||
if values is None:
|
||||
return False
|
||||
return any(check_geo_radius(condition.geo_radius, v) for v in values)
|
||||
if condition.values_count is not None:
|
||||
values = value_by_key(payload, condition.key, flat=False)
|
||||
return check_values_count(condition.values_count, values)
|
||||
if condition.geo_polygon is not None:
|
||||
if values is None:
|
||||
return False
|
||||
return any(check_geo_polygon(condition.geo_polygon, v) for v in values)
|
||||
elif isinstance(condition, models.NestedCondition):
|
||||
values = value_by_key(payload, condition.nested.key)
|
||||
if values is None:
|
||||
return False
|
||||
return check_nested_filter(condition.nested.filter, values)
|
||||
elif isinstance(condition, models.Filter):
|
||||
return check_filter(condition, payload, point_id, has_vector)
|
||||
else:
|
||||
raise ValueError(f"Unknown condition: {condition}")
|
||||
return False
|
||||
|
||||
|
||||
def check_must(
|
||||
conditions: list[models.Condition],
|
||||
payload: dict,
|
||||
point_id: models.ExtendedPointId,
|
||||
has_vector: Dict[str, bool],
|
||||
) -> bool:
|
||||
return all(
|
||||
check_condition(condition, payload, point_id, has_vector) for condition in conditions
|
||||
)
|
||||
|
||||
|
||||
def check_must_not(
|
||||
conditions: list[models.Condition],
|
||||
payload: dict,
|
||||
point_id: models.ExtendedPointId,
|
||||
has_vector: Dict[str, bool],
|
||||
) -> bool:
|
||||
return all(
|
||||
not check_condition(condition, payload, point_id, has_vector) for condition in conditions
|
||||
)
|
||||
|
||||
|
||||
def check_should(
|
||||
conditions: list[models.Condition],
|
||||
payload: dict,
|
||||
point_id: models.ExtendedPointId,
|
||||
has_vector: Dict[str, bool],
|
||||
) -> bool:
|
||||
return any(
|
||||
check_condition(condition, payload, point_id, has_vector) for condition in conditions
|
||||
)
|
||||
|
||||
|
||||
def check_min_should(
|
||||
conditions: list[models.Condition],
|
||||
payload: dict,
|
||||
point_id: models.ExtendedPointId,
|
||||
vectors: Dict[str, Any],
|
||||
min_count: int,
|
||||
) -> bool:
|
||||
return (
|
||||
sum(check_condition(condition, payload, point_id, vectors) for condition in conditions)
|
||||
>= min_count
|
||||
)
|
||||
|
||||
|
||||
def check_filter(
|
||||
payload_filter: models.Filter,
|
||||
payload: dict,
|
||||
point_id: models.ExtendedPointId,
|
||||
has_vector: Dict[str, bool],
|
||||
) -> bool:
|
||||
def ensure_condition_list(
|
||||
condition: Union[models.Condition, list[models.Condition]],
|
||||
) -> list[models.Condition]:
|
||||
if isinstance(condition, list):
|
||||
return condition
|
||||
return [condition]
|
||||
|
||||
if payload_filter.must is not None:
|
||||
if not check_must(
|
||||
ensure_condition_list(payload_filter.must), payload, point_id, has_vector
|
||||
):
|
||||
return False
|
||||
if payload_filter.must_not is not None:
|
||||
if not check_must_not(
|
||||
ensure_condition_list(payload_filter.must_not), payload, point_id, has_vector
|
||||
):
|
||||
return False
|
||||
if payload_filter.should is not None:
|
||||
if not check_should(
|
||||
ensure_condition_list(payload_filter.should), payload, point_id, has_vector
|
||||
):
|
||||
return False
|
||||
if payload_filter.min_should is not None:
|
||||
if not check_min_should(
|
||||
payload_filter.min_should.conditions,
|
||||
payload,
|
||||
point_id,
|
||||
has_vector,
|
||||
payload_filter.min_should.min_count,
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def calculate_payload_mask(
|
||||
payloads: list[dict],
|
||||
payload_filter: Optional[models.Filter],
|
||||
ids_inv: list[models.ExtendedPointId],
|
||||
deleted_per_vector: Dict[str, np.ndarray],
|
||||
) -> types.NumpyArray:
|
||||
if payload_filter is None:
|
||||
return np.ones(len(payloads), dtype=bool)
|
||||
|
||||
mask: types.NumpyArray = np.zeros(len(payloads), dtype=bool)
|
||||
for i, payload in enumerate(payloads):
|
||||
has_vector = {}
|
||||
for vector_name, deleted in deleted_per_vector.items():
|
||||
if not deleted[i]:
|
||||
has_vector[vector_name] = True
|
||||
|
||||
if check_filter(payload_filter, payload, ids_inv[i], has_vector):
|
||||
mask[i] = True
|
||||
return mask
|
||||
@@ -0,0 +1,92 @@
import uuid
from typing import Any, Optional

from qdrant_client.local.json_path_parser import (
    JsonPathItem,
    JsonPathItemType,
    parse_json_path,
)


def value_by_key(payload: dict[str, Any], key: str, flat: bool = True) -> Optional[list[Any]]:
    """
    Get value from payload by key.

    Args:
        payload: arbitrary json-like object
        key:
            Key or path to value in payload.
            Examples:
                - "name"
                - "address.city"
                - "location[].name"
                - "location[0].name"
        flat: If True, list values are flattened into the result; if False, they are appended as-is.
            Filters need the flattened form, while the `count` method needs the arrays kept intact.

    Returns:
        List of values or None if key not found.
    """
    keys = parse_json_path(key)
    result = []

    def _get_value(data: Any, k_list: list[JsonPathItem]) -> None:
        if not k_list:
            return

        current_key = k_list.pop(0)
        if len(k_list) == 0:
            if isinstance(data, dict) and current_key.item_type == JsonPathItemType.KEY:
                if current_key.key in data:
                    value = data[current_key.key]
                    if isinstance(value, list) and flat:
                        result.extend(value)
                    else:
                        result.append(value)

            elif isinstance(data, list):
                if current_key.item_type == JsonPathItemType.WILDCARD_INDEX:
                    result.extend(data)

                elif current_key.item_type == JsonPathItemType.INDEX:
                    assert current_key.index is not None

                    if current_key.index < len(data):
                        result.append(data[current_key.index])

        elif current_key.item_type == JsonPathItemType.KEY:
            if not isinstance(data, dict):
                return

            if current_key.key in data:
                _get_value(data[current_key.key], k_list.copy())

        elif current_key.item_type == JsonPathItemType.INDEX:
            assert current_key.index is not None

            if not isinstance(data, list):
                return

            if current_key.index < len(data):
                _get_value(data[current_key.index], k_list.copy())

        elif current_key.item_type == JsonPathItemType.WILDCARD_INDEX:
            if not isinstance(data, list):
                return

            for item in data:
                _get_value(item, k_list.copy())

    _get_value(payload, keys)
    return result if result else None


def parse_uuid(value: Any) -> Optional[uuid.UUID]:
    """
    Parse UUID from value.

    Args:
        value: arbitrary value
    """
    try:
        return uuid.UUID(str(value))
    except ValueError:
        return None
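A minimal sketch of value_by_key on a nested payload, covering the flat flag and both index forms:

from qdrant_client.local.payload_value_extractor import value_by_key

payload = {"location": [{"name": "Berlin"}, {"name": "Munich"}], "tags": ["a", "b"]}

assert value_by_key(payload, "location[].name") == ["Berlin", "Munich"]  # wildcard index
assert value_by_key(payload, "location[0].name") == ["Berlin"]           # explicit index
assert value_by_key(payload, "tags") == ["a", "b"]                       # flat=True flattens list values
assert value_by_key(payload, "tags", flat=False) == [["a", "b"]]         # flat=False keeps them nested
assert value_by_key(payload, "missing") is None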
@@ -0,0 +1,247 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from qdrant_client.local.json_path_parser import JsonPathItem, JsonPathItemType
|
||||
|
||||
|
||||
def set_value_by_key(payload: dict, keys: list[JsonPathItem], value: Any) -> None:
|
||||
"""
|
||||
Set value in payload by key.
|
||||
Args:
|
||||
payload: arbitrary json-like object
|
||||
keys:
|
||||
list of json path items, e.g.:
|
||||
[
|
||||
JsonPathItem(item_type=<JsonPathItemType.KEY: 'key'>, value='a'),
|
||||
JsonPathItem(item_type=<JsonPathItemType.INDEX: 'index'>, value=0),
|
||||
JsonPathItem(item_type=<JsonPathItemType.INDEX: 'index'>, value=1),
|
||||
JsonPathItem(item_type=<JsonPathItemType.KEY: 'key'>, value='b')
|
||||
]
|
||||
|
||||
The original keys could look like this:
|
||||
- "name"
|
||||
- "address.city"
|
||||
- "location[].name"
|
||||
- "location[0].name"
|
||||
|
||||
value: value to set
|
||||
"""
|
||||
Setter.set(payload, keys.copy(), value, None, None)
|
||||
|
||||
|
||||
class Setter:
|
||||
TYPE: Any
|
||||
SETTERS: dict[JsonPathItemType, Type["Setter"]] = {}
|
||||
|
||||
@classmethod
|
||||
def add_setter(cls, item_type: JsonPathItemType, setter: Type["Setter"]) -> None:
|
||||
cls.SETTERS[item_type] = setter
|
||||
|
||||
@classmethod
|
||||
def set(
|
||||
cls,
|
||||
data: Any,
|
||||
k_list: list[JsonPathItem],
|
||||
value: dict[str, Any],
|
||||
prev_data: Any,
|
||||
prev_key: Optional[JsonPathItem],
|
||||
) -> None:
|
||||
if not k_list:
|
||||
return
|
||||
|
||||
current_key = k_list.pop(0)
|
||||
cls.SETTERS[current_key.item_type]._set(
|
||||
data,
|
||||
current_key,
|
||||
k_list,
|
||||
value,
|
||||
prev_data,
|
||||
prev_key,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _set(
|
||||
cls,
|
||||
data: Any,
|
||||
current_key: JsonPathItem,
|
||||
k_list: list[JsonPathItem],
|
||||
value: dict[str, Any],
|
||||
prev_data: Any,
|
||||
prev_key: Optional[JsonPathItem],
|
||||
) -> None:
|
||||
if isinstance(data, cls.TYPE):
|
||||
cls._set_compatible_types(
|
||||
data=data, current_key=current_key, k_list=k_list, value=value
|
||||
)
|
||||
else:
|
||||
cls._set_incompatible_types(
|
||||
current_key=current_key,
|
||||
k_list=k_list,
|
||||
value=value,
|
||||
prev_data=prev_data,
|
||||
prev_key=prev_key,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _set_compatible_types(
|
||||
cls,
|
||||
data: Any,
|
||||
current_key: JsonPathItem,
|
||||
k_list: list[JsonPathItem],
|
||||
value: dict[str, Any],
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def _set_incompatible_types(
|
||||
cls,
|
||||
current_key: JsonPathItem,
|
||||
k_list: list[JsonPathItem],
|
||||
value: dict[str, Any],
|
||||
prev_data: Any,
|
||||
prev_key: Optional[JsonPathItem],
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class KeySetter(Setter):
|
||||
TYPE = dict
|
||||
|
||||
@classmethod
|
||||
def _set_compatible_types(
|
||||
cls,
|
||||
data: Any,
|
||||
current_key: JsonPathItem,
|
||||
k_list: list[JsonPathItem],
|
||||
value: dict[str, Any],
|
||||
) -> None:
|
||||
if current_key.key not in data:
|
||||
data[current_key.key] = {}
|
||||
|
||||
if len(k_list) == 0:
|
||||
if isinstance(data[current_key.key], dict):
|
||||
data[current_key.key].update(value)
|
||||
else:
|
||||
data[current_key.key] = value
|
||||
else:
|
||||
cls.set(data[current_key.key], k_list.copy(), value, data, current_key)
|
||||
|
||||
@classmethod
|
||||
def _set_incompatible_types(
|
||||
cls,
|
||||
current_key: JsonPathItem,
|
||||
k_list: list[JsonPathItem],
|
||||
value: dict[str, Any],
|
||||
prev_data: Any,
|
||||
prev_key: Optional[JsonPathItem],
|
||||
) -> None:
|
||||
assert prev_key is not None
|
||||
|
||||
if len(k_list) == 0:
|
||||
if prev_key.item_type == JsonPathItemType.KEY:
|
||||
prev_data[prev_key.key] = {current_key.key: value}
|
||||
else: # if prev key was WILDCARD, we need to pass INDEX instead with an index set
|
||||
prev_data[prev_key.index] = {current_key.key: value}
|
||||
else:
|
||||
if prev_key.item_type == JsonPathItemType.KEY:
|
||||
prev_data[prev_key.key] = {current_key.key: {}}
|
||||
cls.set(
|
||||
prev_data[prev_key.key][current_key.key],
|
||||
k_list.copy(),
|
||||
value,
|
||||
prev_data[prev_key.key],
|
||||
current_key,
|
||||
)
|
||||
else:
|
||||
prev_data[prev_key.index] = {current_key.key: {}}
|
||||
cls.set(
|
||||
prev_data[prev_key.index][current_key.key],
|
||||
k_list.copy(),
|
||||
value,
|
||||
prev_data[prev_key.index],
|
||||
current_key,
|
||||
)
|
||||
|
||||
|
||||
class _ListSetter(Setter):
|
||||
TYPE = list
|
||||
|
||||
@classmethod
|
||||
def _set_incompatible_types(
|
||||
cls,
|
||||
current_key: JsonPathItem,
|
||||
k_list: list[JsonPathItem],
|
||||
value: dict[str, Any],
|
||||
prev_data: Any,
|
||||
prev_key: Optional[JsonPathItem],
|
||||
) -> None:
|
||||
assert prev_key is not None
|
||||
|
||||
if prev_key.item_type == JsonPathItemType.KEY:
|
||||
prev_data[prev_key.key] = []
|
||||
return
|
||||
else:
|
||||
prev_data[prev_key.index] = []
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def _set_compatible_types(
|
||||
cls,
|
||||
data: Any,
|
||||
current_key: JsonPathItem,
|
||||
k_list: list[JsonPathItem],
|
||||
value: dict[str, Any],
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class IndexSetter(_ListSetter):
|
||||
@classmethod
|
||||
def _set_compatible_types(
|
||||
cls,
|
||||
data: Any,
|
||||
current_key: JsonPathItem,
|
||||
k_list: list[JsonPathItem],
|
||||
value: dict[str, Any],
|
||||
) -> None:
|
||||
assert current_key.index is not None
|
||||
|
||||
if current_key.index < len(data):
|
||||
if len(k_list) == 0:
|
||||
if isinstance(data[current_key.index], dict):
|
||||
data[current_key.index].update(value)
|
||||
else:
|
||||
data[current_key.index] = value
|
||||
return
|
||||
|
||||
cls.set(data[current_key.index], k_list.copy(), value, data, current_key)
|
||||
|
||||
|
||||
class WildcardIndexSetter(_ListSetter):
|
||||
@classmethod
|
||||
def _set_compatible_types(
|
||||
cls,
|
||||
data: Any,
|
||||
current_key: JsonPathItem,
|
||||
k_list: list[JsonPathItem],
|
||||
value: dict[str, Any],
|
||||
) -> None:
|
||||
if len(k_list) == 0:
|
||||
for i, item in enumerate(data):
|
||||
if isinstance(item, dict):
|
||||
data[i].update(value)
|
||||
else:
|
||||
data[i] = value
|
||||
else:
|
||||
for i, item in enumerate(data):
|
||||
cls.set(
|
||||
item,
|
||||
k_list.copy(),
|
||||
value,
|
||||
data,
|
||||
JsonPathItem(item_type=JsonPathItemType.INDEX, index=i),
|
||||
)
|
||||
|
||||
|
||||
Setter.add_setter(JsonPathItemType.KEY, KeySetter)
|
||||
Setter.add_setter(JsonPathItemType.INDEX, IndexSetter)
|
||||
Setter.add_setter(JsonPathItemType.WILDCARD_INDEX, WildcardIndexSetter)
|
||||
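A minimal sketch of set_value_by_key driven by parse_json_path (both shown elsewhere in this diff):

from qdrant_client.local.json_path_parser import parse_json_path
from qdrant_client.local.payload_value_setter import set_value_by_key

payload = {"a": [{"x": 1}, {"x": 2}]}

# The wildcard setter merges the value into every element of the list.
set_value_by_key(payload, parse_json_path("a[]"), {"y": 10})
assert payload["a"] == [{"x": 1, "y": 10}, {"x": 2, "y": 10}]

# Missing intermediate keys are created on the fly by KeySetter.
set_value_by_key(payload, parse_json_path("b.c"), {"z": 3})
assert payload["b"] == {"c": {"z": 3}}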
@@ -0,0 +1,175 @@
|
||||
import base64
|
||||
import dbm
|
||||
import logging
|
||||
import pickle
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from qdrant_client.http import models
|
||||
|
||||
STORAGE_FILE_NAME_OLD = "storage.dbm"
|
||||
STORAGE_FILE_NAME = "storage.sqlite"
|
||||
|
||||
|
||||
def try_migrate_to_sqlite(location: str) -> None:
|
||||
dbm_path = Path(location) / STORAGE_FILE_NAME_OLD
|
||||
sql_path = Path(location) / STORAGE_FILE_NAME
|
||||
|
||||
if sql_path.exists():
|
||||
return
|
||||
|
||||
if not dbm_path.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
dbm_storage = dbm.open(str(dbm_path), "c")
|
||||
|
||||
con = sqlite3.connect(str(sql_path))
|
||||
cur = con.cursor()
|
||||
|
||||
# Create table
|
||||
cur.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)")
|
||||
|
||||
for key in dbm_storage.keys():
|
||||
value = dbm_storage[key]
|
||||
if isinstance(key, str):
|
||||
key = key.encode("utf-8")
|
||||
key = pickle.loads(key)
|
||||
sqlite_key = CollectionPersistence.encode_key(key)
|
||||
# Insert a row of data
|
||||
cur.execute(
|
||||
"INSERT INTO points VALUES (?, ?)",
|
||||
(
|
||||
sqlite_key,
|
||||
sqlite3.Binary(value),
|
||||
),
|
||||
)
|
||||
con.commit()
|
||||
con.close()
|
||||
dbm_storage.close()
|
||||
dbm_path.unlink()
|
||||
except Exception as e:
|
||||
logging.error("Failed to migrate dbm to sqlite:", e)
|
||||
logging.error(
|
||||
"Please try to use previous version of qdrant-client or re-create collection"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
class CollectionPersistence:
|
||||
CHECK_SAME_THREAD: Optional[bool] = None
|
||||
|
||||
@classmethod
|
||||
def encode_key(cls, key: models.ExtendedPointId) -> str:
|
||||
return base64.b64encode(pickle.dumps(key)).decode("utf-8")
|
||||
|
||||
def __init__(self, location: str, force_disable_check_same_thread: bool = False):
|
||||
"""
|
||||
Create or load a collection from the local storage.
|
||||
Args:
|
||||
location: path to the collection directory.
|
||||
"""
|
||||
|
||||
try_migrate_to_sqlite(location)
|
||||
|
||||
self.location = Path(location) / STORAGE_FILE_NAME
|
||||
self.location.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
if self.CHECK_SAME_THREAD is None and force_disable_check_same_thread is False:
|
||||
with sqlite3.connect(":memory:") as tmp_conn:
|
||||
# it is unsafe to use `sqlite3.threadsafety` until python3.11 since it was hardcoded to 1, thus we
|
||||
# need to fetch threadsafe with a query
|
||||
# THREADSAFE = 0: Threads may not share the module
|
||||
# THREADSAFE = 1: Threads may share the module, connections and cursors. Default for Linux.
|
||||
# THREADSAFE = 2: Threads may share the module, but not connections. Default for macOS.
|
||||
threadsafe = tmp_conn.execute(
|
||||
"select * from pragma_compile_options where compile_options like 'THREADSAFE=%'"
|
||||
).fetchone()[0]
|
||||
self.__class__.CHECK_SAME_THREAD = threadsafe != "THREADSAFE=1"
|
||||
|
||||
if force_disable_check_same_thread:
|
||||
self.__class__.CHECK_SAME_THREAD = False
|
||||
|
||||
self.storage = sqlite3.connect(
|
||||
str(self.location), check_same_thread=self.CHECK_SAME_THREAD # type: ignore
|
||||
)
|
||||
|
||||
self._ensure_table()
|
||||
|
||||
def close(self) -> None:
|
||||
self.storage.close()
|
||||
|
||||
def _ensure_table(self) -> None:
|
||||
cursor = self.storage.cursor()
|
||||
cursor.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)")
|
||||
self.storage.commit()
|
||||
|
||||
def persist(self, point: models.PointStruct) -> None:
|
||||
"""
|
||||
Persist a point in the local storage.
|
||||
Args:
|
||||
point: point to persist
|
||||
"""
|
||||
key = self.encode_key(point.id)
|
||||
value = pickle.dumps(point)
|
||||
|
||||
cursor = self.storage.cursor()
|
||||
# Insert or update by key
|
||||
cursor.execute(
|
||||
"INSERT OR REPLACE INTO points VALUES (?, ?)",
|
||||
(
|
||||
key,
|
||||
sqlite3.Binary(value),
|
||||
),
|
||||
)
|
||||
|
||||
self.storage.commit()
|
||||
|
||||
def delete(self, point_id: models.ExtendedPointId) -> None:
|
||||
"""
|
||||
Delete a point from the local storage.
|
||||
Args:
|
||||
point_id: id of the point to delete
|
||||
"""
|
||||
key = self.encode_key(point_id)
|
||||
cursor = self.storage.cursor()
|
||||
cursor.execute(
|
||||
"DELETE FROM points WHERE id = ?",
|
||||
(key,),
|
||||
)
|
||||
self.storage.commit()
|
||||
|
||||
def load(self) -> Iterable[models.PointStruct]:
|
||||
"""
|
||||
Load a point from the local storage.
|
||||
Returns:
|
||||
point: loaded point
|
||||
"""
|
||||
cursor = self.storage.cursor()
|
||||
cursor.execute("SELECT point FROM points")
|
||||
for row in cursor.fetchall():
|
||||
yield pickle.loads(row[0])
|
||||
|
||||
|
||||
def test_persistence() -> None:
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
persistence = CollectionPersistence(tmpdir)
|
||||
point = models.PointStruct(id=1, vector=[1.0, 2.0, 3.0], payload={"a": 1})
|
||||
persistence.persist(point)
|
||||
for loaded_point in persistence.load():
|
||||
assert loaded_point == point
|
||||
break
|
||||
|
||||
del persistence
|
||||
persistence = CollectionPersistence(tmpdir)
|
||||
for loaded_point in persistence.load():
|
||||
assert loaded_point == point
|
||||
break
|
||||
|
||||
persistence.delete(point.id)
|
||||
persistence.delete(point.id)
|
||||
for _ in persistence.load():
|
||||
assert False, "Should not load anything"
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,36 @@
import numpy as np

from qdrant_client.http.models import SparseVector


def empty_sparse_vector() -> SparseVector:
    return SparseVector(
        indices=[],
        values=[],
    )


def validate_sparse_vector(vector: SparseVector) -> None:
    assert len(vector.indices) == len(
        vector.values
    ), "Indices and values must have the same length"
    assert not np.isnan(vector.values).any(), "Values must not contain NaN"
    assert len(vector.indices) == len(set(vector.indices)), "Indices must be unique"


def is_sorted(vector: SparseVector) -> bool:
    for i in range(1, len(vector.indices)):
        if vector.indices[i] < vector.indices[i - 1]:
            return False
    return True


def sort_sparse_vector(vector: SparseVector) -> SparseVector:
    if is_sorted(vector):
        return vector

    sorted_indices = np.argsort(vector.indices)
    return SparseVector(
        indices=[vector.indices[i] for i in sorted_indices],
        values=[vector.values[i] for i in sorted_indices],
    )
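A quick sketch of the sparse-vector helpers above:

from qdrant_client.http.models import SparseVector
from qdrant_client.local.sparse import is_sorted, sort_sparse_vector, validate_sparse_vector

v = SparseVector(indices=[5, 1, 3], values=[0.5, 0.1, 0.3])
validate_sparse_vector(v)            # equal lengths, no NaN, unique indices
assert not is_sorted(v)

v = sort_sparse_vector(v)            # reorder both lists by ascending index
assert v.indices == [1, 3, 5] and v.values == [0.1, 0.3, 0.5]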
@@ -0,0 +1,314 @@
|
||||
from typing import Callable, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.conversions import common_types as types
|
||||
from qdrant_client.http.models import SparseVector
|
||||
from qdrant_client.local.distances import EPSILON, fast_sigmoid, scaled_fast_sigmoid
|
||||
from qdrant_client.local.sparse import (
|
||||
empty_sparse_vector,
|
||||
is_sorted,
|
||||
sort_sparse_vector,
|
||||
validate_sparse_vector,
|
||||
)
|
||||
|
||||
|
||||
class SparseRecoQuery:
|
||||
def __init__(
|
||||
self,
|
||||
positive: Optional[list[SparseVector]] = None,
|
||||
negative: Optional[list[SparseVector]] = None,
|
||||
strategy: Optional[types.RecommendStrategy] = None,
|
||||
):
|
||||
assert strategy is not None, "Recommend strategy must be provided"
|
||||
|
||||
self.strategy = strategy
|
||||
|
||||
positive = positive if positive is not None else []
|
||||
negative = negative if negative is not None else []
|
||||
|
||||
for i, vector in enumerate(positive):
|
||||
validate_sparse_vector(vector)
|
||||
positive[i] = sort_sparse_vector(vector)
|
||||
|
||||
for i, vector in enumerate(negative):
|
||||
validate_sparse_vector(vector)
|
||||
negative[i] = sort_sparse_vector(vector)
|
||||
|
||||
self.positive = positive
|
||||
self.negative = negative
|
||||
|
||||
def transform_sparse(
|
||||
self, foo: Callable[["SparseVector"], "SparseVector"]
|
||||
) -> "SparseRecoQuery":
|
||||
return SparseRecoQuery(
|
||||
positive=[foo(vector) for vector in self.positive],
|
||||
negative=[foo(vector) for vector in self.negative],
|
||||
strategy=self.strategy,
|
||||
)
|
||||
|
||||
|
||||
class SparseContextPair:
|
||||
def __init__(self, positive: SparseVector, negative: SparseVector):
|
||||
validate_sparse_vector(positive)
|
||||
validate_sparse_vector(negative)
|
||||
self.positive: SparseVector = sort_sparse_vector(positive)
|
||||
self.negative: SparseVector = sort_sparse_vector(negative)
|
||||
|
||||
|
||||
class SparseDiscoveryQuery:
|
||||
def __init__(self, target: SparseVector, context: list[SparseContextPair]):
|
||||
validate_sparse_vector(target)
|
||||
self.target: SparseVector = sort_sparse_vector(target)
|
||||
self.context = context
|
||||
|
||||
def transform_sparse(
|
||||
self, foo: Callable[["SparseVector"], "SparseVector"]
|
||||
) -> "SparseDiscoveryQuery":
|
||||
return SparseDiscoveryQuery(
|
||||
target=foo(self.target),
|
||||
context=[
|
||||
SparseContextPair(foo(pair.positive), foo(pair.negative)) for pair in self.context
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class SparseContextQuery:
|
||||
def __init__(self, context_pairs: list[SparseContextPair]):
|
||||
self.context_pairs = context_pairs
|
||||
|
||||
def transform_sparse(
|
||||
self, foo: Callable[["SparseVector"], "SparseVector"]
|
||||
) -> "SparseContextQuery":
|
||||
return SparseContextQuery(
|
||||
context_pairs=[
|
||||
SparseContextPair(foo(pair.positive), foo(pair.negative))
|
||||
for pair in self.context_pairs
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
SparseQueryVector = Union[
|
||||
SparseVector,
|
||||
SparseDiscoveryQuery,
|
||||
SparseContextQuery,
|
||||
SparseRecoQuery,
|
||||
]
|
||||
|
||||
|
||||
def calculate_distance_sparse(
|
||||
query: SparseVector, vectors: list[SparseVector], empty_is_zero: bool = False
|
||||
) -> types.NumpyArray:
|
||||
"""Calculate distances between a query sparse vector and a list of sparse vectors.
|
||||
|
||||
Args:
|
||||
query (SparseVector): The query sparse vector.
|
||||
vectors (list[SparseVector]): A list of sparse vectors to compare against.
|
||||
empty_is_zero (bool): If True, distance between vectors with no overlap is treated as zero.
|
||||
Otherwise, it is treated as negative infinity.
|
||||
Simple nearest search requires `empty_is_zero` to be False, while methods like
|
||||
recommend, discovery, and context search require True.
|
||||
"""
|
||||
scores = []
|
||||
|
||||
for vector in vectors:
|
||||
score = sparse_dot_product(query, vector)
|
||||
if score is not None:
|
||||
scores.append(score)
|
||||
elif not empty_is_zero:
|
||||
# means no overlap
|
||||
scores.append(np.float32("-inf"))
|
||||
else:
|
||||
scores.append(np.float32(0.0))
|
||||
|
||||
return np.array(scores, dtype=np.float32)
|
||||
|
||||
|
||||
# Expects sorted indices
|
||||
# Returns None if no overlap
|
||||
def sparse_dot_product(vector1: SparseVector, vector2: SparseVector) -> Optional[np.float32]:
|
||||
result = 0.0
|
||||
i, j = 0, 0
|
||||
overlap = False
|
||||
|
||||
assert is_sorted(vector1), "Query sparse vector must be sorted"
|
||||
assert is_sorted(vector2), "Sparse vector to compare with must be sorted"
|
||||
|
||||
while i < len(vector1.indices) and j < len(vector2.indices):
|
||||
if vector1.indices[i] == vector2.indices[j]:
|
||||
overlap = True
|
||||
result += vector1.values[i] * vector2.values[j]
|
||||
i += 1
|
||||
j += 1
|
||||
elif vector1.indices[i] < vector2.indices[j]:
|
||||
i += 1
|
||||
else:
|
||||
j += 1
|
||||
|
||||
if overlap:
|
||||
return np.float32(result)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def calculate_sparse_discovery_ranks(
|
||||
context: list[SparseContextPair],
|
||||
vectors: list[SparseVector],
|
||||
) -> types.NumpyArray:
|
||||
overall_ranks: types.NumpyArray = np.zeros(len(vectors), dtype=np.int32)
|
||||
for pair in context:
|
||||
# Get distances to positive and negative vectors
|
||||
pos = calculate_distance_sparse(pair.positive, vectors, empty_is_zero=True)
|
||||
neg = calculate_distance_sparse(pair.negative, vectors, empty_is_zero=True)
|
||||
|
||||
pair_ranks = np.array(
|
||||
[
|
||||
1 if is_bigger else 0 if is_equal else -1
|
||||
for is_bigger, is_equal in zip(pos > neg, pos == neg)
|
||||
]
|
||||
)
|
||||
|
||||
overall_ranks += pair_ranks
|
||||
|
||||
return overall_ranks
|
||||
|
||||
|
||||
def calculate_sparse_discovery_scores(
|
||||
query: SparseDiscoveryQuery, vectors: list[SparseVector]
|
||||
) -> types.NumpyArray:
|
||||
ranks = calculate_sparse_discovery_ranks(query.context, vectors)
|
||||
|
||||
# Get distances to target
|
||||
distances_to_target = calculate_distance_sparse(query.target, vectors, empty_is_zero=True)
|
||||
|
||||
sigmoided_distances = np.fromiter(
|
||||
(scaled_fast_sigmoid(xi) for xi in distances_to_target), np.float32
|
||||
)
|
||||
|
||||
return ranks + sigmoided_distances
|
||||
|
||||
|
||||
def calculate_sparse_context_scores(
|
||||
query: SparseContextQuery, vectors: list[SparseVector]
|
||||
) -> types.NumpyArray:
|
||||
overall_scores: types.NumpyArray = np.zeros(len(vectors), dtype=np.float32)
|
||||
for pair in query.context_pairs:
|
||||
# Get distances to positive and negative vectors
|
||||
pos = calculate_distance_sparse(pair.positive, vectors, empty_is_zero=True)
|
||||
neg = calculate_distance_sparse(pair.negative, vectors, empty_is_zero=True)
|
||||
|
||||
difference = pos - neg - EPSILON
|
||||
pair_scores = np.fromiter(
|
||||
(fast_sigmoid(xi) for xi in np.minimum(difference, 0.0)), np.float32
|
||||
)
|
||||
overall_scores += pair_scores
|
||||
|
||||
return overall_scores
|
||||
|
||||
|
||||
def calculate_sparse_recommend_best_scores(
|
||||
query: SparseRecoQuery, vectors: list[SparseVector]
|
||||
) -> types.NumpyArray:
|
||||
def get_best_scores(examples: list[SparseVector]) -> types.NumpyArray:
|
||||
vector_count = len(vectors)
|
||||
|
||||
# Get scores to all examples
|
||||
scores: list[types.NumpyArray] = []
|
||||
for example in examples:
|
||||
score = calculate_distance_sparse(example, vectors, empty_is_zero=True)
|
||||
scores.append(score)
|
||||
|
||||
# Keep only max for each vector
|
||||
if len(scores) == 0:
|
||||
scores.append(np.full(vector_count, -np.inf))
|
||||
best_scores = np.array(scores, dtype=np.float32).max(axis=0)
|
||||
|
||||
return best_scores
|
||||
|
||||
pos = get_best_scores(query.positive)
|
||||
neg = get_best_scores(query.negative)
|
||||
|
||||
# Choose from best positive or best negative,
|
||||
# in both cases we apply sigmoid and then negate depending on the order
|
||||
return np.where(
|
||||
pos > neg,
|
||||
np.fromiter((scaled_fast_sigmoid(xi) for xi in pos), pos.dtype),
|
||||
np.fromiter((-scaled_fast_sigmoid(xi) for xi in neg), neg.dtype),
|
||||
)
|
||||
|
||||
|
||||
def calculate_sparse_recommend_sum_scores(
|
||||
query: SparseRecoQuery, vectors: list[SparseVector]
|
||||
) -> types.NumpyArray:
|
||||
def get_sum_scores(examples: list[SparseVector]) -> types.NumpyArray:
|
||||
vector_count = len(vectors)
|
||||
|
||||
scores: list[types.NumpyArray] = []
|
||||
for example in examples:
|
||||
score = calculate_distance_sparse(example, vectors, empty_is_zero=True)
|
||||
scores.append(score)
|
||||
|
||||
if len(scores) == 0:
|
||||
scores.append(np.zeros(vector_count))
|
||||
|
||||
sum_scores = np.array(scores, dtype=np.float32).sum(axis=0)
|
||||
return sum_scores
|
||||
|
||||
pos = get_sum_scores(query.positive)
|
||||
neg = get_sum_scores(query.negative)
|
||||
|
||||
return pos - neg
|
||||
|
||||
|
||||
# Expects sorted indices
|
||||
def combine_aggregate(vector1: SparseVector, vector2: SparseVector, op: Callable) -> SparseVector:
|
||||
result = empty_sparse_vector()
|
||||
i, j = 0, 0
|
||||
while i < len(vector1.indices) and j < len(vector2.indices):
|
||||
if vector1.indices[i] == vector2.indices[j]:
|
||||
result.indices.append(vector1.indices[i])
|
||||
result.values.append(op(vector1.values[i], vector2.values[j]))
|
||||
i += 1
|
||||
j += 1
|
||||
elif vector1.indices[i] < vector2.indices[j]:
|
||||
result.indices.append(vector1.indices[i])
|
||||
result.values.append(op(vector1.values[i], 0.0))
|
||||
i += 1
|
||||
else:
|
||||
result.indices.append(vector2.indices[j])
|
||||
result.values.append(op(0.0, vector2.values[j]))
|
||||
j += 1
|
||||
|
||||
while i < len(vector1.indices):
|
||||
result.indices.append(vector1.indices[i])
|
||||
result.values.append(op(vector1.values[i], 0.0))
|
||||
i += 1
|
||||
|
||||
while j < len(vector2.indices):
|
||||
result.indices.append(vector2.indices[j])
|
||||
result.values.append(op(0.0, vector2.values[j]))
|
||||
j += 1
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Expects sorted indices
|
||||
def sparse_avg(vectors: Sequence[SparseVector]) -> SparseVector:
|
||||
result = empty_sparse_vector()
|
||||
if len(vectors) == 0:
|
||||
return result
|
||||
|
||||
sparse_count = 0
|
||||
for vector in vectors:
|
||||
sparse_count += 1
|
||||
result = combine_aggregate(result, vector, lambda v1, v2: v1 + v2)
|
||||
|
||||
result.values = np.divide(result.values, sparse_count).tolist()
|
||||
return result
|
||||
|
||||
|
||||
# Expects sorted indices
|
||||
def merge_positive_and_negative_avg(
|
||||
positive: SparseVector, negative: SparseVector
|
||||
) -> SparseVector:
|
||||
return combine_aggregate(positive, negative, lambda pos, neg: pos + pos - neg)
|
||||
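A numeric sketch of the merge rule above: per index the result is 2 * positive - negative (i.e. avg_positive + (avg_positive - avg_negative)), with missing entries treated as 0.0:

from qdrant_client.http.models import SparseVector
from qdrant_client.local.sparse_distances import (
    merge_positive_and_negative_avg,
    sparse_avg,
    sparse_dot_product,
)

pos_avg = sparse_avg([SparseVector(indices=[1, 3], values=[2.0, 4.0])])
neg_avg = sparse_avg([SparseVector(indices=[3, 5], values=[1.0, 2.0])])

merged = merge_positive_and_negative_avg(pos_avg, neg_avg)
assert merged.indices == [1, 3, 5]
assert merged.values == [4.0, 7.0, -2.0]

# sparse_dot_product returns None when the two vectors share no indices at all.
assert sparse_dot_product(
    SparseVector(indices=[1], values=[1.0]), SparseVector(indices=[2], values=[1.0])
) is None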
@@ -0,0 +1,57 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from qdrant_client.local.datetime_utils import parse
|
||||
|
||||
|
||||
@pytest.mark.parametrize( # type: ignore
|
||||
"date_str, expected",
|
||||
[
|
||||
("2021-01-01T00:00:00", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
|
||||
("2021-01-01T00:00:00Z", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
|
||||
("2021-01-01T00:00:00+00:00", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
|
||||
("2021-01-01T00:00:00.000000", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
|
||||
("2021-01-01T00:00:00.000000Z", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
|
||||
(
|
||||
"2021-01-01T00:00:00.000000+01:00",
|
||||
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=1))),
|
||||
),
|
||||
(
|
||||
"2021-01-01T00:00:00.000000-10:00",
|
||||
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=-10))),
|
||||
),
|
||||
("2021-01-01", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
|
||||
("2021-01-01 00:00:00", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
|
||||
("2021-01-01 00:00:00Z", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
|
||||
(
|
||||
"2021-01-01 00:00:00+0200",
|
||||
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=2))),
|
||||
),
|
||||
("2021-01-01 00:00:00.000000", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
|
||||
("2021-01-01 00:00:00.000000Z", datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)),
|
||||
(
|
||||
"2021-01-01 00:00:00.000000+00:30",
|
||||
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(minutes=30))),
|
||||
),
|
||||
(
|
||||
"2021-01-01 00:00:00.000009+00:30",
|
||||
datetime(2021, 1, 1, 0, 0, 0, 9, tzinfo=timezone(timedelta(minutes=30))),
|
||||
),
|
||||
# this is accepted in core but not here, there is no specifier for only-hour offset
|
||||
(
|
||||
"2021-01-01 00:00:00.000+01",
|
||||
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=1))),
|
||||
),
|
||||
(
|
||||
"2021-01-01 00:00:00.000-10",
|
||||
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=-10))),
|
||||
),
|
||||
(
|
||||
"2021-01-01 00:00:00-03:00",
|
||||
datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=-3))),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_dates(date_str: str, expected: datetime):
|
||||
assert parse(date_str) == expected
|
||||
@@ -0,0 +1,57 @@
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.local.distances import calculate_distance
|
||||
from qdrant_client.local.multi_distances import calculate_multi_distance
|
||||
from qdrant_client.local.sparse_distances import calculate_distance_sparse
|
||||
|
||||
|
||||
def test_distances() -> None:
|
||||
query = np.array([1.0, 2.0, 3.0])
|
||||
vectors = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
|
||||
assert np.allclose(calculate_distance(query, vectors, models.Distance.DOT), [14.0, 14.0])
|
||||
assert np.allclose(calculate_distance(query, vectors, models.Distance.EUCLID), [0.0, 0.0])
|
||||
assert np.allclose(calculate_distance(query, vectors, models.Distance.MANHATTAN), [0.0, 0.0])
|
||||
# cosine modifies vectors inplace
|
||||
assert np.allclose(calculate_distance(query, vectors, models.Distance.COSINE), [1.0, 1.0])
|
||||
|
||||
query = np.array([1.0, 0.0, 1.0])
|
||||
vectors = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]])
|
||||
|
||||
assert np.allclose(
|
||||
calculate_distance(query, vectors, models.Distance.DOT), [4.0, 0.0], atol=0.0001
|
||||
)
|
||||
assert np.allclose(
|
||||
calculate_distance(query, vectors, models.Distance.EUCLID),
|
||||
[2.82842712, 1.7320508],
|
||||
atol=0.0001,
|
||||
)
|
||||
|
||||
assert np.allclose(
|
||||
calculate_distance(query, vectors, models.Distance.MANHATTAN),
|
||||
[4.0, 3.0],
|
||||
atol=0.0001,
|
||||
)
|
||||
# cosine modifies vectors inplace
|
||||
assert np.allclose(
|
||||
calculate_distance(query, vectors, models.Distance.COSINE),
|
||||
[0.75592895, 0.0],
|
||||
atol=0.0001,
|
||||
)
|
||||
|
||||
sparse_query = models.SparseVector(indices=[1, 2], values=[1, 2])
|
||||
sparse_vectors = [models.SparseVector(indices=[10, 20], values=[1, 2])]
|
||||
|
||||
assert calculate_distance_sparse(sparse_query, sparse_vectors) == [np.float32("-inf")]
|
||||
|
||||
sparse_vectors = [
|
||||
models.SparseVector(indices=[1, 2], values=[3, 4]),
|
||||
models.SparseVector(indices=[1, 2, 3], values=[1, 2, 3]),
|
||||
]
|
||||
assert np.allclose(
|
||||
calculate_distance_sparse(sparse_query, sparse_vectors), [11.0, 5], atol=0.0001
|
||||
)
|
||||
|
||||
multivector_query = np.array([[1, 2, 3], [3, 4, 5]])
|
||||
docs = [np.array([[1, 2, 3], [0, 1, 2]])]
|
||||
assert calculate_multi_distance(multivector_query, docs, models.Distance.DOT)[0] == 40.0
|
||||
@@ -0,0 +1,189 @@
|
||||
from qdrant_client.http.models import models
|
||||
from qdrant_client.local.payload_filters import check_filter
|
||||
|
||||
|
||||
def test_nested_payload_filters():
|
||||
payload = {
|
||||
"country": {
|
||||
"name": "Germany",
|
||||
"capital": "Berlin",
|
||||
"cities": [
|
||||
{
|
||||
"name": "Berlin",
|
||||
"population": 3.7,
|
||||
"location": {
|
||||
"lon": 13.76116,
|
||||
"lat": 52.33826,
|
||||
},
|
||||
"sightseeing": ["Brandenburg Gate", "Reichstag"],
|
||||
},
|
||||
{
|
||||
"name": "Munich",
|
||||
"population": 1.5,
|
||||
"location": {
|
||||
"lon": 11.57549,
|
||||
"lat": 48.13743,
|
||||
},
|
||||
"sightseeing": ["Marienplatz", "Olympiapark"],
|
||||
},
|
||||
{
|
||||
"name": "Hamburg",
|
||||
"population": 1.8,
|
||||
"location": {
|
||||
"lon": 9.99368,
|
||||
"lat": 53.55108,
|
||||
},
|
||||
"sightseeing": ["Reeperbahn", "Elbphilharmonie"],
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
query = models.Filter(
|
||||
**{
|
||||
"must": [
|
||||
{
|
||||
"nested": {
|
||||
"key": "country.cities",
|
||||
"filter": {
|
||||
"must": [
|
||||
{
|
||||
"key": "population",
|
||||
"range": {
|
||||
"gte": 1.0,
|
||||
},
|
||||
}
|
||||
],
|
||||
"must_not": [{"key": "sightseeing", "values_count": {"gt": 1}}],
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
)
    res = check_filter(query, payload, 0, has_vector={})
    assert res is False

    query = models.Filter(
        **{
            "must": [
                {
                    "nested": {
                        "key": "country.cities",
                        "filter": {
                            "must": [
                                {
                                    "key": "population",
                                    "range": {
                                        "gte": 1.0,
                                    },
                                }
                            ]
                        },
                    }
                }
            ]
        }
    )

    res = check_filter(query, payload, 0, has_vector={})
    assert res is True

    query = models.Filter(
        **{
            "must": [
                {
                    "nested": {
                        "key": "country.cities",
                        "filter": {
                            "must": [
                                {
                                    "key": "population",
                                    "range": {
                                        "gte": 1.0,
                                    },
                                },
                                {"key": "sightseeing", "values_count": {"gt": 2}},
                            ]
                        },
                    }
                }
            ]
        }
    )

    res = check_filter(query, payload, 0, has_vector={})
    assert res is False

    query = models.Filter(
        **{
            "must": [
                {
                    "nested": {
                        "key": "country.cities",
                        "filter": {
                            "must": [
                                {
                                    "key": "population",
                                    "range": {
                                        "gte": 9.0,
                                    },
                                }
                            ]
                        },
                    }
                }
            ]
        }
    )

    res = check_filter(query, payload, 0, has_vector={})
    assert res is False


def test_geo_polygon_filter_query():
    payload = {
        "location": [
            {
                "lon": 70.0,
                "lat": 70.0,
            },
        ]
    }

    query = models.Filter(
        **{
            "must": [
                {
                    "key": "location",
                    "geo_polygon": {
                        "exterior": {
                            "points": [
                                {"lon": 55.455868, "lat": 55.495862},
                                {"lon": 86.455868, "lat": 55.495862},
                                {"lon": 86.455868, "lat": 86.495862},
                                {"lon": 55.455868, "lat": 86.495862},
                                {"lon": 55.455868, "lat": 55.495862},
                            ]
                        },
                    },
                }
            ]
        }
    )

    res = check_filter(query, payload, 0, has_vector={})
    assert res is True

    payload = {
        "location": [
            {
                "lon": 30.693738,
                "lat": 30.502165,
            },
        ]
    }

    res = check_filter(query, payload, 0, has_vector={})
    assert res is False
@@ -0,0 +1,549 @@
from typing import Any

import pytest

from qdrant_client.local.json_path_parser import (
    JsonPathItem,
    JsonPathItemType,
    parse_json_path,
)
from qdrant_client.local.payload_value_extractor import value_by_key
from qdrant_client.local.payload_value_setter import set_value_by_key


def test_parse_json_path() -> None:
    jp_key = "a"
    keys = parse_json_path(jp_key)
    assert keys == [JsonPathItem(item_type=JsonPathItemType.KEY, key="a")]

    jp_key = "a.b"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
    ]

    jp_key = 'a."a[b]".c'
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a[b]"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="c"),
    ]

    jp_key = "a[0]"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
    ]

    jp_key = "a[0].b"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
    ]

    jp_key = "a[0].b[1]"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=1),
    ]

    jp_key = "a[][]"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.WILDCARD_INDEX, index=None),
        JsonPathItem(item_type=JsonPathItemType.WILDCARD_INDEX, index=None),
    ]

    jp_key = "a[0][1]"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=1),
    ]

    jp_key = "a[0][1].b"
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=0),
        JsonPathItem(item_type=JsonPathItemType.INDEX, index=1),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
    ]

    jp_key = 'a."k.c"'
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="k.c"),
    ]

    jp_key = 'a."c[][]".b'
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="c[][]"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
    ]

    jp_key = 'a."c..q".b'
    keys = parse_json_path(jp_key)
    assert keys == [
        JsonPathItem(item_type=JsonPathItemType.KEY, key="a"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="c..q"),
        JsonPathItem(item_type=JsonPathItemType.KEY, key="b"),
    ]
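
    # malformed paths below are expected to be rejected with a ValueError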
    with pytest.raises(ValueError):
        jp_key = 'a."k.c'
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = 'a."k.c".'
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = 'a."k.c".[]'
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a.'k.c'"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a["
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a]"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a[]]"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a[][]."
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a[][]b"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = ".a"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a[x]"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = 'a[]""'
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = '""b'
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "[]"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a[.]"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = 'a["1"]'
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = ""
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a..c"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a.c[]b[]"
        parse_json_path(jp_key)

    with pytest.raises(ValueError):
        jp_key = "a.c[].[]"
        parse_json_path(jp_key)


def test_value_by_key() -> None:
    payload = {
        "name": "John",
        "age": 25,
        "counts": [1, 2, 3],
        "address": {
            "city": "New York",
        },
        "location": [
            {"name": "home", "counts": [1, 2, 3]},
            {"name": "work", "counts": [4, 5, 6]},
        ],
        "nested": [{"empty": []}, {"empty": []}, {"empty": None}],
        "the_null": None,
        "the": {"nested.key": "cuckoo"},
        "double-nest-array": [[1, 2], [3, 4], [5, 6]],
    }
    # region flat=True
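    # flat=True merges list values from every match into one flat list
    # (the flat=False region below keeps one entry per matched value instead)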
    assert value_by_key(payload, "name") == ["John"]
    assert value_by_key(payload, "address.city") == ["New York"]
    assert value_by_key(payload, "location[].name") == ["home", "work"]
    assert value_by_key(payload, "location[0].name") == ["home"]
    assert value_by_key(payload, "location[1].name") == ["work"]
    assert value_by_key(payload, "location[2].name") is None
    assert value_by_key(payload, "location[].name[0]") is None
    assert value_by_key(payload, "location[0]") == [{"name": "home", "counts": [1, 2, 3]}]
    assert value_by_key(payload, "not_exits") is None
    assert value_by_key(payload, "address") == [{"city": "New York"}]
    assert value_by_key(payload, "address.city[0]") is None
    assert value_by_key(payload, "counts") == [1, 2, 3]
    assert value_by_key(payload, "location[].counts") == [1, 2, 3, 4, 5, 6]
    assert value_by_key(payload, "nested[].empty") == [None]
    assert value_by_key(payload, "the_null") == [None]
    assert value_by_key(payload, 'the."nested.key"') == ["cuckoo"]
    assert value_by_key(payload, "double-nest-array[][]") == [1, 2, 3, 4, 5, 6]
    assert value_by_key(payload, "double-nest-array[0][]") == [1, 2]
    assert value_by_key(payload, "double-nest-array[0][0]") == [1]
    assert value_by_key(payload, "double-nest-array[0][0]") == [1]
    assert value_by_key(payload, "double-nest-array[][1]") == [2, 4, 6]
    # endregion

    # region flat=False
    assert value_by_key(payload, "name", flat=False) == ["John"]
    assert value_by_key(payload, "address.city", flat=False) == ["New York"]
    assert value_by_key(payload, "location[].name", flat=False) == ["home", "work"]
    assert value_by_key(payload, "location[0].name", flat=False) == ["home"]
    assert value_by_key(payload, "location[1].name", flat=False) == ["work"]
    assert value_by_key(payload, "location[2].name", flat=False) is None
    assert value_by_key(payload, "location[].name[0]", flat=False) is None
    assert value_by_key(payload, "location[0]", flat=False) == [
        {"name": "home", "counts": [1, 2, 3]}
    ]
    assert value_by_key(payload, "not_exist", flat=False) is None
    assert value_by_key(payload, "address", flat=False) == [{"city": "New York"}]
    assert value_by_key(payload, "address.city[0]", flat=False) is None
    assert value_by_key(payload, "counts", flat=False) == [[1, 2, 3]]
    assert value_by_key(payload, "location[].counts", flat=False) == [
        [1, 2, 3],
        [4, 5, 6],
    ]
    assert value_by_key(payload, "nested[].empty", flat=False) == [[], [], None]
    assert value_by_key(payload, "the_null", flat=False) == [None]

    assert value_by_key(payload, "age.nested.not_exist") is None
    # endregion


def test_set_value_by_key() -> None:
    # region valid keys
    payload: dict[str, Any] = {}
    new_value: dict[str, Any] = {}
    key = "a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {}}, payload

    payload = {"a": {"a": 2}}
    new_value = {}
    key = "a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"a": 2}}, payload

    payload = {"a": {"a": 2}}
    new_value = {"b": 3}
    key = "a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"a": 2, "b": 3}}, payload

    payload = {"a": {"a": 2}}
    new_value = {"a": 3}
    key = "a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"a": 3}}, payload

    payload = {"a": {"a": 2}}
    new_value = {"a": 3}
    key = "a.a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"a": {"a": 3}}}, payload

    payload = {"a": {"a": {"a": 1}}}
    new_value = {"b": 2}
    key = "a.a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"a": {"a": 1, "b": 2}}}, payload

    payload = {"a": {"a": {"a": 1}}}
    new_value = {"a": 2}
    key = "a.a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"a": {"a": 2}}}, payload

    payload = {"a": []}
    new_value = {"b": 2}
    key = "a[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": []}, payload

    payload = {"a": [{}]}
    new_value = {"b": 2}
    key = "a[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"b": 2}]}, payload

    payload = {"a": [{"a": 1}]}
    new_value = {"b": 2}
    key = "a[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"a": 1, "b": 2}]}, payload

    payload = {"a": [[]]}
    new_value = {"b": 2}
    key = "a[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"b": 2}]}, payload

    payload = {"a": [[]]}
    new_value = {"b": 2}
    key = "a[1]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [[]]}, payload

    payload = {"a": [{"a": []}]}
    new_value = {"b": 2}
    key = "a[0].a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"a": {"b": 2}}]}, payload

    payload = {"a": [{"a": []}]}
    new_value = {"b": 2}
    key = "a[].a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"a": {"b": 2}}]}, payload

    payload = {"a": [{"a": []}, {"a": []}]}
    new_value = {"b": 2}
    key = "a[].a"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"a": {"b": 2}}, {"a": {"b": 2}}]}, payload

    payload = {"a": 1, "b": 2}
    new_value = {"c": 3}
    key = "c"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": 1, "b": 2, "c": {"c": 3}}, payload

    payload = {"a": {"b": {"c": 1}}}
    new_value = {"d": 2}
    key = "a.b.d"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {"c": 1, "d": {"d": 2}}}}, payload

    payload = {"a": {"b": {"c": 1}}}
    new_value = {"c": 2}
    key = "a.b"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {"c": 2}}}, payload

    payload = {"a": [{"b": 1}, {"b": 2}]}
    new_value = {"c": 3}
    key = "a[1]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"b": 1}, {"b": 2, "c": 3}]}, payload

    payload = {"a": []}
    new_value = {"b": {"c": 1}}
    key = "a[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": []}, payload

    payload = {"a": {"b": {"c": {"d": {"e": 1}}}}}
    new_value = {"f": 2}
    key = "a.b.c.d"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {"c": {"d": {"e": 1, "f": 2}}}}}, payload

    payload = {"a": {"b": {"c": 1}}}
    new_value = {"d": {"e": 2}}
    key = "a.b.c"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {"c": {"d": {"e": 2}}}}}, payload

    payload = {"a": [{"b": 1}]}
    new_value = {"c": 2}
    key = "a[1]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [{"b": 1}]}, payload

    payload = {"a": {"b": [{"c": 1}, {"c": 2}]}}
    new_value = {"d": 3}
    key = "a.b[0].c"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": [{"c": {"d": 3}}, {"c": 2}]}}, payload

    payload = {"a": {"b": {"c": [{"d": 1}]}}}
    new_value = {"e": {"f": 2}}
    key = "a.b.c[0].d"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {"c": [{"d": {"e": {"f": 2}}}]}}}, payload

    payload = {"a": [[{"b": 1}], [{"b": 2}]]}
    new_value = {"c": 3}
    key = "a[0][0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [[{"b": 1, "c": 3}], [{"b": 2}]]}, payload

    payload = {"a": [[{"b": 1}], [{"b": 2}]]}
    new_value = {"c": 3}
    key = "a[1][0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [[{"b": 1}], [{"b": 2, "c": 3}]]}, payload

    payload = {"a": [[{"b": 1}], [{"b": 2}]]}
    new_value = {"c": 3}
    key = "a[1][1]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [[{"b": 1}], [{"b": 2}]]}, payload

    payload = {"a": [[{"b": 1}], [{"b": 2}]]}
    new_value = {"c": 3}
    key = "a[][0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [[{"b": 1, "c": 3}], [{"b": 2, "c": 3}]]}, payload

    payload = {"a": [[{"b": 1}], [{"b": 2}]]}
    new_value = {"c": 3}
    key = "a[][]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": [[{"b": 1, "c": 3}], [{"b": 2, "c": 3}]]}, payload

    payload = {"a": []}
    new_value = {"c": 3}
    key = 'a."b.c"'
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b.c": {"c": 3}}}, payload

    payload = {"a": {"c": [1]}}
    new_value = {"a": 1}
    key = "a.c[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"c": [{"a": 1}]}}, payload

    payload = {"a": {"c": [1]}}
    new_value = {"a": 1}
    key = "a.c[0].d"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"c": [{"d": {"a": 1}}]}}, payload

    payload = {"": 2}
    new_value = {"a": 1}
    key = '""'
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"": {"a": 1}}, payload
    # endregion

    # region exceptions

    try:
        payload = {"a": []}
        new_value = {"c": 3}
        key = "a.'b.c'"
        set_value_by_key(payload, parse_json_path(key), new_value)
        assert False, f"Should've raised an exception due to the key with incorrect quotes: {key}"
    except Exception:
        assert True

    try:
        payload = {"a": [{"b": 1}, {"b": 2}]}
        new_value = {"c": 3}
        key = "a[-1]"
        set_value_by_key(payload, parse_json_path(key), new_value)
        assert False, "Negative indexation is not supported"
    except Exception:
        assert True

    try:
        payload = {"a": [{"b": 1}, {"b": 2}]}
        new_value = {"c": 3}
        key = "a["
        set_value_by_key(payload, parse_json_path(key), new_value)
        assert False, f"Should've raised an exception due to the incorrect key: {key}"
    except Exception:
        assert True

    try:
        payload = {"a": [{"b": 1}, {"b": 2}]}
        new_value = {"c": 3}
        key = "a]"
        set_value_by_key(payload, parse_json_path(key), new_value)
        assert False, f"Should've raised an exception due to the incorrect key: {key}"
    except Exception:
        assert True

    # endregion

    # region wrong keys
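    # when the path does not match the existing structure, the setter appears to
    # rebuild the payload to fit the path instead of raising, e.g. a list may be
    # replaced by a dict (or vice versa)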
    payload = {"a": []}
    new_value = {}
    key = "a.b[0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": []}}, payload

    payload = {"a": []}
    new_value = {}
    key = "a.b"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {}}}, payload

    payload = {"a": []}
    new_value = {"c": 2}
    key = "a.b"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": {"c": 2}}}, payload

    payload = {"a": [[{"a": 1}]]}
    new_value = {"a": 2}
    key = "a.b[0][0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"b": []}}, payload

    payload = {"a": {"c": 2}}
    new_value = {"a": 1}
    key = "a[]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": []}, payload

    payload = {"a": {"c": 2}}
    new_value = {"a": 1}
    key = "a[].b"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": []}, payload

    payload = {"a": {"c": [1]}}
    new_value = {"a": 1}
    key = "a.c[][][0]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"c": [[]]}}, payload

    payload = {"a": {"c": [{"d": 1}]}}
    new_value = {"a": 1}
    key = "a.c[][]"
    set_value_by_key(payload, parse_json_path(key), new_value)
    assert payload == {"a": {"c": [[]]}}, payload
    # endregion
@@ -0,0 +1,135 @@
import copy

import pytest


from qdrant_client.local.qdrant_local import QdrantLocal
from qdrant_client import models


@pytest.fixture(scope="module", autouse=True)
def client():
    """
    Sets up multiple collections with a bunch of points
    """
    client = QdrantLocal(":memory:")
    client.create_collection(
        "collection_default",
        vectors_config=models.VectorParams(
            size=4,
            distance=models.Distance.DOT,
        ),
    )

    client.create_collection(
        "collection_multiple_vectors",
        vectors_config={
            "": models.VectorParams(
                size=4,
                distance=models.Distance.DOT,
            ),
            "byte": models.VectorParams(
                size=4, distance=models.Distance.DOT, datatype=models.Datatype.UINT8
            ),
            "colbert": models.VectorParams(
                size=4,
                distance=models.Distance.DOT,
                multivector_config=models.MultiVectorConfig(
                    comparator=models.MultiVectorComparator.MAX_SIM
                ),
            ),
        },
        sparse_vectors_config={"sparse": models.SparseVectorParams()},
    )

    client.upsert(
        "collection_default",
        [
            models.PointStruct(id=1, vector=[0.25, 0.0, 0.0, 0.0]),
        ],
    )

    client.upsert(
        "collection_multiple_vectors",
        [
            models.PointStruct(
                id=1,
                vector={
                    "": [0.0, 0.25, 0.0, 0.0],
                    "byte": [0, 25, 0, 0],
                    "colbert": [[0.0, 0.25, 0.0, 0.0], [0.0, 0.25, 0.0, 0.0]],
                    "sparse": models.SparseVector(indices=[1], values=[0.25]),
                },
            ),
        ],
    )

    return client


@pytest.mark.parametrize(
    "query",
    [
        models.NearestQuery(nearest=1),
        models.RecommendQuery(recommend=models.RecommendInput(positive=[1], negative=[1])),
        models.DiscoverQuery(
            discover=models.DiscoverInput(
                target=1, context=[models.ContextPair(**{"positive": 1, "negative": 1})]
            )
        ),
        models.ContextQuery(context=[models.ContextPair(**{"positive": 1, "negative": 1})]),
        models.OrderByQuery(order_by=models.OrderBy(key="price", direction=models.Direction.ASC)),
        models.FusionQuery(fusion=models.Fusion.RRF),
    ],
)
@pytest.mark.parametrize(
    "using, lookup_from, expected, mentioned",
    [
        (None, None, [0.25, 0.0, 0.0, 0.0], True),
        ("", None, [0.25, 0.0, 0.0, 0.0], True),
        (
            "byte",
            models.LookupLocation(collection="collection_multiple_vectors"),
            [0, 25, 0, 0],
            False,
        ),
        (
            "",
            models.LookupLocation(collection="collection_multiple_vectors", vector="colbert"),
            [[0.0, 0.25, 0.0, 0.0], [0.0, 0.25, 0.0, 0.0]],
            False,
        ),
        (
            None,
            models.LookupLocation(collection="collection_multiple_vectors", vector="sparse"),
            models.SparseVector(indices=[1], values=[0.25]),
            False,
        ),
    ],
)
def test_vector_dereferencing(client, query, using, lookup_from, expected, mentioned):
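    # _resolve_query_input is expected to replace point ids in the query with the
    # corresponding stored vectors (taken from lookup_from when it is provided)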
    resolved, mentioned_ids = client._resolve_query_input(
        collection_name="collection_default",
        query=copy.deepcopy(query),
        using=using,
        lookup_from=lookup_from,
    )

    if isinstance(resolved, models.NearestQuery):
        assert resolved.nearest == expected
    elif isinstance(resolved, models.RecommendQuery):
        assert resolved.recommend.positive == [expected]
        assert resolved.recommend.negative == [expected]
    elif isinstance(resolved, models.DiscoverQuery):
        assert resolved.discover.target == expected
        assert resolved.discover.context[0].positive == expected
        assert resolved.discover.context[0].negative == expected
    elif isinstance(resolved, models.ContextQuery):
        assert resolved.context[0].positive == expected
        assert resolved.context[0].negative == expected
    else:
        mentioned = False
        assert resolved == query

    if mentioned:
        assert mentioned_ids == {1}
@@ -0,0 +1,21 @@
import random

from qdrant_client import models
from qdrant_client.local.local_collection import LocalCollection, DEFAULT_VECTOR_NAME


def test_get_vectors():
    collection = LocalCollection(
        models.CreateCollection(
            vectors=models.VectorParams(size=2, distance=models.Distance.MANHATTAN)
        )
    )
    collection.upsert(
        points=[
            models.PointStruct(id=i, vector=[random.random(), random.random()]) for i in range(10)
        ]
    )
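
    # with_vectors may be a specific vector name, True (all vectors) or False;
    # with False no vectors are expected to be returned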
    assert collection._get_vectors(idx=1, with_vectors=DEFAULT_VECTOR_NAME)
    assert collection._get_vectors(idx=2, with_vectors=True)
    assert collection._get_vectors(idx=3, with_vectors=False) is None