# Origin: diff --git a/app/services/cache.py b/app/services/cache.py (new file mode 100644, index 0000000..1379834)
"""
In-memory TTL cache for RAG responses.
Keyed on (collection_name, normalised query).
Only used for stateless queries (no conversation history).

State is held in two module-level dicts:

* ``_store`` maps a cache key to ``(value, expires_at)`` where ``expires_at``
  is a ``time.monotonic()`` timestamp.
* ``_index`` maps a collection name to the set of cache keys stored for it, so
  a whole collection can be invalidated without scanning ``_store``.

Every code path that removes a key from ``_store`` also removes it from
``_index``. No locking is performed.
"""
import hashlib
import time
from typing import Any, Iterable, Optional

_store: dict[str, tuple[Any, float]] = {}  # cache key -> (value, monotonic expiry time)
_index: dict[str, set[str]] = {}  # collection_name -> set of cache keys
_MAX_ENTRIES = 500
TTL = 6 * 3600  # 6 hours
_set_ctor = set  # save builtin before it's shadowed by our module-level `set` function


def _make_key(collection_name: str, query: str) -> str:
    """Return the SHA-256 hex digest identifying (collection_name, query).

    The query is lower-cased and stripped of leading/trailing whitespace, so
    lookups ignore case and surrounding blanks.
    """
    normalised = query.lower().strip()
    # Length-prefix the collection name: with a bare "::" separator the pairs
    # ("a::b", "c") and ("a", "b::c") would hash to the same key.
    raw = f"{len(collection_name)}:{collection_name}::{normalised}"
    # "surrogatepass" keeps a lone surrogate in the query from raising
    # UnicodeEncodeError; well-formed text encodes exactly as plain UTF-8.
    return hashlib.sha256(raw.encode("utf-8", "surrogatepass")).hexdigest()


def _unindex(keys: Iterable[str]) -> None:
    """Remove *keys* from every collection bucket and drop buckets left empty."""
    dead = _set_ctor(keys)
    if not dead:
        return
    # Iterate over a snapshot of the names because empty buckets are deleted.
    for name in list(_index):
        bucket = _index[name]
        bucket.difference_update(dead)
        if not bucket:
            del _index[name]


def _make_room() -> None:
    """Free capacity: purge expired entries, then evict by earliest expiry."""
    now = time.monotonic()
    removed = [k for k, (_, expires_at) in _store.items() if expires_at < now]
    for k in removed:
        del _store[k]
    # Looping (rather than evicting once) also copes with a capacity that was
    # lowered at runtime; `_store and` avoids min() on an empty dict when the
    # capacity is zero or negative.
    while _store and len(_store) >= _MAX_ENTRIES:
        oldest = min(_store, key=lambda k: _store[k][1])
        del _store[oldest]
        removed.append(oldest)
    _unindex(removed)


def get(collection_name: str, query: str) -> Optional[Any]:
    """Return the cached value for (collection_name, query), or None.

    None is returned both when nothing is cached and when the entry has
    expired; an expired entry is removed from the store and the index as a
    side effect. A stored value of None is therefore indistinguishable from a
    miss.
    """
    k = _make_key(collection_name, query)
    entry = _store.get(k)
    if entry is None:
        return None
    value, expires_at = entry
    if time.monotonic() > expires_at:
        _store.pop(k, None)
        bucket = _index.get(collection_name)
        if bucket is not None:
            bucket.discard(k)
            if not bucket:
                del _index[collection_name]
        return None
    return value


def set(collection_name: str, query: str, value: Any) -> None:
    """Cache *value* for (collection_name, query) for ``TTL`` seconds.

    Re-setting an existing key replaces its value and restarts its TTL. Room
    is only made when a *new* key would push the store past ``_MAX_ENTRIES``,
    so an overwrite never evicts an unrelated entry.
    """
    k = _make_key(collection_name, query)
    if k not in _store and len(_store) >= _MAX_ENTRIES:
        _make_room()
    _store[k] = (value, time.monotonic() + TTL)
    _index.setdefault(collection_name, _set_ctor()).add(k)


def invalidate(collection_name: str) -> int:
    """Drop all cached entries for a collection — call after any KB update.

    Returns the number of keys that were indexed for the collection (0 when
    the collection is unknown).
    """
    keys = _index.pop(collection_name, _set_ctor())
    for k in keys:
        _store.pop(k, None)
    return len(keys)

# (patch continues) diff --git a/app/services/notification_service.py b/app/services/notification_service.py new file mode 100644 index 0000000..093fb90 ---
/dev/null
+++ b/app/services/notification_service.py
@@ -0,0 +1,69 @@
+import logging
+from typing import List, Optional
+from app.services.telegram_service import send_message as tg_send
+
+logger = logging.getLogger(__name__)
+
+
+async def send_handoff_alert(
+    chatbot_id: str,
+    chatbot_name: str,
+    conversation_history: List[dict],
+    trigger_message: str,  # NOTE(review): accepted but not referenced in the body below
+    conversation_id: str,  # NOTE(review): accepted but not referenced in the body below
+    low_confidence: bool,
+    supabase,
+) -> bool:
+    """
+    Notify the chatbot owner via their connected Telegram bot.
+    Owner must have sent /owner to their bot to register their chat_id.
+    Returns tg_send's result, or False when no usable connection exists or any step raises.
+    """
+    try:
+        conn = (
+            supabase.table("channel_connections")
+            .select("bot_token, owner_chat_id")
+            .eq("chatbot_id", chatbot_id)
+            .eq("channel", "telegram")
+            .eq("is_active", True)
+            .execute()
+        )
+        if not conn.data:
+            logger.info(f"No Telegram connection for chatbot {chatbot_id}")
+            return False
+
+        bot_token: Optional[str] = conn.data[0].get("bot_token")  # only the first matching row is consulted
+        owner_chat_id = conn.data[0].get("owner_chat_id")
+
+        if not bot_token or not owner_chat_id:
+            logger.info(
+                f"Owner has not registered for notifications on chatbot {chatbot_id}. "
+                "They should send /owner to their Telegram bot."
+            )
+            return False
+
+        from app.config import settings  # function-scope import; NOTE(review): presumably avoids an import cycle — confirm
+        recent = conversation_history[-4:]  # at most the last four messages
+        history_lines = "\n".join(
+            f"{'πŸ‘€' if m['role'] == 'user' else 'πŸ€–'} {m['content'][:120]}"  # NOTE(review): emoji literals look mis-encoded (mojibake) — confirm against the repository copy
+            for m in recent  # each message body is cut to 120 characters; a missing 'role'/'content' key raises into the except below
+        )
+
+        reason = "❓ Bot couldn't answer confidently" if low_confidence else "πŸ™‹ User requested a human"
+        inbox_url = f"{settings.app_url}/inbox"
+
+        text = (
+            f"πŸ”” *Handoff β€” {chatbot_name}*\n"  # NOTE(review): Markdown-style markup with name/history interpolated unescaped — confirm tg_send's parse mode
+            f"Reason: {reason}\n\n"
+            f"Last messages:\n{history_lines}\n\n"
+            f"β–Ά [Open in inbox]({inbox_url})"
+        )
+
+        sent = await tg_send(bot_token, owner_chat_id, text)
+        if sent:
+            logger.info(f"Handoff alert sent to owner for chatbot {chatbot_id}")
+        return sent
+
+    except Exception as e:  # best-effort: any failure above is logged and reported as False
+        logger.error(f"Failed to send handoff alert for chatbot {chatbot_id}: {e}")  # NOTE(review): only str(e) is logged, no traceback
+        return False
diff --git a/migrations/004_user_language.sql b/migrations/004_user_language.sql
new file mode 100644
index 0000000..bc53ae6
--- /dev/null
+++ b/migrations/004_user_language.sql
@@ -0,0 +1,3 @@
+-- Add a language preference to user_profiles; the column is NOT NULL and defaults to 'fr'.
+ALTER TABLE user_profiles
+    ADD COLUMN IF NOT EXISTS language VARCHAR(10) NOT NULL DEFAULT 'fr';
diff --git a/tests/test_cache.py b/tests/test_cache.py
new file mode 100644
index 0000000..ccd65fd
--- /dev/null
+++ b/tests/test_cache.py
@@ -0,0 +1,87 @@
+"""Tests for the in-memory response cache."""
+import time
+import pytest
+from app.services import cache
+
+
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Wipe cache state before and after each test."""
+    cache._store.clear()
+    cache._index.clear()
+    yield
+    cache._store.clear()
+    cache._index.clear()
+
+
+class TestCacheGetSet:
+    def test_miss_on_empty_cache(self):
+        assert cache.get("col-1", "hello") is None
+
+    def test_set_then_get_returns_value(self):
+        payload = {"response": "Hi", "sources": []}
+        cache.set("col-1", "hello", payload)
+        assert cache.get("col-1", "hello") == payload
+
+    def test_different_collections_are_independent(self):
+        cache.set("col-a", "query", {"response": "A"})
+        cache.set("col-b", "query", {"response": "B"})
+        assert cache.get("col-a", "query")["response"] == "A"
+        assert cache.get("col-b", "query")["response"] == "B"
+
+    def test_query_normalisation_ignores_case_and_whitespace(self):
+        cache.set("col-1", " Hello World ", {"response": "hi"})
+        assert cache.get("col-1", "hello world") is not None
+        assert cache.get("col-1", "HELLO WORLD") is not None
+
+    def test_overwrite_updates_value(self):
+        cache.set("col-1", "q", {"response": "old"})
+        cache.set("col-1", "q", {"response": "new"})
+        assert cache.get("col-1", "q")["response"] == "new"
+
+
+class TestCacheExpiry:
+    def test_expired_entry_returns_none(self, monkeypatch):  # NOTE(review): monkeypatch is requested but unused
+        cache.set("col-1", "q", {"response": "old"})
+        # Manually expire the entry by rewriting its stored expiry to a past monotonic time
+        k = list(cache._store.keys())[0]
+        cache._store[k] = (cache._store[k][0], time.monotonic() - 1)
+        assert cache.get("col-1", "q") is None
+
+    def test_expired_entry_is_evicted_from_store(self, monkeypatch):  # NOTE(review): monkeypatch is requested but unused
+        cache.set("col-1", "q", {"response": "x"})
+        k = list(cache._store.keys())[0]
+        cache._store[k] = (cache._store[k][0], time.monotonic() - 1)
+        cache.get("col-1", "q")
+        assert k not in cache._store
+
+
+class TestCacheInvalidation:
+    def test_invalidate_removes_all_entries_for_collection(self):
+        cache.set("col-1", "query a", {"response": "a"})
+        cache.set("col-1", "query b", {"response": "b"})
+        cache.set("col-2", "query a", {"response": "c"})
+
+        removed = cache.invalidate("col-1")
+
+        assert removed == 2
+        assert cache.get("col-1", "query a") is None
+        assert cache.get("col-1", "query b") is None
+        # Other collection unaffected
+        assert cache.get("col-2", "query a") is not None
+
+    def test_invalidate_unknown_collection_returns_zero(self):
+        assert cache.invalidate("nonexistent") == 0
+
+    def test_index_cleaned_up_after_invalidation(self):
+        cache.set("col-1", "q", {"response": "x"})
+        cache.invalidate("col-1")
+        assert "col-1" not in cache._index
+
+
+class TestCacheEviction:
+    def test_does_not_exceed_max_entries(self, monkeypatch):
+        monkeypatch.setattr(cache, "_MAX_ENTRIES", 5)  # shrink the capacity so ten inserts must trigger eviction
+        for i in range(10):
+            cache.set("col-1", f"query {i}", {"response": str(i)})
+        assert len(cache._store) <= 5
diff --git a/tests/test_chat_test_endpoint.py b/tests/test_chat_test_endpoint.py
new file mode 100644
index 0000000..a62827c
--- /dev/null
+++ b/tests/test_chat_test_endpoint.py
@@ -0,0 +1,122 @@
+"""Tests for the /chat/{id}/test endpoint."""
+import pytest
+from unittest.mock import MagicMock, AsyncMock, patch
+
+AUTH = {"Authorization": "Bearer test-token"}
+
+
+def _make_chatbot():  # canned chatbot row returned for the "chatbots" table
+    return {
+        "id": "cb-1",
+        "name": "Test Bot",
+        "is_published": True,
+        "qdrant_collection_name": "col-1",
+        "company_id": "company-1",
+        "handoff_enabled": False,
+        "handoff_keywords": [],
+        "lead_capture_enabled": False,
+        "lead_capture_trigger": None,
+        "booking_enabled": False,
+        "system_prompt": "",
+        "companies": {"name": "Acme", "logo_url": None},
+    }
+
+
+def _make_sb(chatbot=None, is_owner=True):  # MagicMock Supabase client with per-table canned results
+    sb = MagicMock()
+
+    def table_side(name):
+        m = MagicMock()
+        m.select.return_value = m  # every builder method returns the same mock so calls can be chained
+        m.eq.return_value = m
+        m.in_.return_value = m
+        m.limit.return_value = m
+        m.order.return_value = m
+
+        if name == "chatbots":
+            m.execute.return_value = MagicMock(data=[chatbot or _make_chatbot()])
+        elif name == "companies":
+            cid = "company-1" if is_owner else "other-company"  # a different company id models a non-owner
+            m.execute.return_value = MagicMock(data=[{"id": cid, "owner_id": "owner-1"}])
+        else:
+            m.execute.return_value = MagicMock(data=[], count=0)
+        return m
+
+    sb.table.side_effect = table_side
+    sb.auth = MagicMock()
+    return sb
+
+
+class TestChatTestEndpoint:
+    def _run_test(self, client, questions, chatbot=None, is_owner=True, rag_result=None):  # POST the questions with Supabase, RAG engine and auth patched
+        default_rag = {
+            "response": "The answer is 42.",
+            "sources": [],
+            "confidence_score": 0.82,
+            "tokens_used": 20,
+            "model": "test-model",
+        }
+        with patch("app.routers.chat.get_supabase") as mock_sb, \
+             patch("app.routers.chat.rag_engine") as mock_rag, \
+             patch("app.dependencies.get_current_user") as mock_user:  # NOTE(review): assumes the app resolves get_current_user through this module attribute — confirm
+            mock_rag.process_query = AsyncMock(return_value=rag_result or default_rag)
+            mock_sb.return_value = _make_sb(chatbot=chatbot, is_owner=is_owner)
+            user = MagicMock()
+            user.id = "owner-1"
+            mock_user.return_value = user
+            return client.post(
+                "/api/v1/chat/cb-1/test",
+                json={"questions": questions},
+                headers=AUTH,
+            )
+
+    def test_returns_list_of_results(self, client):
+        resp = self._run_test(client, ["What is your return policy?"])
+        assert resp.status_code == 200
+        body = resp.json()
+        assert isinstance(body, list)
+        assert len(body) == 1
+
+    def test_result_shape(self, client):
+        resp = self._run_test(client, ["Hello?"])
+        result = resp.json()[0]
+        assert "question" in result
+        assert "response" in result
+        assert "confidence_score" in result
+        assert "sources" in result
+        assert "model_used" in result
+
+    def test_question_echoed_in_result(self, client):
+        resp = self._run_test(client, ["What are your hours?"])
+        assert resp.json()[0]["question"] == "What are your hours?"
+
+    def test_multiple_questions_all_answered(self, client):
+        questions = ["Q1", "Q2", "Q3"]
+        resp = self._run_test(client, questions)
+        assert len(resp.json()) == 3
+        returned_questions = [r["question"] for r in resp.json()]
+        assert returned_questions == questions  # order must match the request
+
+    def test_requires_authentication(self, client):
+        resp = client.post("/api/v1/chat/cb-1/test", json={"questions": ["hi"]})  # no AUTH header and no patches
+        assert resp.status_code == 401
+
+    def test_rejects_more_than_10_questions(self, client):
+        resp = self._run_test(client, [f"Q{i}" for i in range(11)])
+        assert resp.status_code == 422
+
+    def test_rejects_empty_question_list(self, client):
+        resp = self._run_test(client, [])
+        assert resp.status_code == 422
+
+    def test_chatbot_without_collection_returns_400(self, client):
+        bot = _make_chatbot()
+        bot["qdrant_collection_name"] = None
+        resp = self._run_test(client, ["Hi"], chatbot=bot)
+        assert resp.status_code == 400
+
+    def test_confidence_score_passed_through(self, client):
+        rag_result = {"response": "Sure", "sources": [], "confidence_score": 0.73,
+                      "tokens_used": 5, "model": "m"}
+        resp = self._run_test(client, ["Question?"], rag_result=rag_result)
+        assert resp.json()[0]["confidence_score"] == pytest.approx(0.73)
diff --git a/tests/test_rag_cache.py b/tests/test_rag_cache.py
new file mode 100644
index 0000000..0ccb82a
--- /dev/null
+++ b/tests/test_rag_cache.py
@@ -0,0 +1,94 @@
+"""Tests for RAG response caching integration."""
+import pytest
+from unittest.mock import AsyncMock, patch, MagicMock  # NOTE(review): MagicMock is not used in this file
+from app.services import cache as response_cache
+
+
+@pytest.fixture(autouse=True)
+def clear_cache():  # empties the shared cache before and after every test
+    response_cache._store.clear()
+    response_cache._index.clear()
+    yield
+    response_cache._store.clear()
+    response_cache._index.clear()
+
+
+@pytest.fixture
+def rag():
+    from app.services.rag import RAGEngine  # imported inside the fixture rather than at module import time
+    return RAGEngine()
+
+
+@pytest.fixture
+def chatbot_config():
+    return {
+        "model": "accounts/fireworks/models/llama-v3p3-70b-instruct",
+        "max_tokens": 500,
+        "temperature": 0.7,
+        "company_name": "Test Corp",
+        "system_prompt": "",
+    }
+
+
+@pytest.fixture
+def good_search_result():
+    return [{
+        "payload": {"text": "We open 9am–6pm Mon–Fri.", "file_name": "faq.pdf", "page_number": 1},  # NOTE(review): dashes in this literal look mis-encoded — confirm
+        "score": 0.82,
+    }]
+
+
+class TestRAGCaching:
+    async def test_second_identical_query_uses_cache(self, rag, chatbot_config, good_search_result):  # NOTE(review): no asyncio marker; presumably pytest runs in asyncio auto mode — confirm
+        llm_mock = AsyncMock(return_value={"content": "9am to 6pm", "tokens_used": 20, "model": "m"})
+
+        with patch.object(rag.embedding_svc, "embed_text", return_value=[0.1] * 1536), \
+             patch.object(rag.vector_svc, "search", return_value=good_search_result), \
+             patch.object(rag.llm_svc, "generate", llm_mock):
+
+            await rag.process_query("What are your hours?", "col-1", chatbot_config)
+            await rag.process_query("What are your hours?", "col-1", chatbot_config)
+
+        # LLM should only be called once; second call hits cache
+        assert llm_mock.call_count == 1
+
+    async def test_cache_not_used_when_conversation_history_present(self, rag, chatbot_config, good_search_result):
+        llm_mock = AsyncMock(return_value={"content": "Yes!", "tokens_used": 10, "model": "m"})
+
+        history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
+
+        with patch.object(rag.embedding_svc, "embed_text", return_value=[0.1] * 1536), \
+             patch.object(rag.vector_svc, "search", return_value=good_search_result), \
+             patch.object(rag.llm_svc, "generate", llm_mock):
+
+            await rag.process_query("Follow-up question", "col-1", chatbot_config, conversation_history=history)
+            await rag.process_query("Follow-up question", "col-1", chatbot_config, conversation_history=history)
+
+        # Both calls go to LLM because history makes them stateful
+        assert llm_mock.call_count == 2
+
+    async def test_different_collections_cached_separately(self, rag, chatbot_config, good_search_result):
+        llm_mock = AsyncMock(return_value={"content": "Answer", "tokens_used": 10, "model": "m"})
+
+        with patch.object(rag.embedding_svc, "embed_text", return_value=[0.1] * 1536), \
+             patch.object(rag.vector_svc, "search", return_value=good_search_result), \
+             patch.object(rag.llm_svc, "generate", llm_mock):
+
+            await rag.process_query("Same question", "col-A", chatbot_config)
+            await rag.process_query("Same question", "col-B", chatbot_config)
+
+        # Different collections → two LLM calls, not one
+        assert llm_mock.call_count == 2
+
+    async def test_confidence_score_returned_from_cache(self, rag, chatbot_config, good_search_result):
+        llm_mock = AsyncMock(return_value={"content": "Cached answer", "tokens_used": 10, "model": "m"})
+
+        with patch.object(rag.embedding_svc, "embed_text", return_value=[0.1] * 1536), \
+             patch.object(rag.vector_svc, "search", return_value=good_search_result), \
+             patch.object(rag.llm_svc, "generate", llm_mock):
+
+            first = await rag.process_query("hours?", "col-1", chatbot_config)
+            second = await rag.process_query("hours?", "col-1", chatbot_config)
+
+        assert first["confidence_score"] == second["confidence_score"]
+        assert second["response"] == "Cached answer"
diff --git a/tests/test_url_refresh.py b/tests/test_url_refresh.py
new file mode 100644
index 0000000..780a65b
--- /dev/null
+++ b/tests/test_url_refresh.py
@@ -0,0 +1,115 @@
+"""Tests for URL source refresh endpoint."""
+import pytest  # NOTE(review): not referenced anywhere in this file
+from unittest.mock import MagicMock, patch, AsyncMock
+
+AUTH = {"Authorization": "Bearer test-token"}
+
+
+def _make_sb(source=None, chatbot_company="company-1"):  # MagicMock Supabase client with per-table canned results
+    sb = MagicMock()
+    default_source = {
+        "id": "src-1",
+        "chatbot_id": "cb-1",
+        "url": "https://example.com/faq",
+        "status": "completed",
+        "page_title": "FAQ",
+        "chunk_count": 10,
+        "error_message": None,
+    }
+
+    def table_side(name):
+        m = MagicMock()
+        m.select.return_value = m  # every builder method returns the same mock so calls can be chained
+        m.insert.return_value = m
+        m.update.return_value = m
+        m.delete.return_value = m
+        m.eq.return_value = m
+        m.in_.return_value = m
+        m.returning.return_value = m
+        m.limit.return_value = m
+        m.order.return_value = m
+
+        if name == "chatbots":
+            m.execute.return_value = MagicMock(data=[{
+                "id": "cb-1",
+                "company_id": chatbot_company,
+                "qdrant_collection_name": "col-1",
+            }])
+        elif name == "companies":
+            m.execute.return_value = MagicMock(data=[{"id": chatbot_company, "owner_id": "user-1"}])
+        elif name == "url_sources":
+            m.execute.return_value = MagicMock(data=[source or default_source])
+        else:
+            m.execute.return_value = MagicMock(data=[], count=0)
+        return m
+
+    sb.table.side_effect = table_side
+    sb.auth = MagicMock()
+    return sb
+
+
+class TestUrlRefresh:
+    def _refresh(self, client, source=None):  # POST the refresh route with Supabase, vector store, cache, auth and the URL processor patched
+        with patch("app.routers.documents.get_supabase") as mock_sb, \
+             patch("app.routers.documents.vector_store") as mock_vs, \
+             patch("app.routers.documents.response_cache") as mock_cache, \
+             patch("app.dependencies.get_current_user") as mock_user, \
+             patch("app.routers.documents._process_url_source", new_callable=AsyncMock):
+            mock_sb.return_value = _make_sb(source=source)
+            mock_vs.delete_by_document_id = MagicMock()
+            mock_vs.collection_exists = MagicMock(return_value=True)
+            mock_cache.invalidate = MagicMock()
+            user = MagicMock()
+            user.id = "user-1"
+            mock_user.return_value = user
+            return client.post(
+                "/api/v1/chatbots/cb-1/url-sources/src-1/refresh",
+                headers=AUTH,
+            )
+
+    def test_returns_200(self, client):
+        resp = self._refresh(client)
+        assert resp.status_code == 200
+
+    def test_source_reset_to_pending(self, client):
+        resp = self._refresh(client)
+        body = resp.json()
+        assert body["status"] == "pending"
+        assert body["chunk_count"] == 0
+
+    def test_returns_404_for_unknown_source(self, client):
+        with patch("app.routers.documents.get_supabase") as mock_sb, \
+             patch("app.dependencies.get_current_user") as mock_user:
+            sb = _make_sb()  # NOTE(review): unused — sb2 below is what the patched get_supabase returns
+            # Override url_sources to return empty (it falls into the else branch below)
+            def table_side(name):
+                m = MagicMock()
+                m.select.return_value = m
+                m.update.return_value = m
+                m.eq.return_value = m
+                if name == "chatbots":
+                    m.execute.return_value = MagicMock(data=[{
+                        "id": "cb-1", "company_id": "company-1",
+                        "qdrant_collection_name": "col-1",
+                    }])
+                elif name == "companies":
+                    m.execute.return_value = MagicMock(data=[{"id": "company-1", "owner_id": "user-1"}])
+                else:
+                    m.execute.return_value = MagicMock(data=[])
+                return m
+            sb2 = MagicMock()
+            sb2.table.side_effect = table_side
+            sb2.auth = MagicMock()
+            mock_sb.return_value = sb2
+            user = MagicMock()
+            user.id = "user-1"
+            mock_user.return_value = user
+            resp = client.post(
+                "/api/v1/chatbots/cb-1/url-sources/no-such-src/refresh",
+                headers=AUTH,
+            )
+        assert resp.status_code == 404
+
+    def test_requires_authentication(self, client):
+        resp = client.post("/api/v1/chatbots/cb-1/url-sources/src-1/refresh")  # no AUTH header and no patches
+        assert resp.status_code == 401