Fixed the RAG issue in the test pipeline

This commit is contained in:
belviskhoremk
2026-04-26 21:43:19 +00:00
parent 78023ae9c5
commit 260a9c6353
9 changed files with 262 additions and 78 deletions

View File

@@ -13,6 +13,7 @@ class Settings(BaseSettings):
supabase_url: str = ""
supabase_anon_key: str = ""
supabase_service_role_key: str = ""
supabase_jwt_secret: Optional[str] = None # Settings → API → JWT Secret in Supabase dashboard
# Qdrant
qdrant_url: str = "http://localhost:6333"
@@ -99,12 +100,24 @@ MODEL_CATALOG = {
"badge": "Smart",
"description": "Cost-effective and highly capable model",
},
"accounts/fireworks/models/deepseek-v3p2": {
"name": "DeepSeek V3.2",
"provider": "Fireworks AI",
"badge": "Smart",
"description": "Latest DeepSeek — faster and more capable",
},
"accounts/fireworks/models/kimi-k2-instruct": {
"name": "Kimi K2",
"provider": "Fireworks AI",
"badge": "Multilingual",
"description": "Strong multilingual and coding capabilities",
},
"accounts/fireworks/models/kimi-k2p5-instruct": {
"name": "Kimi K2.5",
"provider": "Fireworks AI",
"badge": "Multilingual",
"description": "Upgraded Kimi — stronger reasoning and multilingual",
},
# ── Pro tier (Premium providers) ───────────────────────────────────────────
# OpenAI
@@ -156,7 +169,9 @@ MODEL_PROVIDERS = {
"accounts/fireworks/models/llama-v3p3-70b-instruct": "fireworks",
"accounts/fireworks/models/qwen3-235b-a22b": "fireworks",
"accounts/fireworks/models/deepseek-v3p1": "fireworks",
"accounts/fireworks/models/deepseek-v3p2": "fireworks",
"accounts/fireworks/models/kimi-k2-instruct": "fireworks",
"accounts/fireworks/models/kimi-k2p5-instruct": "fireworks",
# OpenAI
"gpt-4o": "openai",
"gpt-4o-mini": "openai",
@@ -209,7 +224,9 @@ _ALL_FIREWORKS = [
"accounts/fireworks/models/llama-v3p3-70b-instruct",
"accounts/fireworks/models/qwen3-235b-a22b",
"accounts/fireworks/models/deepseek-v3p1",
"accounts/fireworks/models/deepseek-v3p2",
"accounts/fireworks/models/kimi-k2-instruct",
"accounts/fireworks/models/kimi-k2p5-instruct",
]
_ALL_PREMIUM = [
"gpt-4o", "gpt-4o-mini",

View File

@@ -1,18 +1,77 @@
from fastapi import Depends, HTTPException, status, Header
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional
from dataclasses import dataclass, field
from app.database import get_supabase
from app.config import settings
import base64
import hashlib
import hmac
import json
import logging
import time
logger = logging.getLogger(__name__)
security = HTTPBearer(auto_error=False)
@dataclass
class _LocalUser:
    """Minimal user object built from JWT claims — mirrors the fields used downstream.

    Instances are created by ``_verify_jwt_local`` from a locally verified token,
    as a stand-in for the user object returned by ``supabase.auth.get_user()``.
    NOTE(review): only ``id`` is visibly read downstream in this file (suspension
    lookup); confirm other consumers need no additional attributes.
    """
    # Taken from the JWT ``sub`` claim.
    id: str
    # Taken from the JWT ``email`` claim; empty string when the claim is absent.
    email: str
    # Taken from the JWT ``role`` claim; defaults to "authenticated".
    role: str = "authenticated"
    # Taken from the JWT ``app_metadata`` claim; a fresh dict per instance by default.
    app_metadata: dict = field(default_factory=dict)
    # Taken from the JWT ``user_metadata`` claim; a fresh dict per instance by default.
    user_metadata: dict = field(default_factory=dict)
def _b64url_decode(segment: str) -> bytes:
    """Decode one base64url JWT segment, restoring the '=' padding JWTs strip."""
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def _verify_jwt_local(token: str) -> Optional[_LocalUser]:
    """Verify a Supabase HS256 JWT using the local secret (no network call).

    Args:
        token: Compact-serialized JWT (``header.payload.signature``).

    Returns:
        A ``_LocalUser`` built from the token's claims, or ``None`` when the
        secret is not configured, the token is malformed, the signature does
        not match, the header does not declare ``HS256``, the token is expired
        (or has no numeric ``exp``), or it carries no ``sub`` claim.
        This function never raises; any failure yields ``None`` so the caller
        can fall back to another verification path.
    """
    secret = settings.supabase_jwt_secret
    if not secret:
        return None

    try:
        parts = token.split(".")
        if len(parts) != 3:
            return None
        header_b64, payload_b64, sig_b64 = parts

        # Verify the HMAC-SHA256 signature first, so that no token content is
        # parsed before it has been authenticated.
        message = f"{header_b64}.{payload_b64}".encode()
        expected = hmac.new(secret.encode(), message, hashlib.sha256).digest()
        if not hmac.compare_digest(expected, _b64url_decode(sig_b64)):
            return None

        # Only accept tokens that actually declare the algorithm we verified with.
        header = json.loads(_b64url_decode(header_b64))
        if not isinstance(header, dict) or header.get("alg") != "HS256":
            return None

        payload = json.loads(_b64url_decode(payload_b64))
        if not isinstance(payload, dict):
            return None

        # A missing or non-numeric "exp" is treated as expired. The token is
        # valid only strictly before its expiry instant.
        exp = payload.get("exp")
        if isinstance(exp, bool) or not isinstance(exp, (int, float)):
            return None
        if exp <= time.time():
            return None

        # Tokens without a subject (e.g. anon / service-role API keys signed
        # with the same secret) do not identify a user.
        sub = payload.get("sub")
        if not isinstance(sub, str) or not sub:
            return None

        # "or" fallbacks also cover claims that are present but null, keeping
        # the _LocalUser fields at their declared types.
        return _LocalUser(
            id=sub,
            email=payload.get("email") or "",
            role=payload.get("role") or "authenticated",
            app_metadata=payload.get("app_metadata") or {},
            user_metadata=payload.get("user_metadata") or {},
        )
    except Exception:
        # Deliberately broad: any decoding/parsing failure means "not verified
        # locally", never an internal server error.
        return None
async def get_current_user(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
):
"""Extract and verify the current user from Supabase JWT"""
"""Extract and verify the current user from a Supabase JWT.
Tries local HS256 verification first (no network call, no SSL risk).
Falls back to supabase.auth.get_user() only when the JWT secret is not configured.
"""
if not credentials:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -20,8 +79,13 @@ async def get_current_user(
)
token = credentials.credentials
supabase = get_supabase()
# ── Fast path: local verification ────────────────────────────────────────
user = _verify_jwt_local(token)
# ── Slow path: network call (SUPABASE_JWT_SECRET not set, or local verification rejected the token) ─────
if user is None:
supabase = get_supabase()
try:
response = supabase.auth.get_user(token)
if not response or not response.user:
@@ -30,9 +94,18 @@ async def get_current_user(
detail="Invalid or expired token",
)
user = response.user
except HTTPException:
raise
except Exception as e:
logger.error(f"Auth error: {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
)
# Check for suspension
# ── Suspension check (DB, not network-auth, so still fast) ───────────────
try:
supabase = get_supabase()
profile = supabase.table("user_profiles").select("suspended_at").eq("user_id", user.id).execute()
if profile.data and profile.data[0].get("suspended_at"):
raise HTTPException(
@@ -42,17 +115,9 @@ async def get_current_user(
except HTTPException:
raise
except Exception:
pass # Don't block login if profile lookup fails
pass # Never block login if profile lookup fails
return user
except HTTPException:
raise
except Exception as e:
logger.error(f"Auth error: {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
)
async def get_admin_user(

View File

@@ -32,3 +32,4 @@ def configure_logging():
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("hpack").setLevel(logging.WARNING)

View File

@@ -1,3 +1,4 @@
import asyncio
import time
from collections import defaultdict
@@ -311,8 +312,7 @@ async def test_chat(
company_data = chatbot.get("companies", {}) or {}
chatbot_config = {**chatbot, "company_name": company_data.get("name", "")}
results = []
for question in body.questions:
async def _run_one(question: str) -> TestChatResult:
try:
result = await rag_engine.process_query(
query=question,
@@ -322,22 +322,24 @@ async def test_chat(
language="auto",
bypass_cache=True,
)
results.append(TestChatResult(
return TestChatResult(
question=question,
response=result["response"],
confidence_score=result.get("confidence_score", 0.0),
sources=result.get("sources", []),
model_used=result.get("model", ""),
))
)
except Exception as e:
results.append(TestChatResult(
return TestChatResult(
question=question,
response=f"Error: {e}",
confidence_score=0.0,
sources=[],
model_used="",
))
return results
)
results = await asyncio.gather(*[_run_one(q) for q in body.questions])
return list(results)
# ── OLD analytics endpoint REMOVED ───────────────────────────────────────────

View File

@@ -94,7 +94,7 @@ async def upload_document(
file_bytes=file_bytes,
file_name=file.filename,
doc_id=doc_id,
chatbot=chatbot,
chatbot_id=chatbot_id,
supabase=supabase,
)
@@ -105,16 +105,28 @@ async def _process_document_bg(
file_bytes: bytes,
file_name: str,
doc_id: str,
chatbot: dict,
chatbot_id: str,
supabase,
):
"""Background task to process and embed a document"""
try:
# Re-fetch chatbot to guarantee we use the canonical collection and company_id,
# not a snapshot that could have been captured before an update.
chatbot_row = supabase.table("chatbots").select("company_id, qdrant_collection_name").eq("id", chatbot_id).execute()
if not chatbot_row.data:
logger.error(f"Chatbot {chatbot_id} not found during document processing")
supabase.table("documents").update({
"status": "failed",
"error_message": "Chatbot not found"
}).eq("id", doc_id).execute()
return
chatbot = chatbot_row.data[0]
company_id = chatbot.get("company_id", "")
collection_name = chatbot.get("qdrant_collection_name")
if not collection_name:
logger.error(f"No Qdrant collection for chatbot {chatbot['id']}")
logger.error(f"No Qdrant collection for chatbot {chatbot_id}")
supabase.table("documents").update({
"status": "failed",
"error_message": "Vector store not configured"
@@ -168,7 +180,7 @@ async def _process_document_bg(
}).eq("id", doc_id).execute()
response_cache.invalidate(collection_name)
logger.info(f"Document {doc_id} processed: {len(chunks)} chunks")
logger.info(f"Document {doc_id} processed: {len(chunks)} chunks → collection='{collection_name}' company='{company_id}'")
except Exception as e:
logger.error(f"Document processing error for {doc_id}: {e}")
@@ -274,7 +286,7 @@ async def retry_document_processing(
file_bytes=file_bytes,
file_name=document["file_name"],
doc_id=document_id,
chatbot=chatbot,
chatbot_id=chatbot_id,
supabase=supabase,
)
@@ -333,7 +345,7 @@ async def add_url_source(
_process_url_source,
source_id=source_id,
url=data.url,
chatbot=chatbot,
chatbot_id=chatbot_id,
supabase=supabase,
)
@@ -394,12 +406,12 @@ async def refresh_url_source(
"chunk_count": 0,
}).eq("id", source_id).returning("representation").execute()
background_tasks.add_task(_process_url_source, source_id, src["url"], chatbot, supabase)
background_tasks.add_task(_process_url_source, source_id, src["url"], chatbot_id, supabase)
return UrlSourceResponse(**{**src, "status": "pending", "chunk_count": 0})
async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase):
async def _process_url_source(source_id: str, url: str, chatbot_id: str, supabase):
"""Background task to scrape a URL and add its content to the vector store."""
from app.services.web_scraper import scrape_url
from app.services.document_processor import chunk_text
@@ -407,6 +419,18 @@ async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase)
from app.services.vector_store import vector_store
try:
# Re-fetch chatbot to guarantee we use the canonical collection and company_id.
chatbot_row = supabase.table("chatbots").select("company_id, qdrant_collection_name").eq("id", chatbot_id).execute()
if not chatbot_row.data:
logger.error(f"Chatbot {chatbot_id} not found during URL source processing")
supabase.table("url_sources").update({
"status": "failed",
"error_message": "Chatbot not found",
}).eq("id", source_id).execute()
return
chatbot = chatbot_row.data[0]
# Update status to processing
supabase.table("url_sources").update({"status": "processing"}).eq("id", source_id).execute()
@@ -480,7 +504,8 @@ async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase)
}).eq("id", source_id).execute()
response_cache.invalidate(collection_name)
logger.info(f"URL source {source_id} processed: {len(chunks)} chunks from {url}")
logger.info(f"URL source {source_id} processed: {len(chunks)} chunks from {url} → collection='{collection_name}' company='{chatbot.get('company_id', '')}'")
except Exception as e:
logger.error(f"URL source processing error {source_id}: {e}")

View File

@@ -1,9 +1,43 @@
from app.config import settings, MODEL_PROVIDERS, PLAN_LIMITS
from typing import List, Dict, Any, Optional, AsyncGenerator
import logging
import re
logger = logging.getLogger(__name__)
# Ordered fallback chain — tried in sequence when the primary model fails.
# Fireworks models are used for free/starter plans so they must always be available.
# llama-v3p3-70b-instruct is the guaranteed last resort (confirmed working).
# Order matters: the fallback loop in LLMService walks this list front to back,
# skips the entry equal to the model that just failed, and returns the first
# fallback call that succeeds. Fallbacks are only attempted when a Fireworks
# API key is configured.
_FIREWORKS_FALLBACKS = [
    "accounts/fireworks/models/kimi-k2p5-instruct",
    "accounts/fireworks/models/deepseek-v3p2",
    "accounts/fireworks/models/llama-v3p3-70b-instruct",
]
def _normalize_model(model: str) -> str:
"""Strip date-based version suffixes from Fireworks model IDs.
e.g. 'accounts/fireworks/models/kimi-k2-instruct-0905''accounts/fireworks/models/kimi-k2-instruct'
Matches only purely-numeric suffixes (48 digits) so names like 'llama-v3p3-70b' are untouched."""
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
model = re.sub(r"-\d{4,8}$", "", model)
return model
def _infer_provider(model: str) -> str:
"""Infer the LLM provider from the model ID when it's not in MODEL_PROVIDERS.
Handles versioned variants like 'accounts/fireworks/models/kimi-k2-instruct-0905'."""
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
return "fireworks"
if model.startswith("gpt-") or model.startswith("o1") or model.startswith("o3"):
return "openai"
if model.startswith("claude-"):
return "anthropic"
if model.startswith("gemini-"):
return "google"
logger.warning(f"[LLM] Unknown model '{model}', defaulting to fireworks")
return "fireworks"
class LLMService:
"""Routes requests to appropriate LLM provider"""
@@ -16,7 +50,8 @@ class LLMService:
temperature: float = 0.7,
) -> Dict[str, Any]:
"""Generate a response from the LLM"""
provider = MODEL_PROVIDERS.get(model, "openai")
model = _normalize_model(model)
provider = MODEL_PROVIDERS.get(model) or _infer_provider(model)
try:
if provider == "fireworks":
@@ -31,9 +66,16 @@ class LLMService:
return await self._call_openai(messages, model, max_tokens, temperature)
except Exception as e:
logger.error(f"LLM error ({provider}/{model}): {e}")
fallback = "accounts/fireworks/models/llama-v3p3-70b-instruct"
if model != fallback and settings.fireworks_api_key:
if not settings.fireworks_api_key:
raise
for fallback in _FIREWORKS_FALLBACKS:
if model == fallback:
continue
try:
logger.warning(f"[LLM] Falling back to {fallback}")
return await self._call_fireworks(messages, fallback, max_tokens, temperature)
except Exception as fe:
logger.error(f"[LLM] Fallback {fallback} also failed: {fe}")
raise
async def _call_fireworks(

View File

@@ -9,7 +9,7 @@ import logging
logger = logging.getLogger(__name__)
RAG_SYSTEM_PROMPT = """You are a helpful AI assistant for {company_name}.
Your role is to answer questions based on the provided context from company documents.
Your role is to answer questions based on the provided context from the knowledge base (documents and web pages).
IMPORTANT RULES:
1. Answer based on the provided context below
@@ -20,7 +20,7 @@ IMPORTANT RULES:
{language_instruction}
{custom_instructions}
Context from knowledge base:
Knowledge base context:
{context}
"""
@@ -74,14 +74,22 @@ class RAGEngine:
}
# Step 2: Retrieve relevant chunks
# Fetch more than needed so that after filtering low-quality results
# we still have enough context. score_threshold=0.55 keeps only chunks
# that are genuinely relevant for text-embedding-3-small cosine similarity.
# Retrieve more candidates than needed (10) with a slightly relaxed threshold (0.45)
# so that content from both document and URL sources gets fair representation.
# Scraped web text embeds less cleanly than structured documents, so 0.55 was
# filtering out valid URL chunks. Context is capped by char limit below.
total_in_collection = self.vector_svc.count_vectors(collection_name)
logger.info(f"[RAG] Collection '{collection_name}' has {total_in_collection} vectors total")
# No score_threshold — always return the top-N most similar chunks by rank.
# Absolute cosine scores vary widely by document type and embedding model;
# filtering by a fixed cutoff here discards valid context when scores are
# uniformly low. The confidence_score below captures retrieval quality for
# handoff/fallback decisions without silencing the LLM's context.
retrieved = self.vector_svc.search(
collection_name=collection_name,
query_vector=query_embedding,
limit=8,
score_threshold=0.55,
limit=10,
)
logger.info(f"[RAG] Retrieved {len(retrieved)} chunks from collection '{collection_name}'")
@@ -90,20 +98,33 @@ class RAGEngine:
text_preview = item.get("payload", {}).get("text", "")[:80]
logger.info(f"[RAG] Chunk {i+1}: score={score:.4f}, preview='{text_preview}...'")
# Step 3: Build sources
# Step 3: Build sources and labeled context
# Each chunk is prefixed with its source so the LLM can synthesize
# correctly when mixing document and URL content.
MAX_CONTEXT_CHARS = 10_000
sources = []
context_parts = []
seen_texts = set()
total_chars = 0
for item in retrieved:
payload = item.get("payload", {})
text = payload.get("text", "")
if text and text not in seen_texts:
if not text or text in seen_texts:
continue
if total_chars + len(text) > MAX_CONTEXT_CHARS:
break
seen_texts.add(text)
context_parts.append(text)
total_chars += len(text)
file_name = payload.get("file_name", "Document")
source_url = payload.get("source_url")
label = f"[Source: {source_url}]" if source_url else f"[Source: {file_name}]"
context_parts.append(f"{label}\n{text}")
sources.append(
SourceDocument(
document_name=payload.get("file_name", "Document"),
document_name=file_name,
chunk_text=text[:200] + "..." if len(text) > 200 else text,
score=item.get("score", 0.0),
page_number=payload.get("page_number"),

View File

@@ -1,7 +1,5 @@
from qdrant_client import QdrantClient, models
from qdrant_client.http.models import (
Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
)
from qdrant_client.http.models import Distance, VectorParams, PointStruct
from app.config import settings
from typing import List, Dict, Any, Optional
import logging
@@ -103,15 +101,13 @@ class VectorStoreService:
collection_name: str,
query_vector: List[float],
limit: int = 5,
score_threshold: float = 0.3,
) -> List[Dict[str, Any]]:
"""Search for similar vectors"""
"""Search for similar vectors, returning the top-N by cosine score."""
try:
results = self.client.query_points(
collection_name=collection_name,
query=query_vector,
limit=limit,
score_threshold=score_threshold,
).points
return [
{
@@ -122,7 +118,7 @@ class VectorStoreService:
for r in results
]
except Exception as e:
logger.error(f"Error searching vectors: {e}")
logger.error(f"Error searching vectors in '{collection_name}': {e}", exc_info=True)
return []
def delete_by_document_id(self, collection_name: str, document_id: str) -> bool:
@@ -131,19 +127,21 @@ class VectorStoreService:
self.client.delete(
collection_name=collection_name,
points_selector=models.FilterSelector(
filter=Filter(
filter=models.Filter(
must=[
FieldCondition(
models.FieldCondition(
key="document_id",
match=MatchValue(value=document_id),
match=models.MatchValue(value=document_id),
)
]
)
),
wait=True,
)
logger.info(f"Deleted vectors for document '{document_id}' from '{collection_name}'")
return True
except Exception as e:
logger.error(f"Error deleting document vectors: {e}")
logger.error(f"Error deleting vectors for document '{document_id}' in '{collection_name}': {e}", exc_info=True)
return False
def count_vectors(self, collection_name: str) -> int:

View File

@@ -42,9 +42,22 @@ async def scrape_url(url: str) -> dict:
main = soup.find("main") or soup.find("article") or soup.find("body") or soup
text = main.get_text(separator="\n", strip=True)
# Clean up whitespace
lines = [line.strip() for line in text.splitlines() if line.strip()]
text = "\n".join(lines)
# Clean up whitespace and filter structural noise
seen_lines: set[str] = set()
clean_lines = []
for line in text.splitlines():
line = line.strip()
if not line:
continue
# Skip very short lines (nav items, button labels, breadcrumb separators)
if len(line) < 15:
continue
# Skip duplicate lines (nav/footer repeated across sections)
if line in seen_lines:
continue
seen_lines.add(line)
clean_lines.append(line)
text = "\n".join(clean_lines)
# Limit size
if len(text.encode("utf-8")) > MAX_TEXT_BYTES: