183 lines
7.3 KiB
Python
183 lines
7.3 KiB
Python
import time
|
|
from typing import Iterable, Optional
|
|
|
|
from qdrant_client._pydantic_compat import to_dict, model_fields
|
|
from qdrant_client.client_base import QdrantBase
|
|
from qdrant_client.http import models
|
|
|
|
|
|
def upload_with_retry(
|
|
client: QdrantBase,
|
|
collection_name: str,
|
|
points: Iterable[models.PointStruct],
|
|
max_attempts: int = 3,
|
|
pause: float = 3.0,
|
|
) -> None:
|
|
attempts = 1
|
|
while attempts <= max_attempts:
|
|
try:
|
|
client.upload_points(
|
|
collection_name=collection_name,
|
|
points=points,
|
|
wait=True,
|
|
)
|
|
return
|
|
except Exception as e:
|
|
print(f"Exception: {e}, attempt {attempts}/{max_attempts}")
|
|
if attempts < max_attempts:
|
|
print(f"Next attempt in {pause} seconds")
|
|
time.sleep(pause)
|
|
attempts += 1
|
|
|
|
raise Exception(f"Failed to upload points after {max_attempts} attempts")
|
|
|
|
|
|
def migrate(
|
|
source_client: QdrantBase,
|
|
dest_client: QdrantBase,
|
|
collection_names: Optional[list[str]] = None,
|
|
recreate_on_collision: bool = False,
|
|
batch_size: int = 100,
|
|
) -> None:
|
|
"""
|
|
Migrate collections from source client to destination client
|
|
|
|
Args:
|
|
source_client (QdrantBase): Source client
|
|
dest_client (QdrantBase): Destination client
|
|
collection_names (list[str], optional): List of collection names to migrate.
|
|
If None - migrate all source client collections. Defaults to None.
|
|
recreate_on_collision (bool, optional): If True - recreate collection if it exists, otherwise
|
|
raise ValueError.
|
|
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
|
|
"""
|
|
collection_names = _select_source_collections(source_client, collection_names)
|
|
if any(
|
|
_has_custom_shards(source_client, collection_name) for collection_name in collection_names
|
|
):
|
|
raise ValueError("Migration of collections with custom shards is not supported yet")
|
|
|
|
collisions = _find_collisions(dest_client, collection_names)
|
|
absent_dest_collections = set(collection_names) - set(collisions)
|
|
|
|
if collisions and not recreate_on_collision:
|
|
raise ValueError(f"Collections already exist in dest_client: {collisions}")
|
|
|
|
for collection_name in absent_dest_collections:
|
|
_recreate_collection(source_client, dest_client, collection_name)
|
|
_migrate_collection(source_client, dest_client, collection_name, batch_size)
|
|
|
|
for collection_name in collisions:
|
|
_recreate_collection(source_client, dest_client, collection_name)
|
|
_migrate_collection(source_client, dest_client, collection_name, batch_size)
|
|
|
|
|
|
def _has_custom_shards(source_client: QdrantBase, collection_name: str) -> bool:
|
|
collection_info = source_client.get_collection(collection_name)
|
|
return (
|
|
getattr(collection_info.config.params, "sharding_method", None)
|
|
== models.ShardingMethod.CUSTOM
|
|
)
|
|
|
|
|
|
def _select_source_collections(
|
|
source_client: QdrantBase, collection_names: Optional[list[str]] = None
|
|
) -> list[str]:
|
|
source_collections = source_client.get_collections().collections
|
|
source_collection_names = [collection.name for collection in source_collections]
|
|
|
|
if collection_names is not None:
|
|
assert all(
|
|
collection_name in source_collection_names for collection_name in collection_names
|
|
), f"Source client does not have collections: {set(collection_names) - set(source_collection_names)}"
|
|
else:
|
|
collection_names = source_collection_names
|
|
|
|
return collection_names
|
|
|
|
|
|
def _find_collisions(dest_client: QdrantBase, collection_names: list[str]) -> list[str]:
|
|
dest_collections = dest_client.get_collections().collections
|
|
dest_collection_names = {collection.name for collection in dest_collections}
|
|
existing_dest_collections = dest_collection_names & set(collection_names)
|
|
return list(existing_dest_collections)
|
|
|
|
|
|
def _recreate_collection(
|
|
source_client: QdrantBase,
|
|
dest_client: QdrantBase,
|
|
collection_name: str,
|
|
) -> None:
|
|
src_collection_info = source_client.get_collection(collection_name)
|
|
src_config = src_collection_info.config
|
|
src_payload_schema = src_collection_info.payload_schema
|
|
if dest_client.collection_exists(collection_name):
|
|
dest_client.delete_collection(collection_name)
|
|
|
|
strict_mode_config: Optional[models.StrictModeConfig] = None
|
|
if src_config.strict_mode_config is not None:
|
|
strict_mode_config = models.StrictModeConfig(
|
|
**{
|
|
k: v
|
|
for k, v in to_dict(src_config.strict_mode_config).items()
|
|
if k in model_fields(models.StrictModeConfig)
|
|
}
|
|
)
|
|
dest_client.create_collection(
|
|
collection_name,
|
|
vectors_config=src_config.params.vectors,
|
|
sparse_vectors_config=src_config.params.sparse_vectors,
|
|
shard_number=src_config.params.shard_number,
|
|
replication_factor=src_config.params.replication_factor,
|
|
write_consistency_factor=src_config.params.write_consistency_factor,
|
|
on_disk_payload=src_config.params.on_disk_payload,
|
|
hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)),
|
|
optimizers_config=models.OptimizersConfigDiff(**to_dict(src_config.optimizer_config)),
|
|
wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)),
|
|
quantization_config=src_config.quantization_config,
|
|
strict_mode_config=strict_mode_config,
|
|
)
|
|
|
|
_recreate_payload_schema(dest_client, collection_name, src_payload_schema)
|
|
|
|
|
|
def _recreate_payload_schema(
|
|
dest_client: QdrantBase,
|
|
collection_name: str,
|
|
payload_schema: dict[str, models.PayloadIndexInfo],
|
|
) -> None:
|
|
for field_name, field_info in payload_schema.items():
|
|
dest_client.create_payload_index(
|
|
collection_name,
|
|
field_name=field_name,
|
|
field_schema=field_info.data_type if field_info.params is None else field_info.params,
|
|
)
|
|
|
|
|
|
def _migrate_collection(
|
|
source_client: QdrantBase,
|
|
dest_client: QdrantBase,
|
|
collection_name: str,
|
|
batch_size: int = 100,
|
|
) -> None:
|
|
"""Migrate collection from source client to destination client
|
|
|
|
Args:
|
|
collection_name (str): Collection name
|
|
source_client (QdrantBase): Source client
|
|
dest_client (QdrantBase): Destination client
|
|
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
|
|
"""
|
|
records, next_offset = source_client.scroll(collection_name, limit=2, with_vectors=True)
|
|
upload_with_retry(client=dest_client, collection_name=collection_name, points=records) # type: ignore
|
|
while next_offset is not None:
|
|
records, next_offset = source_client.scroll(
|
|
collection_name, offset=next_offset, limit=batch_size, with_vectors=True
|
|
)
|
|
upload_with_retry(client=dest_client, collection_name=collection_name, points=records) # type: ignore
|
|
source_client_vectors_count = source_client.count(collection_name).count
|
|
dest_client_vectors_count = dest_client.count(collection_name).count
|
|
assert (
|
|
source_client_vectors_count == dest_client_vectors_count
|
|
), f"Migration failed, vectors count are not equal: source vector count {source_client_vectors_count}, dest vector count {dest_client_vectors_count}"
|