Fixed the RAG issue in the test pipeline

This commit is contained in:
belviskhoremk
2026-04-26 21:43:19 +00:00
parent 78023ae9c5
commit 260a9c6353
9 changed files with 262 additions and 78 deletions

View File

@@ -1,9 +1,43 @@
from app.config import settings, MODEL_PROVIDERS, PLAN_LIMITS
from typing import List, Dict, Any, Optional, AsyncGenerator
import logging
import re
logger = logging.getLogger(__name__)
# Ordered fallback chain of Fireworks model IDs, tried in sequence (first to
# last) when the primary model call fails.
# Fireworks models are used for free/starter plans, so they must always be
# available.
# llama-v3p3-70b-instruct is deliberately last: it is the guaranteed last
# resort (confirmed working).
_FIREWORKS_FALLBACKS = [
    "accounts/fireworks/models/kimi-k2p5-instruct",
    "accounts/fireworks/models/deepseek-v3p2",
    "accounts/fireworks/models/llama-v3p3-70b-instruct",
]
def _normalize_model(model: str) -> str:
"""Strip date-based version suffixes from Fireworks model IDs.
e.g. 'accounts/fireworks/models/kimi-k2-instruct-0905''accounts/fireworks/models/kimi-k2-instruct'
Matches only purely-numeric suffixes (48 digits) so names like 'llama-v3p3-70b' are untouched."""
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
model = re.sub(r"-\d{4,8}$", "", model)
return model
def _infer_provider(model: str) -> str:
"""Infer the LLM provider from the model ID when it's not in MODEL_PROVIDERS.
Handles versioned variants like 'accounts/fireworks/models/kimi-k2-instruct-0905'."""
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
return "fireworks"
if model.startswith("gpt-") or model.startswith("o1") or model.startswith("o3"):
return "openai"
if model.startswith("claude-"):
return "anthropic"
if model.startswith("gemini-"):
return "google"
logger.warning(f"[LLM] Unknown model '{model}', defaulting to fireworks")
return "fireworks"
class LLMService:
"""Routes requests to appropriate LLM provider"""
@@ -16,7 +50,8 @@ class LLMService:
temperature: float = 0.7,
) -> Dict[str, Any]:
"""Generate a response from the LLM"""
provider = MODEL_PROVIDERS.get(model, "openai")
model = _normalize_model(model)
provider = MODEL_PROVIDERS.get(model) or _infer_provider(model)
try:
if provider == "fireworks":
@@ -31,9 +66,16 @@ class LLMService:
return await self._call_openai(messages, model, max_tokens, temperature)
except Exception as e:
logger.error(f"LLM error ({provider}/{model}): {e}")
fallback = "accounts/fireworks/models/llama-v3p3-70b-instruct"
if model != fallback and settings.fireworks_api_key:
return await self._call_fireworks(messages, fallback, max_tokens, temperature)
if not settings.fireworks_api_key:
raise
for fallback in _FIREWORKS_FALLBACKS:
if model == fallback:
continue
try:
logger.warning(f"[LLM] Falling back to {fallback}")
return await self._call_fireworks(messages, fallback, max_tokens, temperature)
except Exception as fe:
logger.error(f"[LLM] Fallback {fallback} also failed: {fe}")
raise
async def _call_fireworks(