Fixed the RAG issue in the test pipeline

This commit is contained in:
belviskhoremk
2026-04-26 21:43:19 +00:00
parent 78023ae9c5
commit 260a9c6353
9 changed files with 262 additions and 78 deletions

View File

@@ -1,9 +1,43 @@
from app.config import settings, MODEL_PROVIDERS, PLAN_LIMITS
from typing import List, Dict, Any, Optional, AsyncGenerator
import logging
import re
logger = logging.getLogger(__name__)
# Ordered fallback chain — tried in sequence when the primary model fails.
# Fireworks models are used for free/starter plans so they must always be available.
# llama-v3p3-70b-instruct is the guaranteed last resort (confirmed working).
#
# The fallback loop in LLMService walks this list front to back, skips the
# entry that equals the model that just failed, and returns the first
# successful response. It is only entered when settings.fireworks_api_key is
# set. Entries are full Fireworks model IDs.
_FIREWORKS_FALLBACKS = [
"accounts/fireworks/models/kimi-k2p5-instruct",
"accounts/fireworks/models/deepseek-v3p2",
"accounts/fireworks/models/llama-v3p3-70b-instruct",
]
def _normalize_model(model: str) -> str:
"""Strip date-based version suffixes from Fireworks model IDs.
e.g. 'accounts/fireworks/models/kimi-k2-instruct-0905''accounts/fireworks/models/kimi-k2-instruct'
Matches only purely-numeric suffixes (48 digits) so names like 'llama-v3p3-70b' are untouched."""
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
model = re.sub(r"-\d{4,8}$", "", model)
return model
def _infer_provider(model: str) -> str:
"""Infer the LLM provider from the model ID when it's not in MODEL_PROVIDERS.
Handles versioned variants like 'accounts/fireworks/models/kimi-k2-instruct-0905'."""
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
return "fireworks"
if model.startswith("gpt-") or model.startswith("o1") or model.startswith("o3"):
return "openai"
if model.startswith("claude-"):
return "anthropic"
if model.startswith("gemini-"):
return "google"
logger.warning(f"[LLM] Unknown model '{model}', defaulting to fireworks")
return "fireworks"
class LLMService:
"""Routes requests to appropriate LLM provider"""
@@ -16,7 +50,8 @@ class LLMService:
temperature: float = 0.7,
) -> Dict[str, Any]:
"""Generate a response from the LLM"""
provider = MODEL_PROVIDERS.get(model, "openai")
model = _normalize_model(model)
provider = MODEL_PROVIDERS.get(model) or _infer_provider(model)
try:
if provider == "fireworks":
@@ -31,9 +66,16 @@ class LLMService:
return await self._call_openai(messages, model, max_tokens, temperature)
except Exception as e:
logger.error(f"LLM error ({provider}/{model}): {e}")
fallback = "accounts/fireworks/models/llama-v3p3-70b-instruct"
if model != fallback and settings.fireworks_api_key:
return await self._call_fireworks(messages, fallback, max_tokens, temperature)
if not settings.fireworks_api_key:
raise
for fallback in _FIREWORKS_FALLBACKS:
if model == fallback:
continue
try:
logger.warning(f"[LLM] Falling back to {fallback}")
return await self._call_fireworks(messages, fallback, max_tokens, temperature)
except Exception as fe:
logger.error(f"[LLM] Fallback {fallback} also failed: {fe}")
raise
async def _call_fireworks(

View File

@@ -9,7 +9,7 @@ import logging
logger = logging.getLogger(__name__)
RAG_SYSTEM_PROMPT = """You are a helpful AI assistant for {company_name}.
Your role is to answer questions based on the provided context from company documents.
Your role is to answer questions based on the provided context from the knowledge base (documents and web pages).
IMPORTANT RULES:
1. Answer based on the provided context below
@@ -20,7 +20,7 @@ IMPORTANT RULES:
{language_instruction}
{custom_instructions}
Context from knowledge base:
Knowledge base context:
{context}
"""
@@ -74,14 +74,22 @@ class RAGEngine:
}
# Step 2: Retrieve relevant chunks
# Fetch more than needed so that after filtering low-quality results
# we still have enough context. score_threshold=0.55 keeps only chunks
# that are genuinely relevant for text-embedding-3-small cosine similarity.
# Retrieve more candidates than needed (10) with a slightly relaxed threshold (0.45)
# so that content from both document and URL sources gets fair representation.
# Scraped web text embeds less cleanly than structured documents, so 0.55 was
# filtering out valid URL chunks. Context is capped by char limit below.
total_in_collection = self.vector_svc.count_vectors(collection_name)
logger.info(f"[RAG] Collection '{collection_name}' has {total_in_collection} vectors total")
# No score_threshold — always return the top-N most similar chunks by rank.
# Absolute cosine scores vary widely by document type and embedding model;
# filtering by a fixed cutoff here discards valid context when scores are
# uniformly low. The confidence_score below captures retrieval quality for
# handoff/fallback decisions without silencing the LLM's context.
retrieved = self.vector_svc.search(
collection_name=collection_name,
query_vector=query_embedding,
limit=8,
score_threshold=0.55,
limit=10,
)
logger.info(f"[RAG] Retrieved {len(retrieved)} chunks from collection '{collection_name}'")
@@ -90,25 +98,38 @@ class RAGEngine:
text_preview = item.get("payload", {}).get("text", "")[:80]
logger.info(f"[RAG] Chunk {i+1}: score={score:.4f}, preview='{text_preview}...'")
# Step 3: Build sources
# Step 3: Build sources and labeled context
# Each chunk is prefixed with its source so the LLM can synthesize
# correctly when mixing document and URL content.
MAX_CONTEXT_CHARS = 10_000
sources = []
context_parts = []
seen_texts = set()
total_chars = 0
for item in retrieved:
payload = item.get("payload", {})
text = payload.get("text", "")
if text and text not in seen_texts:
seen_texts.add(text)
context_parts.append(text)
sources.append(
SourceDocument(
document_name=payload.get("file_name", "Document"),
chunk_text=text[:200] + "..." if len(text) > 200 else text,
score=item.get("score", 0.0),
page_number=payload.get("page_number"),
)
if not text or text in seen_texts:
continue
if total_chars + len(text) > MAX_CONTEXT_CHARS:
break
seen_texts.add(text)
total_chars += len(text)
file_name = payload.get("file_name", "Document")
source_url = payload.get("source_url")
label = f"[Source: {source_url}]" if source_url else f"[Source: {file_name}]"
context_parts.append(f"{label}\n{text}")
sources.append(
SourceDocument(
document_name=file_name,
chunk_text=text[:200] + "..." if len(text) > 200 else text,
score=item.get("score", 0.0),
page_number=payload.get("page_number"),
)
)
if context_parts:
context = "\n\n---\n\n".join(context_parts)

View File

@@ -1,7 +1,5 @@
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import (
Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
)
from qdrant_client.http.models import Distance, VectorParams, PointStruct
from app.config import settings
from typing import List, Dict, Any, Optional
import logging
@@ -103,15 +101,13 @@ class VectorStoreService:
collection_name: str,
query_vector: List[float],
limit: int = 5,
score_threshold: float = 0.3,
) -> List[Dict[str, Any]]:
"""Search for similar vectors"""
"""Search for similar vectors, returning the top-N by cosine score."""
try:
results = self.client.query_points(
collection_name=collection_name,
query=query_vector,
limit=limit,
score_threshold=score_threshold,
).points
return [
{
@@ -122,7 +118,7 @@ class VectorStoreService:
for r in results
]
except Exception as e:
logger.error(f"Error searching vectors: {e}")
logger.error(f"Error searching vectors in '{collection_name}': {e}", exc_info=True)
return []
def delete_by_document_id(self, collection_name: str, document_id: str) -> bool:
@@ -131,19 +127,21 @@ class VectorStoreService:
self.client.delete(
collection_name=collection_name,
points_selector=models.FilterSelector(
filter=Filter(
filter=models.Filter(
must=[
FieldCondition(
models.FieldCondition(
key="document_id",
match=MatchValue(value=document_id),
match=models.MatchValue(value=document_id),
)
]
)
),
wait=True,
)
logger.info(f"Deleted vectors for document '{document_id}' from '{collection_name}'")
return True
except Exception as e:
logger.error(f"Error deleting document vectors: {e}")
logger.error(f"Error deleting vectors for document '{document_id}' in '{collection_name}': {e}", exc_info=True)
return False
def count_vectors(self, collection_name: str) -> int:

View File

@@ -42,9 +42,22 @@ async def scrape_url(url: str) -> dict:
main = soup.find("main") or soup.find("article") or soup.find("body") or soup
text = main.get_text(separator="\n", strip=True)
# Clean up whitespace
lines = [line.strip() for line in text.splitlines() if line.strip()]
text = "\n".join(lines)
# Clean up whitespace and filter structural noise
seen_lines: set[str] = set()
clean_lines = []
for line in text.splitlines():
line = line.strip()
if not line:
continue
# Skip very short lines (nav items, button labels, breadcrumb separators)
if len(line) < 15:
continue
# Skip duplicate lines (nav/footer repeated across sections)
if line in seen_lines:
continue
seen_lines.add(line)
clean_lines.append(line)
text = "\n".join(clean_lines)
# Limit size
if len(text.encode("utf-8")) > MAX_TEXT_BYTES: