mirror of
http://88.130.71.182:3000/BlitTech/contexta_be.git
synced 2026-06-12 23:23:21 +00:00
fixed the RAg in test pipeline issue
This commit is contained in:
@@ -13,6 +13,7 @@ class Settings(BaseSettings):
|
|||||||
supabase_url: str = ""
|
supabase_url: str = ""
|
||||||
supabase_anon_key: str = ""
|
supabase_anon_key: str = ""
|
||||||
supabase_service_role_key: str = ""
|
supabase_service_role_key: str = ""
|
||||||
|
supabase_jwt_secret: Optional[str] = None # Settings → API → JWT Secret in Supabase dashboard
|
||||||
|
|
||||||
# Qdrant
|
# Qdrant
|
||||||
qdrant_url: str = "http://localhost:6333"
|
qdrant_url: str = "http://localhost:6333"
|
||||||
@@ -99,12 +100,24 @@ MODEL_CATALOG = {
|
|||||||
"badge": "Smart",
|
"badge": "Smart",
|
||||||
"description": "Cost-effective and highly capable model",
|
"description": "Cost-effective and highly capable model",
|
||||||
},
|
},
|
||||||
|
"accounts/fireworks/models/deepseek-v3p2": {
|
||||||
|
"name": "DeepSeek V3.2",
|
||||||
|
"provider": "Fireworks AI",
|
||||||
|
"badge": "Smart",
|
||||||
|
"description": "Latest DeepSeek — faster and more capable",
|
||||||
|
},
|
||||||
"accounts/fireworks/models/kimi-k2-instruct": {
|
"accounts/fireworks/models/kimi-k2-instruct": {
|
||||||
"name": "Kimi K2",
|
"name": "Kimi K2",
|
||||||
"provider": "Fireworks AI",
|
"provider": "Fireworks AI",
|
||||||
"badge": "Multilingual",
|
"badge": "Multilingual",
|
||||||
"description": "Strong multilingual and coding capabilities",
|
"description": "Strong multilingual and coding capabilities",
|
||||||
},
|
},
|
||||||
|
"accounts/fireworks/models/kimi-k2p5-instruct": {
|
||||||
|
"name": "Kimi K2.5",
|
||||||
|
"provider": "Fireworks AI",
|
||||||
|
"badge": "Multilingual",
|
||||||
|
"description": "Upgraded Kimi — stronger reasoning and multilingual",
|
||||||
|
},
|
||||||
|
|
||||||
# ── Pro tier (Premium providers) ───────────────────────────────────────────
|
# ── Pro tier (Premium providers) ───────────────────────────────────────────
|
||||||
# OpenAI
|
# OpenAI
|
||||||
@@ -156,7 +169,9 @@ MODEL_PROVIDERS = {
|
|||||||
"accounts/fireworks/models/llama-v3p3-70b-instruct": "fireworks",
|
"accounts/fireworks/models/llama-v3p3-70b-instruct": "fireworks",
|
||||||
"accounts/fireworks/models/qwen3-235b-a22b": "fireworks",
|
"accounts/fireworks/models/qwen3-235b-a22b": "fireworks",
|
||||||
"accounts/fireworks/models/deepseek-v3p1": "fireworks",
|
"accounts/fireworks/models/deepseek-v3p1": "fireworks",
|
||||||
|
"accounts/fireworks/models/deepseek-v3p2": "fireworks",
|
||||||
"accounts/fireworks/models/kimi-k2-instruct": "fireworks",
|
"accounts/fireworks/models/kimi-k2-instruct": "fireworks",
|
||||||
|
"accounts/fireworks/models/kimi-k2p5-instruct": "fireworks",
|
||||||
# OpenAI
|
# OpenAI
|
||||||
"gpt-4o": "openai",
|
"gpt-4o": "openai",
|
||||||
"gpt-4o-mini": "openai",
|
"gpt-4o-mini": "openai",
|
||||||
@@ -209,7 +224,9 @@ _ALL_FIREWORKS = [
|
|||||||
"accounts/fireworks/models/llama-v3p3-70b-instruct",
|
"accounts/fireworks/models/llama-v3p3-70b-instruct",
|
||||||
"accounts/fireworks/models/qwen3-235b-a22b",
|
"accounts/fireworks/models/qwen3-235b-a22b",
|
||||||
"accounts/fireworks/models/deepseek-v3p1",
|
"accounts/fireworks/models/deepseek-v3p1",
|
||||||
|
"accounts/fireworks/models/deepseek-v3p2",
|
||||||
"accounts/fireworks/models/kimi-k2-instruct",
|
"accounts/fireworks/models/kimi-k2-instruct",
|
||||||
|
"accounts/fireworks/models/kimi-k2p5-instruct",
|
||||||
]
|
]
|
||||||
_ALL_PREMIUM = [
|
_ALL_PREMIUM = [
|
||||||
"gpt-4o", "gpt-4o-mini",
|
"gpt-4o", "gpt-4o-mini",
|
||||||
|
|||||||
@@ -1,18 +1,77 @@
|
|||||||
from fastapi import Depends, HTTPException, status, Header
|
from fastapi import Depends, HTTPException, status, Header
|
||||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from app.database import get_supabase
|
from app.database import get_supabase
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
security = HTTPBearer(auto_error=False)
|
security = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _LocalUser:
|
||||||
|
"""Minimal user object built from JWT claims — mirrors the fields used downstream."""
|
||||||
|
id: str
|
||||||
|
email: str
|
||||||
|
role: str = "authenticated"
|
||||||
|
app_metadata: dict = field(default_factory=dict)
|
||||||
|
user_metadata: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_jwt_local(token: str) -> Optional[_LocalUser]:
|
||||||
|
"""Verify a Supabase HS256 JWT using the local secret (no network call).
|
||||||
|
Returns None if the secret is not configured, the signature is wrong, or the token is expired."""
|
||||||
|
secret = settings.supabase_jwt_secret
|
||||||
|
if not secret:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parts = token.split(".")
|
||||||
|
if len(parts) != 3:
|
||||||
|
return None
|
||||||
|
header_b64, payload_b64, sig_b64 = parts
|
||||||
|
|
||||||
|
# Verify HMAC-SHA256 signature
|
||||||
|
message = f"{header_b64}.{payload_b64}".encode()
|
||||||
|
expected = hmac.new(secret.encode(), message, hashlib.sha256).digest()
|
||||||
|
padding = "=" * (-len(sig_b64) % 4)
|
||||||
|
actual = base64.urlsafe_b64decode(sig_b64 + padding)
|
||||||
|
if not hmac.compare_digest(expected, actual):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Decode payload
|
||||||
|
padding = "=" * (-len(payload_b64) % 4)
|
||||||
|
payload = json.loads(base64.urlsafe_b64decode(payload_b64 + padding))
|
||||||
|
|
||||||
|
# Check expiry
|
||||||
|
if payload.get("exp", 0) < time.time():
|
||||||
|
return None
|
||||||
|
|
||||||
|
return _LocalUser(
|
||||||
|
id=payload["sub"],
|
||||||
|
email=payload.get("email", ""),
|
||||||
|
role=payload.get("role", "authenticated"),
|
||||||
|
app_metadata=payload.get("app_metadata", {}),
|
||||||
|
user_metadata=payload.get("user_metadata", {}),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
|
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
|
||||||
):
|
):
|
||||||
"""Extract and verify the current user from Supabase JWT"""
|
"""Extract and verify the current user from a Supabase JWT.
|
||||||
|
|
||||||
|
Tries local HS256 verification first (no network call, no SSL risk).
|
||||||
|
Falls back to supabase.auth.get_user() only when the JWT secret is not configured.
|
||||||
|
"""
|
||||||
if not credentials:
|
if not credentials:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -20,8 +79,13 @@ async def get_current_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
token = credentials.credentials
|
token = credentials.credentials
|
||||||
supabase = get_supabase()
|
|
||||||
|
|
||||||
|
# ── Fast path: local verification ────────────────────────────────────────
|
||||||
|
user = _verify_jwt_local(token)
|
||||||
|
|
||||||
|
# ── Slow path: network call (only if SUPABASE_JWT_SECRET is not set) ─────
|
||||||
|
if user is None:
|
||||||
|
supabase = get_supabase()
|
||||||
try:
|
try:
|
||||||
response = supabase.auth.get_user(token)
|
response = supabase.auth.get_user(token)
|
||||||
if not response or not response.user:
|
if not response or not response.user:
|
||||||
@@ -30,9 +94,18 @@ async def get_current_user(
|
|||||||
detail="Invalid or expired token",
|
detail="Invalid or expired token",
|
||||||
)
|
)
|
||||||
user = response.user
|
user = response.user
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Auth error: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or expired token",
|
||||||
|
)
|
||||||
|
|
||||||
# Check for suspension
|
# ── Suspension check (DB, not network-auth, so still fast) ───────────────
|
||||||
try:
|
try:
|
||||||
|
supabase = get_supabase()
|
||||||
profile = supabase.table("user_profiles").select("suspended_at").eq("user_id", user.id).execute()
|
profile = supabase.table("user_profiles").select("suspended_at").eq("user_id", user.id).execute()
|
||||||
if profile.data and profile.data[0].get("suspended_at"):
|
if profile.data and profile.data[0].get("suspended_at"):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -42,17 +115,9 @@ async def get_current_user(
|
|||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Don't block login if profile lookup fails
|
pass # Never block login if profile lookup fails
|
||||||
|
|
||||||
return user
|
return user
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Auth error: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid or expired token",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_admin_user(
|
async def get_admin_user(
|
||||||
|
|||||||
@@ -32,3 +32,4 @@ def configure_logging():
|
|||||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("hpack").setLevel(logging.WARNING)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
@@ -311,8 +312,7 @@ async def test_chat(
|
|||||||
company_data = chatbot.get("companies", {}) or {}
|
company_data = chatbot.get("companies", {}) or {}
|
||||||
chatbot_config = {**chatbot, "company_name": company_data.get("name", "")}
|
chatbot_config = {**chatbot, "company_name": company_data.get("name", "")}
|
||||||
|
|
||||||
results = []
|
async def _run_one(question: str) -> TestChatResult:
|
||||||
for question in body.questions:
|
|
||||||
try:
|
try:
|
||||||
result = await rag_engine.process_query(
|
result = await rag_engine.process_query(
|
||||||
query=question,
|
query=question,
|
||||||
@@ -322,22 +322,24 @@ async def test_chat(
|
|||||||
language="auto",
|
language="auto",
|
||||||
bypass_cache=True,
|
bypass_cache=True,
|
||||||
)
|
)
|
||||||
results.append(TestChatResult(
|
return TestChatResult(
|
||||||
question=question,
|
question=question,
|
||||||
response=result["response"],
|
response=result["response"],
|
||||||
confidence_score=result.get("confidence_score", 0.0),
|
confidence_score=result.get("confidence_score", 0.0),
|
||||||
sources=result.get("sources", []),
|
sources=result.get("sources", []),
|
||||||
model_used=result.get("model", ""),
|
model_used=result.get("model", ""),
|
||||||
))
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
results.append(TestChatResult(
|
return TestChatResult(
|
||||||
question=question,
|
question=question,
|
||||||
response=f"Error: {e}",
|
response=f"Error: {e}",
|
||||||
confidence_score=0.0,
|
confidence_score=0.0,
|
||||||
sources=[],
|
sources=[],
|
||||||
model_used="",
|
model_used="",
|
||||||
))
|
)
|
||||||
return results
|
|
||||||
|
results = await asyncio.gather(*[_run_one(q) for q in body.questions])
|
||||||
|
return list(results)
|
||||||
|
|
||||||
|
|
||||||
# ── OLD analytics endpoint REMOVED ───────────────────────────────────────────
|
# ── OLD analytics endpoint REMOVED ───────────────────────────────────────────
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ async def upload_document(
|
|||||||
file_bytes=file_bytes,
|
file_bytes=file_bytes,
|
||||||
file_name=file.filename,
|
file_name=file.filename,
|
||||||
doc_id=doc_id,
|
doc_id=doc_id,
|
||||||
chatbot=chatbot,
|
chatbot_id=chatbot_id,
|
||||||
supabase=supabase,
|
supabase=supabase,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -105,16 +105,28 @@ async def _process_document_bg(
|
|||||||
file_bytes: bytes,
|
file_bytes: bytes,
|
||||||
file_name: str,
|
file_name: str,
|
||||||
doc_id: str,
|
doc_id: str,
|
||||||
chatbot: dict,
|
chatbot_id: str,
|
||||||
supabase,
|
supabase,
|
||||||
):
|
):
|
||||||
"""Background task to process and embed a document"""
|
"""Background task to process and embed a document"""
|
||||||
try:
|
try:
|
||||||
|
# Re-fetch chatbot to guarantee we use the canonical collection and company_id,
|
||||||
|
# not a snapshot that could have been captured before an update.
|
||||||
|
chatbot_row = supabase.table("chatbots").select("company_id, qdrant_collection_name").eq("id", chatbot_id).execute()
|
||||||
|
if not chatbot_row.data:
|
||||||
|
logger.error(f"Chatbot {chatbot_id} not found during document processing")
|
||||||
|
supabase.table("documents").update({
|
||||||
|
"status": "failed",
|
||||||
|
"error_message": "Chatbot not found"
|
||||||
|
}).eq("id", doc_id).execute()
|
||||||
|
return
|
||||||
|
|
||||||
|
chatbot = chatbot_row.data[0]
|
||||||
company_id = chatbot.get("company_id", "")
|
company_id = chatbot.get("company_id", "")
|
||||||
collection_name = chatbot.get("qdrant_collection_name")
|
collection_name = chatbot.get("qdrant_collection_name")
|
||||||
|
|
||||||
if not collection_name:
|
if not collection_name:
|
||||||
logger.error(f"No Qdrant collection for chatbot {chatbot['id']}")
|
logger.error(f"No Qdrant collection for chatbot {chatbot_id}")
|
||||||
supabase.table("documents").update({
|
supabase.table("documents").update({
|
||||||
"status": "failed",
|
"status": "failed",
|
||||||
"error_message": "Vector store not configured"
|
"error_message": "Vector store not configured"
|
||||||
@@ -168,7 +180,7 @@ async def _process_document_bg(
|
|||||||
}).eq("id", doc_id).execute()
|
}).eq("id", doc_id).execute()
|
||||||
|
|
||||||
response_cache.invalidate(collection_name)
|
response_cache.invalidate(collection_name)
|
||||||
logger.info(f"Document {doc_id} processed: {len(chunks)} chunks")
|
logger.info(f"Document {doc_id} processed: {len(chunks)} chunks → collection='{collection_name}' company='{company_id}'")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Document processing error for {doc_id}: {e}")
|
logger.error(f"Document processing error for {doc_id}: {e}")
|
||||||
@@ -274,7 +286,7 @@ async def retry_document_processing(
|
|||||||
file_bytes=file_bytes,
|
file_bytes=file_bytes,
|
||||||
file_name=document["file_name"],
|
file_name=document["file_name"],
|
||||||
doc_id=document_id,
|
doc_id=document_id,
|
||||||
chatbot=chatbot,
|
chatbot_id=chatbot_id,
|
||||||
supabase=supabase,
|
supabase=supabase,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -333,7 +345,7 @@ async def add_url_source(
|
|||||||
_process_url_source,
|
_process_url_source,
|
||||||
source_id=source_id,
|
source_id=source_id,
|
||||||
url=data.url,
|
url=data.url,
|
||||||
chatbot=chatbot,
|
chatbot_id=chatbot_id,
|
||||||
supabase=supabase,
|
supabase=supabase,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -394,12 +406,12 @@ async def refresh_url_source(
|
|||||||
"chunk_count": 0,
|
"chunk_count": 0,
|
||||||
}).eq("id", source_id).returning("representation").execute()
|
}).eq("id", source_id).returning("representation").execute()
|
||||||
|
|
||||||
background_tasks.add_task(_process_url_source, source_id, src["url"], chatbot, supabase)
|
background_tasks.add_task(_process_url_source, source_id, src["url"], chatbot_id, supabase)
|
||||||
|
|
||||||
return UrlSourceResponse(**{**src, "status": "pending", "chunk_count": 0})
|
return UrlSourceResponse(**{**src, "status": "pending", "chunk_count": 0})
|
||||||
|
|
||||||
|
|
||||||
async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase):
|
async def _process_url_source(source_id: str, url: str, chatbot_id: str, supabase):
|
||||||
"""Background task to scrape a URL and add its content to the vector store."""
|
"""Background task to scrape a URL and add its content to the vector store."""
|
||||||
from app.services.web_scraper import scrape_url
|
from app.services.web_scraper import scrape_url
|
||||||
from app.services.document_processor import chunk_text
|
from app.services.document_processor import chunk_text
|
||||||
@@ -407,6 +419,18 @@ async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase)
|
|||||||
from app.services.vector_store import vector_store
|
from app.services.vector_store import vector_store
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Re-fetch chatbot to guarantee we use the canonical collection and company_id.
|
||||||
|
chatbot_row = supabase.table("chatbots").select("company_id, qdrant_collection_name").eq("id", chatbot_id).execute()
|
||||||
|
if not chatbot_row.data:
|
||||||
|
logger.error(f"Chatbot {chatbot_id} not found during URL source processing")
|
||||||
|
supabase.table("url_sources").update({
|
||||||
|
"status": "failed",
|
||||||
|
"error_message": "Chatbot not found",
|
||||||
|
}).eq("id", source_id).execute()
|
||||||
|
return
|
||||||
|
|
||||||
|
chatbot = chatbot_row.data[0]
|
||||||
|
|
||||||
# Update status to processing
|
# Update status to processing
|
||||||
supabase.table("url_sources").update({"status": "processing"}).eq("id", source_id).execute()
|
supabase.table("url_sources").update({"status": "processing"}).eq("id", source_id).execute()
|
||||||
|
|
||||||
@@ -480,7 +504,8 @@ async def _process_url_source(source_id: str, url: str, chatbot: dict, supabase)
|
|||||||
}).eq("id", source_id).execute()
|
}).eq("id", source_id).execute()
|
||||||
|
|
||||||
response_cache.invalidate(collection_name)
|
response_cache.invalidate(collection_name)
|
||||||
logger.info(f"URL source {source_id} processed: {len(chunks)} chunks from {url}")
|
logger.info(f"URL source {source_id} processed: {len(chunks)} chunks from {url} → collection='{collection_name}' company='{chatbot.get('company_id', '')}'")
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"URL source processing error {source_id}: {e}")
|
logger.error(f"URL source processing error {source_id}: {e}")
|
||||||
|
|||||||
@@ -1,9 +1,43 @@
|
|||||||
from app.config import settings, MODEL_PROVIDERS, PLAN_LIMITS
|
from app.config import settings, MODEL_PROVIDERS, PLAN_LIMITS
|
||||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Ordered fallback chain — tried in sequence when the primary model fails.
|
||||||
|
# Fireworks models are used for free/starter plans so they must always be available.
|
||||||
|
# llama-v3p3-70b-instruct is the guaranteed last resort (confirmed working).
|
||||||
|
_FIREWORKS_FALLBACKS = [
|
||||||
|
"accounts/fireworks/models/kimi-k2p5-instruct",
|
||||||
|
"accounts/fireworks/models/deepseek-v3p2",
|
||||||
|
"accounts/fireworks/models/llama-v3p3-70b-instruct",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_model(model: str) -> str:
|
||||||
|
"""Strip date-based version suffixes from Fireworks model IDs.
|
||||||
|
e.g. 'accounts/fireworks/models/kimi-k2-instruct-0905' → 'accounts/fireworks/models/kimi-k2-instruct'
|
||||||
|
Matches only purely-numeric suffixes (4–8 digits) so names like 'llama-v3p3-70b' are untouched."""
|
||||||
|
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
|
||||||
|
model = re.sub(r"-\d{4,8}$", "", model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_provider(model: str) -> str:
|
||||||
|
"""Infer the LLM provider from the model ID when it's not in MODEL_PROVIDERS.
|
||||||
|
Handles versioned variants like 'accounts/fireworks/models/kimi-k2-instruct-0905'."""
|
||||||
|
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
|
||||||
|
return "fireworks"
|
||||||
|
if model.startswith("gpt-") or model.startswith("o1") or model.startswith("o3"):
|
||||||
|
return "openai"
|
||||||
|
if model.startswith("claude-"):
|
||||||
|
return "anthropic"
|
||||||
|
if model.startswith("gemini-"):
|
||||||
|
return "google"
|
||||||
|
logger.warning(f"[LLM] Unknown model '{model}', defaulting to fireworks")
|
||||||
|
return "fireworks"
|
||||||
|
|
||||||
|
|
||||||
class LLMService:
|
class LLMService:
|
||||||
"""Routes requests to appropriate LLM provider"""
|
"""Routes requests to appropriate LLM provider"""
|
||||||
@@ -16,7 +50,8 @@ class LLMService:
|
|||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Generate a response from the LLM"""
|
"""Generate a response from the LLM"""
|
||||||
provider = MODEL_PROVIDERS.get(model, "openai")
|
model = _normalize_model(model)
|
||||||
|
provider = MODEL_PROVIDERS.get(model) or _infer_provider(model)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if provider == "fireworks":
|
if provider == "fireworks":
|
||||||
@@ -31,9 +66,16 @@ class LLMService:
|
|||||||
return await self._call_openai(messages, model, max_tokens, temperature)
|
return await self._call_openai(messages, model, max_tokens, temperature)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM error ({provider}/{model}): {e}")
|
logger.error(f"LLM error ({provider}/{model}): {e}")
|
||||||
fallback = "accounts/fireworks/models/llama-v3p3-70b-instruct"
|
if not settings.fireworks_api_key:
|
||||||
if model != fallback and settings.fireworks_api_key:
|
raise
|
||||||
|
for fallback in _FIREWORKS_FALLBACKS:
|
||||||
|
if model == fallback:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
logger.warning(f"[LLM] Falling back to {fallback}")
|
||||||
return await self._call_fireworks(messages, fallback, max_tokens, temperature)
|
return await self._call_fireworks(messages, fallback, max_tokens, temperature)
|
||||||
|
except Exception as fe:
|
||||||
|
logger.error(f"[LLM] Fallback {fallback} also failed: {fe}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _call_fireworks(
|
async def _call_fireworks(
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
RAG_SYSTEM_PROMPT = """You are a helpful AI assistant for {company_name}.
|
RAG_SYSTEM_PROMPT = """You are a helpful AI assistant for {company_name}.
|
||||||
Your role is to answer questions based on the provided context from company documents.
|
Your role is to answer questions based on the provided context from the knowledge base (documents and web pages).
|
||||||
|
|
||||||
IMPORTANT RULES:
|
IMPORTANT RULES:
|
||||||
1. Answer based on the provided context below
|
1. Answer based on the provided context below
|
||||||
@@ -20,7 +20,7 @@ IMPORTANT RULES:
|
|||||||
{language_instruction}
|
{language_instruction}
|
||||||
{custom_instructions}
|
{custom_instructions}
|
||||||
|
|
||||||
Context from knowledge base:
|
Knowledge base context:
|
||||||
{context}
|
{context}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -74,14 +74,22 @@ class RAGEngine:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Step 2: Retrieve relevant chunks
|
# Step 2: Retrieve relevant chunks
|
||||||
# Fetch more than needed so that after filtering low-quality results
|
# Retrieve more candidates than needed (10) with a slightly relaxed threshold (0.45)
|
||||||
# we still have enough context. score_threshold=0.55 keeps only chunks
|
# so that content from both document and URL sources gets fair representation.
|
||||||
# that are genuinely relevant for text-embedding-3-small cosine similarity.
|
# Scraped web text embeds less cleanly than structured documents, so 0.55 was
|
||||||
|
# filtering out valid URL chunks. Context is capped by char limit below.
|
||||||
|
total_in_collection = self.vector_svc.count_vectors(collection_name)
|
||||||
|
logger.info(f"[RAG] Collection '{collection_name}' has {total_in_collection} vectors total")
|
||||||
|
|
||||||
|
# No score_threshold — always return the top-N most similar chunks by rank.
|
||||||
|
# Absolute cosine scores vary widely by document type and embedding model;
|
||||||
|
# filtering by a fixed cutoff here discards valid context when scores are
|
||||||
|
# uniformly low. The confidence_score below captures retrieval quality for
|
||||||
|
# handoff/fallback decisions without silencing the LLM's context.
|
||||||
retrieved = self.vector_svc.search(
|
retrieved = self.vector_svc.search(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
query_vector=query_embedding,
|
query_vector=query_embedding,
|
||||||
limit=8,
|
limit=10,
|
||||||
score_threshold=0.55,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"[RAG] Retrieved {len(retrieved)} chunks from collection '{collection_name}'")
|
logger.info(f"[RAG] Retrieved {len(retrieved)} chunks from collection '{collection_name}'")
|
||||||
@@ -90,20 +98,33 @@ class RAGEngine:
|
|||||||
text_preview = item.get("payload", {}).get("text", "")[:80]
|
text_preview = item.get("payload", {}).get("text", "")[:80]
|
||||||
logger.info(f"[RAG] Chunk {i+1}: score={score:.4f}, preview='{text_preview}...'")
|
logger.info(f"[RAG] Chunk {i+1}: score={score:.4f}, preview='{text_preview}...'")
|
||||||
|
|
||||||
# Step 3: Build sources
|
# Step 3: Build sources and labeled context
|
||||||
|
# Each chunk is prefixed with its source so the LLM can synthesize
|
||||||
|
# correctly when mixing document and URL content.
|
||||||
|
MAX_CONTEXT_CHARS = 10_000
|
||||||
sources = []
|
sources = []
|
||||||
context_parts = []
|
context_parts = []
|
||||||
seen_texts = set()
|
seen_texts = set()
|
||||||
|
total_chars = 0
|
||||||
|
|
||||||
for item in retrieved:
|
for item in retrieved:
|
||||||
payload = item.get("payload", {})
|
payload = item.get("payload", {})
|
||||||
text = payload.get("text", "")
|
text = payload.get("text", "")
|
||||||
if text and text not in seen_texts:
|
if not text or text in seen_texts:
|
||||||
|
continue
|
||||||
|
if total_chars + len(text) > MAX_CONTEXT_CHARS:
|
||||||
|
break
|
||||||
seen_texts.add(text)
|
seen_texts.add(text)
|
||||||
context_parts.append(text)
|
total_chars += len(text)
|
||||||
|
|
||||||
|
file_name = payload.get("file_name", "Document")
|
||||||
|
source_url = payload.get("source_url")
|
||||||
|
label = f"[Source: {source_url}]" if source_url else f"[Source: {file_name}]"
|
||||||
|
context_parts.append(f"{label}\n{text}")
|
||||||
|
|
||||||
sources.append(
|
sources.append(
|
||||||
SourceDocument(
|
SourceDocument(
|
||||||
document_name=payload.get("file_name", "Document"),
|
document_name=file_name,
|
||||||
chunk_text=text[:200] + "..." if len(text) > 200 else text,
|
chunk_text=text[:200] + "..." if len(text) > 200 else text,
|
||||||
score=item.get("score", 0.0),
|
score=item.get("score", 0.0),
|
||||||
page_number=payload.get("page_number"),
|
page_number=payload.get("page_number"),
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
from qdrant_client import QdrantClient, models
|
from qdrant_client import QdrantClient, models
|
||||||
from qdrant_client.http.models import (
|
from qdrant_client.http.models import Distance, VectorParams, PointStruct
|
||||||
Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
|
|
||||||
)
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
import logging
|
import logging
|
||||||
@@ -103,15 +101,13 @@ class VectorStoreService:
|
|||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_vector: List[float],
|
query_vector: List[float],
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.3,
|
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Search for similar vectors"""
|
"""Search for similar vectors, returning the top-N by cosine score."""
|
||||||
try:
|
try:
|
||||||
results = self.client.query_points(
|
results = self.client.query_points(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
query=query_vector,
|
query=query_vector,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
score_threshold=score_threshold,
|
|
||||||
).points
|
).points
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
@@ -122,7 +118,7 @@ class VectorStoreService:
|
|||||||
for r in results
|
for r in results
|
||||||
]
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error searching vectors: {e}")
|
logger.error(f"Error searching vectors in '{collection_name}': {e}", exc_info=True)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def delete_by_document_id(self, collection_name: str, document_id: str) -> bool:
|
def delete_by_document_id(self, collection_name: str, document_id: str) -> bool:
|
||||||
@@ -131,19 +127,21 @@ class VectorStoreService:
|
|||||||
self.client.delete(
|
self.client.delete(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
points_selector=models.FilterSelector(
|
points_selector=models.FilterSelector(
|
||||||
filter=Filter(
|
filter=models.Filter(
|
||||||
must=[
|
must=[
|
||||||
FieldCondition(
|
models.FieldCondition(
|
||||||
key="document_id",
|
key="document_id",
|
||||||
match=MatchValue(value=document_id),
|
match=models.MatchValue(value=document_id),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
|
wait=True,
|
||||||
)
|
)
|
||||||
|
logger.info(f"Deleted vectors for document '{document_id}' from '{collection_name}'")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting document vectors: {e}")
|
logger.error(f"Error deleting vectors for document '{document_id}' in '{collection_name}': {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def count_vectors(self, collection_name: str) -> int:
|
def count_vectors(self, collection_name: str) -> int:
|
||||||
|
|||||||
@@ -42,9 +42,22 @@ async def scrape_url(url: str) -> dict:
|
|||||||
main = soup.find("main") or soup.find("article") or soup.find("body") or soup
|
main = soup.find("main") or soup.find("article") or soup.find("body") or soup
|
||||||
text = main.get_text(separator="\n", strip=True)
|
text = main.get_text(separator="\n", strip=True)
|
||||||
|
|
||||||
# Clean up whitespace
|
# Clean up whitespace and filter structural noise
|
||||||
lines = [line.strip() for line in text.splitlines() if line.strip()]
|
seen_lines: set[str] = set()
|
||||||
text = "\n".join(lines)
|
clean_lines = []
|
||||||
|
for line in text.splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
# Skip very short lines (nav items, button labels, breadcrumb separators)
|
||||||
|
if len(line) < 15:
|
||||||
|
continue
|
||||||
|
# Skip duplicate lines (nav/footer repeated across sections)
|
||||||
|
if line in seen_lines:
|
||||||
|
continue
|
||||||
|
seen_lines.add(line)
|
||||||
|
clean_lines.append(line)
|
||||||
|
text = "\n".join(clean_lines)
|
||||||
|
|
||||||
# Limit size
|
# Limit size
|
||||||
if len(text.encode("utf-8")) > MAX_TEXT_BYTES:
|
if len(text.encode("utf-8")) > MAX_TEXT_BYTES:
|
||||||
|
|||||||
Reference in New Issue
Block a user