Files
contexta_be/app/services/rag.py
belviskhoremk 5bd496d355 Initial commit
2026-02-22 21:59:37 +00:00

131 lines
4.4 KiB
Python

from app.services.embeddings import embedding_service
from app.services.vector_store import vector_store
from app.services.llm import llm_service
from app.models import SourceDocument
from typing import List, Dict, Any, Optional, Tuple
import logging
# Module-level logger used by RAGEngine for error reporting.
logger = logging.getLogger(__name__)
# System prompt template for RAGEngine, filled in with str.format using:
#   company_name        - the chatbot's company name from chatbot_config
#   custom_instructions - optional per-chatbot instructions (chatbot_config["system_prompt"]);
#                         when empty, this leaves a blank line in the prompt
#   context             - retrieved document chunks joined with "---" separators
# Because of str.format, any literal brace added to this template must be doubled
# ({{ or }}); braces inside the substituted values themselves are safe.
RAG_SYSTEM_PROMPT = """You are a helpful AI assistant for {company_name}.
Your role is to answer questions based on the provided context from company documents.
IMPORTANT RULES:
1. Only answer based on the provided context
2. If information is not in the context, say "I don't have information about that in my knowledge base"
3. Be concise and helpful
4. Always maintain a professional, friendly tone
5. If asked about topics outside the context, politely redirect to relevant topics
{custom_instructions}
Context from knowledge base:
{context}
"""
class RAGEngine:
    """Retrieval-augmented generation pipeline: embed -> retrieve -> generate.

    Holds references to the module-level embedding, vector-store and LLM
    services imported by this module. No per-request state is stored on the
    instance, so one engine can serve many queries.
    """

    # Model used when chatbot_config has no "model" key or sets it to None.
    DEFAULT_MODEL = "accounts/fireworks/models/llama-v3p1-70b-instruct"
    # Generation defaults used when the config omits or nulls these settings.
    DEFAULT_MAX_TOKENS = 1000
    DEFAULT_TEMPERATURE = 0.7
    # Placeholder context inserted into the prompt when retrieval finds nothing.
    NO_CONTEXT_TEXT = "No relevant information found."
    # Number of characters of each chunk returned to callers as a preview.
    SNIPPET_CHARS = 200
    # User-facing messages for the fallback responses.
    PROCESSING_ERROR_MESSAGE = "I'm having trouble processing your request. Please try again."
    GENERATION_ERROR_MESSAGE = "I'm having trouble generating a response. Please try again later."

    def __init__(self):
        self.embedding_svc = embedding_service
        self.vector_svc = vector_store
        self.llm_svc = llm_service

    async def process_query(
        self,
        query: str,
        collection_name: str,
        chatbot_config: Dict[str, Any],
        conversation_history: Optional[List[Dict[str, str]]] = None,
        language: str = "en",
        top_k: int = 5,
        score_threshold: float = 0.3,
        history_limit: int = 10,
    ) -> Dict[str, Any]:
        """Answer ``query`` using documents retrieved from ``collection_name``.

        Full RAG pipeline: embed the query, retrieve similar chunks from the
        vector store, build a system prompt containing those chunks, then ask
        the LLM to answer.

        Args:
            query: The user's question.
            collection_name: Vector-store collection to search.
            chatbot_config: Per-chatbot settings. Keys read here:
                ``company_name``, ``system_prompt``, ``model``, ``max_tokens``
                and ``temperature``. Missing or ``None`` values fall back to
                the class defaults.
            conversation_history: Prior turns as dicts with ``role`` and
                ``content`` keys; only the last ``history_limit`` are sent.
            language: Accepted for interface compatibility; not currently
                used by this method.
            top_k: Maximum number of chunks requested from the vector store.
            score_threshold: Minimum similarity score passed to the vector store.
            history_limit: Maximum number of history messages forwarded to the
                LLM; ``0`` or a negative value sends none.

        Returns:
            Dict with ``response``, ``sources`` (list of ``SourceDocument``),
            ``tokens_used`` and ``model``. Exceptions raised by the embedding,
            retrieval or generation step are logged and turned into a fallback
            response instead of propagating.
        """
        history = conversation_history or []
        model = self._config_value(chatbot_config, "model", self.DEFAULT_MODEL)

        # NOTE(review): embed_text and search are called without `await` inside
        # this coroutine; if they perform blocking I/O they stall the event loop.
        # Confirm, and consider asyncio.to_thread if so.

        # Step 1: Embed the query.
        try:
            query_embedding = self.embedding_svc.embed_text(query)
        except Exception as e:
            logger.exception("Embedding error: %s", e)
            return self._error_response(self.PROCESSING_ERROR_MESSAGE, model)

        # Step 2: Retrieve relevant chunks. Guarded like the other steps so a
        # vector-store outage yields a fallback answer rather than an exception.
        try:
            retrieved = self.vector_svc.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=top_k,
                score_threshold=score_threshold,
            )
        except Exception as e:
            logger.exception("Vector search error: %s", e)
            return self._error_response(self.PROCESSING_ERROR_MESSAGE, model)

        # Step 3: Deduplicate chunks into prompt context and caller-facing sources.
        context, sources = self._build_context(retrieved)

        # Step 4: Build the chat messages.
        messages = self._build_messages(query, context, chatbot_config, history, history_limit)

        # Step 5: Generate the response. The result dict is built inside the try
        # so a malformed LLM result (e.g. missing "content") also falls back.
        try:
            result = await self.llm_svc.generate(
                messages=messages,
                model=model,
                max_tokens=self._config_value(
                    chatbot_config, "max_tokens", self.DEFAULT_MAX_TOKENS
                ),
                temperature=self._config_value(
                    chatbot_config, "temperature", self.DEFAULT_TEMPERATURE
                ),
            )
            return {
                "response": result["content"],
                "sources": sources,
                "tokens_used": result.get("tokens_used", 0),
                "model": result.get("model", model),
            }
        except Exception as e:
            logger.exception("LLM generation error: %s", e)
            return self._error_response(self.GENERATION_ERROR_MESSAGE, model, sources)

    @staticmethod
    def _config_value(config: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``config[key]``, or ``default`` when the key is missing or ``None``.

        Unlike ``dict.get(key, default)``, this also covers keys that are present
        but null, while keeping legitimate falsy values such as ``0``.
        """
        value = config.get(key)
        return default if value is None else value

    @staticmethod
    def _error_response(
        message: str, model: str, sources: Optional[List[Any]] = None
    ) -> Dict[str, Any]:
        """Build the fallback result returned when a pipeline step fails."""
        return {
            "response": message,
            "sources": sources if sources is not None else [],
            "tokens_used": 0,
            "model": model,
        }

    def _build_context(
        self, retrieved: Optional[List[Dict[str, Any]]]
    ) -> Tuple[str, List["SourceDocument"]]:
        """Turn search hits into a prompt context string and a list of sources.

        Hits with empty or already-seen text are skipped. Returns the kept chunk
        texts joined with ``---`` separators (or ``NO_CONTEXT_TEXT`` when none
        are kept) and one ``SourceDocument`` per kept chunk.
        """
        sources = []
        context_parts = []
        seen_texts = set()
        for item in retrieved or []:
            # `or {}` also guards against an explicit None payload.
            payload = item.get("payload") or {}
            text = payload.get("text") or ""
            if not text or text in seen_texts:
                continue
            seen_texts.add(text)
            context_parts.append(text)
            if len(text) > self.SNIPPET_CHARS:
                preview = text[: self.SNIPPET_CHARS] + "..."
            else:
                preview = text
            sources.append(
                SourceDocument(
                    document_name=payload.get("file_name", "Document"),
                    chunk_text=preview,
                    score=item.get("score", 0.0),
                    page_number=payload.get("page_number"),
                )
            )
        context = "\n\n---\n\n".join(context_parts) if context_parts else self.NO_CONTEXT_TEXT
        return context, sources

    def _build_messages(
        self,
        query: str,
        context: str,
        chatbot_config: Dict[str, Any],
        history: List[Dict[str, str]],
        history_limit: int,
    ) -> List[Dict[str, str]]:
        """Assemble the system prompt, recent history and the current query."""
        system_prompt = RAG_SYSTEM_PROMPT.format(
            company_name=self._config_value(chatbot_config, "company_name", ""),
            custom_instructions=chatbot_config.get("system_prompt") or "",
            context=context,
        )
        messages = [{"role": "system", "content": system_prompt}]
        # Guard needed because history[-0:] would return the whole list.
        recent = history[-history_limit:] if history_limit > 0 else []
        messages.extend({"role": msg["role"], "content": msg["content"]} for msg in recent)
        messages.append({"role": "user", "content": query})
        return messages
# Shared module-level engine instance. RAGEngine stores only service references
# and keeps no per-request state, so one instance can serve every query.
rag_engine = RAGEngine()