mirror of
http://88.130.71.182:3000/BlitTech/contexta_be.git
synced 2026-06-12 23:23:21 +00:00
fixed the RAg in test pipeline issue
This commit is contained in:
@@ -1,9 +1,43 @@
|
||||
from app.config import settings, MODEL_PROVIDERS, PLAN_LIMITS
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Ordered fallback chain — tried in sequence when the primary model fails.
|
||||
# Fireworks models are used for free/starter plans so they must always be available.
|
||||
# llama-v3p3-70b-instruct is the guaranteed last resort (confirmed working).
|
||||
_FIREWORKS_FALLBACKS = [
|
||||
"accounts/fireworks/models/kimi-k2p5-instruct",
|
||||
"accounts/fireworks/models/deepseek-v3p2",
|
||||
"accounts/fireworks/models/llama-v3p3-70b-instruct",
|
||||
]
|
||||
|
||||
|
||||
def _normalize_model(model: str) -> str:
|
||||
"""Strip date-based version suffixes from Fireworks model IDs.
|
||||
e.g. 'accounts/fireworks/models/kimi-k2-instruct-0905' → 'accounts/fireworks/models/kimi-k2-instruct'
|
||||
Matches only purely-numeric suffixes (4–8 digits) so names like 'llama-v3p3-70b' are untouched."""
|
||||
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
|
||||
model = re.sub(r"-\d{4,8}$", "", model)
|
||||
return model
|
||||
|
||||
|
||||
def _infer_provider(model: str) -> str:
|
||||
"""Infer the LLM provider from the model ID when it's not in MODEL_PROVIDERS.
|
||||
Handles versioned variants like 'accounts/fireworks/models/kimi-k2-instruct-0905'."""
|
||||
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
|
||||
return "fireworks"
|
||||
if model.startswith("gpt-") or model.startswith("o1") or model.startswith("o3"):
|
||||
return "openai"
|
||||
if model.startswith("claude-"):
|
||||
return "anthropic"
|
||||
if model.startswith("gemini-"):
|
||||
return "google"
|
||||
logger.warning(f"[LLM] Unknown model '{model}', defaulting to fireworks")
|
||||
return "fireworks"
|
||||
|
||||
|
||||
class LLMService:
|
||||
"""Routes requests to appropriate LLM provider"""
|
||||
@@ -16,7 +50,8 @@ class LLMService:
|
||||
temperature: float = 0.7,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate a response from the LLM"""
|
||||
provider = MODEL_PROVIDERS.get(model, "openai")
|
||||
model = _normalize_model(model)
|
||||
provider = MODEL_PROVIDERS.get(model) or _infer_provider(model)
|
||||
|
||||
try:
|
||||
if provider == "fireworks":
|
||||
@@ -31,9 +66,16 @@ class LLMService:
|
||||
return await self._call_openai(messages, model, max_tokens, temperature)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM error ({provider}/{model}): {e}")
|
||||
fallback = "accounts/fireworks/models/llama-v3p3-70b-instruct"
|
||||
if model != fallback and settings.fireworks_api_key:
|
||||
return await self._call_fireworks(messages, fallback, max_tokens, temperature)
|
||||
if not settings.fireworks_api_key:
|
||||
raise
|
||||
for fallback in _FIREWORKS_FALLBACKS:
|
||||
if model == fallback:
|
||||
continue
|
||||
try:
|
||||
logger.warning(f"[LLM] Falling back to {fallback}")
|
||||
return await self._call_fireworks(messages, fallback, max_tokens, temperature)
|
||||
except Exception as fe:
|
||||
logger.error(f"[LLM] Fallback {fallback} also failed: {fe}")
|
||||
raise
|
||||
|
||||
async def _call_fireworks(
|
||||
|
||||
Reference in New Issue
Block a user