mirror of
http://88.130.71.182:3000/BlitTech/contexta_be.git
synced 2026-06-13 08:30:07 +00:00
166 lines
5.4 KiB
Python
166 lines
5.4 KiB
Python
from app.config import settings, MODEL_PROVIDERS, PLAN_LIMITS
|
|
from typing import List, Dict, Any, Optional, AsyncGenerator, Tuple
|
|
import logging
|
|
|
|
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LLMService:
    """Route chat-completion requests to the appropriate LLM provider.

    Every provider adapter returns a dict of the same shape::

        {"content": <generated text>, "tokens_used": int, "model": str}

    The provider for a model is looked up in ``MODEL_PROVIDERS``; models not
    listed there (and unknown provider names) are routed to OpenAI.
    """

    # Model retried on Fireworks when the primary provider call fails and a
    # Fireworks API key is configured.
    FALLBACK_MODEL = "accounts/fireworks/models/llama-v3p3-70b-instruct"

    async def generate(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int = 1000,
        temperature: float = 0.7,
    ) -> Dict[str, Any]:
        """Generate a response from the LLM that serves *model*.

        Args:
            messages: Chat history as ``{"role": ..., "content": ...}`` dicts.
                Messages with role ``"system"`` are treated as system prompts
                by the Anthropic and Google adapters.
            model: Model identifier; also the key used to pick the provider.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.

        Returns:
            ``{"content", "tokens_used", "model"}``. If the fallback was used,
            ``"model"`` is the fallback model id, not the requested one.

        Raises:
            Exception: the provider's original error when no fallback is
                possible, or the fallback call's own error if it also fails.
        """
        provider = MODEL_PROVIDERS.get(model, "openai")
        handlers = {
            "fireworks": self._call_fireworks,
            "openai": self._call_openai,
            "anthropic": self._call_anthropic,
            "google": self._call_google,
        }
        # Unrecognised provider names fall through to OpenAI, like unknown models.
        handler = handlers.get(provider, self._call_openai)

        try:
            return await handler(messages, model, max_tokens, temperature)
        except Exception:
            logger.exception("LLM error (%s/%s)", provider, model)
            fallback = self.FALLBACK_MODEL
            if model != fallback and settings.fireworks_api_key:
                logger.warning("Retrying LLM request with fallback model %s", fallback)
                return await self._call_fireworks(messages, fallback, max_tokens, temperature)
            raise

    @staticmethod
    def _split_system(messages: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
        """Separate system prompts from the conversation turns.

        Returns ``(system_text, conversation)``. All system messages are kept,
        joined by a blank line, rather than only the last one.
        """
        system_parts = [m["content"] for m in messages if m["role"] == "system"]
        conversation = [m for m in messages if m["role"] != "system"]
        return "\n\n".join(system_parts), conversation

    async def _call_fireworks(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> Dict[str, Any]:
        """POST to the Fireworks chat-completions endpoint and normalise the reply.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response.
        """
        import httpx

        headers = {
            "Authorization": f"Bearer {settings.fireworks_api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(
                "https://api.fireworks.ai/inference/v1/chat/completions",
                headers=headers,
                json=payload,
            )
            resp.raise_for_status()
            data = resp.json()

        # "usage" may be absent or explicitly null; treat both as no usage info.
        usage = data.get("usage") or {}
        return {
            "content": data["choices"][0]["message"]["content"],
            "tokens_used": usage.get("total_tokens") or 0,
            "model": model,
        }

    async def _call_openai(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> Dict[str, Any]:
        """Call the OpenAI chat-completions API and normalise the reply."""
        from openai import AsyncOpenAI

        # Use the client as an async context manager so its HTTP connection
        # pool is closed after the request instead of leaking on every call.
        async with AsyncOpenAI(api_key=settings.openai_api_key) as client:
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )

        usage = response.usage
        return {
            "content": response.choices[0].message.content,
            "tokens_used": usage.total_tokens if usage else 0,
            "model": model,
        }

    async def _call_anthropic(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> Dict[str, Any]:
        """Call the Anthropic Messages API and normalise the reply.

        The Messages API takes the system prompt as a separate ``system``
        argument, so system messages are pulled out of the conversation.
        """
        import anthropic

        system_msg, conv_messages = self._split_system(messages)

        # Context manager closes the client's connection pool after the call.
        async with anthropic.AsyncAnthropic(api_key=settings.anthropic_api_key) as client:
            response = await client.messages.create(
                model=model,
                max_tokens=max_tokens,
                system=system_msg or "You are a helpful assistant.",
                messages=conv_messages,
                temperature=temperature,
            )

        # The reply is a list of content blocks. Only "text" blocks carry
        # assistant text, and there may be zero or several of them.
        text = "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )
        return {
            "content": text,
            "tokens_used": response.usage.input_tokens + response.usage.output_tokens,
            "model": model,
        }

    async def _call_google(
        self,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> Dict[str, Any]:
        """Call Gemini through the ``google-genai`` SDK and normalise the reply.

        System messages become ``system_instruction``. ``"user"`` turns keep
        their role and every other role is sent as ``"model"``.
        """
        import google.genai as genai
        from google.genai import types

        system_msg, conv_messages = self._split_system(messages)
        contents = [
            types.Content(
                role="user" if m["role"] == "user" else "model",
                parts=[types.Part(text=m["content"])],
            )
            for m in conv_messages
        ]
        if not contents and system_msg:
            # A request needs at least one content turn; with only a system
            # prompt, send it as the user turn instead.
            contents = [types.Content(role="user", parts=[types.Part(text=system_msg)])]
            system_msg = ""

        config = types.GenerateContentConfig(
            max_output_tokens=max_tokens,
            temperature=temperature,
            system_instruction=system_msg or None,
        )
        client = genai.Client(api_key=settings.google_api_key)
        response = await client.aio.models.generate_content(
            model=model,
            contents=contents,
            config=config,
        )

        usage = response.usage_metadata
        return {
            "content": response.text,
            "tokens_used": (usage.total_token_count or 0) if usage else 0,
            "model": model,
        }
|
|
|
|
|
|
# Shared module-level instance (LLMService keeps no per-instance state);
# presumably imported by callers instead of constructing their own — verify.
llm_service = LLMService()
|