mirror of
http://88.130.71.182:3000/BlitTech/contexta_be.git
synced 2026-06-12 23:23:21 +00:00
Initial commit
This commit is contained in:
171
app/services/llm.py
Normal file
171
app/services/llm.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from app.config import settings, MODEL_PROVIDERS, PLAN_LIMITS
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMService:
|
||||
"""Routes requests to appropriate LLM provider"""
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
max_tokens: int = 1000,
|
||||
temperature: float = 0.7,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate a response from the LLM"""
|
||||
provider = MODEL_PROVIDERS.get(model, "openai")
|
||||
|
||||
try:
|
||||
if provider == "fireworks":
|
||||
return await self._call_fireworks(messages, model, max_tokens, temperature)
|
||||
elif provider == "openai":
|
||||
return await self._call_openai(messages, model, max_tokens, temperature)
|
||||
elif provider == "anthropic":
|
||||
return await self._call_anthropic(messages, model, max_tokens, temperature)
|
||||
elif provider == "google":
|
||||
return await self._call_google(messages, model, max_tokens, temperature)
|
||||
else:
|
||||
return await self._call_openai(messages, model, max_tokens, temperature)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM error ({provider}/{model}): {e}")
|
||||
# Fallback to a basic model if available
|
||||
if model != "accounts/fireworks/models/llama-v3p1-70b-instruct" and settings.fireworks_api_key:
|
||||
logger.info("Falling back to Fireworks AI")
|
||||
return await self._call_fireworks(
|
||||
messages,
|
||||
"accounts/fireworks/models/llama-v3p1-70b-instruct",
|
||||
max_tokens,
|
||||
temperature,
|
||||
)
|
||||
raise
|
||||
|
||||
async def _call_fireworks(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
) -> Dict[str, Any]:
|
||||
import httpx
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {settings.fireworks_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
resp = await client.post(
|
||||
"https://api.fireworks.ai/inference/v1/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return {
|
||||
"content": data["choices"][0]["message"]["content"],
|
||||
"tokens_used": data.get("usage", {}).get("total_tokens", 0),
|
||||
"model": model,
|
||||
}
|
||||
|
||||
async def _call_openai(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
) -> Dict[str, Any]:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
client = AsyncOpenAI(api_key=settings.openai_api_key)
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
return {
|
||||
"content": response.choices[0].message.content,
|
||||
"tokens_used": response.usage.total_tokens if response.usage else 0,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
async def _call_anthropic(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
) -> Dict[str, Any]:
|
||||
import anthropic
|
||||
|
||||
client = anthropic.AsyncAnthropic(api_key=settings.anthropic_api_key)
|
||||
|
||||
# Separate system message from conversation
|
||||
system_msg = ""
|
||||
conv_messages = []
|
||||
for msg in messages:
|
||||
if msg["role"] == "system":
|
||||
system_msg = msg["content"]
|
||||
else:
|
||||
conv_messages.append(msg)
|
||||
|
||||
response = await client.messages.create(
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
system=system_msg if system_msg else "You are a helpful assistant.",
|
||||
messages=conv_messages,
|
||||
temperature=temperature,
|
||||
)
|
||||
return {
|
||||
"content": response.content[0].text,
|
||||
"tokens_used": response.usage.input_tokens + response.usage.output_tokens,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
async def _call_google(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
) -> Dict[str, Any]:
|
||||
import google.generativeai as genai
|
||||
|
||||
genai.configure(api_key=settings.google_api_key)
|
||||
gemini_model = genai.GenerativeModel(model)
|
||||
|
||||
# Convert messages
|
||||
parts = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] in ("user", "system") else "model"
|
||||
parts.append({"role": role, "parts": [msg["content"]]})
|
||||
|
||||
# Use last message as prompt if only one
|
||||
if len(parts) == 1:
|
||||
response = await gemini_model.generate_content_async(
|
||||
parts[0]["parts"][0],
|
||||
generation_config={"max_output_tokens": max_tokens, "temperature": temperature},
|
||||
)
|
||||
else:
|
||||
chat = gemini_model.start_chat(history=parts[:-1])
|
||||
response = await chat.send_message_async(
|
||||
parts[-1]["parts"][0],
|
||||
generation_config={"max_output_tokens": max_tokens, "temperature": temperature},
|
||||
)
|
||||
|
||||
return {
|
||||
"content": response.text,
|
||||
"tokens_used": 0,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
|
||||
llm_service = LLMService()
|
||||
Reference in New Issue
Block a user