mirror of
http://88.130.71.182:3000/BlitTech/contexta_be.git
synced 2026-06-12 23:23:21 +00:00
fixed the RAg in test pipeline issue
This commit is contained in:
@@ -1,9 +1,43 @@
|
||||
from app.config import settings, MODEL_PROVIDERS, PLAN_LIMITS
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Ordered fallback chain — tried in sequence when the primary model fails.
|
||||
# Fireworks models are used for free/starter plans so they must always be available.
|
||||
# llama-v3p3-70b-instruct is the guaranteed last resort (confirmed working).
|
||||
_FIREWORKS_FALLBACKS = [
|
||||
"accounts/fireworks/models/kimi-k2p5-instruct",
|
||||
"accounts/fireworks/models/deepseek-v3p2",
|
||||
"accounts/fireworks/models/llama-v3p3-70b-instruct",
|
||||
]
|
||||
|
||||
|
||||
def _normalize_model(model: str) -> str:
|
||||
"""Strip date-based version suffixes from Fireworks model IDs.
|
||||
e.g. 'accounts/fireworks/models/kimi-k2-instruct-0905' → 'accounts/fireworks/models/kimi-k2-instruct'
|
||||
Matches only purely-numeric suffixes (4–8 digits) so names like 'llama-v3p3-70b' are untouched."""
|
||||
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
|
||||
model = re.sub(r"-\d{4,8}$", "", model)
|
||||
return model
|
||||
|
||||
|
||||
def _infer_provider(model: str) -> str:
|
||||
"""Infer the LLM provider from the model ID when it's not in MODEL_PROVIDERS.
|
||||
Handles versioned variants like 'accounts/fireworks/models/kimi-k2-instruct-0905'."""
|
||||
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
|
||||
return "fireworks"
|
||||
if model.startswith("gpt-") or model.startswith("o1") or model.startswith("o3"):
|
||||
return "openai"
|
||||
if model.startswith("claude-"):
|
||||
return "anthropic"
|
||||
if model.startswith("gemini-"):
|
||||
return "google"
|
||||
logger.warning(f"[LLM] Unknown model '{model}', defaulting to fireworks")
|
||||
return "fireworks"
|
||||
|
||||
|
||||
class LLMService:
|
||||
"""Routes requests to appropriate LLM provider"""
|
||||
@@ -16,7 +50,8 @@ class LLMService:
|
||||
temperature: float = 0.7,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate a response from the LLM"""
|
||||
provider = MODEL_PROVIDERS.get(model, "openai")
|
||||
model = _normalize_model(model)
|
||||
provider = MODEL_PROVIDERS.get(model) or _infer_provider(model)
|
||||
|
||||
try:
|
||||
if provider == "fireworks":
|
||||
@@ -31,9 +66,16 @@ class LLMService:
|
||||
return await self._call_openai(messages, model, max_tokens, temperature)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM error ({provider}/{model}): {e}")
|
||||
fallback = "accounts/fireworks/models/llama-v3p3-70b-instruct"
|
||||
if model != fallback and settings.fireworks_api_key:
|
||||
return await self._call_fireworks(messages, fallback, max_tokens, temperature)
|
||||
if not settings.fireworks_api_key:
|
||||
raise
|
||||
for fallback in _FIREWORKS_FALLBACKS:
|
||||
if model == fallback:
|
||||
continue
|
||||
try:
|
||||
logger.warning(f"[LLM] Falling back to {fallback}")
|
||||
return await self._call_fireworks(messages, fallback, max_tokens, temperature)
|
||||
except Exception as fe:
|
||||
logger.error(f"[LLM] Fallback {fallback} also failed: {fe}")
|
||||
raise
|
||||
|
||||
async def _call_fireworks(
|
||||
|
||||
@@ -9,7 +9,7 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RAG_SYSTEM_PROMPT = """You are a helpful AI assistant for {company_name}.
|
||||
Your role is to answer questions based on the provided context from company documents.
|
||||
Your role is to answer questions based on the provided context from the knowledge base (documents and web pages).
|
||||
|
||||
IMPORTANT RULES:
|
||||
1. Answer based on the provided context below
|
||||
@@ -20,7 +20,7 @@ IMPORTANT RULES:
|
||||
{language_instruction}
|
||||
{custom_instructions}
|
||||
|
||||
Context from knowledge base:
|
||||
Knowledge base context:
|
||||
{context}
|
||||
"""
|
||||
|
||||
@@ -74,14 +74,22 @@ class RAGEngine:
|
||||
}
|
||||
|
||||
# Step 2: Retrieve relevant chunks
|
||||
# Fetch more than needed so that after filtering low-quality results
|
||||
# we still have enough context. score_threshold=0.55 keeps only chunks
|
||||
# that are genuinely relevant for text-embedding-3-small cosine similarity.
|
||||
# Retrieve more candidates than needed (10) with a slightly relaxed threshold (0.45)
|
||||
# so that content from both document and URL sources gets fair representation.
|
||||
# Scraped web text embeds less cleanly than structured documents, so 0.55 was
|
||||
# filtering out valid URL chunks. Context is capped by char limit below.
|
||||
total_in_collection = self.vector_svc.count_vectors(collection_name)
|
||||
logger.info(f"[RAG] Collection '{collection_name}' has {total_in_collection} vectors total")
|
||||
|
||||
# No score_threshold — always return the top-N most similar chunks by rank.
|
||||
# Absolute cosine scores vary widely by document type and embedding model;
|
||||
# filtering by a fixed cutoff here discards valid context when scores are
|
||||
# uniformly low. The confidence_score below captures retrieval quality for
|
||||
# handoff/fallback decisions without silencing the LLM's context.
|
||||
retrieved = self.vector_svc.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_embedding,
|
||||
limit=8,
|
||||
score_threshold=0.55,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
logger.info(f"[RAG] Retrieved {len(retrieved)} chunks from collection '{collection_name}'")
|
||||
@@ -90,25 +98,38 @@ class RAGEngine:
|
||||
text_preview = item.get("payload", {}).get("text", "")[:80]
|
||||
logger.info(f"[RAG] Chunk {i+1}: score={score:.4f}, preview='{text_preview}...'")
|
||||
|
||||
# Step 3: Build sources
|
||||
# Step 3: Build sources and labeled context
|
||||
# Each chunk is prefixed with its source so the LLM can synthesize
|
||||
# correctly when mixing document and URL content.
|
||||
MAX_CONTEXT_CHARS = 10_000
|
||||
sources = []
|
||||
context_parts = []
|
||||
seen_texts = set()
|
||||
total_chars = 0
|
||||
|
||||
for item in retrieved:
|
||||
payload = item.get("payload", {})
|
||||
text = payload.get("text", "")
|
||||
if text and text not in seen_texts:
|
||||
seen_texts.add(text)
|
||||
context_parts.append(text)
|
||||
sources.append(
|
||||
SourceDocument(
|
||||
document_name=payload.get("file_name", "Document"),
|
||||
chunk_text=text[:200] + "..." if len(text) > 200 else text,
|
||||
score=item.get("score", 0.0),
|
||||
page_number=payload.get("page_number"),
|
||||
)
|
||||
if not text or text in seen_texts:
|
||||
continue
|
||||
if total_chars + len(text) > MAX_CONTEXT_CHARS:
|
||||
break
|
||||
seen_texts.add(text)
|
||||
total_chars += len(text)
|
||||
|
||||
file_name = payload.get("file_name", "Document")
|
||||
source_url = payload.get("source_url")
|
||||
label = f"[Source: {source_url}]" if source_url else f"[Source: {file_name}]"
|
||||
context_parts.append(f"{label}\n{text}")
|
||||
|
||||
sources.append(
|
||||
SourceDocument(
|
||||
document_name=file_name,
|
||||
chunk_text=text[:200] + "..." if len(text) > 200 else text,
|
||||
score=item.get("score", 0.0),
|
||||
page_number=payload.get("page_number"),
|
||||
)
|
||||
)
|
||||
|
||||
if context_parts:
|
||||
context = "\n\n---\n\n".join(context_parts)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from qdrant_client import QdrantClient, models
|
||||
from qdrant_client.http.models import (
|
||||
Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
|
||||
)
|
||||
from qdrant_client.http.models import Distance, VectorParams, PointStruct
|
||||
from app.config import settings
|
||||
from typing import List, Dict, Any, Optional
|
||||
import logging
|
||||
@@ -103,15 +101,13 @@ class VectorStoreService:
|
||||
collection_name: str,
|
||||
query_vector: List[float],
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.3,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar vectors"""
|
||||
"""Search for similar vectors, returning the top-N by cosine score."""
|
||||
try:
|
||||
results = self.client.query_points(
|
||||
collection_name=collection_name,
|
||||
query=query_vector,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
).points
|
||||
return [
|
||||
{
|
||||
@@ -122,7 +118,7 @@ class VectorStoreService:
|
||||
for r in results
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching vectors: {e}")
|
||||
logger.error(f"Error searching vectors in '{collection_name}': {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def delete_by_document_id(self, collection_name: str, document_id: str) -> bool:
|
||||
@@ -131,19 +127,21 @@ class VectorStoreService:
|
||||
self.client.delete(
|
||||
collection_name=collection_name,
|
||||
points_selector=models.FilterSelector(
|
||||
filter=Filter(
|
||||
filter=models.Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
models.FieldCondition(
|
||||
key="document_id",
|
||||
match=MatchValue(value=document_id),
|
||||
match=models.MatchValue(value=document_id),
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
wait=True,
|
||||
)
|
||||
logger.info(f"Deleted vectors for document '{document_id}' from '{collection_name}'")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting document vectors: {e}")
|
||||
logger.error(f"Error deleting vectors for document '{document_id}' in '{collection_name}': {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def count_vectors(self, collection_name: str) -> int:
|
||||
|
||||
@@ -42,9 +42,22 @@ async def scrape_url(url: str) -> dict:
|
||||
main = soup.find("main") or soup.find("article") or soup.find("body") or soup
|
||||
text = main.get_text(separator="\n", strip=True)
|
||||
|
||||
# Clean up whitespace
|
||||
lines = [line.strip() for line in text.splitlines() if line.strip()]
|
||||
text = "\n".join(lines)
|
||||
# Clean up whitespace and filter structural noise
|
||||
seen_lines: set[str] = set()
|
||||
clean_lines = []
|
||||
for line in text.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# Skip very short lines (nav items, button labels, breadcrumb separators)
|
||||
if len(line) < 15:
|
||||
continue
|
||||
# Skip duplicate lines (nav/footer repeated across sections)
|
||||
if line in seen_lines:
|
||||
continue
|
||||
seen_lines.add(line)
|
||||
clean_lines.append(line)
|
||||
text = "\n".join(clean_lines)
|
||||
|
||||
# Limit size
|
||||
if len(text.encode("utf-8")) > MAX_TEXT_BYTES:
|
||||
|
||||
Reference in New Issue
Block a user