From 260a9c6353a3342e2ac63d941391a5a7f66c7986 Mon Sep 17 00:00:00 2001 From: belviskhoremk Date: Sun, 26 Apr 2026 21:43:19 +0000 Subject: [PATCH] fixed the RAg in test pipeline issue --- app/config.py | 17 +++++ app/dependencies.py | 117 +++++++++++++++++++++++++++-------- app/logging_config.py | 1 + app/routers/chat.py | 16 ++--- app/routers/documents.py | 43 ++++++++++--- app/services/llm.py | 50 +++++++++++++-- app/services/rag.py | 57 +++++++++++------ app/services/vector_store.py | 20 +++--- app/services/web_scraper.py | 19 +++++- 9 files changed, 262 insertions(+), 78 deletions(-) diff --git a/app/config.py b/app/config.py index a4c9635..6ae3aa1 100644 --- a/app/config.py +++ b/app/config.py @@ -13,6 +13,7 @@ class Settings(BaseSettings): supabase_url: str = "" supabase_anon_key: str = "" supabase_service_role_key: str = "" + supabase_jwt_secret: Optional[str] = None # Settings → API → JWT Secret in Supabase dashboard # Qdrant qdrant_url: str = "http://localhost:6333" @@ -99,12 +100,24 @@ MODEL_CATALOG = { "badge": "Smart", "description": "Cost-effective and highly capable model", }, + "accounts/fireworks/models/deepseek-v3p2": { + "name": "DeepSeek V3.2", + "provider": "Fireworks AI", + "badge": "Smart", + "description": "Latest DeepSeek — faster and more capable", + }, "accounts/fireworks/models/kimi-k2-instruct": { "name": "Kimi K2", "provider": "Fireworks AI", "badge": "Multilingual", "description": "Strong multilingual and coding capabilities", }, + "accounts/fireworks/models/kimi-k2p5-instruct": { + "name": "Kimi K2.5", + "provider": "Fireworks AI", + "badge": "Multilingual", + "description": "Upgraded Kimi — stronger reasoning and multilingual", + }, # ── Pro tier (Premium providers) ─────────────────────────────────────────── # OpenAI @@ -156,7 +169,9 @@ MODEL_PROVIDERS = { "accounts/fireworks/models/llama-v3p3-70b-instruct": "fireworks", "accounts/fireworks/models/qwen3-235b-a22b": "fireworks", "accounts/fireworks/models/deepseek-v3p1": "fireworks", + "accounts/fireworks/models/deepseek-v3p2": "fireworks", "accounts/fireworks/models/kimi-k2-instruct": "fireworks", + "accounts/fireworks/models/kimi-k2p5-instruct": "fireworks", # OpenAI "gpt-4o": "openai", "gpt-4o-mini": "openai", @@ -209,7 +224,9 @@ _ALL_FIREWORKS = [ "accounts/fireworks/models/llama-v3p3-70b-instruct", "accounts/fireworks/models/qwen3-235b-a22b", "accounts/fireworks/models/deepseek-v3p1", + "accounts/fireworks/models/deepseek-v3p2", "accounts/fireworks/models/kimi-k2-instruct", + "accounts/fireworks/models/kimi-k2p5-instruct", ] _ALL_PREMIUM = [ "gpt-4o", "gpt-4o-mini", diff --git a/app/dependencies.py b/app/dependencies.py index b9ba036..fb9e651 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -1,18 +1,77 @@ from fastapi import Depends, HTTPException, status, Header from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from typing import Optional +from dataclasses import dataclass, field from app.database import get_supabase from app.config import settings +import base64 +import hashlib +import hmac +import json import logging +import time logger = logging.getLogger(__name__) security = HTTPBearer(auto_error=False) +@dataclass +class _LocalUser: + """Minimal user object built from JWT claims — mirrors the fields used downstream.""" + id: str + email: str + role: str = "authenticated" + app_metadata: dict = field(default_factory=dict) + user_metadata: dict = field(default_factory=dict) + + +def _verify_jwt_local(token: str) -> Optional[_LocalUser]: + """Verify a Supabase HS256 JWT using the local secret (no network call). + Returns None if the secret is not configured, the signature is wrong, or the token is expired.""" + secret = settings.supabase_jwt_secret + if not secret: + return None + try: + parts = token.split(".") + if len(parts) != 3: + return None + header_b64, payload_b64, sig_b64 = parts + + # Verify HMAC-SHA256 signature + message = f"{header_b64}.{payload_b64}".encode() + expected = hmac.new(secret.encode(), message, hashlib.sha256).digest() + padding = "=" * (-len(sig_b64) % 4) + actual = base64.urlsafe_b64decode(sig_b64 + padding) + if not hmac.compare_digest(expected, actual): + return None + + # Decode payload + padding = "=" * (-len(payload_b64) % 4) + payload = json.loads(base64.urlsafe_b64decode(payload_b64 + padding)) + + # Check expiry + if payload.get("exp", 0) < time.time(): + return None + + return _LocalUser( + id=payload["sub"], + email=payload.get("email", ""), + role=payload.get("role", "authenticated"), + app_metadata=payload.get("app_metadata", {}), + user_metadata=payload.get("user_metadata", {}), + ) + except Exception: + return None + + async def get_current_user( credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), ): - """Extract and verify the current user from Supabase JWT""" + """Extract and verify the current user from a Supabase JWT. + + Tries local HS256 verification first (no network call, no SSL risk). + Falls back to supabase.auth.get_user() only when the JWT secret is not configured. + """ if not credentials: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -20,39 +79,45 @@ async def get_current_user( ) token = credentials.credentials - supabase = get_supabase() - try: - response = supabase.auth.get_user(token) - if not response or not response.user: + # ── Fast path: local verification ──────────────────────────────────────── + user = _verify_jwt_local(token) + + # ── Slow path: network call (only if SUPABASE_JWT_SECRET is not set) ───── + if user is None: + supabase = get_supabase() + try: + response = supabase.auth.get_user(token) + if not response or not response.user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + ) + user = response.user + except HTTPException: + raise + except Exception as e: + logger.error(f"Auth error: {e}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token", ) - user = response.user - # Check for suspension - try: - profile = supabase.table("user_profiles").select("suspended_at").eq("user_id", user.id).execute() - if profile.data and profile.data[0].get("suspended_at"): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Account suspended. Please contact support.", - ) - except HTTPException: - raise - except Exception: - pass # Don't block login if profile lookup fails - - return user + # ── Suspension check (DB, not network-auth, so still fast) ─────────────── + try: + supabase = get_supabase() + profile = supabase.table("user_profiles").select("suspended_at").eq("user_id", user.id).execute() + if profile.data and profile.data[0].get("suspended_at"): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Account suspended. Please contact support.", + ) except HTTPException: raise - except Exception as e: - logger.error(f"Auth error: {e}") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - ) + except Exception: + pass # Never block login if profile lookup fails + + return user async def get_admin_user( diff --git a/app/logging_config.py b/app/logging_config.py index d97ef27..75657c3 100644 --- a/app/logging_config.py +++ b/app/logging_config.py @@ -32,3 +32,4 @@ def configure_logging(): logging.getLogger("uvicorn.access").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) + logging.getLogger("hpack").setLevel(logging.WARNING) diff --git a/app/routers/chat.py b/app/routers/chat.py index d212356..d687712 100644 --- a/app/routers/chat.py +++ b/app/routers/chat.py @@ -1,3 +1,4 @@ +import asyncio import time from collections import defaultdict @@ -311,8 +312,7 @@ async def test_chat( company_data = chatbot.get("companies", {}) or {} chatbot_config = {**chatbot, "company_name": company_data.get("name", "")} - results = [] - for question in body.questions: + async def _run_one(question: str) -> TestChatResult: try: result = await rag_engine.process_query( query=question, @@ -322,22 +322,24 @@ async def test_chat( language="auto", bypass_cache=True, ) - results.append(TestChatResult( + return TestChatResult( question=question, response=result["response"], confidence_score=result.get("confidence_score", 0.0), sources=result.get("sources", []), model_used=result.get("model", ""), - )) + ) except Exception as e: - results.append(TestChatResult( + return TestChatResult( question=question, response=f"Error: {e}", confidence_score=0.0, sources=[], model_used="", - )) - return results + ) + + results = await asyncio.gather(*[_run_one(q) for q in body.questions]) + return list(results) # ── OLD analytics endpoint REMOVED ─────────────────────────────────────────── diff --git a/app/routers/documents.py b/app/routers/documents.py index 7e8faea..d53c7cd 100644 --- a/app/routers/documents.py +++ b/app/routers/documents.py @@ -94,7 +94,7 @@ async def upload_document( file_bytes=file_bytes, file_name=file.filename, doc_id=doc_id, - chatbot=chatbot, + chatbot_id=chatbot_id, supabase=supabase, ) @@ -105,16 +105,28 @@ async def _process_document_bg( file_bytes: bytes, file_name: str, doc_id: str, - chatbot: dict, + chatbot_id: str, supabase, ): """Background task to process and embed a document""" try: + # Re-fetch chatbot to guarantee we use the canonical collection and company_id, + # not a snapshot that could have been captured before an update. + chatbot_row = supabase.table("chatbots").select("company_id, qdrant_collection_name").eq("id", chatbot_id).execute() + if not chatbot_row.data: + logger.error(f"Chatbot {chatbot_id} not found during document processing") + supabase.table("documents").update({ + "status": "failed", + "error_message": "Chatbot not found" + }).eq("id", doc_id).execute() + return + + chatbot = chatbot_row.data[0] company_id = chatbot.get("company_id", "") collection_name = chatbot.get("qdrant_collection_name") if not collection_name: - logger.error(f"No Qdrant collection for chatbot {chatbot['id']}") + logger.error(f"No Qdrant collection for chatbot {chatbot_id}") supabase.table("documents").update({ "status": "failed", "error_message": "Vector store not configured" @@ -168,7 +180,7 @@ async def _process_document_bg( }).eq("id", doc_id).execute() response_cache.invalidate(collection_name) - logger.info(f"Document {doc_id} processed: {len(chunks)} chunks") + logger.info(f"Document {doc_id} processed: {len(chunks)} chunks → collection='{collection_name}' company='{company_id}'") except Exception as e: logger.error(f"Document processing error for {doc_id}: {e}") @@ -274,7 +286,7 @@ async def retry_document_processing( file_bytes=file_bytes, file_name=document["file_name"], doc_id=document_id, - chatbot=chatbot, + chatbot_id=chatbot_id, supabase=supabase, ) @@ -333,7 +345,7 @@ async def add_url_source( _process_url_source, source_id=source_id, url=data.url, - chatbot=chatbot, + chatbot_id=chatbot_id, supabase=supabase, ) @@ -394,12 +406,12 @@ async def refresh_url_source( "chunk_count": 0, }).eq("id", source_id).returning("representation").execute() - background_tasks.add_task(_process_url_source, source_id, src["url"], chatbot, supabase) + background_tasks.add_task(_process_url_source, source_id, src["url"], chatbot_id, supabase) return UrlSourceResponse(**{**src, "status": "pending", "chunk_count": 0}) -async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase): +async def _process_url_source(source_id: str, url: str, chatbot_id: str, supabase): """Background task to scrape a URL and add its content to the vector store.""" from app.services.web_scraper import scrape_url from app.services.document_processor import chunk_text @@ -407,6 +419,18 @@ async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase) from app.services.vector_store import vector_store try: + # Re-fetch chatbot to guarantee we use the canonical collection and company_id. + chatbot_row = supabase.table("chatbots").select("company_id, qdrant_collection_name").eq("id", chatbot_id).execute() + if not chatbot_row.data: + logger.error(f"Chatbot {chatbot_id} not found during URL source processing") + supabase.table("url_sources").update({ + "status": "failed", + "error_message": "Chatbot not found", + }).eq("id", source_id).execute() + return + + chatbot = chatbot_row.data[0] + # Update status to processing supabase.table("url_sources").update({"status": "processing"}).eq("id", source_id).execute() @@ -480,7 +504,8 @@ async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase) }).eq("id", source_id).execute() response_cache.invalidate(collection_name) - logger.info(f"URL source {source_id} processed: {len(chunks)} chunks from {url}") + logger.info(f"URL source {source_id} processed: {len(chunks)} chunks from {url} → collection='{collection_name}' company='{chatbot.get('company_id', '')}'") + except Exception as e: logger.error(f"URL source processing error {source_id}: {e}") diff --git a/app/services/llm.py b/app/services/llm.py index ec6f10c..8fb256a 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -1,9 +1,43 @@ from app.config import settings, MODEL_PROVIDERS, PLAN_LIMITS from typing import List, Dict, Any, Optional, AsyncGenerator import logging +import re logger = logging.getLogger(__name__) +# Ordered fallback chain — tried in sequence when the primary model fails. +# Fireworks models are used for free/starter plans so they must always be available. +# llama-v3p3-70b-instruct is the guaranteed last resort (confirmed working). +_FIREWORKS_FALLBACKS = [ + "accounts/fireworks/models/kimi-k2p5-instruct", + "accounts/fireworks/models/deepseek-v3p2", + "accounts/fireworks/models/llama-v3p3-70b-instruct", +] + + +def _normalize_model(model: str) -> str: + """Strip date-based version suffixes from Fireworks model IDs. + e.g. 'accounts/fireworks/models/kimi-k2-instruct-0905' → 'accounts/fireworks/models/kimi-k2-instruct' + Matches only purely-numeric suffixes (4–8 digits) so names like 'llama-v3p3-70b' are untouched.""" + if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"): + model = re.sub(r"-\d{4,8}$", "", model) + return model + + +def _infer_provider(model: str) -> str: + """Infer the LLM provider from the model ID when it's not in MODEL_PROVIDERS. + Handles versioned variants like 'accounts/fireworks/models/kimi-k2-instruct-0905'.""" + if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"): + return "fireworks" + if model.startswith("gpt-") or model.startswith("o1") or model.startswith("o3"): + return "openai" + if model.startswith("claude-"): + return "anthropic" + if model.startswith("gemini-"): + return "google" + logger.warning(f"[LLM] Unknown model '{model}', defaulting to fireworks") + return "fireworks" + class LLMService: """Routes requests to appropriate LLM provider""" @@ -16,7 +50,8 @@ class LLMService: temperature: float = 0.7, ) -> Dict[str, Any]: """Generate a response from the LLM""" - provider = MODEL_PROVIDERS.get(model, "openai") + model = _normalize_model(model) + provider = MODEL_PROVIDERS.get(model) or _infer_provider(model) try: if provider == "fireworks": @@ -31,9 +66,16 @@ class LLMService: return await self._call_openai(messages, model, max_tokens, temperature) except Exception as e: logger.error(f"LLM error ({provider}/{model}): {e}") - fallback = "accounts/fireworks/models/llama-v3p3-70b-instruct" - if model != fallback and settings.fireworks_api_key: - return await self._call_fireworks(messages, fallback, max_tokens, temperature) + if not settings.fireworks_api_key: + raise + for fallback in _FIREWORKS_FALLBACKS: + if model == fallback: + continue + try: + logger.warning(f"[LLM] Falling back to {fallback}") + return await self._call_fireworks(messages, fallback, max_tokens, temperature) + except Exception as fe: + logger.error(f"[LLM] Fallback {fallback} also failed: {fe}") raise async def _call_fireworks( diff --git a/app/services/rag.py b/app/services/rag.py index 9ef1d25..7cb6dde 100644 --- a/app/services/rag.py +++ b/app/services/rag.py @@ -9,7 +9,7 @@ import logging logger = logging.getLogger(__name__) RAG_SYSTEM_PROMPT = """You are a helpful AI assistant for {company_name}. -Your role is to answer questions based on the provided context from company documents. +Your role is to answer questions based on the provided context from the knowledge base (documents and web pages). IMPORTANT RULES: 1. Answer based on the provided context below @@ -20,7 +20,7 @@ IMPORTANT RULES: {language_instruction} {custom_instructions} -Context from knowledge base: +Knowledge base context: {context} """ @@ -74,14 +74,22 @@ class RAGEngine: } # Step 2: Retrieve relevant chunks - # Fetch more than needed so that after filtering low-quality results - # we still have enough context. score_threshold=0.55 keeps only chunks - # that are genuinely relevant for text-embedding-3-small cosine similarity. + # Retrieve more candidates than needed (10) with a slightly relaxed threshold (0.45) + # so that content from both document and URL sources gets fair representation. + # Scraped web text embeds less cleanly than structured documents, so 0.55 was + # filtering out valid URL chunks. Context is capped by char limit below. + total_in_collection = self.vector_svc.count_vectors(collection_name) + logger.info(f"[RAG] Collection '{collection_name}' has {total_in_collection} vectors total") + + # No score_threshold — always return the top-N most similar chunks by rank. + # Absolute cosine scores vary widely by document type and embedding model; + # filtering by a fixed cutoff here discards valid context when scores are + # uniformly low. The confidence_score below captures retrieval quality for + # handoff/fallback decisions without silencing the LLM's context. retrieved = self.vector_svc.search( collection_name=collection_name, query_vector=query_embedding, - limit=8, - score_threshold=0.55, + limit=10, ) logger.info(f"[RAG] Retrieved {len(retrieved)} chunks from collection '{collection_name}'") @@ -90,25 +98,38 @@ class RAGEngine: text_preview = item.get("payload", {}).get("text", "")[:80] logger.info(f"[RAG] Chunk {i+1}: score={score:.4f}, preview='{text_preview}...'") - # Step 3: Build sources + # Step 3: Build sources and labeled context + # Each chunk is prefixed with its source so the LLM can synthesize + # correctly when mixing document and URL content. + MAX_CONTEXT_CHARS = 10_000 sources = [] context_parts = [] seen_texts = set() + total_chars = 0 for item in retrieved: payload = item.get("payload", {}) text = payload.get("text", "") - if text and text not in seen_texts: - seen_texts.add(text) - context_parts.append(text) - sources.append( - SourceDocument( - document_name=payload.get("file_name", "Document"), - chunk_text=text[:200] + "..." if len(text) > 200 else text, - score=item.get("score", 0.0), - page_number=payload.get("page_number"), - ) + if not text or text in seen_texts: + continue + if total_chars + len(text) > MAX_CONTEXT_CHARS: + break + seen_texts.add(text) + total_chars += len(text) + + file_name = payload.get("file_name", "Document") + source_url = payload.get("source_url") + label = f"[Source: {source_url}]" if source_url else f"[Source: {file_name}]" + context_parts.append(f"{label}\n{text}") + + sources.append( + SourceDocument( + document_name=file_name, + chunk_text=text[:200] + "..." if len(text) > 200 else text, + score=item.get("score", 0.0), + page_number=payload.get("page_number"), ) + ) if context_parts: context = "\n\n---\n\n".join(context_parts) diff --git a/app/services/vector_store.py b/app/services/vector_store.py index 79f76af..16bc062 100644 --- a/app/services/vector_store.py +++ b/app/services/vector_store.py @@ -1,7 +1,5 @@ from qdrant_client import QdrantClient, models -from qdrant_client.http.models import ( - Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue -) +from qdrant_client.http.models import Distance, VectorParams, PointStruct from app.config import settings from typing import List, Dict, Any, Optional import logging @@ -103,15 +101,13 @@ class VectorStoreService: collection_name: str, query_vector: List[float], limit: int = 5, - score_threshold: float = 0.3, ) -> List[Dict[str, Any]]: - """Search for similar vectors""" + """Search for similar vectors, returning the top-N by cosine score.""" try: results = self.client.query_points( collection_name=collection_name, query=query_vector, limit=limit, - score_threshold=score_threshold, ).points return [ { @@ -122,7 +118,7 @@ class VectorStoreService: for r in results ] except Exception as e: - logger.error(f"Error searching vectors: {e}") + logger.error(f"Error searching vectors in '{collection_name}': {e}", exc_info=True) return [] def delete_by_document_id(self, collection_name: str, document_id: str) -> bool: @@ -131,19 +127,21 @@ class VectorStoreService: self.client.delete( collection_name=collection_name, points_selector=models.FilterSelector( - filter=Filter( + filter=models.Filter( must=[ - FieldCondition( + models.FieldCondition( key="document_id", - match=MatchValue(value=document_id), + match=models.MatchValue(value=document_id), ) ] ) ), + wait=True, ) + logger.info(f"Deleted vectors for document '{document_id}' from '{collection_name}'") return True except Exception as e: - logger.error(f"Error deleting document vectors: {e}") + logger.error(f"Error deleting vectors for document '{document_id}' in '{collection_name}': {e}", exc_info=True) return False def count_vectors(self, collection_name: str) -> int: diff --git a/app/services/web_scraper.py b/app/services/web_scraper.py index ee1cc6c..2a12ffd 100644 --- a/app/services/web_scraper.py +++ b/app/services/web_scraper.py @@ -42,9 +42,22 @@ async def scrape_url(url: str) -> dict: main = soup.find("main") or soup.find("article") or soup.find("body") or soup text = main.get_text(separator="\n", strip=True) - # Clean up whitespace - lines = [line.strip() for line in text.splitlines() if line.strip()] - text = "\n".join(lines) + # Clean up whitespace and filter structural noise + seen_lines: set[str] = set() + clean_lines = [] + for line in text.splitlines(): + line = line.strip() + if not line: + continue + # Skip very short lines (nav items, button labels, breadcrumb separators) + if len(line) < 15: + continue + # Skip duplicate lines (nav/footer repeated across sections) + if line in seen_lines: + continue + seen_lines.add(line) + clean_lines.append(line) + text = "\n".join(clean_lines) # Limit size if len(text.encode("utf-8")) > MAX_TEXT_BYTES: