refactor: excel parse
This commit is contained in:
@@ -0,0 +1,147 @@
|
||||
from itertools import count
|
||||
from time import sleep
|
||||
from typing import Any, Generator, Iterable, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
from qdrant_client import grpc as grpc
|
||||
from qdrant_client import models as rest
|
||||
from qdrant_client.common.client_exceptions import ResourceExhaustedResponse
|
||||
from qdrant_client.connection import get_channel
|
||||
from qdrant_client.conversions.conversion import RestToGrpc, payload_to_grpc
|
||||
from qdrant_client.uploader.uploader import BaseUploader
|
||||
from qdrant_client.common.client_warnings import show_warning
|
||||
from qdrant_client.conversions import common_types as types
|
||||
|
||||
|
||||
def upload_batch_grpc(
    points_client: grpc.PointsStub,
    collection_name: str,
    batch: Union[rest.Batch, tuple],  # type: ignore[name-defined]
    max_retries: int,
    shard_key_selector: Optional[grpc.ShardKeySelector],  # type: ignore[name-defined]
    update_filter: Optional[grpc.Filter],
    wait: bool = False,
    timeout: Optional[int] = None,
) -> bool:
    """Upsert a single batch of points over gRPC, retrying on failure.

    Args:
        points_client: gRPC stub used to issue the ``Upsert`` call.
        collection_name: Name of the target collection.
        batch: ``(ids, vectors, payloads)`` triple; ids and payloads may be None.
        max_retries: Maximum number of upsert attempts before giving up.
        shard_key_selector: Optional shard routing for the points.
        update_filter: Optional filter restricting which points are updated.
        wait: If True, ask the server to apply the operation before responding.
        timeout: Optional per-call timeout (seconds) passed to the stub.

    Returns:
        True once the batch has been upserted successfully.

    Raises:
        ResourceExhaustedResponse: if the server is still rate-limiting after
            ``max_retries`` attempts.
        Exception: the last error, if every attempt fails for another reason.
    """
    ids_batch, vectors_batch, payload_batch = batch

    # Substitute lazily generated defaults when ids/payloads were not supplied:
    # random UUID point ids and empty payloads, produced on demand so that
    # zip() below stays bounded by the vectors iterable.
    ids_batch = (
        (grpc.PointId(uuid=str(uuid4())) for _ in count()) if ids_batch is None else ids_batch
    )
    payload_batch = (None for _ in count()) if payload_batch is None else payload_batch

    points = [
        grpc.PointStruct(
            id=RestToGrpc.convert_extended_point_id(idx)
            if not isinstance(idx, grpc.PointId)
            else idx,
            vectors=RestToGrpc.convert_vector_struct(vector),
            payload=payload_to_grpc(payload or {}),
        )
        for idx, vector, payload in zip(ids_batch, vectors_batch, payload_batch)
    ]

    attempt = 0
    while attempt < max_retries:
        try:
            points_client.Upsert(
                grpc.UpsertPoints(
                    collection_name=collection_name,
                    points=points,
                    wait=wait,
                    shard_key_selector=shard_key_selector,
                    update_filter=update_filter,
                ),
                timeout=timeout,
            )
            break
        except ResourceExhaustedResponse as ex:
            show_warning(
                message=f"Batch upload failed due to rate limit. Waiting for {ex.retry_after_s} seconds before retrying...",
                category=UserWarning,
                stacklevel=8,
            )
            # Bug fix: previously the loop could spend every attempt on rate
            # limits and then fall through to `return True` without the batch
            # ever being upserted. Surface the failure instead, consistent with
            # the generic-exception branch below.
            if attempt == max_retries - 1:
                raise ex
            sleep(ex.retry_after_s)
        except Exception as e:
            show_warning(
                message=f"Batch upload failed {attempt + 1} times. Retrying...",
                category=UserWarning,
                stacklevel=8,
            )

            if attempt == max_retries - 1:
                raise e

        attempt += 1
    return True
