mirror of
http://88.130.71.182:3000/BlitTech/contexta_be.git
synced 2026-06-12 23:23:21 +00:00
fixed the RAg in test pipeline issue
This commit is contained in:
54
app/services/cache.py
Normal file
54
app/services/cache.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""
|
||||||
|
In-memory TTL cache for RAG responses.
|
||||||
|
Keyed on (collection_name, normalised query).
|
||||||
|
Only used for stateless queries (no conversation history).
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
_store: dict[str, tuple[Any, float]] = {}
|
||||||
|
_index: dict[str, set[str]] = {} # collection_name → set of cache keys
|
||||||
|
_MAX_ENTRIES = 500
|
||||||
|
TTL = 6 * 3600 # 6 hours
|
||||||
|
_set_ctor = set # save builtin before it's shadowed by our module-level `set` function
|
||||||
|
|
||||||
|
|
||||||
|
def _make_key(collection_name: str, query: str) -> str:
|
||||||
|
raw = f"{collection_name}::{query.lower().strip()}"
|
||||||
|
return hashlib.sha256(raw.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def get(collection_name: str, query: str) -> Optional[Any]:
|
||||||
|
k = _make_key(collection_name, query)
|
||||||
|
entry = _store.get(k)
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
value, expires_at = entry
|
||||||
|
if time.monotonic() > expires_at:
|
||||||
|
_store.pop(k, None)
|
||||||
|
_index.get(collection_name, _set_ctor()).discard(k)
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def set(collection_name: str, query: str, value: Any) -> None:
|
||||||
|
if len(_store) >= _MAX_ENTRIES:
|
||||||
|
now = time.monotonic()
|
||||||
|
expired = [k for k, (_, exp) in _store.items() if exp < now]
|
||||||
|
for k in expired:
|
||||||
|
_store.pop(k, None)
|
||||||
|
if len(_store) >= _MAX_ENTRIES:
|
||||||
|
oldest = min(_store, key=lambda k: _store[k][1])
|
||||||
|
_store.pop(oldest, None)
|
||||||
|
k = _make_key(collection_name, query)
|
||||||
|
_store[k] = (value, time.monotonic() + TTL)
|
||||||
|
_index.setdefault(collection_name, _set_ctor()).add(k)
|
||||||
|
|
||||||
|
|
||||||
|
def invalidate(collection_name: str) -> int:
|
||||||
|
"""Drop all cached entries for a collection — call after any KB update."""
|
||||||
|
keys = _index.pop(collection_name, _set_ctor())
|
||||||
|
for k in keys:
|
||||||
|
_store.pop(k, None)
|
||||||
|
return len(keys)
|
||||||
69
app/services/notification_service.py
Normal file
69
app/services/notification_service.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
from app.services.telegram_service import send_message as tg_send
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def send_handoff_alert(
|
||||||
|
chatbot_id: str,
|
||||||
|
chatbot_name: str,
|
||||||
|
conversation_history: List[dict],
|
||||||
|
trigger_message: str,
|
||||||
|
conversation_id: str,
|
||||||
|
low_confidence: bool,
|
||||||
|
supabase,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Notify the chatbot owner via their connected Telegram bot.
|
||||||
|
Owner must have sent /owner to their bot to register their chat_id.
|
||||||
|
Returns True if notification was sent.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
conn = (
|
||||||
|
supabase.table("channel_connections")
|
||||||
|
.select("bot_token, owner_chat_id")
|
||||||
|
.eq("chatbot_id", chatbot_id)
|
||||||
|
.eq("channel", "telegram")
|
||||||
|
.eq("is_active", True)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
if not conn.data:
|
||||||
|
logger.info(f"No Telegram connection for chatbot {chatbot_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
bot_token: Optional[str] = conn.data[0].get("bot_token")
|
||||||
|
owner_chat_id = conn.data[0].get("owner_chat_id")
|
||||||
|
|
||||||
|
if not bot_token or not owner_chat_id:
|
||||||
|
logger.info(
|
||||||
|
f"Owner has not registered for notifications on chatbot {chatbot_id}. "
|
||||||
|
"They should send /owner to their Telegram bot."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
recent = conversation_history[-4:]
|
||||||
|
history_lines = "\n".join(
|
||||||
|
f"{'👤' if m['role'] == 'user' else '🤖'} {m['content'][:120]}"
|
||||||
|
for m in recent
|
||||||
|
)
|
||||||
|
|
||||||
|
reason = "❓ Bot couldn't answer confidently" if low_confidence else "🙋 User requested a human"
|
||||||
|
inbox_url = f"{settings.app_url}/inbox"
|
||||||
|
|
||||||
|
text = (
|
||||||
|
f"🔔 *Handoff — {chatbot_name}*\n"
|
||||||
|
f"Reason: {reason}\n\n"
|
||||||
|
f"Last messages:\n{history_lines}\n\n"
|
||||||
|
f"▶ [Open in inbox]({inbox_url})"
|
||||||
|
)
|
||||||
|
|
||||||
|
sent = await tg_send(bot_token, owner_chat_id, text)
|
||||||
|
if sent:
|
||||||
|
logger.info(f"Handoff alert sent to owner for chatbot {chatbot_id}")
|
||||||
|
return sent
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send handoff alert for chatbot {chatbot_id}: {e}")
|
||||||
|
return False
|
||||||
3
migrations/004_user_language.sql
Normal file
3
migrations/004_user_language.sql
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
-- Add language preference to user_profiles
|
||||||
|
ALTER TABLE user_profiles
|
||||||
|
ADD COLUMN IF NOT EXISTS language VARCHAR(10) NOT NULL DEFAULT 'fr';
|
||||||
87
tests/test_cache.py
Normal file
87
tests/test_cache.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""Tests for the in-memory response cache."""
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
from app.services import cache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_cache():
|
||||||
|
"""Wipe cache state before each test."""
|
||||||
|
cache._store.clear()
|
||||||
|
cache._index.clear()
|
||||||
|
yield
|
||||||
|
cache._store.clear()
|
||||||
|
cache._index.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheGetSet:
|
||||||
|
def test_miss_on_empty_cache(self):
|
||||||
|
assert cache.get("col-1", "hello") is None
|
||||||
|
|
||||||
|
def test_set_then_get_returns_value(self):
|
||||||
|
payload = {"response": "Hi", "sources": []}
|
||||||
|
cache.set("col-1", "hello", payload)
|
||||||
|
assert cache.get("col-1", "hello") == payload
|
||||||
|
|
||||||
|
def test_different_collections_are_independent(self):
|
||||||
|
cache.set("col-a", "query", {"response": "A"})
|
||||||
|
cache.set("col-b", "query", {"response": "B"})
|
||||||
|
assert cache.get("col-a", "query")["response"] == "A"
|
||||||
|
assert cache.get("col-b", "query")["response"] == "B"
|
||||||
|
|
||||||
|
def test_query_normalisation_ignores_case_and_whitespace(self):
|
||||||
|
cache.set("col-1", " Hello World ", {"response": "hi"})
|
||||||
|
assert cache.get("col-1", "hello world") is not None
|
||||||
|
assert cache.get("col-1", "HELLO WORLD") is not None
|
||||||
|
|
||||||
|
def test_overwrite_updates_value(self):
|
||||||
|
cache.set("col-1", "q", {"response": "old"})
|
||||||
|
cache.set("col-1", "q", {"response": "new"})
|
||||||
|
assert cache.get("col-1", "q")["response"] == "new"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheExpiry:
|
||||||
|
def test_expired_entry_returns_none(self, monkeypatch):
|
||||||
|
cache.set("col-1", "q", {"response": "old"})
|
||||||
|
# Manually expire the entry
|
||||||
|
k = list(cache._store.keys())[0]
|
||||||
|
cache._store[k] = (cache._store[k][0], time.monotonic() - 1)
|
||||||
|
assert cache.get("col-1", "q") is None
|
||||||
|
|
||||||
|
def test_expired_entry_is_evicted_from_store(self, monkeypatch):
|
||||||
|
cache.set("col-1", "q", {"response": "x"})
|
||||||
|
k = list(cache._store.keys())[0]
|
||||||
|
cache._store[k] = (cache._store[k][0], time.monotonic() - 1)
|
||||||
|
cache.get("col-1", "q")
|
||||||
|
assert k not in cache._store
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheInvalidation:
|
||||||
|
def test_invalidate_removes_all_entries_for_collection(self):
|
||||||
|
cache.set("col-1", "query a", {"response": "a"})
|
||||||
|
cache.set("col-1", "query b", {"response": "b"})
|
||||||
|
cache.set("col-2", "query a", {"response": "c"})
|
||||||
|
|
||||||
|
removed = cache.invalidate("col-1")
|
||||||
|
|
||||||
|
assert removed == 2
|
||||||
|
assert cache.get("col-1", "query a") is None
|
||||||
|
assert cache.get("col-1", "query b") is None
|
||||||
|
# Other collection unaffected
|
||||||
|
assert cache.get("col-2", "query a") is not None
|
||||||
|
|
||||||
|
def test_invalidate_unknown_collection_returns_zero(self):
|
||||||
|
assert cache.invalidate("nonexistent") == 0
|
||||||
|
|
||||||
|
def test_index_cleaned_up_after_invalidation(self):
|
||||||
|
cache.set("col-1", "q", {"response": "x"})
|
||||||
|
cache.invalidate("col-1")
|
||||||
|
assert "col-1" not in cache._index
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheEviction:
|
||||||
|
def test_does_not_exceed_max_entries(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(cache, "_MAX_ENTRIES", 5)
|
||||||
|
for i in range(10):
|
||||||
|
cache.set("col-1", f"query {i}", {"response": str(i)})
|
||||||
|
assert len(cache._store) <= 5
|
||||||
122
tests/test_chat_test_endpoint.py
Normal file
122
tests/test_chat_test_endpoint.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""Tests for the /chat/{id}/test endpoint."""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
AUTH = {"Authorization": "Bearer test-token"}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_chatbot():
|
||||||
|
return {
|
||||||
|
"id": "cb-1",
|
||||||
|
"name": "Test Bot",
|
||||||
|
"is_published": True,
|
||||||
|
"qdrant_collection_name": "col-1",
|
||||||
|
"company_id": "company-1",
|
||||||
|
"handoff_enabled": False,
|
||||||
|
"handoff_keywords": [],
|
||||||
|
"lead_capture_enabled": False,
|
||||||
|
"lead_capture_trigger": None,
|
||||||
|
"booking_enabled": False,
|
||||||
|
"system_prompt": "",
|
||||||
|
"companies": {"name": "Acme", "logo_url": None},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sb(chatbot=None, is_owner=True):
|
||||||
|
sb = MagicMock()
|
||||||
|
|
||||||
|
def table_side(name):
|
||||||
|
m = MagicMock()
|
||||||
|
m.select.return_value = m
|
||||||
|
m.eq.return_value = m
|
||||||
|
m.in_.return_value = m
|
||||||
|
m.limit.return_value = m
|
||||||
|
m.order.return_value = m
|
||||||
|
|
||||||
|
if name == "chatbots":
|
||||||
|
m.execute.return_value = MagicMock(data=[chatbot or _make_chatbot()])
|
||||||
|
elif name == "companies":
|
||||||
|
cid = "company-1" if is_owner else "other-company"
|
||||||
|
m.execute.return_value = MagicMock(data=[{"id": cid, "owner_id": "owner-1"}])
|
||||||
|
else:
|
||||||
|
m.execute.return_value = MagicMock(data=[], count=0)
|
||||||
|
return m
|
||||||
|
|
||||||
|
sb.table.side_effect = table_side
|
||||||
|
sb.auth = MagicMock()
|
||||||
|
return sb
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatTestEndpoint:
|
||||||
|
def _run_test(self, client, questions, chatbot=None, is_owner=True, rag_result=None):
|
||||||
|
default_rag = {
|
||||||
|
"response": "The answer is 42.",
|
||||||
|
"sources": [],
|
||||||
|
"confidence_score": 0.82,
|
||||||
|
"tokens_used": 20,
|
||||||
|
"model": "test-model",
|
||||||
|
}
|
||||||
|
with patch("app.routers.chat.get_supabase") as mock_sb, \
|
||||||
|
patch("app.routers.chat.rag_engine") as mock_rag, \
|
||||||
|
patch("app.dependencies.get_current_user") as mock_user:
|
||||||
|
mock_rag.process_query = AsyncMock(return_value=rag_result or default_rag)
|
||||||
|
mock_sb.return_value = _make_sb(chatbot=chatbot, is_owner=is_owner)
|
||||||
|
user = MagicMock()
|
||||||
|
user.id = "owner-1"
|
||||||
|
mock_user.return_value = user
|
||||||
|
return client.post(
|
||||||
|
"/api/v1/chat/cb-1/test",
|
||||||
|
json={"questions": questions},
|
||||||
|
headers=AUTH,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_returns_list_of_results(self, client):
|
||||||
|
resp = self._run_test(client, ["What is your return policy?"])
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert isinstance(body, list)
|
||||||
|
assert len(body) == 1
|
||||||
|
|
||||||
|
def test_result_shape(self, client):
|
||||||
|
resp = self._run_test(client, ["Hello?"])
|
||||||
|
result = resp.json()[0]
|
||||||
|
assert "question" in result
|
||||||
|
assert "response" in result
|
||||||
|
assert "confidence_score" in result
|
||||||
|
assert "sources" in result
|
||||||
|
assert "model_used" in result
|
||||||
|
|
||||||
|
def test_question_echoed_in_result(self, client):
|
||||||
|
resp = self._run_test(client, ["What are your hours?"])
|
||||||
|
assert resp.json()[0]["question"] == "What are your hours?"
|
||||||
|
|
||||||
|
def test_multiple_questions_all_answered(self, client):
|
||||||
|
questions = ["Q1", "Q2", "Q3"]
|
||||||
|
resp = self._run_test(client, questions)
|
||||||
|
assert len(resp.json()) == 3
|
||||||
|
returned_questions = [r["question"] for r in resp.json()]
|
||||||
|
assert returned_questions == questions
|
||||||
|
|
||||||
|
def test_requires_authentication(self, client):
|
||||||
|
resp = client.post("/api/v1/chat/cb-1/test", json={"questions": ["hi"]})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_rejects_more_than_10_questions(self, client):
|
||||||
|
resp = self._run_test(client, [f"Q{i}" for i in range(11)])
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
def test_rejects_empty_question_list(self, client):
|
||||||
|
resp = self._run_test(client, [])
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
def test_chatbot_without_collection_returns_400(self, client):
|
||||||
|
bot = _make_chatbot()
|
||||||
|
bot["qdrant_collection_name"] = None
|
||||||
|
resp = self._run_test(client, ["Hi"], chatbot=bot)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
def test_confidence_score_passed_through(self, client):
|
||||||
|
rag_result = {"response": "Sure", "sources": [], "confidence_score": 0.73,
|
||||||
|
"tokens_used": 5, "model": "m"}
|
||||||
|
resp = self._run_test(client, ["Question?"], rag_result=rag_result)
|
||||||
|
assert resp.json()[0]["confidence_score"] == pytest.approx(0.73)
|
||||||
94
tests/test_rag_cache.py
Normal file
94
tests/test_rag_cache.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
"""Tests for RAG response caching integration."""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
from app.services import cache as response_cache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_cache():
|
||||||
|
response_cache._store.clear()
|
||||||
|
response_cache._index.clear()
|
||||||
|
yield
|
||||||
|
response_cache._store.clear()
|
||||||
|
response_cache._index.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def rag():
|
||||||
|
from app.services.rag import RAGEngine
|
||||||
|
return RAGEngine()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def chatbot_config():
|
||||||
|
return {
|
||||||
|
"model": "accounts/fireworks/models/llama-v3p3-70b-instruct",
|
||||||
|
"max_tokens": 500,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"company_name": "Test Corp",
|
||||||
|
"system_prompt": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def good_search_result():
|
||||||
|
return [{
|
||||||
|
"payload": {"text": "We open 9am–6pm Mon–Fri.", "file_name": "faq.pdf", "page_number": 1},
|
||||||
|
"score": 0.82,
|
||||||
|
}]
|
||||||
|
|
||||||
|
|
||||||
|
class TestRAGCaching:
|
||||||
|
async def test_second_identical_query_uses_cache(self, rag, chatbot_config, good_search_result):
|
||||||
|
llm_mock = AsyncMock(return_value={"content": "9am to 6pm", "tokens_used": 20, "model": "m"})
|
||||||
|
|
||||||
|
with patch.object(rag.embedding_svc, "embed_text", return_value=[0.1] * 1536), \
|
||||||
|
patch.object(rag.vector_svc, "search", return_value=good_search_result), \
|
||||||
|
patch.object(rag.llm_svc, "generate", llm_mock):
|
||||||
|
|
||||||
|
await rag.process_query("What are your hours?", "col-1", chatbot_config)
|
||||||
|
await rag.process_query("What are your hours?", "col-1", chatbot_config)
|
||||||
|
|
||||||
|
# LLM should only be called once; second call hits cache
|
||||||
|
assert llm_mock.call_count == 1
|
||||||
|
|
||||||
|
async def test_cache_not_used_when_conversation_history_present(self, rag, chatbot_config, good_search_result):
|
||||||
|
llm_mock = AsyncMock(return_value={"content": "Yes!", "tokens_used": 10, "model": "m"})
|
||||||
|
|
||||||
|
history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
|
||||||
|
|
||||||
|
with patch.object(rag.embedding_svc, "embed_text", return_value=[0.1] * 1536), \
|
||||||
|
patch.object(rag.vector_svc, "search", return_value=good_search_result), \
|
||||||
|
patch.object(rag.llm_svc, "generate", llm_mock):
|
||||||
|
|
||||||
|
await rag.process_query("Follow-up question", "col-1", chatbot_config, conversation_history=history)
|
||||||
|
await rag.process_query("Follow-up question", "col-1", chatbot_config, conversation_history=history)
|
||||||
|
|
||||||
|
# Both calls go to LLM because history makes them stateful
|
||||||
|
assert llm_mock.call_count == 2
|
||||||
|
|
||||||
|
async def test_different_collections_cached_separately(self, rag, chatbot_config, good_search_result):
|
||||||
|
llm_mock = AsyncMock(return_value={"content": "Answer", "tokens_used": 10, "model": "m"})
|
||||||
|
|
||||||
|
with patch.object(rag.embedding_svc, "embed_text", return_value=[0.1] * 1536), \
|
||||||
|
patch.object(rag.vector_svc, "search", return_value=good_search_result), \
|
||||||
|
patch.object(rag.llm_svc, "generate", llm_mock):
|
||||||
|
|
||||||
|
await rag.process_query("Same question", "col-A", chatbot_config)
|
||||||
|
await rag.process_query("Same question", "col-B", chatbot_config)
|
||||||
|
|
||||||
|
# Different collections → two LLM calls, not one
|
||||||
|
assert llm_mock.call_count == 2
|
||||||
|
|
||||||
|
async def test_confidence_score_returned_from_cache(self, rag, chatbot_config, good_search_result):
|
||||||
|
llm_mock = AsyncMock(return_value={"content": "Cached answer", "tokens_used": 10, "model": "m"})
|
||||||
|
|
||||||
|
with patch.object(rag.embedding_svc, "embed_text", return_value=[0.1] * 1536), \
|
||||||
|
patch.object(rag.vector_svc, "search", return_value=good_search_result), \
|
||||||
|
patch.object(rag.llm_svc, "generate", llm_mock):
|
||||||
|
|
||||||
|
first = await rag.process_query("hours?", "col-1", chatbot_config)
|
||||||
|
second = await rag.process_query("hours?", "col-1", chatbot_config)
|
||||||
|
|
||||||
|
assert first["confidence_score"] == second["confidence_score"]
|
||||||
|
assert second["response"] == "Cached answer"
|
||||||
115
tests/test_url_refresh.py
Normal file
115
tests/test_url_refresh.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
"""Tests for URL source refresh endpoint."""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
|
|
||||||
|
AUTH = {"Authorization": "Bearer test-token"}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sb(source=None, chatbot_company="company-1"):
|
||||||
|
sb = MagicMock()
|
||||||
|
default_source = {
|
||||||
|
"id": "src-1",
|
||||||
|
"chatbot_id": "cb-1",
|
||||||
|
"url": "https://example.com/faq",
|
||||||
|
"status": "completed",
|
||||||
|
"page_title": "FAQ",
|
||||||
|
"chunk_count": 10,
|
||||||
|
"error_message": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def table_side(name):
|
||||||
|
m = MagicMock()
|
||||||
|
m.select.return_value = m
|
||||||
|
m.insert.return_value = m
|
||||||
|
m.update.return_value = m
|
||||||
|
m.delete.return_value = m
|
||||||
|
m.eq.return_value = m
|
||||||
|
m.in_.return_value = m
|
||||||
|
m.returning.return_value = m
|
||||||
|
m.limit.return_value = m
|
||||||
|
m.order.return_value = m
|
||||||
|
|
||||||
|
if name == "chatbots":
|
||||||
|
m.execute.return_value = MagicMock(data=[{
|
||||||
|
"id": "cb-1",
|
||||||
|
"company_id": chatbot_company,
|
||||||
|
"qdrant_collection_name": "col-1",
|
||||||
|
}])
|
||||||
|
elif name == "companies":
|
||||||
|
m.execute.return_value = MagicMock(data=[{"id": chatbot_company, "owner_id": "user-1"}])
|
||||||
|
elif name == "url_sources":
|
||||||
|
m.execute.return_value = MagicMock(data=[source or default_source])
|
||||||
|
else:
|
||||||
|
m.execute.return_value = MagicMock(data=[], count=0)
|
||||||
|
return m
|
||||||
|
|
||||||
|
sb.table.side_effect = table_side
|
||||||
|
sb.auth = MagicMock()
|
||||||
|
return sb
|
||||||
|
|
||||||
|
|
||||||
|
class TestUrlRefresh:
|
||||||
|
def _refresh(self, client, source=None):
|
||||||
|
with patch("app.routers.documents.get_supabase") as mock_sb, \
|
||||||
|
patch("app.routers.documents.vector_store") as mock_vs, \
|
||||||
|
patch("app.routers.documents.response_cache") as mock_cache, \
|
||||||
|
patch("app.dependencies.get_current_user") as mock_user, \
|
||||||
|
patch("app.routers.documents._process_url_source", new_callable=AsyncMock):
|
||||||
|
mock_sb.return_value = _make_sb(source=source)
|
||||||
|
mock_vs.delete_by_document_id = MagicMock()
|
||||||
|
mock_vs.collection_exists = MagicMock(return_value=True)
|
||||||
|
mock_cache.invalidate = MagicMock()
|
||||||
|
user = MagicMock()
|
||||||
|
user.id = "user-1"
|
||||||
|
mock_user.return_value = user
|
||||||
|
return client.post(
|
||||||
|
"/api/v1/chatbots/cb-1/url-sources/src-1/refresh",
|
||||||
|
headers=AUTH,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_returns_200(self, client):
|
||||||
|
resp = self._refresh(client)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_source_reset_to_pending(self, client):
|
||||||
|
resp = self._refresh(client)
|
||||||
|
body = resp.json()
|
||||||
|
assert body["status"] == "pending"
|
||||||
|
assert body["chunk_count"] == 0
|
||||||
|
|
||||||
|
def test_returns_404_for_unknown_source(self, client):
|
||||||
|
with patch("app.routers.documents.get_supabase") as mock_sb, \
|
||||||
|
patch("app.dependencies.get_current_user") as mock_user:
|
||||||
|
sb = _make_sb()
|
||||||
|
# Override url_sources to return empty
|
||||||
|
def table_side(name):
|
||||||
|
m = MagicMock()
|
||||||
|
m.select.return_value = m
|
||||||
|
m.update.return_value = m
|
||||||
|
m.eq.return_value = m
|
||||||
|
if name == "chatbots":
|
||||||
|
m.execute.return_value = MagicMock(data=[{
|
||||||
|
"id": "cb-1", "company_id": "company-1",
|
||||||
|
"qdrant_collection_name": "col-1",
|
||||||
|
}])
|
||||||
|
elif name == "companies":
|
||||||
|
m.execute.return_value = MagicMock(data=[{"id": "company-1", "owner_id": "user-1"}])
|
||||||
|
else:
|
||||||
|
m.execute.return_value = MagicMock(data=[])
|
||||||
|
return m
|
||||||
|
sb2 = MagicMock()
|
||||||
|
sb2.table.side_effect = table_side
|
||||||
|
sb2.auth = MagicMock()
|
||||||
|
mock_sb.return_value = sb2
|
||||||
|
user = MagicMock()
|
||||||
|
user.id = "user-1"
|
||||||
|
mock_user.return_value = user
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/chatbots/cb-1/url-sources/no-such-src/refresh",
|
||||||
|
headers=AUTH,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_requires_authentication(self, client):
|
||||||
|
resp = client.post("/api/v1/chatbots/cb-1/url-sources/src-1/refresh")
|
||||||
|
assert resp.status_code == 401
|
||||||
Reference in New Issue
Block a user