Files
contexta_be/app/services/llm.py
2026-04-26 21:43:19 +00:00

208 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from app.config import settings, MODEL_PROVIDERS, PLAN_LIMITS
from typing import List, Dict, Any, Optional, AsyncGenerator
import logging
import re
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Ordered fallback chain — tried in sequence when the primary model fails.
# Fireworks models are used for free/starter plans so they must always be available.
# llama-v3p3-70b-instruct is the guaranteed last resort (confirmed working).
# Order matters: LLMService.generate walks this list front to back and skips
# the entry equal to the model that just failed.
_FIREWORKS_FALLBACKS: List[str] = [
    "accounts/fireworks/models/kimi-k2p5-instruct",
    "accounts/fireworks/models/deepseek-v3p2",
    "accounts/fireworks/models/llama-v3p3-70b-instruct",
]
def _normalize_model(model: str) -> str:
"""Strip date-based version suffixes from Fireworks model IDs.
e.g. 'accounts/fireworks/models/kimi-k2-instruct-0905''accounts/fireworks/models/kimi-k2-instruct'
Matches only purely-numeric suffixes (48 digits) so names like 'llama-v3p3-70b' are untouched."""
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
model = re.sub(r"-\d{4,8}$", "", model)
return model
def _infer_provider(model: str) -> str:
"""Infer the LLM provider from the model ID when it's not in MODEL_PROVIDERS.
Handles versioned variants like 'accounts/fireworks/models/kimi-k2-instruct-0905'."""
if model.startswith("accounts/fireworks/") or model.startswith("fireworks/"):
return "fireworks"
if model.startswith("gpt-") or model.startswith("o1") or model.startswith("o3"):
return "openai"
if model.startswith("claude-"):
return "anthropic"
if model.startswith("gemini-"):
return "google"
logger.warning(f"[LLM] Unknown model '{model}', defaulting to fireworks")
return "fireworks"
class LLMService:
    """Routes chat requests to the appropriate LLM provider.

    Every provider helper returns a dict with the keys ``content`` (reply
    text), ``tokens_used`` (token count reported by the provider, 0 when none
    is reported) and ``model`` (the model ID that produced the reply).
    """

    async def generate(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int = 1000,
        temperature: float = 0.7,
    ) -> Dict[str, Any]:
        """Generate a response from the LLM, with Fireworks as a safety net.

        The model ID is normalized, its provider is looked up in
        MODEL_PROVIDERS (or inferred from the ID), and the request is sent to
        that provider. If the call raises and a Fireworks API key is
        configured, each entry of _FIREWORKS_FALLBACKS other than the model
        that just failed is tried in order.

        Args:
            messages: Chat history as dicts with "role" and "content" keys.
            model: Requested model ID.
            max_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature.

        Returns:
            Dict with "content", "tokens_used" and "model"; "model" is the
            fallback ID when a fallback produced the reply.

        Raises:
            Exception: The primary provider's error, when no Fireworks key is
                configured or when every fallback fails as well.
        """
        model = _normalize_model(model)
        provider = MODEL_PROVIDERS.get(model) or _infer_provider(model)
        handlers = {
            "fireworks": self._call_fireworks,
            "openai": self._call_openai,
            "anthropic": self._call_anthropic,
            "google": self._call_google,
        }
        # Unrecognized provider labels are sent to the OpenAI client.
        call_provider = handlers.get(provider, self._call_openai)
        try:
            return await call_provider(messages, model, max_tokens, temperature)
        except Exception as e:
            logger.error(f"LLM error ({provider}/{model}): {e}")
            if not settings.fireworks_api_key:
                raise
            for fallback in _FIREWORKS_FALLBACKS:
                if model == fallback:
                    # This is the model that just failed; do not retry it.
                    continue
                try:
                    logger.warning(f"[LLM] Falling back to {fallback}")
                    return await self._call_fireworks(messages, fallback, max_tokens, temperature)
                except Exception as fe:
                    logger.error(f"[LLM] Fallback {fallback} also failed: {fe}")
            # Every fallback failed: re-raise the primary provider's error.
            raise

    async def _call_fireworks(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> Dict[str, Any]:
        """Send a chat-completion request to the Fireworks HTTP API.

        Raises whatever httpx raises for transport errors or non-2xx
        responses (via ``raise_for_status``).
        """
        import httpx

        headers = {
            "Authorization": f"Bearer {settings.fireworks_api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(
                "https://api.fireworks.ai/inference/v1/chat/completions",
                headers=headers,
                json=payload,
            )
            resp.raise_for_status()
            data = resp.json()
        # "usage" may be absent or null; neither should fail a good reply.
        usage = data.get("usage") or {}
        return {
            "content": data["choices"][0]["message"]["content"],
            "tokens_used": usage.get("total_tokens", 0),
            "model": model,
        }

    async def _call_openai(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> Dict[str, Any]:
        """Send a chat-completion request through the OpenAI SDK.

        Models whose ID starts with "o1" or "o3" receive
        ``max_completion_tokens`` and no ``temperature``, because those
        models reject ``max_tokens`` and non-default temperatures. All other
        models receive ``max_tokens`` and ``temperature``.
        """
        from openai import AsyncOpenAI

        request: Dict[str, Any] = {"model": model, "messages": messages}
        if model.startswith(("o1", "o3")):
            request["max_completion_tokens"] = max_tokens
        else:
            request["max_tokens"] = max_tokens
            request["temperature"] = temperature

        # The context manager closes the client's HTTP connections on exit.
        async with AsyncOpenAI(api_key=settings.openai_api_key) as client:
            response = await client.chat.completions.create(**request)
        return {
            "content": response.choices[0].message.content,
            "tokens_used": response.usage.total_tokens if response.usage else 0,
            "model": model,
        }

    async def _call_anthropic(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> Dict[str, Any]:
        """Send a request through the Anthropic Messages API.

        System-role messages are taken out of the conversation and passed via
        the ``system`` parameter. Several system messages are joined with a
        blank line so none is dropped; when there is none, a generic default
        prompt is used.
        """
        import anthropic

        # Separate system messages from the conversation turns.
        system_parts = [
            msg["content"]
            for msg in messages
            if msg["role"] == "system" and msg["content"]
        ]
        conv_messages = [msg for msg in messages if msg["role"] != "system"]
        system_prompt = "\n\n".join(system_parts) or "You are a helpful assistant."

        # The context manager closes the client's HTTP connections on exit.
        async with anthropic.AsyncAnthropic(api_key=settings.anthropic_api_key) as client:
            response = await client.messages.create(
                model=model,
                max_tokens=max_tokens,
                system=system_prompt,
                messages=conv_messages,
                temperature=temperature,
            )
        # Join every text block; an empty block list yields "" rather than an
        # IndexError.
        reply_text = "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )
        return {
            "content": reply_text,
            "tokens_used": response.usage.input_tokens + response.usage.output_tokens,
            "model": model,
        }

    async def _call_google(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> Dict[str, Any]:
        """Send a request to Gemini through the ``google.genai`` client.

        ``google.genai`` has no ``configure`` / ``GenerativeModel`` (those
        belong to the older ``google.generativeai`` package), so the request
        goes through ``genai.Client`` and its async ``aio.models`` interface.

        Roles are mapped as before: "user" and "system" messages become
        "user" turns, everything else becomes a "model" turn.
        """
        from google import genai
        from google.genai import types

        client = genai.Client(api_key=settings.google_api_key)
        contents = [
            types.Content(
                role="user" if msg["role"] in ("user", "system") else "model",
                parts=[types.Part(text=msg["content"])],
            )
            for msg in messages
        ]
        response = await client.aio.models.generate_content(
            model=model,
            contents=contents,
            config=types.GenerateContentConfig(
                max_output_tokens=max_tokens,
                temperature=temperature,
            ),
        )
        # Report real usage when the API supplies it instead of a constant 0.
        usage = response.usage_metadata
        tokens_used = (usage.total_token_count if usage else 0) or 0
        return {
            "content": response.text,
            "tokens_used": tokens_used,
            "model": model,
        }
# Module-level LLMService instance, created once when this module is imported.
llm_service = LLMService()