|
||||
|
||||
|
||||
class GrpcBatchUploader(BaseUploader):
    """Worker that streams point batches into a Qdrant collection over gRPC."""

    def __init__(
        self,
        host: str,
        port: int,
        collection_name: str,
        max_retries: int,
        wait: bool = False,
        shard_key_selector: Optional[types.ShardKeySelector] = None,
        update_filter: Optional[types.Filter] = None,
        **kwargs: Any,
    ):
        self.collection_name = collection_name
        self.max_retries = max_retries
        self._host = host
        self._port = port
        self._wait = wait
        # `timeout` is consumed here so it is not forwarded to get_channel();
        # the remaining kwargs are passed through to the channel factory.
        self._timeout = kwargs.pop("timeout", None)
        self._kwargs = kwargs
        if shard_key_selector is None:
            self._shard_key_selector = None
        else:
            self._shard_key_selector = RestToGrpc.convert_shard_key_selector(
                shard_key_selector
            )
        # REST-model filters are converted to their gRPC counterpart up front.
        if isinstance(update_filter, rest.Filter):  # type: ignore[attr-defined]
            self._update_filter = RestToGrpc.convert_filter(update_filter)
        else:
            self._update_filter = update_filter

    @classmethod
    def start(
        cls,
        collection_name: Optional[str] = None,
        host: str = "localhost",
        port: int = 6334,
        max_retries: int = 3,
        **kwargs: Any,
    ) -> "GrpcBatchUploader":
        """Build an uploader instance; a non-empty collection name is required."""
        if not collection_name:
            raise RuntimeError("Collection name could not be empty")

        return cls(
            host=host,
            port=port,
            collection_name=collection_name,
            max_retries=max_retries,
            **kwargs,
        )

    def process_upload(self, items: Iterable[Any]) -> Generator[bool, None, None]:
        """Open a gRPC channel and upsert every batch, yielding one bool each."""
        channel = get_channel(host=self._host, port=self._port, **self._kwargs)
        stub = grpc.PointsStub(channel)
        for point_batch in items:
            yield upload_batch_grpc(
                stub,
                self.collection_name,
                point_batch,
                max_retries=self.max_retries,
                shard_key_selector=self._shard_key_selector,
                update_filter=self._update_filter,
                wait=self._wait,
                timeout=self._timeout,
            )

    def process(self, items: Iterable[Any]) -> Iterable[bool]:
        """Worker entry point: delegate batch processing to process_upload."""
        for outcome in self.process_upload(items):
            yield outcome
|
||||
@@ -0,0 +1,118 @@
|
||||
from itertools import count
|
||||
from time import sleep
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client import grpc as grpc
|
||||
from qdrant_client.common.client_exceptions import ResourceExhaustedResponse
|
||||
from qdrant_client.http import SyncApis
|
||||
from qdrant_client import models as rest
|
||||
from qdrant_client.uploader.uploader import BaseUploader
|
||||
from qdrant_client.common.client_warnings import show_warning
|
||||
from qdrant_client.conversions import common_types as types
|
||||
from qdrant_client.conversions.conversion import GrpcToRest
|
||||
|
||||
|
||||
def upload_batch(
    openapi_client: SyncApis,
    collection_name: str,
    batch: Union[tuple, rest.Batch],  # type: ignore[name-defined]
    max_retries: int,
    shard_key_selector: Optional[rest.ShardKeySelector],  # type: ignore[name-defined]
    update_filter: Optional[rest.Filter],  # type: ignore[name-defined]
    wait: bool = False,
) -> bool:
    """Upsert a single batch of points through the REST API, retrying on failure.

    Args:
        openapi_client: REST client used for the upsert call.
        collection_name: Name of the target collection.
        batch: ``(ids, vectors, payloads)`` triple; ids and payloads may be None.
        max_retries: Maximum number of upsert attempts before giving up.
        shard_key_selector: Optional shard routing for the points.
        update_filter: Optional filter restricting which points are updated.
        wait: If True, ask the server to apply the operation before responding.

    Returns:
        True once the batch has been upserted successfully.

    Raises:
        ResourceExhaustedResponse: if the server is still rate-limiting after
            ``max_retries`` attempts.
        Exception: the last error, if every attempt fails for another reason.
    """
    ids_batch, vectors_batch, payload_batch = batch

    # Substitute lazily generated defaults when ids/payloads were not supplied,
    # so zip() below stays bounded by the vectors iterable.
    ids_batch = (str(uuid4()) for _ in count()) if ids_batch is None else ids_batch
    payload_batch = (None for _ in count()) if payload_batch is None else payload_batch

    points = [
        rest.PointStruct(  # type: ignore[attr-defined]
            id=idx,
            vector=(vector.tolist() if isinstance(vector, np.ndarray) else vector) or {},
            payload=payload,
        )
        for idx, vector, payload in zip(ids_batch, vectors_batch, payload_batch)
    ]

    attempt = 0
    while attempt < max_retries:
        try:
            openapi_client.points_api.upsert_points(
                collection_name=collection_name,
                point_insert_operations=rest.PointsList(  # type: ignore[attr-defined]
                    points=points, shard_key=shard_key_selector, update_filter=update_filter
                ),
                wait=wait,
            )
            break
        except ResourceExhaustedResponse as ex:
            show_warning(
                message=f"Batch upload failed due to rate limit. Waiting for {ex.retry_after_s} seconds before retrying...",
                category=UserWarning,
                stacklevel=7,
            )
            # Bug fix: previously the loop could spend every attempt on rate
            # limits and then fall through to `return True` without the batch
            # ever being upserted. Surface the failure instead, consistent with
            # the generic-exception branch below.
            if attempt == max_retries - 1:
                raise ex
            sleep(ex.retry_after_s)
        except Exception as e:
            show_warning(
                message=f"Batch upload failed {attempt + 1} times. Retrying...",
                category=UserWarning,
                stacklevel=7,
            )

            if attempt == max_retries - 1:
                raise e

        attempt += 1
    return True
|
||||
|
||||
|
||||
class RestBatchUploader(BaseUploader):
    """Worker that streams point batches into a Qdrant collection over REST."""

    def __init__(
        self,
        uri: str,
        collection_name: str,
        max_retries: int,
        wait: bool = False,
        shard_key_selector: Optional[types.ShardKeySelector] = None,
        update_filter: Optional[types.Filter] = None,
        **kwargs: Any,
    ):
        self.collection_name = collection_name
        self.max_retries = max_retries
        self._wait = wait
        self._shard_key_selector = shard_key_selector
        self.openapi_client: SyncApis = SyncApis(host=uri, **kwargs)
        # gRPC-model filters are converted to their REST counterpart up front.
        if isinstance(update_filter, grpc.Filter):
            self._update_filter = GrpcToRest.convert_filter(model=update_filter)
        else:
            self._update_filter = update_filter

    @classmethod
    def start(
        cls,
        collection_name: Optional[str] = None,
        uri: str = "http://localhost:6333",
        max_retries: int = 3,
        **kwargs: Any,
    ) -> "RestBatchUploader":
        """Build an uploader instance; a non-empty collection name is required."""
        if not collection_name:
            raise RuntimeError("Collection name could not be empty")
        return cls(uri=uri, collection_name=collection_name, max_retries=max_retries, **kwargs)

    def process(self, items: Iterable[Any]) -> Iterable[bool]:
        """Upsert every batch from `items`, yielding one bool per batch."""
        for point_batch in items:
            yield upload_batch(
                self.openapi_client,
                self.collection_name,
                point_batch,
                max_retries=self.max_retries,
                shard_key_selector=self._shard_key_selector,
                update_filter=self._update_filter,
                wait=self._wait,
            )
|
||||
@@ -0,0 +1,94 @@
|
||||
from abc import ABC
|
||||
from itertools import count, islice
|
||||
from typing import Any, Generator, Iterable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qdrant_client.conversions import common_types as types
|
||||
from qdrant_client.conversions.common_types import Record
|
||||
from qdrant_client.http.models import ExtendedPointId
|
||||
from qdrant_client.parallel_processor import Worker
|
||||
|
||||
|
||||
def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
    """Split an iterable into lists of at most ``size`` items.

    The final batch may be shorter than ``size``; an empty iterable yields
    nothing.

    >>> list(iter_batch([1,2,3,4,5], 3))
    [[1, 2, 3], [4, 5]]
    """
    source_iter = iter(iterable)
    # The original `while source_iter:` was misleading: an iterator object is
    # always truthy, so the condition never terminated the loop. Make the
    # unconditional loop explicit; an empty islice chunk signals exhaustion.
    while True:
        b = list(islice(source_iter, size))
        if not b:
            break
        yield b
