refactor: excel parse
File diff suppressed because it is too large
@@ -0,0 +1,94 @@
from typing import Optional, Any

from qdrant_client.http import models

from qdrant_client.embed.models import NumericVector


class BuiltinEmbedder:
    _SUPPORTED_MODELS = ("Qdrant/Bm25",)

    def __init__(self, **kwargs: Any) -> None:
        pass

    def embed(
        self,
        model_name: str,
        texts: Optional[list[str]] = None,
        options: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> NumericVector:
        if texts is None:
            if "images" in kwargs:
                raise ValueError(
                    "Image processing is only available with cloud inference of FastEmbed"
                )

            raise ValueError("Texts must be provided for the inference")

        if not self.is_supported_sparse_model(model_name):
            raise ValueError(
                f"Model {model_name} is not supported in {self.__class__.__name__}. "
                f"Did you forget to enable cloud inference or install FastEmbed for local inference?"
            )

        return [models.Document(text=text, options=options, model=model_name) for text in texts]

    @classmethod
    def is_supported_text_model(cls, model_name: str) -> bool:
        """Mock embedder interface, only sparse text model Qdrant/Bm25 is supported

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return False  # currently only Qdrant/Bm25 is supported

    @classmethod
    def is_supported_image_model(cls, model_name: str) -> bool:
        """Mock embedder interface, only sparse text model Qdrant/Bm25 is supported

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return False  # currently only Qdrant/Bm25 is supported

    @classmethod
    def is_supported_late_interaction_text_model(cls, model_name: str) -> bool:
        """Mock embedder interface, only sparse text model Qdrant/Bm25 is supported

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return False  # currently only Qdrant/Bm25 is supported

    @classmethod
    def is_supported_late_interaction_multimodal_model(cls, model_name: str) -> bool:
        """Mock embedder interface, only sparse text model Qdrant/Bm25 is supported

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return False  # currently only Qdrant/Bm25 is supported

    @classmethod
    def is_supported_sparse_model(cls, model_name: str) -> bool:
        """Checks if the model is supported. Only `Qdrant/Bm25` is supported

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return model_name.lower() in [model.lower() for model in cls._SUPPORTED_MODELS]
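Usage sketch (not part of the diff): the new BuiltinEmbedder never computes vectors locally; it only validates the model name and wraps texts into models.Document objects that the server can embed via cloud inference. The snippet below assumes the module path qdrant_client.embed.builtin_embedder, which is the path used by the imports further down in this commit.

from qdrant_client.embed.builtin_embedder import BuiltinEmbedder
from qdrant_client.http import models

embedder = BuiltinEmbedder()
docs = embedder.embed(model_name="Qdrant/Bm25", texts=["hello", "world"])
# Only the sparse model "Qdrant/Bm25" passes is_supported_sparse_model();
# any other model name raises ValueError.
assert all(isinstance(doc, models.Document) for doc in docs)
assert docs[0].model == "Qdrant/Bm25"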
@@ -0,0 +1,6 @@
from typing import Union

from qdrant_client.http import models

INFERENCE_OBJECT_NAMES: set[str] = {"Document", "Image", "InferenceObject"}
INFERENCE_OBJECT_TYPES = Union[models.Document, models.Image, models.InferenceObject]
@@ -0,0 +1,176 @@
from copy import copy
from typing import Union, Optional, Iterable, get_args

from pydantic import BaseModel

from qdrant_client._pydantic_compat import model_fields_set
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
from qdrant_client.embed.schema_parser import ModelSchemaParser

from qdrant_client.embed.utils import convert_paths, FieldPath


class InspectorEmbed:
    """Inspector which collects paths to objects requiring inference in the received models

    Attributes:
        parser: ModelSchemaParser instance
    """

    def __init__(self, parser: Optional[ModelSchemaParser] = None) -> None:
        self.parser = ModelSchemaParser() if parser is None else parser

    def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> list[FieldPath]:
        """Looks for all the paths to objects requiring inference in the received models

        Args:
            points: models to inspect

        Returns:
            list of FieldPath objects
        """
        paths = []
        if isinstance(points, BaseModel):
            self.parser.parse_model(points.__class__)
            paths.extend(self._inspect_model(points))
        elif isinstance(points, dict):
            for value in points.values():
                paths.extend(self.inspect(value))
        elif isinstance(points, Iterable):
            for point in points:
                if isinstance(point, BaseModel):
                    self.parser.parse_model(point.__class__)
                    paths.extend(self._inspect_model(point))

        paths = sorted(set(paths))

        return convert_paths(paths)

    def _inspect_model(
        self, mod: BaseModel, paths: Optional[list[FieldPath]] = None, accum: Optional[str] = None
    ) -> list[str]:
        """Looks for all the paths to objects requiring inference in the received model

        Args:
            mod: model to inspect
            paths: list of paths to the fields possibly containing objects for inference
            accum: accumulator for the path. A path is a dot-separated string of field names which we assemble recursively

        Returns:
            list of paths to the model fields containing objects for inference
        """
        paths = self.parser.path_cache.get(mod.__class__.__name__, []) if paths is None else paths

        found_paths = []
        for path in paths:
            found_paths.extend(
                self._inspect_inner_models(
                    mod, path.current, path.tail if path.tail else [], accum
                )
            )
        return found_paths

    def _inspect_inner_models(
        self,
        original_model: BaseModel,
        current_path: str,
        tail: list[FieldPath],
        accum: Optional[str] = None,
    ) -> list[str]:
        """Looks for all the paths to objects requiring inference in the received model

        Args:
            original_model: model to inspect
            current_path: the field to inspect on the current iteration
            tail: list of FieldPath objects to the fields possibly containing objects for inference
            accum: accumulator for the path. A path is a dot-separated string of field names which we assemble recursively

        Returns:
            list of paths to the model fields containing objects for inference
        """
        found_paths = []
        if accum is None:
            accum = current_path
        else:
            accum += f".{current_path}"

        def inspect_recursive(member: BaseModel, accumulator: str) -> list[str]:
            """Iterates over the set model fields, expands recursive ones and finds paths to objects requiring inference

            Args:
                member: currently inspected model, which may or may not contain recursive fields
                accumulator: accumulator for the path, which is a dot-separated string assembled recursively
            """
            recursive_paths = []
            for field in model_fields_set(member):
                if field in self.parser.name_recursive_ref_mapping:
                    mapped_field = self.parser.name_recursive_ref_mapping[field]
                    recursive_paths.extend(self.parser.path_cache[mapped_field])

            return self._inspect_model(member, copy(recursive_paths), accumulator)

        model = getattr(original_model, current_path, None)
        if model is None:
            return []

        if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)):
            return [accum]

        if isinstance(model, BaseModel):
            found_paths.extend(inspect_recursive(model, accum))

            for next_path in tail:
                found_paths.extend(
                    self._inspect_inner_models(
                        model, next_path.current, next_path.tail if next_path.tail else [], accum
                    )
                )

            return found_paths

        elif isinstance(model, list):
            for current_model in model:
                if not isinstance(current_model, BaseModel):
                    continue

                if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)):
                    found_paths.append(accum)

                found_paths.extend(inspect_recursive(current_model, accum))

            for next_path in tail:
                for current_model in model:
                    found_paths.extend(
                        self._inspect_inner_models(
                            current_model,
                            next_path.current,
                            next_path.tail if next_path.tail else [],
                            accum,
                        )
                    )
            return found_paths

        elif isinstance(model, dict):
            found_paths = []
            for key, values in model.items():
                values = [values] if not isinstance(values, list) else values
                for current_model in values:
                    if not isinstance(current_model, BaseModel):
                        continue

                    if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)):
                        found_paths.append(accum)

                    found_paths.extend(inspect_recursive(current_model, accum))

                for next_path in tail:
                    for current_model in values:
                        found_paths.extend(
                            self._inspect_inner_models(
                                current_model,
                                next_path.current,
                                next_path.tail if next_path.tail else [],
                                accum,
                            )
                        )
        return found_paths
@@ -0,0 +1,447 @@
from collections import defaultdict
from typing import Optional, Sequence, Any, TypeVar, Generic

from pydantic import BaseModel

from qdrant_client.http import models
from qdrant_client.embed.models import NumericVector
from qdrant_client.fastembed_common import (
    OnnxProvider,
    ImageInput,
    TextEmbedding,
    SparseTextEmbedding,
    LateInteractionTextEmbedding,
    LateInteractionMultimodalEmbedding,
    ImageEmbedding,
    FastEmbedMisc,
)


T = TypeVar("T")


class ModelInstance(BaseModel, Generic[T], arbitrary_types_allowed=True):  # type: ignore[call-arg]
    model: T
    options: dict[str, Any]
    deprecated: bool = False


class Embedder:
    def __init__(self, threads: Optional[int] = None, **kwargs: Any) -> None:
        self.embedding_models: dict[str, list[ModelInstance[TextEmbedding]]] = defaultdict(list)
        self.sparse_embedding_models: dict[str, list[ModelInstance[SparseTextEmbedding]]] = (
            defaultdict(list)
        )
        self.late_interaction_embedding_models: dict[
            str, list[ModelInstance[LateInteractionTextEmbedding]]
        ] = defaultdict(list)
        self.image_embedding_models: dict[str, list[ModelInstance[ImageEmbedding]]] = defaultdict(
            list
        )
        self.late_interaction_multimodal_embedding_models: dict[
            str, list[ModelInstance[LateInteractionMultimodalEmbedding]]
        ] = defaultdict(list)
        self._threads = threads

    def get_or_init_model(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        deprecated: bool = False,
        **kwargs: Any,
    ) -> TextEmbedding:
        if not FastEmbedMisc.is_supported_text_model(model_name):
            raise ValueError(
                f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_text_models()}"
            )
        options = {
            "cache_dir": cache_dir,
            "threads": threads or self._threads,
            "providers": providers,
            "cuda": cuda,
            "device_ids": device_ids,
            **kwargs,
        }
        for instance in self.embedding_models[model_name]:
            if (deprecated and instance.deprecated) or (
                not deprecated and instance.options == options
            ):
                return instance.model

        model = TextEmbedding(model_name=model_name, **options)
        model_instance: ModelInstance[TextEmbedding] = ModelInstance(
            model=model, options=options, deprecated=deprecated
        )
        self.embedding_models[model_name].append(model_instance)
        return model

    def get_or_init_sparse_model(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        deprecated: bool = False,
        **kwargs: Any,
    ) -> SparseTextEmbedding:
        if not FastEmbedMisc.is_supported_sparse_model(model_name):
            raise ValueError(
                f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_sparse_models()}"
            )

        options = {
            "cache_dir": cache_dir,
            "threads": threads or self._threads,
            "providers": providers,
            "cuda": cuda,
            "device_ids": device_ids,
            **kwargs,
        }

        for instance in self.sparse_embedding_models[model_name]:
            if (deprecated and instance.deprecated) or (
                not deprecated and instance.options == options
            ):
                return instance.model

        model = SparseTextEmbedding(model_name=model_name, **options)
        model_instance: ModelInstance[SparseTextEmbedding] = ModelInstance(
            model=model, options=options, deprecated=deprecated
        )
        self.sparse_embedding_models[model_name].append(model_instance)
        return model

    def get_or_init_late_interaction_model(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        **kwargs: Any,
    ) -> LateInteractionTextEmbedding:
        if not FastEmbedMisc.is_supported_late_interaction_text_model(model_name):
            raise ValueError(
                f"Unsupported embedding model: {model_name}. "
                f"Supported models: {FastEmbedMisc.list_late_interaction_text_models()}"
            )
        options = {
            "cache_dir": cache_dir,
            "threads": threads or self._threads,
            "providers": providers,
            "cuda": cuda,
            "device_ids": device_ids,
            **kwargs,
        }

        for instance in self.late_interaction_embedding_models[model_name]:
            if instance.options == options:
                return instance.model

        model = LateInteractionTextEmbedding(model_name=model_name, **options)
        model_instance: ModelInstance[LateInteractionTextEmbedding] = ModelInstance(
            model=model, options=options
        )
        self.late_interaction_embedding_models[model_name].append(model_instance)
        return model

    def get_or_init_late_interaction_multimodal_model(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        **kwargs: Any,
    ) -> LateInteractionMultimodalEmbedding:
        if not FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
            raise ValueError(
                f"Unsupported embedding model: {model_name}. "
                f"Supported models: {FastEmbedMisc.list_late_interaction_multimodal_models()}"
            )
        options = {
            "cache_dir": cache_dir,
            "threads": threads or self._threads,
            "providers": providers,
            "cuda": cuda,
            "device_ids": device_ids,
            **kwargs,
        }

        for instance in self.late_interaction_multimodal_embedding_models[model_name]:
            if instance.options == options:
                return instance.model

        model = LateInteractionMultimodalEmbedding(model_name=model_name, **options)
        model_instance: ModelInstance[LateInteractionMultimodalEmbedding] = ModelInstance(
            model=model, options=options
        )
        self.late_interaction_multimodal_embedding_models[model_name].append(model_instance)
        return model

    def get_or_init_image_model(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence["OnnxProvider"]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        **kwargs: Any,
    ) -> ImageEmbedding:
        if not FastEmbedMisc.is_supported_image_model(model_name):
            raise ValueError(
                f"Unsupported embedding model: {model_name}. Supported models: {FastEmbedMisc.list_image_models()}"
            )
        options = {
            "cache_dir": cache_dir,
            "threads": threads or self._threads,
            "providers": providers,
            "cuda": cuda,
            "device_ids": device_ids,
            **kwargs,
        }

        for instance in self.image_embedding_models[model_name]:
            if instance.options == options:
                return instance.model

        model = ImageEmbedding(model_name=model_name, **options)
        model_instance: ModelInstance[ImageEmbedding] = ModelInstance(model=model, options=options)
        self.image_embedding_models[model_name].append(model_instance)
        return model

    def embed(
        self,
        model_name: str,
        texts: Optional[list[str]] = None,
        images: Optional[list[ImageInput]] = None,
        options: Optional[dict[str, Any]] = None,
        is_query: bool = False,
        batch_size: int = 8,
    ) -> NumericVector:
        if (texts is None) is (images is None):
            raise ValueError("Either documents or images should be provided")

        embeddings: NumericVector  # define type for a static type checker
        if texts is not None:
            if FastEmbedMisc.is_supported_text_model(model_name):
                embeddings = self._embed_dense_text(
                    texts, model_name, options, is_query, batch_size
                )
            elif FastEmbedMisc.is_supported_sparse_model(model_name):
                embeddings = self._embed_sparse_text(
                    texts, model_name, options, is_query, batch_size
                )
            elif FastEmbedMisc.is_supported_late_interaction_text_model(model_name):
                embeddings = self._embed_late_interaction_text(
                    texts, model_name, options, is_query, batch_size
                )
            elif FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
                embeddings = self._embed_late_interaction_multimodal_text(
                    texts, model_name, options, batch_size
                )
            else:
                raise ValueError(f"Unsupported embedding model: {model_name}")
        else:
            assert (
                images is not None
            )  # just to satisfy mypy which can't infer it from the previous conditions
            if FastEmbedMisc.is_supported_image_model(model_name):
                embeddings = self._embed_dense_image(images, model_name, options, batch_size)
            elif FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name):
                embeddings = self._embed_late_interaction_multimodal_image(
                    images, model_name, options, batch_size
                )
            else:
                raise ValueError(f"Unsupported embedding model: {model_name}")

        return embeddings

    def _embed_dense_text(
        self,
        texts: list[str],
        model_name: str,
        options: Optional[dict[str, Any]],
        is_query: bool,
        batch_size: int,
    ) -> list[list[float]]:
        embedding_model_inst = self.get_or_init_model(model_name=model_name, **options or {})

        if not is_query:
            embeddings = [
                embedding.tolist()
                for embedding in embedding_model_inst.embed(documents=texts, batch_size=batch_size)
            ]
        else:
            embeddings = [
                embedding.tolist() for embedding in embedding_model_inst.query_embed(query=texts)
            ]
        return embeddings

    def _embed_sparse_text(
        self,
        texts: list[str],
        model_name: str,
        options: Optional[dict[str, Any]],
        is_query: bool,
        batch_size: int,
    ) -> list[models.SparseVector]:
        embedding_model_inst = self.get_or_init_sparse_model(
            model_name=model_name, **options or {}
        )
        if not is_query:
            embeddings = [
                models.SparseVector(
                    indices=sparse_embedding.indices.tolist(),
                    values=sparse_embedding.values.tolist(),
                )
                for sparse_embedding in embedding_model_inst.embed(
                    documents=texts, batch_size=batch_size
                )
            ]
        else:
            embeddings = [
                models.SparseVector(
                    indices=sparse_embedding.indices.tolist(),
                    values=sparse_embedding.values.tolist(),
                )
                for sparse_embedding in embedding_model_inst.query_embed(query=texts)
            ]
        return embeddings

    def _embed_late_interaction_text(
        self,
        texts: list[str],
        model_name: str,
        options: Optional[dict[str, Any]],
        is_query: bool,
        batch_size: int,
    ) -> list[list[list[float]]]:
        embedding_model_inst = self.get_or_init_late_interaction_model(
            model_name=model_name, **options or {}
        )
        if not is_query:
            embeddings = [
                embedding.tolist()
                for embedding in embedding_model_inst.embed(documents=texts, batch_size=batch_size)
            ]
        else:
            embeddings = [
                embedding.tolist() for embedding in embedding_model_inst.query_embed(query=texts)
            ]
        return embeddings

    def _embed_late_interaction_multimodal_text(
        self,
        texts: list[str],
        model_name: str,
        options: Optional[dict[str, Any]],
        batch_size: int,
    ) -> list[list[list[float]]]:
        embedding_model_inst = self.get_or_init_late_interaction_multimodal_model(
            model_name=model_name, **options or {}
        )
        return [
            embedding.tolist()
            for embedding in embedding_model_inst.embed_text(
                documents=texts, batch_size=batch_size
            )
        ]

    def _embed_late_interaction_multimodal_image(
        self,
        images: list[ImageInput],
        model_name: str,
        options: Optional[dict[str, Any]],
        batch_size: int,
    ) -> list[list[list[float]]]:
        embedding_model_inst = self.get_or_init_late_interaction_multimodal_model(
            model_name=model_name, **options or {}
        )
        return [
            embedding.tolist()
            for embedding in embedding_model_inst.embed_image(images=images, batch_size=batch_size)
        ]

    def _embed_dense_image(
        self,
        images: list[ImageInput],
        model_name: str,
        options: Optional[dict[str, Any]],
        batch_size: int,
    ) -> list[list[float]]:
        embedding_model_inst = self.get_or_init_image_model(model_name=model_name, **options or {})
        embeddings = [
            embedding.tolist()
            for embedding in embedding_model_inst.embed(images=images, batch_size=batch_size)
        ]
        return embeddings

    @classmethod
    def is_supported_text_model(cls, model_name: str) -> bool:
        """Check if model is supported by fastembed

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return FastEmbedMisc.is_supported_text_model(model_name)

    @classmethod
    def is_supported_image_model(cls, model_name: str) -> bool:
        """Check if model is supported by fastembed

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return FastEmbedMisc.is_supported_image_model(model_name)

    @classmethod
    def is_supported_late_interaction_text_model(cls, model_name: str) -> bool:
        """Check if model is supported by fastembed

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return FastEmbedMisc.is_supported_late_interaction_text_model(model_name)

    @classmethod
    def is_supported_late_interaction_multimodal_model(cls, model_name: str) -> bool:
        """Check if model is supported by fastembed

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return FastEmbedMisc.is_supported_late_interaction_multimodal_model(model_name)

    @classmethod
    def is_supported_sparse_model(cls, model_name: str) -> bool:
        """Check if model is supported by fastembed

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model is supported, False otherwise.
        """
        return FastEmbedMisc.is_supported_sparse_model(model_name)
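Usage sketch (not part of the diff): Embedder lazily instantiates FastEmbed models and caches them per model name and option set, so repeated calls with the same options reuse one instance. It assumes FastEmbed is installed, and "BAAI/bge-small-en" below is only an illustrative stand-in for any dense model accepted by FastEmbedMisc.is_supported_text_model.

from qdrant_client.embed.embedder import Embedder

embedder = Embedder(threads=2)
m1 = embedder.get_or_init_model(model_name="BAAI/bge-small-en")
m2 = embedder.get_or_init_model(model_name="BAAI/bge-small-en")
assert m1 is m2  # identical options -> the cached TextEmbedding instance is reused

# embed() routes to the dense/sparse/late-interaction/image helpers based on
# the FastEmbedMisc.is_supported_* checks and returns plain Python lists.
vectors = embedder.embed(model_name="BAAI/bge-small-en", texts=["hello"], batch_size=8)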
@@ -0,0 +1,498 @@
import os
from collections import defaultdict
from copy import deepcopy
from multiprocessing import get_all_start_methods
from typing import Optional, Union, Iterable, Any, Type, get_args

from pydantic import BaseModel

from qdrant_client.embed.builtin_embedder import BuiltinEmbedder
from qdrant_client.http import models
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
from qdrant_client.embed.embed_inspector import InspectorEmbed
from qdrant_client.embed.embedder import Embedder
from qdrant_client.embed.models import NumericVector, NumericVectorStruct
from qdrant_client.embed.schema_parser import ModelSchemaParser
from qdrant_client.embed.utils import FieldPath
from qdrant_client.fastembed_common import FastEmbedMisc
from qdrant_client.parallel_processor import ParallelWorkerPool, Worker
from qdrant_client.uploader.uploader import iter_batch


class ModelEmbedderWorker(Worker):
    def __init__(self, batch_size: int, **kwargs: Any):
        self.model_embedder = ModelEmbedder(**kwargs)
        self.batch_size = batch_size

    @classmethod
    def start(cls, batch_size: int, **kwargs: Any) -> "ModelEmbedderWorker":
        return cls(threads=1, batch_size=batch_size, **kwargs)

    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
        for idx, batch in items:
            yield (
                idx,
                list(
                    self.model_embedder.embed_models_batch(
                        batch, inference_batch_size=self.batch_size
                    )
                ),
            )


class ModelEmbedder:
    MAX_INTERNAL_BATCH_SIZE = 64

    def __init__(
        self,
        parser: Optional[ModelSchemaParser] = None,
        is_local_mode: bool = False,
        server_version: Optional[str] = None,
        **kwargs: Any,
    ):
        self._batch_accumulator: dict[str, list[INFERENCE_OBJECT_TYPES]] = {}
        self._embed_storage: dict[str, list[NumericVector]] = {}
        self._embed_inspector = InspectorEmbed(parser=parser)
        self._is_builtin_embedder_available = self._check_builtin_embedder_availability(
            is_local_mode, server_version
        )
        self.embedder = (
            Embedder(**kwargs) if FastEmbedMisc.is_installed() else BuiltinEmbedder(**kwargs)
        )

    @staticmethod
    def _check_builtin_embedder_availability(
        is_local_mode: bool, server_version: Optional[str]
    ) -> bool:
        if is_local_mode:
            return False

        if (
            server_version is None
        ):  # failed to detect server version, it might happen due to security or network
            # problems even on supported server versions, so we are not blocking usage of BuiltinEmbedder.
            return True

        try:
            major, minor, patch = server_version.split(".")
            patch = patch.split("-")[0]

            if (int(major), int(minor), int(patch)) >= (1, 15, 3):
                return True

            return False
        except Exception:
            return True

    def embed_models(
        self,
        raw_models: Union[BaseModel, Iterable[BaseModel]],
        is_query: bool = False,
        batch_size: int = 8,
    ) -> Iterable[BaseModel]:
        """Embed raw data fields in models and return models with vectors

        If any of the model fields require inference, a deepcopy of a model with computed embeddings is returned,
        otherwise returns original models.

        Args:
            raw_models: Iterable[BaseModel] - models which can contain fields with raw data
            is_query: bool - flag to determine which embed method to use. Defaults to False.
            batch_size: int - batch size for inference
        Returns:
            list[BaseModel]: models with embedded fields
        """
        if not self._is_builtin_embedder_available:
            FastEmbedMisc.import_fastembed()  # fail fast if fastembed is required

        if isinstance(raw_models, BaseModel):
            raw_models = [raw_models]
        for raw_models_batch in iter_batch(raw_models, batch_size):
            yield from self.embed_models_batch(
                raw_models_batch, is_query, inference_batch_size=batch_size
            )

    def embed_models_strict(
        self,
        raw_models: Iterable[Union[dict[str, BaseModel], BaseModel]],
        batch_size: int = 8,
        parallel: Optional[int] = None,
    ) -> Iterable[Union[dict[str, BaseModel], BaseModel]]:
        """Embed raw data fields in models and return models with vectors

        Requires every element of the input sequence to contain raw data fields for inference.
        Does not accept ready vectors.

        Args:
            raw_models: Iterable[BaseModel] - models which contain fields with raw data for inference
            batch_size: int - batch size for inference
            parallel: int - number of parallel processes to use. Defaults to None.

        Returns:
            Iterable[Union[dict[str, BaseModel], BaseModel]]: models with embedded fields
        """
        if not self._is_builtin_embedder_available:
            FastEmbedMisc.import_fastembed()  # fail fast if fastembed is required

        is_small = False

        if isinstance(raw_models, list):
            if len(raw_models) < batch_size:
                is_small = True

        if (
            isinstance(self.embedder, BuiltinEmbedder)
            or parallel is None
            or parallel == 1
            or is_small
        ):
            for batch in iter_batch(raw_models, batch_size):
                yield from self.embed_models_batch(batch, inference_batch_size=batch_size)
        else:
            multiprocessing_batch_size = 1  # larger batch sizes do not help with data parallel
            # on cpu. todo: adjust when multi-gpu is available
            raw_models_batches = iter_batch(raw_models, size=multiprocessing_batch_size)
            if parallel == 0:
                parallel = os.cpu_count()

            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
            assert parallel is not None  # just a mypy complaint
            pool = ParallelWorkerPool(
                num_workers=parallel,
                worker=self._get_worker_class(),
                start_method=start_method,
                max_internal_batch_size=self.MAX_INTERNAL_BATCH_SIZE,
            )

            for batch in pool.ordered_map(
                raw_models_batches, batch_size=multiprocessing_batch_size
            ):
                yield from batch

    def embed_models_batch(
        self,
        raw_models: list[Union[dict[str, BaseModel], BaseModel]],
        is_query: bool = False,
        inference_batch_size: int = 8,
    ) -> Iterable[BaseModel]:
        """Embed a batch of models with raw data fields and return models with vectors

        If any of the model fields require inference, a deepcopy of a model with computed embeddings is returned,
        otherwise returns original models.

        Args:
            raw_models: list[Union[dict[str, BaseModel], BaseModel]] - models which can contain fields with raw data
            is_query: bool - flag to determine which embed method to use. Defaults to False.
            inference_batch_size: int - batch size for inference
        Returns:
            Iterable[BaseModel]: models with embedded fields
        """
        if not self._is_builtin_embedder_available:
            FastEmbedMisc.import_fastembed()  # fail fast if fastembed is required

        for raw_model in raw_models:
            self._process_model(raw_model, is_query=is_query, accumulating=True)

        if not self._batch_accumulator:
            yield from raw_models
        else:
            yield from (
                self._process_model(
                    raw_model,
                    is_query=is_query,
                    accumulating=False,
                    inference_batch_size=inference_batch_size,
                )
                for raw_model in raw_models
            )

    def _process_model(
        self,
        model: Union[dict[str, BaseModel], BaseModel],
        paths: Optional[list[FieldPath]] = None,
        is_query: bool = False,
        accumulating: bool = False,
        inference_batch_size: Optional[int] = None,
    ) -> Union[dict[str, BaseModel], dict[str, NumericVector], BaseModel, NumericVector]:
        """Embed model's fields requiring inference

        Args:
            model: Qdrant http model containing fields to embed
            paths: Path to fields to embed. E.g. [FieldPath(current="recommend", tail=[FieldPath(current="negative", tail=None)])]
            is_query: Flag to determine which embed method to use. Defaults to False.
            accumulating: Flag to determine if we are accumulating models for batch embedding. Defaults to False.
            inference_batch_size: Optional[int] - batch size for inference

        Returns:
            A deepcopy of the model with embedded fields
        """

        if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)):
            if accumulating:
                self._accumulate(model)  # type: ignore
            else:
                assert (
                    inference_batch_size is not None
                ), "inference_batch_size should be passed for inference"
                return self._drain_accumulator(
                    model,  # type: ignore
                    is_query=is_query,
                    inference_batch_size=inference_batch_size,
                )

        if paths is None:
            model = deepcopy(model) if not accumulating else model

        if isinstance(model, dict):
            for key, value in model.items():
                if accumulating:
                    self._process_model(value, paths, accumulating=True)
                else:
                    model[key] = self._process_model(
                        value,
                        paths,
                        is_query=is_query,
                        accumulating=False,
                        inference_batch_size=inference_batch_size,
                    )
            return model

        paths = paths if paths is not None else self._embed_inspector.inspect(model)

        for path in paths:
            list_model = [model] if not isinstance(model, list) else model
            for item in list_model:
                current_model = getattr(item, path.current, None)
                if current_model is None:
                    continue
                if path.tail:
                    self._process_model(
                        current_model,
                        path.tail,
                        is_query=is_query,
                        accumulating=accumulating,
                        inference_batch_size=inference_batch_size,
                    )
                else:
                    was_list = isinstance(current_model, list)
                    current_model = current_model if was_list else [current_model]

                    if not accumulating:
                        assert (
                            inference_batch_size is not None
                        ), "inference_batch_size should be passed for inference"
                        embeddings = [
                            self._drain_accumulator(
                                data, is_query=is_query, inference_batch_size=inference_batch_size
                            )
                            for data in current_model
                        ]
                        if was_list:
                            setattr(item, path.current, embeddings)
                        else:
                            setattr(item, path.current, embeddings[0])
                    else:
                        for data in current_model:
                            self._accumulate(data)
        return model

    def _accumulate(self, data: models.VectorStruct) -> None:
        """Add data to batch accumulator

        Args:
            data: models.VectorStruct - any vector struct data; if there are inference object instances in `data`, add them
                to the accumulator, otherwise do nothing. `InferenceObject` instances are converted to proper types.

        Returns:
            None
        """
        if isinstance(data, dict):
            for value in data.values():
                self._accumulate(value)
            return None

        if isinstance(data, list):
            for value in data:
                if not isinstance(value, get_args(INFERENCE_OBJECT_TYPES)):  # if value is a vector
                    return None
                self._accumulate(value)

        if not isinstance(data, get_args(INFERENCE_OBJECT_TYPES)):
            return None

        data = self._resolve_inference_object(data)
        if data.model not in self._batch_accumulator:
            self._batch_accumulator[data.model] = []
        self._batch_accumulator[data.model].append(data)
        return None

    def _drain_accumulator(
        self, data: models.VectorStruct, is_query: bool, inference_batch_size: int = 8
    ) -> NumericVectorStruct:
        """Drain accumulator and replace inference objects with computed embeddings

        It is assumed objects are traversed in the same order as they were added to the accumulator

        Args:
            data: models.VectorStruct - any vector struct data; if there are inference object instances in `data`, replace
                them with computed embeddings. If embeddings haven't yet been computed, compute them and then replace
                the inference objects.
            inference_batch_size: int - batch size for inference

        Returns:
            NumericVectorStruct: data with replaced inference objects
        """
        if isinstance(data, dict):
            for key, value in data.items():
                data[key] = self._drain_accumulator(
                    value, is_query=is_query, inference_batch_size=inference_batch_size
                )
            return data

        if isinstance(data, list):
            for i, value in enumerate(data):
                if not isinstance(value, get_args(INFERENCE_OBJECT_TYPES)):  # if value is vector
                    return data

                data[i] = self._drain_accumulator(
                    value, is_query=is_query, inference_batch_size=inference_batch_size
                )
            return data

        if not isinstance(
            data, get_args(INFERENCE_OBJECT_TYPES)
        ):  # ide type checker ignores `not` and scolds
            return data  # type: ignore

        if not self._embed_storage or not self._embed_storage.get(data.model, None):
            self._embed_accumulator(is_query=is_query, inference_batch_size=inference_batch_size)

        return self._next_embed(data.model)

    def _embed_accumulator(self, is_query: bool = False, inference_batch_size: int = 8) -> None:
        """Embed all accumulated objects for all models

        Args:
            is_query: bool - flag to determine which embed method to use. Defaults to False.
            inference_batch_size: int - batch size for inference
        Returns:
            None
        """

        def embed(
            objects: list[INFERENCE_OBJECT_TYPES], model_name: str, batch_size: int
        ) -> list[NumericVector]:
            """
            Assemble batches grouped by options and data type, embed them, and return embeddings in the original order
            """
            unique_options: list[dict[str, Any]] = []
            unique_options_is_text: list[bool] = []  # multimodal models can have both text
            # and image data, we need to track which data we process to construct separate batches for texts and images
            batches: list[Any] = []
            group_indices: dict[int, list[int]] = defaultdict(list)
            for i, obj in enumerate(objects):
                is_text = isinstance(obj, models.Document)
                for j, (options, options_is_text) in enumerate(
                    zip(unique_options, unique_options_is_text)
                ):
                    if options == obj.options and is_text == options_is_text:
                        group_indices[j].append(i)
                        batches[j].append(obj.text if is_text else obj.image)
                        break
                else:
                    # Create a new group if no match was found
                    group_indices[len(unique_options)] = [i]
                    unique_options.append(obj.options)
                    unique_options_is_text.append(is_text)
                    batches.append([obj.text if is_text else obj.image])

            embeddings = []
            for i, (options, is_text) in enumerate(zip(unique_options, unique_options_is_text)):
                embeddings.extend(
                    [
                        embedding
                        for embedding in self.embedder.embed(
                            model_name=model_name,
                            texts=batches[i] if is_text else None,
                            images=batches[i] if not is_text else None,
                            is_query=is_query,
                            options=options or {},
                            batch_size=batch_size,
                        )
                    ]
                )

            iter_embeddings = iter(embeddings)
            ordered_embeddings: list[list[NumericVector]] = [[]] * len(objects)
            for indices in group_indices.values():
                for index in indices:
                    ordered_embeddings[index] = next(iter_embeddings)
            return ordered_embeddings

        for model in self._batch_accumulator:
            if not any(
                (
                    self.embedder.is_supported_text_model(model),
                    self.embedder.is_supported_sparse_model(model),
                    self.embedder.is_supported_late_interaction_text_model(model),
                    self.embedder.is_supported_image_model(model),
                    self.embedder.is_supported_late_interaction_multimodal_model(model),
                )
            ):
                if isinstance(self.embedder, BuiltinEmbedder):
                    raise ValueError(
                        f"{model} is not among supported models. "
                        f"Have you forgotten to set `cloud_inference` or install `fastembed` for local inference?"
                    )
                else:
                    raise ValueError(f"{model} is not among supported models")

        for model, data in self._batch_accumulator.items():
            self._embed_storage[model] = embed(
                objects=data, model_name=model, batch_size=inference_batch_size
            )
        self._batch_accumulator.clear()

    def _next_embed(self, model_name: str) -> NumericVector:
        """Get next computed embedding from embedded batch

        Args:
            model_name: str - retrieve embedding from the storage by this model name

        Returns:
            NumericVector: computed embedding
        """
        return self._embed_storage[model_name].pop(0)

    def _resolve_inference_object(self, data: models.VectorStruct) -> models.VectorStruct:
        """Resolve inference object into a model

        Args:
            data: models.VectorStruct - data to resolve; if it's an inference object, convert it to a proper type,
                otherwise keep it unchanged

        Returns:
            models.VectorStruct: resolved data
        """

        if not isinstance(data, models.InferenceObject):
            return data

        model_name = data.model
        value = data.object
        options = data.options
        if any(
            (
                self.embedder.is_supported_text_model(model_name),
                self.embedder.is_supported_sparse_model(model_name),
                self.embedder.is_supported_late_interaction_text_model(model_name),
            )
        ):
            return models.Document(model=model_name, text=value, options=options)
        if self.embedder.is_supported_image_model(model_name):
            return models.Image(model=model_name, image=value, options=options)
        if self.embedder.is_supported_late_interaction_multimodal_model(model_name):
            raise ValueError(f"{model_name} does not support `InferenceObject` interface")

        raise ValueError(f"{model_name} is not among supported models")

    @classmethod
    def _get_worker_class(cls) -> Type[ModelEmbedderWorker]:
        return ModelEmbedderWorker
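Usage sketch (not part of the diff): ModelEmbedder makes two passes over a batch — it first accumulates every Document/Image/InferenceObject per model, then drains the accumulator and writes results back into deep copies of the points. The module path qdrant_client.embed.model_embedder is assumed here, since file names are suppressed in this diff view; with FastEmbed absent and a server at or above 1.15.3, the BuiltinEmbedder branch is taken and documents are passed through for server-side inference.

from qdrant_client import models
from qdrant_client.embed.model_embedder import ModelEmbedder

embedder = ModelEmbedder(is_local_mode=False, server_version="1.15.3")
points = [
    models.PointStruct(id=1, vector=models.Document(text="hello", model="Qdrant/Bm25")),
    models.PointStruct(id=2, vector=models.Document(text="world", model="Qdrant/Bm25")),
]
# accumulate on the first pass, drain and substitute on the second
embedded = list(embedder.embed_models_batch(points, inference_batch_size=8))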
@@ -0,0 +1,25 @@
from typing import Union

from pydantic import StrictFloat, StrictStr

from qdrant_client.http.models import ExtendedPointId, SparseVector


NumericVector = Union[
    list[StrictFloat],
    SparseVector,
    list[list[StrictFloat]],
]
NumericVectorInput = Union[
    list[StrictFloat],
    SparseVector,
    list[list[StrictFloat]],
    ExtendedPointId,
]
NumericVectorStruct = Union[
    list[StrictFloat],
    list[list[StrictFloat]],
    dict[StrictStr, NumericVector],
]

__all__ = ["NumericVector", "NumericVectorInput", "NumericVectorStruct"]
@@ -0,0 +1,305 @@
from copy import copy, deepcopy
from pathlib import Path
from typing import Type, Union, Any, Optional

from pydantic import BaseModel

from qdrant_client._pydantic_compat import model_json_schema
from qdrant_client.embed.utils import FieldPath, convert_paths


try:
    from qdrant_client.embed._inspection_cache import (
        DEFS,
        CACHE_STR_PATH,
        RECURSIVE_REFS,
        EXCLUDED_RECURSIVE_REFS,
        INCLUDED_RECURSIVE_REFS,
        NAME_RECURSIVE_REF_MAPPING,
    )
except ImportError as e:
    DEFS = {}
    CACHE_STR_PATH = {}
    RECURSIVE_REFS = set()  # type: ignore
    EXCLUDED_RECURSIVE_REFS = {"Filter"}  # type: ignore
    INCLUDED_RECURSIVE_REFS = set()  # type: ignore
    NAME_RECURSIVE_REF_MAPPING = {}


class ModelSchemaParser:
    """Model schema parser. Parses json schemas to retrieve paths to objects requiring inference.

    The parser is stateful, it accumulates the results of parsing in its internal structures.

    Attributes:
        _defs: definitions extracted from json schemas
        _recursive_refs: set of recursive refs found in the processed schemas, e.g.:
            {"Filter", "Prefetch"}
        _excluded_recursive_refs: predefined time-consuming recursive refs which don't have inference objects, e.g.:
            {"Filter"}
        _included_recursive_refs: set of recursive refs which have inference objects, e.g.:
            {"Prefetch"}
        _cache: cache of string paths for models containing objects for inference, e.g.:
            {"Prefetch": ['prefetch.query', 'prefetch.query.context.negative', ...]}
        path_cache: cache of FieldPath objects for models containing objects for inference, e.g.:
            {
                "Prefetch": [
                    FieldPath(
                        current="prefetch",
                        tail=[
                            FieldPath(
                                current="query",
                                tail=[
                                    FieldPath(
                                        current="recommend",
                                        tail=[
                                            FieldPath(current="negative", tail=None),
                                            FieldPath(current="positive", tail=None),
                                        ],
                                    ),
                                    ...,
                                ],
                            ),
                        ],
                    )
                ]
            }
        name_recursive_ref_mapping: mapping of model field names to ref names, e.g.:
            {"prefetch": "Prefetch"}
    """

    CACHE_PATH = "_inspection_cache.py"
    INFERENCE_OBJECT_NAMES = {"Document", "Image", "InferenceObject"}

    def __init__(self) -> None:
        # self._defs does not include the whole schema, but only the part with the structures used in $defs
        self._defs: dict[str, Union[dict[str, Any], list[dict[str, Any]]]] = deepcopy(DEFS)  # type: ignore[arg-type]
        self._cache: dict[str, list[str]] = deepcopy(CACHE_STR_PATH)

        self._recursive_refs: set[str] = set(RECURSIVE_REFS)
        self._excluded_recursive_refs: set[str] = set(EXCLUDED_RECURSIVE_REFS)
        self._included_recursive_refs: set[str] = set(INCLUDED_RECURSIVE_REFS)

        self.name_recursive_ref_mapping: dict[str, str] = {
            k: v for k, v in NAME_RECURSIVE_REF_MAPPING.items()
        }
        self.path_cache: dict[str, list[FieldPath]] = {
            model: convert_paths(paths) for model, paths in self._cache.items()
        }
        self._processed_recursive_defs: dict[str, Any] = {}

    def _replace_refs(
        self,
        schema: Union[dict[str, Any], list[dict[str, Any]]],
        parent: Optional[str] = None,
        seen_refs: Optional[set] = None,
    ) -> Union[dict[str, Any], list[dict[str, Any]]]:
        """Replace refs in schema with their definitions

        Args:
            schema: schema to parse
            parent: previous level key
            seen_refs: set of seen refs to spot recursive paths

        Returns:
            schema with replaced refs
        """
        parent = parent if parent else None
        seen_refs = seen_refs if seen_refs else set()

        if isinstance(schema, dict):
            if "$ref" in schema:
                ref_path = schema["$ref"]
                def_key = ref_path.split("/")[-1]
                if def_key in self._processed_recursive_defs:
                    return self._processed_recursive_defs[def_key]

                if def_key == parent or def_key in seen_refs:
                    self._recursive_refs.add(def_key)
                    self._processed_recursive_defs[def_key] = schema
                    return schema

                seen_refs.add(def_key)

                return self._replace_refs(
                    self._defs[def_key], parent=def_key, seen_refs=copy(seen_refs)
                )

            schemes = {}
            if "properties" in schema:
                for k, v in schema.items():
                    if k == "properties":
                        schemes[k] = self._replace_refs(
                            schema=v, parent=parent, seen_refs=copy(seen_refs)
                        )
                    else:
                        schemes[k] = v
            else:
                for k, v in schema.items():
                    parent_key = k if isinstance(v, dict) and "properties" in v else parent
                    schemes[k] = self._replace_refs(
                        schema=v, parent=parent_key, seen_refs=copy(seen_refs)
                    )

            return schemes
        elif isinstance(schema, list):
            return [
                self._replace_refs(schema=item, parent=parent, seen_refs=copy(seen_refs))  # type: ignore
                for item in schema
            ]
        else:
            return schema

    def _find_document_paths(
        self,
        schema: Union[dict[str, Any], list[dict[str, Any]]],
        current_path: str = "",
        after_properties: bool = False,
        seen_refs: Optional[set] = None,
    ) -> list[str]:
        """Read a schema and find paths to objects requiring inference

        Populates the mapping of model field names to ref names

        Args:
            schema: schema to parse
            current_path: current path in the schema
            after_properties: flag indicating if the current path is after "properties" key
            seen_refs: set of seen refs to spot recursive paths

        Returns:
            List of string dot-separated paths to objects requiring inference
        """
        document_paths: list[str] = []
        seen_recursive_refs = seen_refs if seen_refs is not None else set()

        parts = current_path.split(".")
        if len(parts) != len(set(parts)):  # check for recursive paths
            return document_paths

        if not isinstance(schema, dict):
            return document_paths

        if "title" in schema and schema["title"] in self.INFERENCE_OBJECT_NAMES:
            document_paths.append(current_path)
            return document_paths

        for key, value in schema.items():
            if key == "$defs":
                continue

            if key == "$ref":
                model_name = value.split("/")[-1]

                value = self._defs[model_name]
                if model_name in seen_recursive_refs:
                    continue

                if (
                    model_name in self._excluded_recursive_refs
                ):  # on the first run it might be empty
                    continue

                if (
                    model_name in self._recursive_refs
                ):  # included and excluded refs might not be filled up yet, we're looking in all recursive refs
                    # we would need to clean up name recursive ref mapping later and delete excluded refs from there
                    seen_recursive_refs.add(model_name)
                    self.name_recursive_ref_mapping[current_path.split(".")[-1]] = model_name

            if after_properties:  # field name seen in pydantic models comes after "properties" key
                if current_path:
                    new_path = f"{current_path}.{key}"
                else:
                    new_path = key
            else:
                new_path = current_path

            if isinstance(value, dict):
                document_paths.extend(
                    self._find_document_paths(
                        value, new_path, key == "properties", seen_refs=seen_recursive_refs
                    )
                )
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        document_paths.extend(
                            self._find_document_paths(
                                item,
                                new_path,
                                key == "properties",
                                seen_refs=seen_recursive_refs,
                            )
                        )

        return sorted(set(document_paths))

    def parse_model(self, model: Type[BaseModel]) -> None:
        """Parse model schema to retrieve paths to objects requiring inference.

        Checks model json schema, extracts definitions and finds paths to objects requiring inference.
        No parsing happens if model has already been processed.

        Args:
            model: model to parse

        Returns:
            None
        """
        model_name = model.__name__
        if model_name in self._cache:
            return None

        schema = model_json_schema(model)

        for k, v in schema.get("$defs", {}).items():
            if k not in self._defs:
                self._defs[k] = v

        if "$defs" in schema:
            raw_refs = (
                {"$ref": schema["$ref"]}
                if "$ref" in schema
                else {"properties": schema["properties"]}
            )
            refs = self._replace_refs(raw_refs)
            self._cache[model_name] = self._find_document_paths(refs)
        else:
            self._cache[model_name] = []

        for ref in self._recursive_refs:
            if ref in self._excluded_recursive_refs or ref in self._included_recursive_refs:
                continue

            if self._find_document_paths(self._defs[ref]):
                self._included_recursive_refs.add(ref)
            else:
                self._excluded_recursive_refs.add(ref)

        self.name_recursive_ref_mapping = {
            k: v
            for k, v in self.name_recursive_ref_mapping.items()
            if v not in self._excluded_recursive_refs
        }

        # convert str paths to FieldPath objects which group path parts and reduce the time of the traversal
        self.path_cache = {model: convert_paths(paths) for model, paths in self._cache.items()}

    def _persist(self, output_path: Union[Path, str] = CACHE_PATH) -> None:
        """Persist the parser state to a file

        Args:
            output_path: path to the file to save the parser state

        Returns:
            None
        """
        with open(output_path, "w") as f:
            f.write(f"CACHE_STR_PATH = {self._cache}\n")
            f.write(f"DEFS = {self._defs}\n")
            # `sorted` is required to be able to use `diff` in comparisons
            f.write(f"RECURSIVE_REFS = {sorted(self._recursive_refs)}\n")
            f.write(f"INCLUDED_RECURSIVE_REFS = {sorted(self._included_recursive_refs)}\n")
            f.write(f"EXCLUDED_RECURSIVE_REFS = {sorted(self._excluded_recursive_refs)}\n")
            f.write(f"NAME_RECURSIVE_REF_MAPPING = {self.name_recursive_ref_mapping}\n")
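Usage sketch (not part of the diff): ModelSchemaParser walks a pydantic model's JSON schema once, then keeps the discovered paths in path_cache so later inspections are plain dictionary lookups. The cache contents shown in the comment are illustrative only, borrowed from the Prefetch example in the class docstring above.

from qdrant_client.http import models
from qdrant_client.embed.schema_parser import ModelSchemaParser

parser = ModelSchemaParser()
parser.parse_model(models.Prefetch)
# path_cache now maps model names to FieldPath trees pointing at
# Document/Image/InferenceObject fields, e.g.
# {"Prefetch": [FieldPath(current="prefetch", tail=[FieldPath(current="query", ...)])]}
print(parser.path_cache.get("Prefetch"))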
@@ -0,0 +1,149 @@
from typing import Union, Optional, Iterable, get_args

from pydantic import BaseModel

from qdrant_client._pydantic_compat import model_fields_set
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES

from qdrant_client.embed.schema_parser import ModelSchemaParser
from qdrant_client.embed.utils import FieldPath


class Inspector:
    """Inspector which tries to find at least one occurrence of an object requiring inference

    Inspector is stateful and accumulates parsed model schemes in its parser.

    Attributes:
        parser: ModelSchemaParser instance to inspect model json schemas
    """

    def __init__(self, parser: Optional[ModelSchemaParser] = None) -> None:
        self.parser = ModelSchemaParser() if parser is None else parser

    def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> bool:
        """Looks for at least one occurrence of an object requiring inference in the received models

        Args:
            points: models to inspect

        Returns:
            True if at least one object requiring inference is found, False otherwise
        """
        if isinstance(points, BaseModel):
            self.parser.parse_model(points.__class__)
            return self._inspect_model(points)

        elif isinstance(points, dict):
            for value in points.values():
                if self.inspect(value):
                    return True

        elif isinstance(points, Iterable):
            for point in points:
                if isinstance(point, BaseModel):
                    self.parser.parse_model(point.__class__)
                    if self._inspect_model(point):
                        return True
                else:
                    return False
        return False

    def _inspect_model(self, model: BaseModel, paths: Optional[list[FieldPath]] = None) -> bool:
        if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)):
            return True

        paths = (
            self.parser.path_cache.get(model.__class__.__name__, []) if paths is None else paths
        )

        for path in paths:
            type_found = self._inspect_inner_models(
                model, path.current, path.tail if path.tail else []
            )
            if type_found:
                return True
        return False

    def _inspect_inner_models(
        self, original_model: BaseModel, current_path: str, tail: list[FieldPath]
    ) -> bool:
        def inspect_recursive(member: BaseModel) -> bool:
            recursive_paths = []
            for field_name in model_fields_set(member):
                if field_name in self.parser.name_recursive_ref_mapping:
                    mapped_model_name = self.parser.name_recursive_ref_mapping[field_name]
                    recursive_paths.extend(self.parser.path_cache[mapped_model_name])

            if recursive_paths:
                found = self._inspect_model(member, recursive_paths)
                if found:
                    return True

            return False

        model = getattr(original_model, current_path, None)
        if model is None:
            return False

        if isinstance(model, get_args(INFERENCE_OBJECT_TYPES)):
            return True

        if isinstance(model, BaseModel):
            type_found = inspect_recursive(model)
            if type_found:
                return True

            for next_path in tail:
                type_found = self._inspect_inner_models(
                    model, next_path.current, next_path.tail if next_path.tail else []
                )
                if type_found:
                    return True
            return False

        elif isinstance(model, list):
            for current_model in model:
                if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)):
                    return True

                if not isinstance(current_model, BaseModel):
                    continue

                type_found = inspect_recursive(current_model)
                if type_found:
                    return True

            for next_path in tail:
                for current_model in model:
                    type_found = self._inspect_inner_models(
                        current_model, next_path.current, next_path.tail if next_path.tail else []
                    )
                    if type_found:
                        return True
            return False

        elif isinstance(model, dict):
            for key, values in model.items():
                values = [values] if not isinstance(values, list) else values
                for current_model in values:
                    if isinstance(current_model, get_args(INFERENCE_OBJECT_TYPES)):
                        return True

                    if not isinstance(current_model, BaseModel):
                        continue

                    found_type = inspect_recursive(current_model)
                    if found_type:
                        return True

                for next_path in tail:
                    for current_model in values:
                        found_type = self._inspect_inner_models(
                            current_model,
                            next_path.current,
                            next_path.tail if next_path.tail else [],
                        )
                        if found_type:
                            return True
        return False
@@ -0,0 +1,79 @@
import base64
from pathlib import Path
from typing import Optional, Union

from pydantic import BaseModel, Field


class FieldPath(BaseModel):
    current: str
    tail: Optional[list["FieldPath"]] = Field(default=None)

    def as_str_list(self) -> list[str]:
        """
        >>> FieldPath(current='a', tail=[FieldPath(current='b', tail=[FieldPath(current='c'), FieldPath(current='d')])]).as_str_list()
        ['a.b.c', 'a.b.d']
        """

        # Recursive function to collect all paths
        def collect_paths(path: FieldPath, prefix: str = "") -> list[str]:
            current_path = prefix + path.current
            if not path.tail:
                return [current_path]
            else:
                paths = []
                for sub_path in path.tail:
                    paths.extend(collect_paths(sub_path, current_path + "."))
                return paths

        # Collect all paths starting from this object
        return collect_paths(self)


def convert_paths(paths: list[str]) -> list[FieldPath]:
    """Convert string paths into FieldPath objects

    Paths which share the same root are grouped together.

    Args:
        paths: List[str]: List of str paths containing "." as separator

    Returns:
        List[FieldPath]: List of FieldPath objects
    """
    sorted_paths = sorted(paths)
    prev_root = None
    converted_paths = []
    for path in sorted_paths:
        parts = path.split(".")
        root = parts[0]
        if root != prev_root:
            converted_paths.append(FieldPath(current=root))
            prev_root = root
        current = converted_paths[-1]
        for part in parts[1:]:
            if current.tail is None:
                current.tail = []
            found = False
            for tail in current.tail:
                if tail.current == part:
                    current = tail
                    found = True
                    break
            if not found:
                new_tail = FieldPath(current=part)
                assert current.tail is not None
                current.tail.append(new_tail)
                current = new_tail
    return converted_paths


def read_base64(file_path: Union[str, Path]) -> str:
    """Convert a file path to a base64 encoded string."""
    path = Path(file_path)
    if not path.exists():
        raise FileNotFoundError(f"The file {path} does not exist.")

    with open(path, "rb") as file:
        file_content = file.read()
        return base64.b64encode(file_content).decode("utf-8")
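Usage sketch (not part of the diff): convert_paths is the inverse of FieldPath.as_str_list — dot-separated strings that share a root are folded into a single FieldPath tree, which is what the inspectors above traverse.

from qdrant_client.embed.utils import FieldPath, convert_paths

tree = convert_paths(["a.b.c", "a.b.d"])
assert len(tree) == 1  # one root "a"
assert tree[0].as_str_list() == ["a.b.c", "a.b.d"]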