mirror of
http://88.130.71.182:3000/BlitTech/contexta_be.git
synced 2026-06-12 23:23:21 +00:00
157 lines
5.1 KiB
Python
157 lines
5.1 KiB
Python
from qdrant_client import QdrantClient, models
|
|
from qdrant_client.http.models import Distance, VectorParams, PointStruct
|
|
from app.config import settings
|
|
from typing import List, Dict, Any, Optional
|
|
import logging
|
|
import uuid
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
# Module-level cache for the Qdrant client. It starts as None and is assigned
# by get_qdrant_client() the first time that function is called.
_qdrant_client: Optional[QdrantClient] = None
|
|
|
|
|
|
def get_qdrant_client() -> QdrantClient:
    """Return the module-wide Qdrant client, building it on the first call.

    The client is constructed from ``settings.qdrant_url``; the API key is
    passed along only when ``settings.qdrant_api_key`` is truthy. Later calls
    return the cached instance.
    """
    global _qdrant_client

    # Fast path: a client has already been built and cached.
    if _qdrant_client is not None:
        return _qdrant_client

    connection_options = {"url": settings.qdrant_url}
    if settings.qdrant_api_key:
        connection_options["api_key"] = settings.qdrant_api_key

    _qdrant_client = QdrantClient(**connection_options)
    return _qdrant_client
|
|
|
|
|
|
class VectorStoreService:
    """Wrapper around the shared Qdrant client for collection and vector operations.

    Every method takes the target collection name explicitly. Failure reporting
    differs per method: ``create_collection`` and ``upsert_vectors`` re-raise
    client errors, while the other methods log the error and return a fallback
    value (``False``, ``[]`` or ``0``).
    """

    VECTOR_SIZE = 1536  # text-embedding-3-small

    def __init__(self):
        self.client = get_qdrant_client()

    def create_collection(self, collection_name: str) -> bool:
        """Create a cosine-distance collection and a keyword index on ``document_id``.

        Returns:
            True when the collection and index were created, or when the
            client reported an "already exists" error.

        Raises:
            Exception: any other client error, re-raised after being logged.
        """
        try:
            self.client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.VECTOR_SIZE,
                    distance=Distance.COSINE,
                ),
            )
            # Create payload index for filtering/deleting by document_id
            self.client.create_payload_index(
                collection_name=collection_name,
                field_name="document_id",
                field_schema="keyword",
            )
            logger.info(f"Created collection: {collection_name}")
            return True
        except Exception as e:
            # An existing collection is detected only by the error message
            # text, and is treated as success.
            # NOTE(review): on this path the payload index is not (re)created;
            # confirm that is acceptable for collections created elsewhere.
            if "already exists" in str(e).lower():
                return True
            logger.error(f"Error creating collection {collection_name}: {e}")
            raise

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a collection.

        Returns:
            True on success, False if the client raised (the error is logged).
        """
        try:
            self.client.delete_collection(collection_name=collection_name)
            logger.info(f"Deleted collection: {collection_name}")
            return True
        except Exception as e:
            logger.error(
                f"Error deleting collection {collection_name}: {e}",
                exc_info=True,
            )
            return False

    def collection_exists(self, collection_name: str) -> bool:
        """Report whether ``collection_name`` exists.

        Returns:
            The client's answer as a bool. If the check itself fails, the
            error is logged and False is returned, so callers keep the same
            never-raises contract.
        """
        try:
            return bool(
                self.client.collection_exists(collection_name=collection_name)
            )
        except Exception as e:
            # A failed check is not the same as a missing collection; log it
            # so the two cases can be told apart afterwards.
            logger.error(
                f"Error checking collection '{collection_name}': {e}",
                exc_info=True,
            )
            return False

    def upsert_vectors(
        self,
        collection_name: str,
        vectors: List[List[float]],
        payloads: List[Dict[str, Any]],
        ids: Optional[List[str]] = None,
    ) -> bool:
        """Upsert vectors into a collection.

        Args:
            collection_name: Target collection.
            vectors: One embedding per point.
            payloads: One payload dict per vector, in the same order.
            ids: Optional point ids, one per vector. Random UUID4 strings are
                generated when omitted.

        Returns:
            True when the upsert succeeded or there was nothing to upsert.

        Raises:
            ValueError: if ``payloads`` or ``ids`` differ in length from
                ``vectors``.
            Exception: any client error, re-raised after being logged.
        """
        # zip() would silently drop the surplus items, so reject mismatched
        # inputs up front instead of storing a truncated batch.
        if len(payloads) != len(vectors):
            raise ValueError(
                f"vectors and payloads must have the same length "
                f"(got {len(vectors)} and {len(payloads)})"
            )
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]
        elif len(ids) != len(vectors):
            raise ValueError(
                f"ids and vectors must have the same length "
                f"(got {len(ids)} and {len(vectors)})"
            )

        # Nothing to write: skip the request entirely.
        if not vectors:
            return True

        points = [
            PointStruct(id=point_id, vector=vector, payload=payload)
            for point_id, vector, payload in zip(ids, vectors, payloads)
        ]

        try:
            self.client.upsert(
                collection_name=collection_name,
                points=points,
            )
            return True
        except Exception as e:
            logger.error(f"Error upserting vectors: {e}")
            raise

    def search(
        self,
        collection_name: str,
        query_vector: List[float],
        limit: int = 5,
    ) -> List[Dict[str, Any]]:
        """Search for similar vectors, returning the top-N by cosine score.

        Returns:
            A list of ``{"id", "score", "payload"}`` dicts with the id
            converted to ``str``. An empty list is returned if the client
            raised (the error is logged).
        """
        try:
            results = self.client.query_points(
                collection_name=collection_name,
                query=query_vector,
                limit=limit,
            ).points
            return [
                {
                    "id": str(r.id),
                    "score": r.score,
                    "payload": r.payload,
                }
                for r in results
            ]
        except Exception as e:
            logger.error(
                f"Error searching vectors in '{collection_name}': {e}",
                exc_info=True,
            )
            return []

    def delete_by_document_id(self, collection_name: str, document_id: str) -> bool:
        """Delete all vectors whose ``document_id`` payload field matches.

        The delete is issued with ``wait=True``.

        Returns:
            True on success, False if the client raised (the error is logged).
        """
        # Select points by payload value rather than by point id.
        document_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="document_id",
                    match=models.MatchValue(value=document_id),
                )
            ]
        )
        try:
            self.client.delete(
                collection_name=collection_name,
                points_selector=models.FilterSelector(filter=document_filter),
                wait=True,
            )
            logger.info(
                f"Deleted vectors for document '{document_id}' from '{collection_name}'"
            )
            return True
        except Exception as e:
            logger.error(
                f"Error deleting vectors for document '{document_id}' in '{collection_name}': {e}",
                exc_info=True,
            )
            return False

    def count_vectors(self, collection_name: str) -> int:
        """Count vectors in a collection.

        Returns:
            The count reported by the client, or 0 if the client raised. The
            failure is logged so it is not mistaken for an empty collection.
        """
        try:
            result = self.client.count(collection_name=collection_name)
            return result.count
        except Exception as e:
            logger.error(
                f"Error counting vectors in '{collection_name}': {e}",
                exc_info=True,
            )
            return 0
|
|
|
|
|
|
# Shared module-level service instance. Instantiating it here runs
# VectorStoreService.__init__, which calls get_qdrant_client(), when this
# module is imported.
vector_store = VectorStoreService()
|