|
||||
|
||||
|
||||
class BaseUploader(Worker, ABC):
    """Shared batching logic for the parallel uploaders.

    Turns user-supplied points — either full records or separate
    ids/vectors/payloads streams — into ``(ids, vectors, payloads)`` batch
    tuples of at most ``batch_size`` points each.
    """

    @classmethod
    def iterate_records_batches(
        cls,
        records: Iterable[Union[Record, types.PointStruct]],
        batch_size: int,
    ) -> Iterable:
        """Yield ``(ids, vectors, payloads)`` tuples built from point records.

        Each tuple holds three parallel lists extracted from one batch of
        records via their ``id``/``vector``/``payload`` attributes.
        """
        record_batches = iter_batch(records, batch_size)
        for record_batch in record_batches:
            ids_batch, vectors_batch, payload_batch = [], [], []

            for record in record_batch:
                ids_batch.append(record.id)
                vectors_batch.append(record.vector)
                payload_batch.append(record.payload)

            yield ids_batch, vectors_batch, payload_batch

    @classmethod
    def iterate_batches(
        cls,
        vectors: Union[
            dict[str, types.NumpyArray], types.NumpyArray, Iterable[types.VectorStruct]
        ],
        payload: Optional[Iterable[dict]],
        ids: Optional[Iterable[ExtendedPointId]],
        batch_size: int,
    ) -> Iterable:
        """Yield ``(ids, vectors, payloads)`` batch tuples from separate streams.

        Missing ``ids``/``payload`` streams are replaced with infinite ``None``
        generators, so the zip below is bounded by the vectors iterable alone
        and the uploaders generate defaults downstream.
        """
        if ids is None:
            ids_batches: Iterable = (None for _ in count())
        else:
            ids_batches = iter_batch(ids, batch_size)

        if payload is None:
            payload_batches: Iterable = (None for _ in count())
        else:
            payload_batches = iter_batch(payload, batch_size)

        # Numpy inputs (a single array, or a named-vectors dict containing
        # arrays) need conversion to plain lists before serialization.
        if isinstance(vectors, np.ndarray):
            vector_batches: Iterable[Any] = cls._vector_batches_from_numpy(vectors, batch_size)
        elif isinstance(vectors, dict) and any(
            isinstance(value, np.ndarray) for value in vectors.values()
        ):
            vector_batches = cls._vector_batches_from_numpy_named_vectors(vectors, batch_size)
        else:
            vector_batches = iter_batch(vectors, batch_size)

        yield from zip(ids_batches, vector_batches, payload_batches)

    @staticmethod
    def _vector_batches_from_numpy(
        vectors: types.NumpyArray, batch_size: int
    ) -> Iterable[list[list[float]]]:
        """Yield row-slices of a 2-D array as nested Python lists.

        Note: the original annotation (``Iterable[float]``) was wrong — each
        yielded item is one batch: a list of vectors, each a list of floats.
        """
        for i in range(0, vectors.shape[0], batch_size):
            yield vectors[i : i + batch_size].tolist()

    @staticmethod
    def _vector_batches_from_numpy_named_vectors(
        vectors: dict[str, types.NumpyArray], batch_size: int
    ) -> Iterable[list[dict[str, list[float]]]]:
        """Yield batches of named-vector dicts built from per-name arrays.

        Each yielded item is one batch: a list of ``{name: vector}`` dicts,
        one dict per point. All arrays must describe the same number of
        points along their first axis.
        """
        assert (
            len(set([arr.shape[0] for arr in vectors.values()])) == 1
        ), "Each named vector should have the same number of vectors"

        num_vectors = next(iter(vectors.values())).shape[0]
        # Convert dict[str, np.ndarray] to Generator(dict[str, list[float]])
        vector_batches = (
            {name: vectors[name][i].tolist() for name in vectors.keys()}
            for i in range(num_vectors)
        )
        yield from iter_batch(vector_batches, batch_size)
|
||||
Reference in New Issue
Block a user