mirror of
http://88.130.71.182:3000/BlitTech/contexta_be.git
synced 2026-06-13 10:39:00 +00:00
204 lines
8.6 KiB
Python
204 lines
8.6 KiB
Python
from app.services.embeddings import embedding_service
|
|
from app.services.vector_store import vector_store
|
|
from app.services.llm import llm_service
|
|
from app.services import cache as response_cache
|
|
from app.models import SourceDocument
|
|
from typing import List, Dict, Any, Optional
|
|
import logging
|
|
|
|
# Logger for this module, named after the module's dotted path.
logger = logging.getLogger(__name__)


# System prompt template sent with every RAG request. It is filled with
# str.format() using four fields: company_name, language_instruction,
# custom_instructions and context. Any literal braces added to this template
# would need to be doubled.
RAG_SYSTEM_PROMPT = """You are a helpful AI assistant for {company_name}.
Your role is to answer questions based on the provided context from the knowledge base (documents and web pages).

IMPORTANT RULES:
1. Answer based on the provided context below
2. If the context does not contain enough information, say so, but also try to be helpful with what IS available
3. Be concise and helpful
4. Always maintain a professional, friendly tone
5. If asked about topics completely outside the context, politely redirect to relevant topics
{language_instruction}
{custom_instructions}

Knowledge base context:
{context}
"""


# Two-letter language code -> English display name. Insertion order is kept
# stable for anything that lists the supported languages.
LANGUAGE_NAMES = {
    "en": "English",
    "fr": "French",
    "es": "Spanish",
    "de": "German",
    "it": "Italian",
    "pt": "Portuguese",
    "ar": "Arabic",
    "zh": "Chinese",
    "ja": "Japanese",
    "ko": "Korean",
    "ru": "Russian",
    "nl": "Dutch",
    "tr": "Turkish",
    "pl": "Polish",
    "vi": "Vietnamese",
    "th": "Thai",
}
|
|
|
|
|
|
class RAGEngine:
    """Retrieval-augmented generation pipeline.

    For each query the engine:

    1. embeds the query;
    2. retrieves similar chunks from the vector store;
    3. asks the LLM to answer from those chunks.

    The engine keeps references to the module-level embedding, vector-store
    and LLM service objects. It stores no per-request state on ``self``.
    """

    # Model used when the chatbot config does not name one.
    DEFAULT_MODEL = "accounts/fireworks/models/kimi-k2-instruct"
    # Number of nearest-neighbour chunks requested from the vector store.
    RETRIEVAL_LIMIT = 10
    # Budget for chunk text placed in the prompt. Source labels and separators
    # are not counted, so the final context string is slightly longer.
    MAX_CONTEXT_CHARS = 10_000
    # Most recent conversation messages forwarded to the LLM.
    MAX_HISTORY_MESSAGES = 10
    # How many of the highest chunk scores are averaged into confidence_score.
    CONFIDENCE_TOP_K = 3
    # Fills {language_instruction} in RAG_SYSTEM_PROMPT, which directly follows
    # the template's numbered rules 1-5 (hence the "6." prefix).
    LANGUAGE_INSTRUCTION = (
        "\n6. CRITICAL: Always reply in the exact same language the user wrote in. "
        "If they write in French, reply in French. If Spanish, reply in Spanish. "
        "Never switch to English unless the user writes in English."
    )
    # User-facing fallback messages returned instead of raising.
    _RETRIEVAL_ERROR_MESSAGE = "I'm having trouble processing your request. Please try again."
    _GENERATION_ERROR_MESSAGE = "I'm having trouble generating a response. Please try again later."

    def __init__(self):
        """Bind the shared embedding, vector-store and LLM service objects."""
        self.embedding_svc = embedding_service
        self.vector_svc = vector_store
        self.llm_svc = llm_service

    async def process_query(
        self,
        query: str,
        collection_name: str,
        chatbot_config: Dict[str, Any],
        conversation_history: Optional[List[Dict[str, str]]] = None,
        language: str = "en",
        bypass_cache: bool = False,
    ) -> Dict[str, Any]:
        """Answer *query* from the knowledge base in *collection_name*.

        The full pipeline is: cache lookup -> embed query -> vector search ->
        build labelled context -> LLM generation -> cache store.

        Args:
            query: The user's message.
            collection_name: Vector-store collection to search. It is also
                part of the response-cache key.
            chatbot_config: Per-chatbot settings. Keys read here are
                ``model``, ``max_tokens``, ``temperature``, ``company_name``
                and ``system_prompt`` (extra instructions inserted into the
                system prompt).
            conversation_history: Prior messages as ``{"role", "content"}``
                dicts, oldest first. Only the last ``MAX_HISTORY_MESSAGES``
                are sent to the LLM.
            language: Accepted for interface compatibility but currently
                unused. The prompt instead tells the model to mirror the
                user's language.
            bypass_cache: Skip both the cache lookup and the cache store.

        Returns:
            A dict with these keys:

            - ``response`` (str)
            - ``sources`` (list of ``SourceDocument``)
            - ``confidence_score`` (mean of the top retrieval scores)
            - ``tokens_used``
            - ``model``

            If embedding, vector search or LLM generation fails, ``response``
            holds a user-facing apology and ``tokens_used`` is 0. The method
            does not raise for those failures.
        """
        if conversation_history is None:
            conversation_history = []

        # Only stateless queries are cached. With history, the same text can
        # need a different answer.
        # NOTE(review): the cache key is (collection_name, query) only, so a
        # cached answer may reflect an older chatbot_config (model, system
        # prompt). Confirm this is intended.
        use_cache = not conversation_history and not bypass_cache
        if use_cache:
            cached = response_cache.get(collection_name, query)
            if cached is not None:
                logger.info("[RAG] Cache hit for query in '%s'", collection_name)
                return cached

        # Resolve the model once, so logs, error responses and the LLM call
        # all report the same model.
        model = chatbot_config.get("model") or self.DEFAULT_MODEL

        # Step 1: embed the query.
        try:
            query_embedding = self.embedding_svc.embed_text(query)
            logger.info(
                "[RAG] Query embedded successfully. Vector length: %d", len(query_embedding)
            )
        except Exception as e:
            logger.error("[RAG] Embedding error: %s", e, exc_info=True)
            return self._error_response(self._RETRIEVAL_ERROR_MESSAGE, model)

        # Step 2: retrieve the top-N chunks by similarity rank.
        # No score threshold is applied. Absolute cosine scores vary widely by
        # document type and embedding model; scraped web text embeds less
        # cleanly than structured documents. A fixed cutoff would therefore
        # discard valid context when scores are uniformly low.
        # confidence_score (below) still reflects retrieval quality for
        # handoff/fallback decisions.
        try:
            total_in_collection = self.vector_svc.count_vectors(collection_name)
            logger.info(
                "[RAG] Collection '%s' has %s vectors total", collection_name, total_in_collection
            )
        except Exception as e:
            # Diagnostic only: a failure here must not abort the query.
            logger.warning("[RAG] Could not count vectors in '%s': %s", collection_name, e)

        try:
            retrieved = self.vector_svc.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=self.RETRIEVAL_LIMIT,
            ) or []
        except Exception as e:
            logger.error("[RAG] Vector search error: %s", e, exc_info=True)
            return self._error_response(self._RETRIEVAL_ERROR_MESSAGE, model)

        logger.info("[RAG] Retrieved %d chunks from collection '%s'", len(retrieved), collection_name)
        for rank, item in enumerate(retrieved, start=1):
            logger.info(
                "[RAG] Chunk %d: score=%.4f, preview='%s...'",
                rank,
                self._item_score(item),
                self._item_text(item)[:80],
            )

        # Step 3: build the labelled context and the source citations.
        # Each chunk is prefixed with its origin, so the LLM can tell document
        # and URL content apart when synthesising an answer.
        selected = self._select_chunks(retrieved, self.MAX_CONTEXT_CHARS)
        sources = []
        context_parts = []
        for item in selected:
            meta = item.get("payload") or {}
            text = self._item_text(item)
            file_name = meta.get("file_name", "Document")
            source_url = meta.get("source_url")
            label = f"[Source: {source_url}]" if source_url else f"[Source: {file_name}]"
            context_parts.append(f"{label}\n{text}")
            sources.append(
                SourceDocument(
                    document_name=file_name,
                    chunk_text=(text[:200] + "...") if len(text) > 200 else text,
                    score=self._item_score(item),
                    page_number=meta.get("page_number"),
                )
            )

        if context_parts:
            context = "\n\n---\n\n".join(context_parts)
            logger.info(
                "[RAG] Built context from %d chunks (%d chars)", len(context_parts), len(context)
            )
        else:
            context = "No relevant information found in the knowledge base."
            logger.warning(
                "[RAG] No context found for query: '%s' in collection '%s'", query, collection_name
            )

        confidence_score = self._confidence(
            [self._item_score(item) for item in selected], self.CONFIDENCE_TOP_K
        )

        # Step 4: build the prompt and the chat messages.
        system_prompt = RAG_SYSTEM_PROMPT.format(
            company_name=chatbot_config.get("company_name") or "",
            language_instruction=self.LANGUAGE_INSTRUCTION,
            custom_instructions=chatbot_config.get("system_prompt") or "",
            context=context,
        )
        messages = self._build_messages(
            system_prompt, conversation_history, query, self.MAX_HISTORY_MESSAGES
        )
        logger.info("[RAG] Sending %d messages to LLM (model: %s)", len(messages), model)

        # Step 5: generate the answer.
        try:
            result = await self.llm_svc.generate(
                messages=messages,
                model=model,
                max_tokens=chatbot_config.get("max_tokens", 1000),
                temperature=chatbot_config.get("temperature", 0.7),
            )
            tokens_used = result.get("tokens_used", 0)
            logger.info("[RAG] LLM response generated. Tokens used: %s", tokens_used)
            response = {
                "response": result["content"],
                "sources": sources,
                "confidence_score": confidence_score,
                "tokens_used": tokens_used,
                "model": result.get("model", model),
            }
        except Exception as e:
            logger.error("[RAG] LLM generation error: %s", e, exc_info=True)
            return self._error_response(
                self._GENERATION_ERROR_MESSAGE, model, sources, confidence_score
            )

        if use_cache:
            # Best effort: a cache failure must not replace a good answer with
            # an error message.
            try:
                response_cache.set(collection_name, query, response)
            except Exception as e:
                logger.warning("[RAG] Failed to cache response for '%s': %s", collection_name, e)
        return response

    @staticmethod
    def _item_text(item: Dict[str, Any]) -> str:
        """Return a search hit's ``payload["text"]``, or "" if missing or None."""
        return (item.get("payload") or {}).get("text") or ""

    @staticmethod
    def _item_score(item: Dict[str, Any]) -> float:
        """Return a search hit's ``score`` as a float, or 0.0 if missing or None."""
        return float(item.get("score") or 0.0)

    @staticmethod
    def _select_chunks(retrieved: List[Dict[str, Any]], max_chars: int) -> List[Dict[str, Any]]:
        """Choose which search hits go into the prompt, keeping their rank order.

        These hits are skipped:

        - hits with empty text;
        - hits whose text duplicates an earlier hit;
        - hits whose text would push the running total past *max_chars*.

        A hit skipped for size does not stop the scan, so later, smaller hits
        can still fill the budget. An oversized top hit therefore no longer
        empties the whole context.
        """
        selected = []
        seen_texts = set()
        total_chars = 0
        for item in retrieved:
            text = RAGEngine._item_text(item)
            if not text or text in seen_texts:
                continue
            if total_chars + len(text) > max_chars:
                continue
            seen_texts.add(text)
            total_chars += len(text)
            selected.append(item)
        return selected

    @staticmethod
    def _confidence(scores: List[float], top_k: int = 3) -> float:
        """Return the mean of the *top_k* highest scores, rounded to 4 places.

        Returns 0.0 when *scores* is empty.
        """
        top = sorted(scores, reverse=True)[:top_k]
        return round(sum(top) / len(top), 4) if top else 0.0

    @staticmethod
    def _build_messages(
        system_prompt: str,
        conversation_history: List[Dict[str, str]],
        query: str,
        max_history: int,
    ) -> List[Dict[str, str]]:
        """Assemble the chat messages: system prompt, recent history, then the query.

        *conversation_history* must be in chronological order (oldest first).
        Only its last *max_history* entries are kept, and only their
        ``role`` and ``content`` keys are forwarded.
        """
        # Guard needed because history[-0:] would return the whole list.
        recent = conversation_history[-max_history:] if max_history > 0 else []
        return [
            {"role": "system", "content": system_prompt},
            *({"role": msg["role"], "content": msg["content"]} for msg in recent),
            {"role": "user", "content": query},
        ]

    @staticmethod
    def _error_response(
        message: str,
        model: str,
        sources: Optional[list] = None,
        confidence_score: float = 0.0,
    ) -> Dict[str, Any]:
        """Build the fallback result returned when a pipeline step fails.

        It has the same keys as a successful result, so callers can read
        ``confidence_score`` (and the other keys) unconditionally.
        """
        return {
            "response": message,
            "sources": sources if sources is not None else [],
            "confidence_score": confidence_score,
            "tokens_used": 0,
            "model": model,
        }
|
|
|
|
|
# Module-level engine instance, created once at import time.
rag_engine = RAGEngine()