mirror of
http://88.130.71.182:3000/BlitTech/contexta_be.git
synced 2026-06-13 08:45:24 +00:00
208 lines
7.3 KiB
Python
208 lines
7.3 KiB
Python
from app.config import settings, MODEL_PROVIDERS, PLAN_LIMITS
|
||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||
import logging
|
||
import re
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
# Fallback models, listed in the order they are attempted once the primary
# model call has failed.
# Free and starter plans are served by Fireworks, so these must always be
# available.
# The final entry, llama-v3p3-70b-instruct, is the guaranteed last resort
# (confirmed working).
_FIREWORKS_FALLBACKS = [
    "accounts/fireworks/models/kimi-k2p5-instruct",
    "accounts/fireworks/models/deepseek-v3p2",
    "accounts/fireworks/models/llama-v3p3-70b-instruct",
]
|
||
|
||
|
||
def _normalize_model(model: str) -> str:
    """Drop a trailing numeric version tag from a Fireworks model ID.

    Example: 'accounts/fireworks/models/kimi-k2-instruct-0905' becomes
    'accounts/fireworks/models/kimi-k2-instruct'.

    Only a final hyphen followed by 4-8 digits is removed, so a name such as
    'llama-v3p3-70b' passes through unchanged. IDs that do not start with a
    Fireworks prefix are returned as-is.
    """
    fireworks_prefixes = ("accounts/fireworks/", "fireworks/")
    if not model.startswith(fireworks_prefixes):
        return model
    return re.sub(r"-\d{4,8}$", "", model)
|
||
|
||
|
||
def _infer_provider(model: str) -> str:
    """Guess the provider name from a model ID's prefix.

    Used when the model ID has no entry in MODEL_PROVIDERS, e.g. a versioned
    variant like 'accounts/fireworks/models/kimi-k2-instruct-0905'.

    Returns one of "fireworks", "openai", "anthropic" or "google". An ID that
    matches no known prefix is logged as a warning and treated as "fireworks".
    """
    prefix_table = (
        (("accounts/fireworks/", "fireworks/"), "fireworks"),
        (("gpt-", "o1", "o3"), "openai"),
        (("claude-",), "anthropic"),
        (("gemini-",), "google"),
    )
    for prefixes, provider in prefix_table:
        if model.startswith(prefixes):
            return provider

    logger.warning(f"[LLM] Unknown model '{model}', defaulting to fireworks")
    return "fireworks"
