Fixed the RAG issue in the test pipeline

This commit is contained in:
belviskhoremk
2026-04-26 18:51:48 +00:00
parent 205d9d7901
commit 97a501097d
14 changed files with 249 additions and 57 deletions

View File

@@ -99,7 +99,7 @@ MODEL_CATALOG = {
"badge": "Smart", "badge": "Smart",
"description": "Cost-effective and highly capable model", "description": "Cost-effective and highly capable model",
}, },
"accounts/fireworks/models/kimi-k2-instruct-0905": { "accounts/fireworks/models/kimi-k2-instruct": {
"name": "Kimi K2", "name": "Kimi K2",
"provider": "Fireworks AI", "provider": "Fireworks AI",
"badge": "Multilingual", "badge": "Multilingual",
@@ -156,7 +156,7 @@ MODEL_PROVIDERS = {
"accounts/fireworks/models/llama-v3p3-70b-instruct": "fireworks", "accounts/fireworks/models/llama-v3p3-70b-instruct": "fireworks",
"accounts/fireworks/models/qwen3-235b-a22b": "fireworks", "accounts/fireworks/models/qwen3-235b-a22b": "fireworks",
"accounts/fireworks/models/deepseek-v3p1": "fireworks", "accounts/fireworks/models/deepseek-v3p1": "fireworks",
"accounts/fireworks/models/kimi-k2-instruct-0905": "fireworks", "accounts/fireworks/models/kimi-k2-instruct": "fireworks",
# OpenAI # OpenAI
"gpt-4o": "openai", "gpt-4o": "openai",
"gpt-4o-mini": "openai", "gpt-4o-mini": "openai",
@@ -209,7 +209,7 @@ _ALL_FIREWORKS = [
"accounts/fireworks/models/llama-v3p3-70b-instruct", "accounts/fireworks/models/llama-v3p3-70b-instruct",
"accounts/fireworks/models/qwen3-235b-a22b", "accounts/fireworks/models/qwen3-235b-a22b",
"accounts/fireworks/models/deepseek-v3p1", "accounts/fireworks/models/deepseek-v3p1",
"accounts/fireworks/models/kimi-k2-instruct-0905", "accounts/fireworks/models/kimi-k2-instruct",
] ]
_ALL_PREMIUM = [ _ALL_PREMIUM = [
"gpt-4o", "gpt-4o-mini", "gpt-4o", "gpt-4o-mini",

View File

@@ -62,6 +62,7 @@ class UserResponse(BaseModel):
plan: str = "free" plan: str = "free"
is_admin: bool = False is_admin: bool = False
created_at: Optional[datetime] = None created_at: Optional[datetime] = None
language: Optional[str] = "fr"
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
@@ -101,7 +102,7 @@ class ChatbotCreate(BaseModel):
name: str = Field(min_length=2, max_length=100) name: str = Field(min_length=2, max_length=100)
description: Optional[str] = None description: Optional[str] = None
system_prompt: Optional[str] = None system_prompt: Optional[str] = None
model: str = "accounts/fireworks/models/kimi-k2-instruct-0905" model: str = "accounts/fireworks/models/kimi-k2-instruct"
@field_validator("name", mode="before") @field_validator("name", mode="before")
@classmethod @classmethod
@@ -301,6 +302,7 @@ class ChatResponse(BaseModel):
tokens_used: int = 0 tokens_used: int = 0
needs_lead_capture: bool = False needs_lead_capture: bool = False
handoff: bool = False handoff: bool = False
low_confidence: bool = False
class MessageResponse(BaseModel): class MessageResponse(BaseModel):
@@ -460,6 +462,24 @@ class FeedbackCreate(BaseModel):
feedback: str # 'positive' or 'negative' feedback: str # 'positive' or 'negative'
# ─── Test Models ──────────────────────────────────────────────────────────────
class TestQuestion(BaseModel):
    """Model holding a single test question string."""
    question: str


class TestChatRequest(BaseModel):
    """Request body for the chatbot test endpoint: a batch of question strings."""
    # Field constraints bound the batch to between 1 and 10 questions.
    questions: List[str] = Field(min_length=1, max_length=10)


class TestChatResult(BaseModel):
    """Result of running one test question: the answer plus retrieval metadata."""
    # The question exactly as it was submitted.
    question: str
    # Generated answer text (the test endpoint puts an "Error: ..." string here on failure).
    response: str
    # Confidence value reported for the answer (0.0 is used when none is available).
    confidence_score: float
    # Source documents returned alongside the answer; may be empty.
    sources: List[SourceDocument]
    # Identifier of the model that produced the answer; empty string when unknown.
    model_used: str
# ─── Inbox Models ───────────────────────────────────────────────────────────── # ─── Inbox Models ─────────────────────────────────────────────────────────────
class InboxConversation(BaseModel): class InboxConversation(BaseModel):

View File

@@ -480,7 +480,7 @@ async def get_knowledge_gaps(chatbot_id: str, user=Depends(get_current_user)):
low_conf = supabase.table("messages").select("id, conversation_id, created_at") \ low_conf = supabase.table("messages").select("id, conversation_id, created_at") \
.in_("conversation_id", conv_ids[:100]) \ .in_("conversation_id", conv_ids[:100]) \
.eq("role", "assistant") \ .eq("role", "assistant") \
.lt("confidence_score", 0.2) \ .lt("confidence_score", 0.55) \
.limit(100).execute() .limit(100).execute()
if not low_conf.data: if not low_conf.data:

View File

@@ -24,6 +24,7 @@ class ProfileUpdate(BaseModel):
company_name: Optional[str] = None company_name: Optional[str] = None
current_password: Optional[str] = None current_password: Optional[str] = None
new_password: Optional[str] = Field(default=None, min_length=8) new_password: Optional[str] = Field(default=None, min_length=8)
language: Optional[str] = None
@router.post("/signup", response_model=TokenResponse) @router.post("/signup", response_model=TokenResponse)
@@ -116,12 +117,14 @@ async def login(data: UserLogin):
) )
plan = sub.data[0]["plan"] if sub.data else "free" plan = sub.data[0]["plan"] if sub.data else "free"
# Get is_admin flag # Get is_admin and language from profile
try: try:
profile = supabase.table("user_profiles").select("is_admin").eq("user_id", user.id).execute() profile = supabase.table("user_profiles").select("is_admin, language").eq("user_id", user.id).execute()
is_admin = profile.data[0].get("is_admin", False) if profile.data else False is_admin = profile.data[0].get("is_admin", False) if profile.data else False
language = profile.data[0].get("language", "fr") if profile.data else "fr"
except Exception: except Exception:
is_admin = False is_admin = False
language = "fr"
return TokenResponse( return TokenResponse(
access_token=auth_resp.session.access_token, access_token=auth_resp.session.access_token,
@@ -131,6 +134,7 @@ async def login(data: UserLogin):
company_name=company_name, company_name=company_name,
plan=plan, plan=plan,
is_admin=is_admin, is_admin=is_admin,
language=language,
), ),
) )
except HTTPException: except HTTPException:
@@ -191,12 +195,22 @@ async def update_profile(data: ProfileUpdate, user=Depends(get_current_user)):
raise HTTPException(status_code=400, detail="Current password is incorrect") raise HTTPException(status_code=400, detail="Current password is incorrect")
supabase.auth.admin.update_user_by_id(user.id, {"password": data.new_password}) supabase.auth.admin.update_user_by_id(user.id, {"password": data.new_password})
if data.language:
supabase.table("user_profiles").update({"language": data.language}).eq("user_id", user.id).execute()
company = supabase.table("companies").select("name").eq("owner_id", user.id).execute() company = supabase.table("companies").select("name").eq("owner_id", user.id).execute()
company_name = company.data[0]["name"] if company.data else "" company_name = company.data[0]["name"] if company.data else ""
sub = supabase.table("subscriptions").select("plan").eq("user_id", user.id).eq("status", "active").execute() sub = supabase.table("subscriptions").select("plan").eq("user_id", user.id).eq("status", "active").execute()
plan = sub.data[0]["plan"] if sub.data else "free" plan = sub.data[0]["plan"] if sub.data else "free"
try:
profile = supabase.table("user_profiles").select("is_admin, language").eq("user_id", user.id).execute()
is_admin = profile.data[0].get("is_admin", False) if profile.data else False
language = profile.data[0].get("language", "fr") if profile.data else "fr"
except Exception:
is_admin = False
language = data.language or "fr"
return UserResponse(id=user.id, email=user.email, company_name=company_name, plan=plan) return UserResponse(id=user.id, email=user.email, company_name=company_name, plan=plan, is_admin=is_admin, language=language)
@router.delete("/account") @router.delete("/account")
@@ -246,10 +260,12 @@ async def get_me(user=Depends(get_current_user)):
plan = sub.data[0]["plan"] if sub.data else "free" plan = sub.data[0]["plan"] if sub.data else "free"
try: try:
profile = supabase.table("user_profiles").select("is_admin").eq("user_id", user.id).execute() profile = supabase.table("user_profiles").select("is_admin, language").eq("user_id", user.id).execute()
is_admin = profile.data[0].get("is_admin", False) if profile.data else False is_admin = profile.data[0].get("is_admin", False) if profile.data else False
language = profile.data[0].get("language", "fr") if profile.data else "fr"
except Exception: except Exception:
is_admin = False is_admin = False
language = "fr"
return UserResponse( return UserResponse(
id=user.id, id=user.id,
@@ -257,4 +273,5 @@ async def get_me(user=Depends(get_current_user)):
company_name=company_name, company_name=company_name,
plan=plan, plan=plan,
is_admin=is_admin, is_admin=is_admin,
language=language,
) )

View File

@@ -245,6 +245,13 @@ async def telegram_webhook(bot_token: str, request: Request):
await tg_send(bot_token, chat_id, welcome) await tg_send(bot_token, chat_id, welcome)
return {"ok": True} return {"ok": True}
if text == "/owner":
supabase.table("channel_connections").update(
{"owner_chat_id": str(chat_id)}
).eq("channel", "telegram").eq("bot_token", bot_token).execute()
await tg_send(bot_token, chat_id, "✅ You're registered as the owner of this bot. You'll receive handoff alerts here when a visitor needs human support.")
return {"ok": True}
# Use first 8 chars of token as namespace to avoid collisions between bots # Use first 8 chars of token as namespace to avoid collisions between bots
external_id = f"tg:{bot_token[:8]}:{chat_id}" external_id = f"tg:{bot_token[:8]}:{chat_id}"
session = _get_or_create_channel_session(chatbot_id, "telegram", external_id, supabase) session = _get_or_create_channel_session(chatbot_id, "telegram", external_id, supabase)
@@ -274,7 +281,7 @@ async def telegram_webhook(bot_token: str, request: Request):
await tg_send(bot_token, chat_id, "Sorry, I encountered an error. Please try again.") await tg_send(bot_token, chat_id, "Sorry, I encountered an error. Please try again.")
return {"ok": True} return {"ok": True}
confidence_score = max((s.score for s in result.get("sources", [])), default=0.0) confidence_score = result.get("confidence_score", 0.0)
_save_message(conversation["id"], "user", text, supabase) _save_message(conversation["id"], "user", text, supabase)
_save_message( _save_message(
conversation["id"], "assistant", result["response"], supabase, conversation["id"], "assistant", result["response"], supabase,

View File

@@ -2,7 +2,7 @@ import time
from collections import defaultdict from collections import defaultdict
from fastapi import APIRouter, HTTPException, Depends, Request from fastapi import APIRouter, HTTPException, Depends, Request
from app.models import ChatMessage, ChatResponse, ConversationResponse, MessageResponse, FeedbackCreate from app.models import ChatMessage, ChatResponse, ConversationResponse, MessageResponse, FeedbackCreate, TestChatRequest, TestChatResult
from app.database import get_supabase from app.database import get_supabase
from app.dependencies import get_current_user, get_optional_user from app.dependencies import get_current_user, get_optional_user
from app.services.rag import rag_engine from app.services.rag import rag_engine
@@ -15,6 +15,8 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(tags=["Chat"]) router = APIRouter(tags=["Chat"])
CONFIDENCE_THRESHOLD = 0.55
# ── Simple in-memory rate limiter ──────────────────────────────────────────── # ── Simple in-memory rate limiter ────────────────────────────────────────────
_rate_store: dict = defaultdict(list) _rate_store: dict = defaultdict(list)
_RATE_LIMIT = 30 # max requests _RATE_LIMIT = 30 # max requests
@@ -166,29 +168,27 @@ async def chat(
language=message.language, language=message.language,
) )
# Compute confidence score confidence_score = result.get("confidence_score", 0.0)
confidence_score = max((s.score for s in result.get("sources", [])), default=0.0)
# Check handoff # Check handoff
is_handoff = False is_handoff = False
low_confidence = confidence_score < CONFIDENCE_THRESHOLD
if chatbot.get("handoff_enabled"): if chatbot.get("handoff_enabled"):
handoff_keywords = chatbot.get("handoff_keywords", []) handoff_keywords = chatbot.get("handoff_keywords", [])
msg_lower = message.message.lower() msg_lower = message.message.lower()
if any(kw.lower() in msg_lower for kw in handoff_keywords): keyword_triggered = any(kw.lower() in msg_lower for kw in handoff_keywords)
if keyword_triggered or low_confidence:
is_handoff = True is_handoff = True
# Fire n8n notification (async, non-blocking)
try: try:
from app.services.n8n_service import send_handoff_notification from app.services.notification_service import send_handoff_alert
from app.config import settings as _settings await send_handoff_alert(
company_data_for_handoff = chatbot.get("companies", {}) or {} chatbot_id=chatbot_id,
await send_handoff_notification(
chatbot_name=chatbot.get("name", ""), chatbot_name=chatbot.get("name", ""),
owner_email=chatbot.get("handoff_email") or "",
conversation_history=history, conversation_history=history,
trigger_message=message.message, trigger_message=message.message,
chatbot_id=chatbot_id,
conversation_id=conversation["id"], conversation_id=conversation["id"],
webhook_url=_settings.n8n_handoff_webhook_url, low_confidence=low_confidence,
supabase=supabase,
) )
except Exception: except Exception:
pass # never block chat on handoff failure pass # never block chat on handoff failure
@@ -228,6 +228,7 @@ async def chat(
tokens_used=result.get("tokens_used", 0), tokens_used=result.get("tokens_used", 0),
needs_lead_capture=needs_lead_capture, needs_lead_capture=needs_lead_capture,
handoff=is_handoff, handoff=is_handoff,
low_confidence=low_confidence,
) )
@@ -289,6 +290,56 @@ async def submit_feedback(chatbot_id: str, data: FeedbackCreate):
return {"success": True} return {"success": True}
@router.post("/chat/{chatbot_id}/test", response_model=List[TestChatResult])
async def test_chat(
    chatbot_id: str,
    body: TestChatRequest,
    user=Depends(get_current_user),
):
    """Run test questions against a chatbot without saving to conversation history. Owner only.

    Each question is sent through the RAG engine with an empty conversation
    history and ``bypass_cache=True``, so the answer is neither read from nor
    written to the response cache. Nothing is persisted by this endpoint.

    Args:
        chatbot_id: Identifier of the chatbot to test.
        body: The batch of questions to run.
        user: The authenticated caller, injected by ``get_current_user``.

    Returns:
        One ``TestChatResult`` per submitted question, in submission order.
        A question whose processing raised is reported with an ``Error: ...``
        response, a confidence of 0.0 and no sources instead of failing the
        whole batch.

    Raises:
        HTTPException: 403 when the caller's company does not own the chatbot,
            400 when the chatbot has no Qdrant collection configured.
    """
    supabase = get_supabase()
    chatbot = _get_public_chatbot(chatbot_id, supabase)

    # Ownership check: the caller's company must be the chatbot's company.
    company = supabase.table("companies").select("id").eq("owner_id", user.id).execute()
    if not company.data or company.data[0]["id"] != chatbot.get("company_id"):
        raise HTTPException(status_code=403, detail="Access denied")

    collection_name = chatbot.get("qdrant_collection_name")
    if not collection_name:
        raise HTTPException(status_code=400, detail="Chatbot has no knowledge base configured")

    company_data = chatbot.get("companies", {}) or {}
    chatbot_config = {**chatbot, "company_name": company_data.get("name", "")}

    results = []
    for question in body.questions:
        try:
            result = await rag_engine.process_query(
                query=question,
                collection_name=collection_name,
                chatbot_config=chatbot_config,
                conversation_history=[],
                language="auto",
                bypass_cache=True,
            )
            results.append(TestChatResult(
                question=question,
                response=result["response"],
                confidence_score=result.get("confidence_score", 0.0),
                sources=result.get("sources", []),
                model_used=result.get("model", ""),
            ))
        except Exception as e:
            # Keep the batch going, but leave a traceback in the server log so
            # pipeline failures can be diagnosed; previously they were only
            # visible as a string in the HTTP response.
            logger.exception("Test query failed for chatbot %s", chatbot_id)
            results.append(TestChatResult(
                question=question,
                response=f"Error: {e}",
                confidence_score=0.0,
                sources=[],
                model_used="",
            ))

    return results
# ── OLD analytics endpoint REMOVED ─────────────────────────────────────────── # ── OLD analytics endpoint REMOVED ───────────────────────────────────────────
# The /analytics/{chatbot_id} endpoint that was here has been replaced by # The /analytics/{chatbot_id} endpoint that was here has been replaced by
# the dedicated analytics router (app/routers/analytics.py) which provides: # the dedicated analytics router (app/routers/analytics.py) which provides:

View File

@@ -296,7 +296,7 @@ def _format_chatbot(chatbot: dict, supabase) -> ChatbotResponse:
name=chatbot["name"], name=chatbot["name"],
description=chatbot.get("description"), description=chatbot.get("description"),
system_prompt=chatbot.get("system_prompt"), system_prompt=chatbot.get("system_prompt"),
model=chatbot.get("model", "accounts/fireworks/models/kimi-k2-instruct-0905"), model=chatbot.get("model", "accounts/fireworks/models/kimi-k2-instruct"),
temperature=chatbot.get("temperature", 0.7), temperature=chatbot.get("temperature", 0.7),
max_tokens=chatbot.get("max_tokens", 1000), max_tokens=chatbot.get("max_tokens", 1000),
primary_color=chatbot.get("primary_color", "#6366f1"), primary_color=chatbot.get("primary_color", "#6366f1"),

View File

@@ -6,6 +6,7 @@ from app.services.document_processor import process_document
from app.services.embeddings import embedding_service from app.services.embeddings import embedding_service
from app.services.vector_store import vector_store from app.services.vector_store import vector_store
from app.services.storage import delete_from_storage, extract_storage_path from app.services.storage import delete_from_storage, extract_storage_path
from app.services import cache as response_cache
from app.config import settings from app.config import settings
from typing import List from typing import List
import uuid import uuid
@@ -166,6 +167,7 @@ async def _process_document_bg(
"chunk_count": len(chunks), "chunk_count": len(chunks),
}).eq("id", doc_id).execute() }).eq("id", doc_id).execute()
response_cache.invalidate(collection_name)
logger.info(f"Document {doc_id} processed: {len(chunks)} chunks") logger.info(f"Document {doc_id} processed: {len(chunks)} chunks")
except Exception as e: except Exception as e:
@@ -211,6 +213,8 @@ async def delete_document(chatbot_id: str, document_id: str, user=Depends(get_cu
delete_from_storage(supabase, "documents", doc.data[0]["file_url"]) delete_from_storage(supabase, "documents", doc.data[0]["file_url"])
supabase.table("documents").delete().eq("id", document_id).execute() supabase.table("documents").delete().eq("id", document_id).execute()
if collection_name:
response_cache.invalidate(collection_name)
return SuccessResponse(success=True, message="Document deleted") return SuccessResponse(success=True, message="Document deleted")
@@ -259,6 +263,11 @@ async def retry_document_processing(
"chunk_count": 0, "chunk_count": 0,
}).eq("id", document_id).execute() }).eq("id", document_id).execute()
# Clear stale cache before re-processing so tests see fresh results
collection_name = chatbot.get("qdrant_collection_name")
if collection_name:
response_cache.invalidate(collection_name)
# Re-enqueue background processing # Re-enqueue background processing
background_tasks.add_task( background_tasks.add_task(
_process_document_bg, _process_document_bg,
@@ -340,10 +349,56 @@ async def delete_url_source(chatbot_id: str, source_id: str, user=Depends(get_cu
if not source.data: if not source.data:
raise HTTPException(status_code=404, detail="URL source not found") raise HTTPException(status_code=404, detail="URL source not found")
chatbot = _get_user_chatbot(chatbot_id, user.id, supabase)
collection_name = chatbot.get("qdrant_collection_name")
if collection_name:
try:
vector_store.delete_by_document_id(collection_name, source_id)
except Exception:
pass
response_cache.invalidate(collection_name)
supabase.table("url_sources").delete().eq("id", source_id).execute() supabase.table("url_sources").delete().eq("id", source_id).execute()
return SuccessResponse(success=True, message="URL source deleted") return SuccessResponse(success=True, message="URL source deleted")
@url_router.post("/{source_id}/refresh", response_model=UrlSourceResponse)
async def refresh_url_source(
    chatbot_id: str,
    source_id: str,
    background_tasks: BackgroundTasks,
    user=Depends(get_current_user),
):
    """Re-scrape a URL source and rebuild its vectors.

    Deletes the source's existing vectors, invalidates the response cache for
    the chatbot's collection, resets the ``url_sources`` row to ``pending`` and
    schedules ``_process_url_source`` as a background task.

    Args:
        chatbot_id: Chatbot the URL source belongs to.
        source_id: Identifier of the URL source to refresh.
        background_tasks: FastAPI task queue used to schedule the re-scrape.
        user: The authenticated caller, injected by ``get_current_user``.

    Returns:
        The URL source as it looks after the reset (status ``pending``,
        ``chunk_count`` 0, ``error_message`` cleared).

    Raises:
        HTTPException: 404 when no such URL source exists for this chatbot.
    """
    supabase = get_supabase()
    chatbot = _get_user_chatbot(chatbot_id, user.id, supabase)

    source = (
        supabase.table("url_sources")
        .select("*")
        .eq("id", source_id)
        .eq("chatbot_id", chatbot_id)
        .execute()
    )
    if not source.data:
        raise HTTPException(status_code=404, detail="URL source not found")
    src = source.data[0]

    collection_name = chatbot.get("qdrant_collection_name")

    # Drop existing vectors for this source; failure here is logged but does
    # not abort the refresh.
    if collection_name:
        try:
            vector_store.delete_by_document_id(collection_name, source_id)
        except Exception as e:
            logger.warning(f"Could not delete old vectors for url source {source_id}: {e}")
        response_cache.invalidate(collection_name)

    # Reset to pending and reprocess.
    # NOTE(review): the previous code chained `.returning("representation")`
    # after `.eq(...)`; the supabase-py filter builder does not appear to offer
    # that method (`returning` is a keyword of `update()` and already defaults
    # to representation), and the result was never used — confirm against the
    # installed client version.
    reset_fields = {
        "status": "pending",
        "error_message": None,
        "chunk_count": 0,
    }
    supabase.table("url_sources").update(reset_fields).eq("id", source_id).execute()

    background_tasks.add_task(_process_url_source, source_id, src["url"], chatbot, supabase)

    # Overlay every reset field so the response does not echo a stale error.
    return UrlSourceResponse(**{**src, **reset_fields})
async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase): async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase):
"""Background task to scrape a URL and add its content to the vector store.""" """Background task to scrape a URL and add its content to the vector store."""
from app.services.web_scraper import scrape_url from app.services.web_scraper import scrape_url
@@ -424,6 +479,7 @@ async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase)
"chunk_count": len(chunks), "chunk_count": len(chunks),
}).eq("id", source_id).execute() }).eq("id", source_id).execute()
response_cache.invalidate(collection_name)
logger.info(f"URL source {source_id} processed: {len(chunks)} chunks from {url}") logger.info(f"URL source {source_id} processed: {len(chunks)} chunks from {url}")
except Exception as e: except Exception as e:

View File

@@ -175,13 +175,21 @@ async def rate_chatbot(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
supabase = get_supabase() supabase = get_supabase()
chatbot = supabase.table("chatbots").select("id, average_rating").eq("id", chatbot_id).eq("is_published", True).execute() chatbot = supabase.table("chatbots") \
.select("id, average_rating, rating_count") \
.eq("id", chatbot_id).eq("is_published", True).execute()
if not chatbot.data: if not chatbot.data:
raise HTTPException(status_code=404, detail="Chatbot not found") raise HTTPException(status_code=404, detail="Chatbot not found")
# Simple rating update (average) row = chatbot.data[0]
current = chatbot.data[0].get("average_rating") or rating.rating current_avg = row.get("average_rating") or 0.0
new_avg = (current + rating.rating) / 2 current_count = row.get("rating_count") or 0
supabase.table("chatbots").update({"average_rating": round(new_avg, 1)}).eq("id", chatbot_id).execute() new_count = current_count + 1
return {"message": "Rating submitted", "new_average": round(new_avg, 1)} new_avg = round((current_avg * current_count + rating.rating) / new_count, 2)
supabase.table("chatbots").update({
"average_rating": new_avg,
"rating_count": new_count,
}).eq("id", chatbot_id).execute()
return {"average_rating": new_avg, "rating_count": new_count}

View File

@@ -31,14 +31,9 @@ class LLMService:
return await self._call_openai(messages, model, max_tokens, temperature) return await self._call_openai(messages, model, max_tokens, temperature)
except Exception as e: except Exception as e:
logger.error(f"LLM error ({provider}/{model}): {e}") logger.error(f"LLM error ({provider}/{model}): {e}")
# Fallback to a basic model if available fallback = "accounts/fireworks/models/llama-v3p3-70b-instruct"
if model != "accounts/fireworks/models/kimi-k2-instruct-0905" and settings.fireworks_api_key: if model != fallback and settings.fireworks_api_key:
return await self._call_fireworks( return await self._call_fireworks(messages, fallback, max_tokens, temperature)
messages,
"accounts/fireworks/models/kimi-k2-instruct-0905",
max_tokens,
temperature,
)
raise raise
async def _call_fireworks( async def _call_fireworks(

View File

@@ -1,8 +1,9 @@
from app.services.embeddings import embedding_service from app.services.embeddings import embedding_service
from app.services.vector_store import vector_store from app.services.vector_store import vector_store
from app.services.llm import llm_service from app.services.llm import llm_service
from app.services import cache as response_cache
from app.models import SourceDocument from app.models import SourceDocument
from typing import List, Dict, Any, Optional, Tuple from typing import List, Dict, Any, Optional
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -44,6 +45,7 @@ class RAGEngine:
chatbot_config: Dict[str, Any], chatbot_config: Dict[str, Any],
conversation_history: List[Dict[str, str]] = None, conversation_history: List[Dict[str, str]] = None,
language: str = "en", language: str = "en",
bypass_cache: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Full RAG pipeline: embed → retrieve → generate Full RAG pipeline: embed → retrieve → generate
@@ -51,6 +53,13 @@ class RAGEngine:
if conversation_history is None: if conversation_history is None:
conversation_history = [] conversation_history = []
# Cache hit — only for stateless (no history) queries, and not bypassed
if not conversation_history and not bypass_cache:
cached = response_cache.get(collection_name, query)
if cached is not None:
logger.info(f"[RAG] Cache hit for query in '{collection_name}'")
return cached
# Step 1: Embed the query # Step 1: Embed the query
try: try:
query_embedding = self.embedding_svc.embed_text(query) query_embedding = self.embedding_svc.embed_text(query)
@@ -65,14 +74,14 @@ class RAGEngine:
} }
# Step 2: Retrieve relevant chunks # Step 2: Retrieve relevant chunks
# FIX: Lowered score_threshold from 0.3 to 0.1 to avoid filtering out # Fetch more than needed so that after filtering low-quality results
# all results. With cosine similarity, 0.3 can be too aggressive for # we still have enough context. score_threshold=0.55 keeps only chunks
# many document types and query patterns. # that are genuinely relevant for text-embedding-3-small cosine similarity.
retrieved = self.vector_svc.search( retrieved = self.vector_svc.search(
collection_name=collection_name, collection_name=collection_name,
query_vector=query_embedding, query_vector=query_embedding,
limit=5, limit=8,
score_threshold=0.1, # FIX: was 0.3, now 0.1 to avoid over-filtering score_threshold=0.55,
) )
logger.info(f"[RAG] Retrieved {len(retrieved)} chunks from collection '{collection_name}'") logger.info(f"[RAG] Retrieved {len(retrieved)} chunks from collection '{collection_name}'")
@@ -108,11 +117,15 @@ class RAGEngine:
context = "No relevant information found in the knowledge base." context = "No relevant information found in the knowledge base."
logger.warning(f"[RAG] No context found for query: '{query}' in collection '{collection_name}'") logger.warning(f"[RAG] No context found for query: '{query}' in collection '{collection_name}'")
# Confidence: mean of top-3 scores (more stable than max alone)
top_scores = sorted([s.score for s in sources], reverse=True)[:3]
confidence_score = round(sum(top_scores) / len(top_scores), 4) if top_scores else 0.0
# Step 4: Build messages # Step 4: Build messages
lang_name = LANGUAGE_NAMES.get(language, "English") if language and language != "en" else ""
language_instruction = ( language_instruction = (
f"\n6. Respond in {lang_name}. Match the language of the user's message." "\n6. CRITICAL: Always reply in the exact same language the user wrote in. "
if lang_name else "" "If they write in French, reply in French. If Spanish, reply in Spanish. "
"Never switch to English unless the user writes in English."
) )
system_prompt = RAG_SYSTEM_PROMPT.format( system_prompt = RAG_SYSTEM_PROMPT.format(
@@ -137,7 +150,7 @@ class RAGEngine:
logger.info(f"[RAG] Sending {len(messages)} messages to LLM (model: {chatbot_config.get('model')})") logger.info(f"[RAG] Sending {len(messages)} messages to LLM (model: {chatbot_config.get('model')})")
# Step 5: Generate response # Step 5: Generate response
model = chatbot_config.get("model", "accounts/fireworks/models/kimi-k2-instruct-0905") model = chatbot_config.get("model", "accounts/fireworks/models/kimi-k2-instruct")
try: try:
result = await self.llm_svc.generate( result = await self.llm_svc.generate(
messages=messages, messages=messages,
@@ -146,17 +159,22 @@ class RAGEngine:
temperature=chatbot_config.get("temperature", 0.7), temperature=chatbot_config.get("temperature", 0.7),
) )
logger.info(f"[RAG] LLM response generated. Tokens used: {result.get('tokens_used', 0)}") logger.info(f"[RAG] LLM response generated. Tokens used: {result.get('tokens_used', 0)}")
return { payload = {
"response": result["content"], "response": result["content"],
"sources": sources, "sources": sources,
"confidence_score": confidence_score,
"tokens_used": result.get("tokens_used", 0), "tokens_used": result.get("tokens_used", 0),
"model": result.get("model", model), "model": result.get("model", model),
} }
if not conversation_history and not bypass_cache:
response_cache.set(collection_name, query, payload)
return payload
except Exception as e: except Exception as e:
logger.error(f"[RAG] LLM generation error: {e}", exc_info=True) logger.error(f"[RAG] LLM generation error: {e}", exc_info=True)
return { return {
"response": "I'm having trouble generating a response. Please try again later.", "response": "I'm having trouble generating a response. Please try again later.",
"sources": sources, "sources": sources,
"confidence_score": confidence_score,
"tokens_used": 0, "tokens_used": 0,
"model": model, "model": model,
} }

View File

@@ -206,9 +206,11 @@ class TestChatHandoff:
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["handoff"] is True assert resp.json()["handoff"] is True
def test_handoff_not_triggered_without_keyword_match(self, client): def test_handoff_not_triggered_when_high_confidence_no_keyword(self, client):
# High confidence + no keyword match → no handoff
chatbot = _make_chatbot(handoff_enabled=True, handoff_keywords=["human"]) chatbot = _make_chatbot(handoff_enabled=True, handoff_keywords=["human"])
rag_result = {"response": "Sure!", "sources": [], "model": "m", "tokens_used": 5} rag_result = {"response": "We open at 9am.", "sources": [], "confidence_score": 0.85,
"model": "m", "tokens_used": 5}
with patch("app.routers.chat.get_supabase") as mock_sb, \ with patch("app.routers.chat.get_supabase") as mock_sb, \
patch("app.routers.chat.rag_engine") as mock_rag: patch("app.routers.chat.rag_engine") as mock_rag:
@@ -219,6 +221,22 @@ class TestChatHandoff:
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["handoff"] is False assert resp.json()["handoff"] is False
def test_handoff_triggered_by_low_confidence(self, client):
    """A low-confidence answer triggers handoff even when no keyword matches."""
    bot = _make_chatbot(handoff_enabled=True, handoff_keywords=["human"])
    # No sources and a confidence of 0.1; the message below contains no
    # handoff keyword, so only the low confidence can trigger the handoff.
    low_conf_result = {
        "response": "I'm not sure.",
        "sources": [],
        "confidence_score": 0.1,
        "model": "m",
        "tokens_used": 5,
    }
    request_body = {"message": "Tell me about quantum physics", "language": "en"}

    with patch("app.routers.chat.get_supabase") as sb_factory, \
            patch("app.routers.chat.rag_engine") as rag_mock:
        rag_mock.process_query = AsyncMock(return_value=low_conf_result)
        sb_factory.return_value = _make_chat_sb(chatbot=bot)
        resp = client.post("/api/v1/chat/cb-1", json=request_body)

    assert resp.status_code == 200
    payload = resp.json()
    assert payload["handoff"] is True
    assert payload["low_confidence"] is True
class TestChatHistory: class TestChatHistory:
def test_history_returns_empty_for_unknown_session(self, client): def test_history_returns_empty_for_unknown_session(self, client):

View File

@@ -216,8 +216,10 @@ class TestDocumentDelete:
"file_url": None, "file_url": None,
} }
with patch("app.routers.documents.get_supabase") as mock_sb, \ with patch("app.routers.documents.get_supabase") as mock_sb, \
patch("app.routers.documents.vector_store") as mock_vs: patch("app.routers.documents.vector_store") as mock_vs, \
patch("app.routers.documents.response_cache") as mock_cache:
mock_vs.delete_by_document_id = MagicMock() mock_vs.delete_by_document_id = MagicMock()
mock_cache.invalidate = MagicMock()
mock_sb.return_value = _make_doc_sb(doc=doc) mock_sb.return_value = _make_doc_sb(doc=doc)
resp = client.delete("/api/v1/chatbots/cb-1/documents/doc-1", headers=AUTH) resp = client.delete("/api/v1/chatbots/cb-1/documents/doc-1", headers=AUTH)
assert resp.status_code == 200 assert resp.status_code == 200

View File

@@ -185,7 +185,7 @@ class TestMarketplaceRating:
assert resp.status_code == 404 assert resp.status_code == 404
def test_rate_chatbot_success(self, client): def test_rate_chatbot_success(self, client):
bot = {"id": "bot-1", "average_rating": 4.0} bot = {"id": "bot-1", "average_rating": 4.0, "rating_count": 1}
with patch("app.routers.marketplace.get_supabase") as mock_sb: with patch("app.routers.marketplace.get_supabase") as mock_sb:
mock_sb.return_value = _make_marketplace_sb(chatbot_data=[bot]) mock_sb.return_value = _make_marketplace_sb(chatbot_data=[bot])
resp = client.post( resp = client.post(
@@ -195,12 +195,12 @@ class TestMarketplaceRating:
) )
assert resp.status_code == 200 assert resp.status_code == 200
body = resp.json() body = resp.json()
assert "new_average" in body assert "average_rating" in body
assert body["new_average"] == 4.5 # (4.0 + 5) / 2 assert body["average_rating"] == 4.5 # (4.0 * 1 + 5) / 2
def test_rate_chatbot_first_rating(self, client): def test_rate_chatbot_first_rating(self, client):
"""When average_rating is None, should use the submitted rating as both sides.""" """When average_rating is None, should use the submitted rating as the new average."""
bot = {"id": "bot-1", "average_rating": None} bot = {"id": "bot-1", "average_rating": None, "rating_count": 0}
with patch("app.routers.marketplace.get_supabase") as mock_sb: with patch("app.routers.marketplace.get_supabase") as mock_sb:
mock_sb.return_value = _make_marketplace_sb(chatbot_data=[bot]) mock_sb.return_value = _make_marketplace_sb(chatbot_data=[bot])
resp = client.post( resp = client.post(
@@ -209,4 +209,4 @@ class TestMarketplaceRating:
headers=AUTH, headers=AUTH,
) )
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["new_average"] == 5.0 assert resp.json()["average_rating"] == 5.0