Files
contexta_be/app/services/rag.py
2026-04-26 21:43:19 +00:00

204 lines
8.6 KiB
Python

from app.services.embeddings import embedding_service
from app.services.vector_store import vector_store
from app.services.llm import llm_service
from app.services import cache as response_cache
from app.models import SourceDocument
from typing import List, Dict, Any, Optional
import logging
logger = logging.getLogger(__name__)

# System prompt template for RAG answers, filled with str.format():
#   {company_name}          - the organisation the assistant speaks for
#   {language_instruction}  - extra rule text about the reply language
#   {custom_instructions}   - per-chatbot instructions
#   {context}               - retrieved knowledge-base text
# Assembled line by line so every newline in the prompt is explicit; the
# trailing "" element leaves a final newline after {context}.
RAG_SYSTEM_PROMPT = "\n".join(
    [
        "You are a helpful AI assistant for {company_name}.",
        "Your role is to answer questions based on the provided context from the knowledge base (documents and web pages).",
        "IMPORTANT RULES:",
        "1. Answer based on the provided context below",
        "2. If the context does not contain enough information, say so, but also try to be helpful with what IS available",
        "3. Be concise and helpful",
        "4. Always maintain a professional, friendly tone",
        "5. If asked about topics completely outside the context, politely redirect to relevant topics",
        "{language_instruction}",
        "{custom_instructions}",
        "Knowledge base context:",
        "{context}",
        "",
    ]
)
# Two-letter language code -> English display name.
# NOTE(review): not referenced by the visible code of this module; presumably
# imported elsewhere — verify before removing or renaming.
LANGUAGE_NAMES = dict(
    en="English",
    fr="French",
    es="Spanish",
    de="German",
    it="Italian",
    pt="Portuguese",
    ar="Arabic",
    zh="Chinese",
    ja="Japanese",
    ko="Korean",
    ru="Russian",
    nl="Dutch",
    tr="Turkish",
    pl="Polish",
    vi="Vietnamese",
    th="Thai",
)
class RAGEngine:
    """Retrieval-augmented generation pipeline: embed -> retrieve -> generate.

    Holds references to the module-level embedding, vector-store and LLM
    services; ``process_query`` runs one user question through all three.
    """

    # Maximum total characters of chunk text placed in the prompt context.
    # Source labels and separators are not counted, so the rendered context is
    # slightly longer than this.
    MAX_CONTEXT_CHARS = 10_000
    # Number of top-ranked chunks requested from the vector store.
    RETRIEVAL_LIMIT = 10
    # Number of most recent conversation messages forwarded to the LLM.
    HISTORY_WINDOW = 10
    # Model used when chatbot_config does not name one (or names a null one).
    DEFAULT_MODEL = "accounts/fireworks/models/kimi-k2-instruct"
    # Substituted for {language_instruction} in RAG_SYSTEM_PROMPT as rule 6.
    # The leading newline leaves a blank line after rule 5 in the rendered prompt.
    LANGUAGE_INSTRUCTION = (
        "\n6. CRITICAL: Always reply in the exact same language the user wrote in. "
        "If they write in French, reply in French. If Spanish, reply in Spanish. "
        "Never switch to English unless the user writes in English."
    )
    # Context text used when retrieval yields no usable chunk.
    NO_CONTEXT = "No relevant information found in the knowledge base."

    def __init__(self):
        self.embedding_svc = embedding_service
        self.vector_svc = vector_store
        self.llm_svc = llm_service

    async def process_query(
        self,
        query: str,
        collection_name: str,
        chatbot_config: Dict[str, Any],
        conversation_history: Optional[List[Dict[str, str]]] = None,
        language: str = "en",
        bypass_cache: bool = False,
    ) -> Dict[str, Any]:
        """Answer ``query`` from the knowledge base stored in ``collection_name``.

        Full RAG pipeline: embed -> retrieve -> generate.

        Args:
            query: The user's current message.
            collection_name: Vector-store collection holding the chatbot's chunks.
            chatbot_config: Chatbot settings; this method reads ``model``,
                ``max_tokens``, ``temperature``, ``company_name`` and
                ``system_prompt``.
            conversation_history: Prior turns as ``{"role": ..., "content": ...}``
                dicts, oldest first. Only the last ``HISTORY_WINDOW`` are sent.
            language: Not used by this method — the system prompt instead tells
                the model to mirror the user's language. Kept for API compatibility.
            bypass_cache: Skip both the response-cache lookup and the cache write.

        Returns:
            A dict with ``response``, ``sources`` (list of ``SourceDocument``),
            ``confidence_score``, ``tokens_used`` and ``model``. Failures while
            embedding, retrieving or generating are logged and reported through a
            fallback ``response`` instead of being raised; cache read/write
            failures are logged and otherwise ignored.
        """
        history = conversation_history or []
        # `or` (not a .get default) so an explicit null model also falls back.
        model = chatbot_config.get("model") or self.DEFAULT_MODEL
        # Only stateless queries are cacheable: with history the same text can
        # legitimately need a different answer.
        # NOTE(review): the cache key is (collection_name, query) only, so edits to
        # chatbot_config (model, prompt, temperature) are not reflected until the
        # cached entry goes away — confirm that is acceptable.
        use_cache = not history and not bypass_cache

        if use_cache:
            try:
                cached = response_cache.get(collection_name, query)
            except Exception as e:
                # Caching is best-effort; fall through to a fresh answer.
                logger.warning("[RAG] Cache read failed: %s", e, exc_info=True)
                cached = None
            if cached is not None:
                logger.info("[RAG] Cache hit for query in '%s'", collection_name)
                return cached

        # Step 1: Embed the query
        try:
            query_embedding = self.embedding_svc.embed_text(query)
            logger.info(
                "[RAG] Query embedded successfully. Vector length: %d",
                len(query_embedding),
            )
        except Exception as e:
            logger.error("[RAG] Embedding error: %s", e, exc_info=True)
            return self._error_payload(
                "I'm having trouble processing your request. Please try again.",
                model,
            )

        # Step 2: Retrieve relevant chunks (guarded like the embedding step so a
        # vector-store outage degrades the same way instead of raising).
        try:
            retrieved = self._retrieve(collection_name, query_embedding)
        except Exception as e:
            logger.error("[RAG] Retrieval error: %s", e, exc_info=True)
            return self._error_payload(
                "I'm having trouble processing your request. Please try again.",
                model,
            )

        # Step 3: Build sources and labeled context
        sources, context_parts = self._build_context(retrieved)
        if context_parts:
            context = "\n\n---\n\n".join(context_parts)
            logger.info(
                "[RAG] Built context from %d chunks (%d chars)",
                len(context_parts),
                len(context),
            )
        else:
            context = self.NO_CONTEXT
            logger.warning(
                "[RAG] No context found for query: '%s' in collection '%s'",
                query,
                collection_name,
            )
        confidence_score = self._confidence([s.score for s in sources])

        # Step 4: Build messages
        messages = self._build_messages(query, chatbot_config, context, history)
        logger.info("[RAG] Sending %d messages to LLM (model: %s)", len(messages), model)

        # Step 5: Generate response
        try:
            result = await self.llm_svc.generate(
                messages=messages,
                model=model,
                max_tokens=chatbot_config.get("max_tokens", 1000),
                temperature=chatbot_config.get("temperature", 0.7),
            )
            payload = {
                "response": result["content"],
                "sources": sources,
                "confidence_score": confidence_score,
                "tokens_used": result.get("tokens_used", 0),
                "model": result.get("model", model),
            }
        except Exception as e:
            logger.error("[RAG] LLM generation error: %s", e, exc_info=True)
            return self._error_payload(
                "I'm having trouble generating a response. Please try again later.",
                model,
                sources,
                confidence_score,
            )
        logger.info("[RAG] LLM response generated. Tokens used: %s", payload["tokens_used"])

        if use_cache:
            # Outside the generation try-block: a failed cache write must not
            # replace a good answer with the error message.
            try:
                response_cache.set(collection_name, query, payload)
            except Exception as e:
                logger.warning("[RAG] Cache write failed: %s", e, exc_info=True)
        return payload

    def _retrieve(self, collection_name: str, query_embedding: Any):
        """Return the top ``RETRIEVAL_LIMIT`` hits for ``query_embedding``, logging each.

        No score_threshold is passed: absolute cosine scores vary widely by
        document type and embedding model, and a fixed cutoff discards valid
        context when scores are uniformly low. Retrieval quality is instead
        summarised by the confidence score computed by the caller.
        """
        # NOTE(review): count_vectors only feeds this log line, so it costs an
        # extra vector-store call per query.
        total_in_collection = self.vector_svc.count_vectors(collection_name)
        logger.info(
            "[RAG] Collection '%s' has %s vectors total",
            collection_name,
            total_in_collection,
        )
        retrieved = self.vector_svc.search(
            collection_name=collection_name,
            query_vector=query_embedding,
            limit=self.RETRIEVAL_LIMIT,
        )
        logger.info(
            "[RAG] Retrieved %d chunks from collection '%s'",
            len(retrieved),
            collection_name,
        )
        for rank, item in enumerate(retrieved, start=1):
            logger.info(
                "[RAG] Chunk %d: score=%.4f, preview='%s...'",
                rank,
                item.get("score", 0),
                item.get("payload", {}).get("text", "")[:80],
            )
        return retrieved

    def _build_context(self, retrieved):
        """Select chunk texts for the prompt and build matching source citations.

        Hits are consumed in the order the vector store returned them. Empty and
        duplicate texts are skipped, and selection stops at the first chunk that
        would push the running total past ``MAX_CONTEXT_CHARS``. Each kept chunk
        is prefixed with ``[Source: <url or file name>]`` so the LLM can
        attribute mixed document/URL content correctly.

        Returns:
            ``(sources, context_parts)``: one ``SourceDocument`` and one labelled
            context string per kept chunk, in the same order.
        """
        sources = []
        context_parts = []
        seen_texts = set()
        total_chars = 0
        for item in retrieved:
            payload = item.get("payload", {})
            text = payload.get("text", "")
            if not text or text in seen_texts:
                continue
            seen_texts.add(text)
            room = self.MAX_CONTEXT_CHARS - total_chars
            if len(text) > room:
                if context_parts:
                    break
                # The top chunk alone exceeds the budget: keep its prefix rather
                # than sending the LLM no context at all.
                text = text[:room]
            total_chars += len(text)

            file_name = payload.get("file_name", "Document")
            source_url = payload.get("source_url")
            label = f"[Source: {source_url}]" if source_url else f"[Source: {file_name}]"
            context_parts.append(f"{label}\n{text}")
            sources.append(
                SourceDocument(
                    document_name=file_name,
                    chunk_text=(text[:200] + "...") if len(text) > 200 else text,
                    score=item.get("score", 0.0),
                    page_number=payload.get("page_number"),
                )
            )
        return sources, context_parts

    def _build_messages(self, query, chatbot_config, context, history):
        """Assemble the chat messages: system prompt, recent history, then the query.

        ``history`` must be in chronological order (oldest first); the chat
        router is expected to supply it sorted ascending. Only its last
        ``HISTORY_WINDOW`` entries are forwarded, to bound the context window.
        """
        system_prompt = RAG_SYSTEM_PROMPT.format(
            # `or ""` keeps a null config value from rendering as the text "None".
            company_name=chatbot_config.get("company_name") or "",
            language_instruction=self.LANGUAGE_INSTRUCTION,
            custom_instructions=chatbot_config.get("system_prompt") or "",
            context=context,
        )
        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(
            {"role": msg["role"], "content": msg["content"]}
            for msg in history[-self.HISTORY_WINDOW:]
        )
        messages.append({"role": "user", "content": query})
        return messages

    @staticmethod
    def _confidence(scores):
        """Mean of the (up to) three highest scores, rounded to 4 places; 0.0 if none.

        Averaging the top three is more stable than using the single best score.
        """
        top_scores = sorted(scores, reverse=True)[:3]
        return round(sum(top_scores) / len(top_scores), 4) if top_scores else 0.0

    @staticmethod
    def _error_payload(message, model, sources=None, confidence_score=0.0):
        """Build a fallback result with the same keys as a successful answer."""
        return {
            "response": message,
            "sources": sources if sources is not None else [],
            "confidence_score": confidence_score,
            "tokens_used": 0,
            "model": model,
        }
# Module-level engine instance, created once at import time and shared by
# every importer of this module.
rag_engine: RAGEngine = RAGEngine()