|
||
|
||
|
||
class LLMService:
    """Route chat requests to the matching LLM provider.

    Every provider method returns the same dict shape:
    ``{"content": <reply text>, "tokens_used": <int>, "model": <model id>}``.
    """

    @staticmethod
    def _split_system_messages(messages: List[Dict[str, str]]) -> tuple:
        """Separate system prompts from the conversation turns.

        Args:
            messages: Chat messages, each a dict with "role" and "content".

        Returns:
            A ``(system_text, conversation)`` tuple. ``system_text`` is the
            content of every system message joined with a blank line (empty
            string when there is none); ``conversation`` is the remaining
            messages in their original order.
        """
        system_parts = []
        conversation = []
        for msg in messages:
            if msg["role"] == "system":
                system_parts.append(msg["content"])
            else:
                conversation.append(msg)
        return "\n\n".join(system_parts), conversation

    async def generate(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int = 1000,
        temperature: float = 0.7,
    ) -> Dict[str, Any]:
        """Generate a response from the LLM.

        The model ID is normalized, its provider is looked up in
        MODEL_PROVIDERS (or inferred from the ID), and the request is sent to
        that provider. If the call fails and a Fireworks API key is
        configured, each model in _FIREWORKS_FALLBACKS is tried in order,
        skipping the one that just failed.

        Args:
            messages: Chat messages, each a dict with "role" and "content".
            model: Requested model ID.
            max_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature.

        Returns:
            Dict with "content", "tokens_used" and "model" (the model that
            actually answered, which may be a fallback).

        Raises:
            Exception: The primary call's error, re-raised when no Fireworks
                key is configured or every fallback also failed.
        """
        model = _normalize_model(model)
        provider = MODEL_PROVIDERS.get(model) or _infer_provider(model)

        handlers = {
            "fireworks": self._call_fireworks,
            "openai": self._call_openai,
            "anthropic": self._call_anthropic,
            "google": self._call_google,
        }
        # Unrecognized provider names are routed to OpenAI.
        handler = handlers.get(provider, self._call_openai)

        try:
            return await handler(messages, model, max_tokens, temperature)
        except Exception as e:
            logger.error(f"LLM error ({provider}/{model}): {e}")
            if not settings.fireworks_api_key:
                raise
            for fallback in _FIREWORKS_FALLBACKS:
                if model == fallback:
                    # Do not retry the model that just failed.
                    continue
                try:
                    logger.warning(f"[LLM] Falling back to {fallback}")
                    return await self._call_fireworks(messages, fallback, max_tokens, temperature)
                except Exception as fe:
                    logger.error(f"[LLM] Fallback {fallback} also failed: {fe}")
            # Every fallback failed: surface the original (primary) error.
            raise

    async def _call_fireworks(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> Dict[str, Any]:
        """Call the Fireworks chat-completions HTTP endpoint.

        Raises:
            httpx.HTTPStatusError: On a non-success HTTP status.
        """
        import httpx

        headers = {
            "Authorization": f"Bearer {settings.fireworks_api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(
                "https://api.fireworks.ai/inference/v1/chat/completions",
                headers=headers,
                json=payload,
            )
            resp.raise_for_status()
            data = resp.json()
            return {
                "content": data["choices"][0]["message"]["content"],
                "tokens_used": data.get("usage", {}).get("total_tokens", 0),
                "model": model,
            }

    async def _call_openai(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> Dict[str, Any]:
        """Call the OpenAI chat-completions API through the async SDK client."""
        from openai import AsyncOpenAI

        client = AsyncOpenAI(api_key=settings.openai_api_key)
        # NOTE(review): o1/o3 reasoning models reportedly reject `max_tokens`
        # (expecting `max_completion_tokens`) and custom `temperature` --
        # confirm against the OpenAI API reference and the installed SDK.
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return {
            "content": response.choices[0].message.content,
            "tokens_used": response.usage.total_tokens if response.usage else 0,
            "model": model,
        }

    async def _call_anthropic(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> Dict[str, Any]:
        """Call the Anthropic Messages API.

        System messages are passed through the separate ``system`` parameter
        rather than in ``messages``. All system messages are joined so that
        none is silently dropped; a generic prompt is used when there is none.
        """
        import anthropic

        client = anthropic.AsyncAnthropic(api_key=settings.anthropic_api_key)

        system_msg, conv_messages = self._split_system_messages(messages)

        response = await client.messages.create(
            model=model,
            max_tokens=max_tokens,
            system=system_msg if system_msg else "You are a helpful assistant.",
            messages=conv_messages,
            temperature=temperature,
        )
        return {
            "content": response.content[0].text,
            "tokens_used": response.usage.input_tokens + response.usage.output_tokens,
            "model": model,
        }

    async def _call_google(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> Dict[str, Any]:
        """Call Gemini through the ``google.genai`` SDK.

        Uses the client-based API of ``google.genai``; the module-level
        ``configure`` / ``GenerativeModel`` helpers belong to the older
        ``google.generativeai`` package and are not available here.

        System messages are sent as the system instruction; "user" messages
        keep the "user" role and every other role is sent as "model".
        """
        import google.genai as genai
        from google.genai import types

        client = genai.Client(api_key=settings.google_api_key)

        system_msg, conv_messages = self._split_system_messages(messages)
        if not conv_messages and system_msg:
            # A request needs at least one content turn, so a system-only
            # request is sent as the user prompt.
            conv_messages = [{"role": "user", "content": system_msg}]
            system_msg = ""

        contents = [
            types.Content(
                role="user" if msg["role"] == "user" else "model",
                parts=[types.Part(text=msg["content"])],
            )
            for msg in conv_messages
        ]
        config = types.GenerateContentConfig(
            max_output_tokens=max_tokens,
            temperature=temperature,
            system_instruction=system_msg or None,
        )

        response = await client.aio.models.generate_content(
            model=model,
            contents=contents,
            config=config,
        )

        usage = response.usage_metadata
        return {
            "content": response.text,
            "tokens_used": (usage.total_token_count or 0) if usage else 0,
            "model": model,
        }
|
||
|
||
|
||
# Shared module-level LLMService instance, created once at import time.
llm_service = LLMService()
|