From 97a501097dee5e4aff773e93a851c3f25610353b Mon Sep 17 00:00:00 2001 From: belviskhoremk Date: Sun, 26 Apr 2026 18:51:48 +0000 Subject: [PATCH] fixed the RAg in test pipeline issue --- app/config.py | 6 +-- app/models.py | 22 ++++++++++- app/routers/analytics.py | 2 +- app/routers/auth.py | 25 +++++++++++-- app/routers/channels.py | 9 ++++- app/routers/chat.py | 75 ++++++++++++++++++++++++++++++++------ app/routers/chatbots.py | 2 +- app/routers/documents.py | 56 ++++++++++++++++++++++++++++ app/routers/marketplace.py | 20 +++++++--- app/services/llm.py | 11 ++---- app/services/rag.py | 40 ++++++++++++++------ tests/test_chat.py | 22 ++++++++++- tests/test_documents.py | 4 +- tests/test_marketplace.py | 12 +++--- 14 files changed, 249 insertions(+), 57 deletions(-) diff --git a/app/config.py b/app/config.py index 1867743..a4c9635 100644 --- a/app/config.py +++ b/app/config.py @@ -99,7 +99,7 @@ MODEL_CATALOG = { "badge": "Smart", "description": "Cost-effective and highly capable model", }, - "accounts/fireworks/models/kimi-k2-instruct-0905": { + "accounts/fireworks/models/kimi-k2-instruct": { "name": "Kimi K2", "provider": "Fireworks AI", "badge": "Multilingual", @@ -156,7 +156,7 @@ MODEL_PROVIDERS = { "accounts/fireworks/models/llama-v3p3-70b-instruct": "fireworks", "accounts/fireworks/models/qwen3-235b-a22b": "fireworks", "accounts/fireworks/models/deepseek-v3p1": "fireworks", - "accounts/fireworks/models/kimi-k2-instruct-0905": "fireworks", + "accounts/fireworks/models/kimi-k2-instruct": "fireworks", # OpenAI "gpt-4o": "openai", "gpt-4o-mini": "openai", @@ -209,7 +209,7 @@ _ALL_FIREWORKS = [ "accounts/fireworks/models/llama-v3p3-70b-instruct", "accounts/fireworks/models/qwen3-235b-a22b", "accounts/fireworks/models/deepseek-v3p1", - "accounts/fireworks/models/kimi-k2-instruct-0905", + "accounts/fireworks/models/kimi-k2-instruct", ] _ALL_PREMIUM = [ "gpt-4o", "gpt-4o-mini", diff --git a/app/models.py b/app/models.py index 99f5230..b472051 100644 --- a/app/models.py +++ b/app/models.py @@ -62,6 +62,7 @@ class UserResponse(BaseModel): plan: str = "free" is_admin: bool = False created_at: Optional[datetime] = None + language: Optional[str] = "fr" class TokenResponse(BaseModel): @@ -101,7 +102,7 @@ class ChatbotCreate(BaseModel): name: str = Field(min_length=2, max_length=100) description: Optional[str] = None system_prompt: Optional[str] = None - model: str = "accounts/fireworks/models/kimi-k2-instruct-0905" + model: str = "accounts/fireworks/models/kimi-k2-instruct" @field_validator("name", mode="before") @classmethod @@ -301,6 +302,7 @@ class ChatResponse(BaseModel): tokens_used: int = 0 needs_lead_capture: bool = False handoff: bool = False + low_confidence: bool = False class MessageResponse(BaseModel): @@ -460,6 +462,24 @@ class FeedbackCreate(BaseModel): feedback: str # 'positive' or 'negative' +# ─── Test Models ────────────────────────────────────────────────────────────── + +class TestQuestion(BaseModel): + question: str + + +class TestChatRequest(BaseModel): + questions: List[str] = Field(min_length=1, max_length=10) + + +class TestChatResult(BaseModel): + question: str + response: str + confidence_score: float + sources: List[SourceDocument] + model_used: str + + # ─── Inbox Models ───────────────────────────────────────────────────────────── class InboxConversation(BaseModel): diff --git a/app/routers/analytics.py b/app/routers/analytics.py index 20843ec..a113c0e 100644 --- a/app/routers/analytics.py +++ b/app/routers/analytics.py @@ -480,7 +480,7 @@ async def get_knowledge_gaps(chatbot_id: str, user=Depends(get_current_user)): low_conf = supabase.table("messages").select("id, conversation_id, created_at") \ .in_("conversation_id", conv_ids[:100]) \ .eq("role", "assistant") \ - .lt("confidence_score", 0.2) \ + .lt("confidence_score", 0.55) \ .limit(100).execute() if not low_conf.data: diff --git a/app/routers/auth.py b/app/routers/auth.py index 2f1f54e..288ee19 100644 --- a/app/routers/auth.py +++ b/app/routers/auth.py @@ -24,6 +24,7 @@ class ProfileUpdate(BaseModel): company_name: Optional[str] = None current_password: Optional[str] = None new_password: Optional[str] = Field(default=None, min_length=8) + language: Optional[str] = None @router.post("/signup", response_model=TokenResponse) @@ -116,12 +117,14 @@ async def login(data: UserLogin): ) plan = sub.data[0]["plan"] if sub.data else "free" - # Get is_admin flag + # Get is_admin and language from profile try: - profile = supabase.table("user_profiles").select("is_admin").eq("user_id", user.id).execute() + profile = supabase.table("user_profiles").select("is_admin, language").eq("user_id", user.id).execute() is_admin = profile.data[0].get("is_admin", False) if profile.data else False + language = profile.data[0].get("language", "fr") if profile.data else "fr" except Exception: is_admin = False + language = "fr" return TokenResponse( access_token=auth_resp.session.access_token, @@ -131,6 +134,7 @@ async def login(data: UserLogin): company_name=company_name, plan=plan, is_admin=is_admin, + language=language, ), ) except HTTPException: @@ -191,12 +195,22 @@ async def update_profile(data: ProfileUpdate, user=Depends(get_current_user)): raise HTTPException(status_code=400, detail="Current password is incorrect") supabase.auth.admin.update_user_by_id(user.id, {"password": data.new_password}) + if data.language: + supabase.table("user_profiles").update({"language": data.language}).eq("user_id", user.id).execute() + company = supabase.table("companies").select("name").eq("owner_id", user.id).execute() company_name = company.data[0]["name"] if company.data else "" sub = supabase.table("subscriptions").select("plan").eq("user_id", user.id).eq("status", "active").execute() plan = sub.data[0]["plan"] if sub.data else "free" + try: + profile = supabase.table("user_profiles").select("is_admin, language").eq("user_id", user.id).execute() + is_admin = profile.data[0].get("is_admin", False) if profile.data else False + language = profile.data[0].get("language", "fr") if profile.data else "fr" + except Exception: + is_admin = False + language = data.language or "fr" - return UserResponse(id=user.id, email=user.email, company_name=company_name, plan=plan) + return UserResponse(id=user.id, email=user.email, company_name=company_name, plan=plan, is_admin=is_admin, language=language) @router.delete("/account") @@ -246,10 +260,12 @@ async def get_me(user=Depends(get_current_user)): plan = sub.data[0]["plan"] if sub.data else "free" try: - profile = supabase.table("user_profiles").select("is_admin").eq("user_id", user.id).execute() + profile = supabase.table("user_profiles").select("is_admin, language").eq("user_id", user.id).execute() is_admin = profile.data[0].get("is_admin", False) if profile.data else False + language = profile.data[0].get("language", "fr") if profile.data else "fr" except Exception: is_admin = False + language = "fr" return UserResponse( id=user.id, @@ -257,4 +273,5 @@ async def get_me(user=Depends(get_current_user)): company_name=company_name, plan=plan, is_admin=is_admin, + language=language, ) diff --git a/app/routers/channels.py b/app/routers/channels.py index 807ee3c..8daaa83 100644 --- a/app/routers/channels.py +++ b/app/routers/channels.py @@ -245,6 +245,13 @@ async def telegram_webhook(bot_token: str, request: Request): await tg_send(bot_token, chat_id, welcome) return {"ok": True} + if text == "/owner": + supabase.table("channel_connections").update( + {"owner_chat_id": str(chat_id)} + ).eq("channel", "telegram").eq("bot_token", bot_token).execute() + await tg_send(bot_token, chat_id, "✅ You're registered as the owner of this bot. You'll receive handoff alerts here when a visitor needs human support.") + return {"ok": True} + # Use first 8 chars of token as namespace to avoid collisions between bots external_id = f"tg:{bot_token[:8]}:{chat_id}" session = _get_or_create_channel_session(chatbot_id, "telegram", external_id, supabase) @@ -274,7 +281,7 @@ async def telegram_webhook(bot_token: str, request: Request): await tg_send(bot_token, chat_id, "Sorry, I encountered an error. Please try again.") return {"ok": True} - confidence_score = max((s.score for s in result.get("sources", [])), default=0.0) + confidence_score = result.get("confidence_score", 0.0) _save_message(conversation["id"], "user", text, supabase) _save_message( conversation["id"], "assistant", result["response"], supabase, diff --git a/app/routers/chat.py b/app/routers/chat.py index 1b448b0..d212356 100644 --- a/app/routers/chat.py +++ b/app/routers/chat.py @@ -2,7 +2,7 @@ import time from collections import defaultdict from fastapi import APIRouter, HTTPException, Depends, Request -from app.models import ChatMessage, ChatResponse, ConversationResponse, MessageResponse, FeedbackCreate +from app.models import ChatMessage, ChatResponse, ConversationResponse, MessageResponse, FeedbackCreate, TestChatRequest, TestChatResult from app.database import get_supabase from app.dependencies import get_current_user, get_optional_user from app.services.rag import rag_engine @@ -15,6 +15,8 @@ import logging logger = logging.getLogger(__name__) router = APIRouter(tags=["Chat"]) +CONFIDENCE_THRESHOLD = 0.55 + # ── Simple in-memory rate limiter ──────────────────────────────────────────── _rate_store: dict = defaultdict(list) _RATE_LIMIT = 30 # max requests @@ -166,29 +168,27 @@ async def chat( language=message.language, ) - # Compute confidence score - confidence_score = max((s.score for s in result.get("sources", [])), default=0.0) + confidence_score = result.get("confidence_score", 0.0) # Check handoff is_handoff = False + low_confidence = confidence_score < CONFIDENCE_THRESHOLD if chatbot.get("handoff_enabled"): handoff_keywords = chatbot.get("handoff_keywords", []) msg_lower = message.message.lower() - if any(kw.lower() in msg_lower for kw in handoff_keywords): + keyword_triggered = any(kw.lower() in msg_lower for kw in handoff_keywords) + if keyword_triggered or low_confidence: is_handoff = True - # Fire n8n notification (async, non-blocking) try: - from app.services.n8n_service import send_handoff_notification - from app.config import settings as _settings - company_data_for_handoff = chatbot.get("companies", {}) or {} - await send_handoff_notification( + from app.services.notification_service import send_handoff_alert + await send_handoff_alert( + chatbot_id=chatbot_id, chatbot_name=chatbot.get("name", ""), - owner_email=chatbot.get("handoff_email") or "", conversation_history=history, trigger_message=message.message, - chatbot_id=chatbot_id, conversation_id=conversation["id"], - webhook_url=_settings.n8n_handoff_webhook_url, + low_confidence=low_confidence, + supabase=supabase, ) except Exception: pass # never block chat on handoff failure @@ -228,6 +228,7 @@ async def chat( tokens_used=result.get("tokens_used", 0), needs_lead_capture=needs_lead_capture, handoff=is_handoff, + low_confidence=low_confidence, ) @@ -289,6 +290,56 @@ async def submit_feedback(chatbot_id: str, data: FeedbackCreate): return {"success": True} +@router.post("/chat/{chatbot_id}/test", response_model=List[TestChatResult]) +async def test_chat( + chatbot_id: str, + body: TestChatRequest, + user=Depends(get_current_user), +): + """Run test questions against a chatbot without saving to conversation history. Owner only.""" + supabase = get_supabase() + chatbot = _get_public_chatbot(chatbot_id, supabase) + + company = supabase.table("companies").select("id").eq("owner_id", user.id).execute() + if not company.data or company.data[0]["id"] != chatbot.get("company_id"): + raise HTTPException(status_code=403, detail="Access denied") + + collection_name = chatbot.get("qdrant_collection_name") + if not collection_name: + raise HTTPException(status_code=400, detail="Chatbot has no knowledge base configured") + + company_data = chatbot.get("companies", {}) or {} + chatbot_config = {**chatbot, "company_name": company_data.get("name", "")} + + results = [] + for question in body.questions: + try: + result = await rag_engine.process_query( + query=question, + collection_name=collection_name, + chatbot_config=chatbot_config, + conversation_history=[], + language="auto", + bypass_cache=True, + ) + results.append(TestChatResult( + question=question, + response=result["response"], + confidence_score=result.get("confidence_score", 0.0), + sources=result.get("sources", []), + model_used=result.get("model", ""), + )) + except Exception as e: + results.append(TestChatResult( + question=question, + response=f"Error: {e}", + confidence_score=0.0, + sources=[], + model_used="", + )) + return results + + # ── OLD analytics endpoint REMOVED ─────────────────────────────────────────── # The /analytics/{chatbot_id} endpoint that was here has been replaced by # the dedicated analytics router (app/routers/analytics.py) which provides: diff --git a/app/routers/chatbots.py b/app/routers/chatbots.py index cb0d52b..856c2d4 100644 --- a/app/routers/chatbots.py +++ b/app/routers/chatbots.py @@ -296,7 +296,7 @@ def _format_chatbot(chatbot: dict, supabase) -> ChatbotResponse: name=chatbot["name"], description=chatbot.get("description"), system_prompt=chatbot.get("system_prompt"), - model=chatbot.get("model", "accounts/fireworks/models/kimi-k2-instruct-0905"), + model=chatbot.get("model", "accounts/fireworks/models/kimi-k2-instruct"), temperature=chatbot.get("temperature", 0.7), max_tokens=chatbot.get("max_tokens", 1000), primary_color=chatbot.get("primary_color", "#6366f1"), diff --git a/app/routers/documents.py b/app/routers/documents.py index 1cad3e8..7e8faea 100644 --- a/app/routers/documents.py +++ b/app/routers/documents.py @@ -6,6 +6,7 @@ from app.services.document_processor import process_document from app.services.embeddings import embedding_service from app.services.vector_store import vector_store from app.services.storage import delete_from_storage, extract_storage_path +from app.services import cache as response_cache from app.config import settings from typing import List import uuid @@ -166,6 +167,7 @@ async def _process_document_bg( "chunk_count": len(chunks), }).eq("id", doc_id).execute() + response_cache.invalidate(collection_name) logger.info(f"Document {doc_id} processed: {len(chunks)} chunks") except Exception as e: @@ -211,6 +213,8 @@ async def delete_document(chatbot_id: str, document_id: str, user=Depends(get_cu delete_from_storage(supabase, "documents", doc.data[0]["file_url"]) supabase.table("documents").delete().eq("id", document_id).execute() + if collection_name: + response_cache.invalidate(collection_name) return SuccessResponse(success=True, message="Document deleted") @@ -259,6 +263,11 @@ async def retry_document_processing( "chunk_count": 0, }).eq("id", document_id).execute() + # Clear stale cache before re-processing so tests see fresh results + collection_name = chatbot.get("qdrant_collection_name") + if collection_name: + response_cache.invalidate(collection_name) + # Re-enqueue background processing background_tasks.add_task( _process_document_bg, @@ -340,10 +349,56 @@ async def delete_url_source(chatbot_id: str, source_id: str, user=Depends(get_cu if not source.data: raise HTTPException(status_code=404, detail="URL source not found") + chatbot = _get_user_chatbot(chatbot_id, user.id, supabase) + collection_name = chatbot.get("qdrant_collection_name") + if collection_name: + try: + vector_store.delete_by_document_id(collection_name, source_id) + except Exception: + pass + response_cache.invalidate(collection_name) supabase.table("url_sources").delete().eq("id", source_id).execute() return SuccessResponse(success=True, message="URL source deleted") +@url_router.post("/{source_id}/refresh", response_model=UrlSourceResponse) +async def refresh_url_source( + chatbot_id: str, + source_id: str, + background_tasks: BackgroundTasks, + user=Depends(get_current_user), +): + """Re-scrape a URL source and rebuild its vectors.""" + supabase = get_supabase() + chatbot = _get_user_chatbot(chatbot_id, user.id, supabase) + + source = supabase.table("url_sources").select("*").eq("id", source_id).eq("chatbot_id", chatbot_id).execute() + if not source.data: + raise HTTPException(status_code=404, detail="URL source not found") + + src = source.data[0] + collection_name = chatbot.get("qdrant_collection_name") + + # Drop existing vectors for this source + if collection_name: + try: + vector_store.delete_by_document_id(collection_name, source_id) + except Exception as e: + logger.warning(f"Could not delete old vectors for url source {source_id}: {e}") + response_cache.invalidate(collection_name) + + # Reset to pending and reprocess + updated = supabase.table("url_sources").update({ + "status": "pending", + "error_message": None, + "chunk_count": 0, + }).eq("id", source_id).returning("representation").execute() + + background_tasks.add_task(_process_url_source, source_id, src["url"], chatbot, supabase) + + return UrlSourceResponse(**{**src, "status": "pending", "chunk_count": 0}) + + async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase): """Background task to scrape a URL and add its content to the vector store.""" from app.services.web_scraper import scrape_url @@ -424,6 +479,7 @@ async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase) "chunk_count": len(chunks), }).eq("id", source_id).execute() + response_cache.invalidate(collection_name) logger.info(f"URL source {source_id} processed: {len(chunks)} chunks from {url}") except Exception as e: diff --git a/app/routers/marketplace.py b/app/routers/marketplace.py index 1b712cc..69fa271 100644 --- a/app/routers/marketplace.py +++ b/app/routers/marketplace.py @@ -175,13 +175,21 @@ async def rate_chatbot( user=Depends(get_current_user), ): supabase = get_supabase() - chatbot = supabase.table("chatbots").select("id, average_rating").eq("id", chatbot_id).eq("is_published", True).execute() + chatbot = supabase.table("chatbots") \ + .select("id, average_rating, rating_count") \ + .eq("id", chatbot_id).eq("is_published", True).execute() if not chatbot.data: raise HTTPException(status_code=404, detail="Chatbot not found") - # Simple rating update (average) - current = chatbot.data[0].get("average_rating") or rating.rating - new_avg = (current + rating.rating) / 2 + row = chatbot.data[0] + current_avg = row.get("average_rating") or 0.0 + current_count = row.get("rating_count") or 0 - supabase.table("chatbots").update({"average_rating": round(new_avg, 1)}).eq("id", chatbot_id).execute() - return {"message": "Rating submitted", "new_average": round(new_avg, 1)} \ No newline at end of file + new_count = current_count + 1 + new_avg = round((current_avg * current_count + rating.rating) / new_count, 2) + + supabase.table("chatbots").update({ + "average_rating": new_avg, + "rating_count": new_count, + }).eq("id", chatbot_id).execute() + return {"average_rating": new_avg, "rating_count": new_count} \ No newline at end of file diff --git a/app/services/llm.py b/app/services/llm.py index 989d031..ec6f10c 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -31,14 +31,9 @@ class LLMService: return await self._call_openai(messages, model, max_tokens, temperature) except Exception as e: logger.error(f"LLM error ({provider}/{model}): {e}") - # Fallback to a basic model if available - if model != "accounts/fireworks/models/kimi-k2-instruct-0905" and settings.fireworks_api_key: - return await self._call_fireworks( - messages, - "accounts/fireworks/models/kimi-k2-instruct-0905", - max_tokens, - temperature, - ) + fallback = "accounts/fireworks/models/llama-v3p3-70b-instruct" + if model != fallback and settings.fireworks_api_key: + return await self._call_fireworks(messages, fallback, max_tokens, temperature) raise async def _call_fireworks( diff --git a/app/services/rag.py b/app/services/rag.py index 3c05817..9ef1d25 100644 --- a/app/services/rag.py +++ b/app/services/rag.py @@ -1,8 +1,9 @@ from app.services.embeddings import embedding_service from app.services.vector_store import vector_store from app.services.llm import llm_service +from app.services import cache as response_cache from app.models import SourceDocument -from typing import List, Dict, Any, Optional, Tuple +from typing import List, Dict, Any, Optional import logging logger = logging.getLogger(__name__) @@ -44,6 +45,7 @@ class RAGEngine: chatbot_config: Dict[str, Any], conversation_history: List[Dict[str, str]] = None, language: str = "en", + bypass_cache: bool = False, ) -> Dict[str, Any]: """ Full RAG pipeline: embed → retrieve → generate @@ -51,6 +53,13 @@ class RAGEngine: if conversation_history is None: conversation_history = [] + # Cache hit — only for stateless (no history) queries, and not bypassed + if not conversation_history and not bypass_cache: + cached = response_cache.get(collection_name, query) + if cached is not None: + logger.info(f"[RAG] Cache hit for query in '{collection_name}'") + return cached + # Step 1: Embed the query try: query_embedding = self.embedding_svc.embed_text(query) @@ -65,14 +74,14 @@ class RAGEngine: } # Step 2: Retrieve relevant chunks - # FIX: Lowered score_threshold from 0.3 to 0.1 to avoid filtering out - # all results. With cosine similarity, 0.3 can be too aggressive for - # many document types and query patterns. + # Fetch more than needed so that after filtering low-quality results + # we still have enough context. score_threshold=0.55 keeps only chunks + # that are genuinely relevant for text-embedding-3-small cosine similarity. retrieved = self.vector_svc.search( collection_name=collection_name, query_vector=query_embedding, - limit=5, - score_threshold=0.1, # FIX: was 0.3, now 0.1 to avoid over-filtering + limit=8, + score_threshold=0.55, ) logger.info(f"[RAG] Retrieved {len(retrieved)} chunks from collection '{collection_name}'") @@ -108,11 +117,15 @@ class RAGEngine: context = "No relevant information found in the knowledge base." logger.warning(f"[RAG] No context found for query: '{query}' in collection '{collection_name}'") + # Confidence: mean of top-3 scores (more stable than max alone) + top_scores = sorted([s.score for s in sources], reverse=True)[:3] + confidence_score = round(sum(top_scores) / len(top_scores), 4) if top_scores else 0.0 + # Step 4: Build messages - lang_name = LANGUAGE_NAMES.get(language, "English") if language and language != "en" else "" language_instruction = ( - f"\n6. Respond in {lang_name}. Match the language of the user's message." - if lang_name else "" + "\n6. CRITICAL: Always reply in the exact same language the user wrote in. " + "If they write in French, reply in French. If Spanish, reply in Spanish. " + "Never switch to English unless the user writes in English." ) system_prompt = RAG_SYSTEM_PROMPT.format( @@ -137,7 +150,7 @@ class RAGEngine: logger.info(f"[RAG] Sending {len(messages)} messages to LLM (model: {chatbot_config.get('model')})") # Step 5: Generate response - model = chatbot_config.get("model", "accounts/fireworks/models/kimi-k2-instruct-0905") + model = chatbot_config.get("model", "accounts/fireworks/models/kimi-k2-instruct") try: result = await self.llm_svc.generate( messages=messages, @@ -146,17 +159,22 @@ class RAGEngine: temperature=chatbot_config.get("temperature", 0.7), ) logger.info(f"[RAG] LLM response generated. Tokens used: {result.get('tokens_used', 0)}") - return { + payload = { "response": result["content"], "sources": sources, + "confidence_score": confidence_score, "tokens_used": result.get("tokens_used", 0), "model": result.get("model", model), } + if not conversation_history and not bypass_cache: + response_cache.set(collection_name, query, payload) + return payload except Exception as e: logger.error(f"[RAG] LLM generation error: {e}", exc_info=True) return { "response": "I'm having trouble generating a response. Please try again later.", "sources": sources, + "confidence_score": confidence_score, "tokens_used": 0, "model": model, } diff --git a/tests/test_chat.py b/tests/test_chat.py index dd800b7..34368a4 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -206,9 +206,11 @@ class TestChatHandoff: assert resp.status_code == 200 assert resp.json()["handoff"] is True - def test_handoff_not_triggered_without_keyword_match(self, client): + def test_handoff_not_triggered_when_high_confidence_no_keyword(self, client): + # High confidence + no keyword match → no handoff chatbot = _make_chatbot(handoff_enabled=True, handoff_keywords=["human"]) - rag_result = {"response": "Sure!", "sources": [], "model": "m", "tokens_used": 5} + rag_result = {"response": "We open at 9am.", "sources": [], "confidence_score": 0.85, + "model": "m", "tokens_used": 5} with patch("app.routers.chat.get_supabase") as mock_sb, \ patch("app.routers.chat.rag_engine") as mock_rag: @@ -219,6 +221,22 @@ class TestChatHandoff: assert resp.status_code == 200 assert resp.json()["handoff"] is False + def test_handoff_triggered_by_low_confidence(self, client): + # Low confidence (no sources) triggers handoff even without a keyword + chatbot = _make_chatbot(handoff_enabled=True, handoff_keywords=["human"]) + rag_result = {"response": "I'm not sure.", "sources": [], "confidence_score": 0.1, + "model": "m", "tokens_used": 5} + + with patch("app.routers.chat.get_supabase") as mock_sb, \ + patch("app.routers.chat.rag_engine") as mock_rag: + mock_rag.process_query = AsyncMock(return_value=rag_result) + mock_sb.return_value = _make_chat_sb(chatbot=chatbot) + resp = client.post("/api/v1/chat/cb-1", json={"message": "Tell me about quantum physics", "language": "en"}) + + assert resp.status_code == 200 + assert resp.json()["handoff"] is True + assert resp.json()["low_confidence"] is True + class TestChatHistory: def test_history_returns_empty_for_unknown_session(self, client): diff --git a/tests/test_documents.py b/tests/test_documents.py index 8fd0313..441e1e6 100644 --- a/tests/test_documents.py +++ b/tests/test_documents.py @@ -216,8 +216,10 @@ class TestDocumentDelete: "file_url": None, } with patch("app.routers.documents.get_supabase") as mock_sb, \ - patch("app.routers.documents.vector_store") as mock_vs: + patch("app.routers.documents.vector_store") as mock_vs, \ + patch("app.routers.documents.response_cache") as mock_cache: mock_vs.delete_by_document_id = MagicMock() + mock_cache.invalidate = MagicMock() mock_sb.return_value = _make_doc_sb(doc=doc) resp = client.delete("/api/v1/chatbots/cb-1/documents/doc-1", headers=AUTH) assert resp.status_code == 200 diff --git a/tests/test_marketplace.py b/tests/test_marketplace.py index 6c81026..75f9a87 100644 --- a/tests/test_marketplace.py +++ b/tests/test_marketplace.py @@ -185,7 +185,7 @@ class TestMarketplaceRating: assert resp.status_code == 404 def test_rate_chatbot_success(self, client): - bot = {"id": "bot-1", "average_rating": 4.0} + bot = {"id": "bot-1", "average_rating": 4.0, "rating_count": 1} with patch("app.routers.marketplace.get_supabase") as mock_sb: mock_sb.return_value = _make_marketplace_sb(chatbot_data=[bot]) resp = client.post( @@ -195,12 +195,12 @@ class TestMarketplaceRating: ) assert resp.status_code == 200 body = resp.json() - assert "new_average" in body - assert body["new_average"] == 4.5 # (4.0 + 5) / 2 + assert "average_rating" in body + assert body["average_rating"] == 4.5 # (4.0 * 1 + 5) / 2 def test_rate_chatbot_first_rating(self, client): - """When average_rating is None, should use the submitted rating as both sides.""" - bot = {"id": "bot-1", "average_rating": None} + """When average_rating is None, should use the submitted rating as the new average.""" + bot = {"id": "bot-1", "average_rating": None, "rating_count": 0} with patch("app.routers.marketplace.get_supabase") as mock_sb: mock_sb.return_value = _make_marketplace_sb(chatbot_data=[bot]) resp = client.post( @@ -209,4 +209,4 @@ class TestMarketplaceRating: headers=AUTH, ) assert resp.status_code == 200 - assert resp.json()["new_average"] == 5.0 + assert resp.json()["average_rating"] == 5.0