from fastapi import Depends, HTTPException, status, Header
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional
from dataclasses import dataclass, field
from app.database import get_supabase
from app.config import settings
import base64
import hashlib
import hmac
import json
import logging
import time

logger = logging.getLogger(__name__)
security = HTTPBearer(auto_error=False)


@dataclass
class _LocalUser:
    """Minimal user object built from JWT claims — mirrors the fields used downstream."""

    id: str
    email: str
    role: str = "authenticated"
    app_metadata: dict = field(default_factory=dict)
    user_metadata: dict = field(default_factory=dict)


def _b64url_decode(segment: str) -> bytes:
    """Decode one base64url JWT segment, restoring the '=' padding JWTs strip."""
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def _verify_jwt_local(token: str) -> Optional[_LocalUser]:
    """Verify a Supabase HS256 JWT using the local secret (no network call).

    Returns a ``_LocalUser`` built from the token's claims, or None when:
    the secret is not configured; the token is not three dot-separated
    segments or cannot be decoded; its HMAC-SHA256 signature does not match
    (which includes tokens signed with any other key or algorithm); it has
    no ``exp`` claim or is expired; or it has no ``sub`` claim.
    """
    secret = settings.supabase_jwt_secret
    if not secret:
        return None
    try:
        parts = token.split(".")
        if len(parts) != 3:
            return None
        header_b64, payload_b64, sig_b64 = parts

        # Signature covers "<header>.<payload>" exactly as transmitted.
        message = f"{header_b64}.{payload_b64}".encode()
        expected = hmac.new(secret.encode(), message, hashlib.sha256).digest()
        # Constant-time comparison avoids leaking how many bytes matched.
        if not hmac.compare_digest(expected, _b64url_decode(sig_b64)):
            return None

        # Only decode the payload once the signature has been verified.
        payload = json.loads(_b64url_decode(payload_b64))

        # A missing exp defaults to 0, so tokens without an expiry are rejected.
        if payload.get("exp", 0) < time.time():
            return None

        return _LocalUser(
            id=payload["sub"],  # KeyError (-> None) if the token has no subject
            # `or` guards against explicit nulls so the dataclass types hold.
            email=payload.get("email") or "",
            role=payload.get("role", "authenticated"),
            app_metadata=payload.get("app_metadata") or {},
            user_metadata=payload.get("user_metadata") or {},
        )
    except Exception:
        # Any decode/shape error means "cannot verify locally"; the caller
        # decides whether to fall back to Supabase.
        return None


def _verify_jwt_remote(token: str):
    """Validate *token* via supabase.auth.get_user() (network call); return its user or raise 401."""
    supabase = get_supabase()
    try:
        response = supabase.auth.get_user(token)
        if not response or not response.user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid or expired token",
            )
        return response.user
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Auth error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
        )


def _ensure_not_suspended(user_id) -> None:
    """Raise 403 if user_profiles.suspended_at is set for *user_id*; fail open on lookup errors."""
    try:
        supabase = get_supabase()
        profile = (
            supabase.table("user_profiles")
            .select("suspended_at")
            .eq("user_id", user_id)
            .execute()
        )
        if profile.data and profile.data[0].get("suspended_at"):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Account suspended. Please contact support.",
            )
    except HTTPException:
        raise
    except Exception as e:
        # Deliberately fail open: never block login if the profile lookup
        # fails, but leave a trace so outages are visible.
        logger.warning("Suspension check skipped for user %s: %s", user_id, e)


async def get_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
):
    """Extract and verify the current user from a Supabase JWT.

    Tries local HS256 verification first (no network call, no SSL risk).
    Falls back to supabase.auth.get_user() whenever local verification does
    not produce a user. That happens when SUPABASE_JWT_SECRET is not
    configured, and also when the token is malformed, expired, lacks a
    ``sub`` claim, or is signed with a different key/algorithm. In those
    cases Supabase makes the final decision.

    After authentication the user's ``user_profiles.suspended_at`` flag is
    checked. That lookup is best-effort: on failure it is logged and the
    request proceeds.

    Raises:
        HTTPException: 401 if no bearer credentials are sent or the token is
            rejected; 403 if the account is suspended.
    """
    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
        )
    token = credentials.credentials

    # Fast path: local signature check. Slow path: network call to Supabase.
    user = _verify_jwt_local(token)
    if user is None:
        user = _verify_jwt_remote(token)

    # Suspension check hits the DB, not the auth endpoint.
    _ensure_not_suspended(user.id)
    return user


async def get_admin_user(
    current_user=Depends(get_current_user),
):
    """Require the current user to be an admin.

    Raises:
        HTTPException: 403 if the profile is missing, not flagged ``is_admin``,
            or the lookup fails (fails closed, unlike the suspension check).
    """
    supabase = get_supabase()
    try:
        profile = (
            supabase.table("user_profiles")
            .select("is_admin")
            .eq("user_id", current_user.id)
            .execute()
        )
        if not profile.data or not profile.data[0].get("is_admin"):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Admin access required",
            )
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Admin check failed: %s", e)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin access required",
        )
    return current_user


async def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
):
    """Optional auth - returns None if not authenticated.

    Any 401/403 from get_current_user (including a suspended account) is
    treated as anonymous.
    """
    if not credentials:
        return None
    try:
        return await get_current_user(credentials)
    except HTTPException:
        return None


def _free_subscription(user_id) -> dict:
    """Synthetic subscription row used when the user has no active paid plan."""
    return {"plan": "free", "status": "active", "user_id": user_id}


async def get_user_subscription(user=Depends(get_current_user)):
    """Get the user's active subscription row.

    Returns the first active row from ``subscriptions``, or a synthetic free
    plan when none exists or the lookup fails (the failure is logged).
    """
    supabase = get_supabase()
    try:
        result = (
            supabase.table("subscriptions")
            .select("*")
            .eq("user_id", user.id)
            .eq("status", "active")
            .execute()
        )
        if result.data:
            return result.data[0]
        return _free_subscription(user.id)
    except Exception as e:
        logger.warning("Subscription lookup failed for user %s: %s", user.id, e)
        return _free_subscription(user.id)


async def require_plan(min_plan: str, user=Depends(get_current_user)):
    """Require a minimum plan level.

    A missing or unrecognised plan on the user's subscription is treated as
    "free" (and logged) instead of crashing with ValueError.

    Raises:
        ValueError: if *min_plan* is not a known tier (a caller bug).
        HTTPException: 402 if the user's plan ranks below *min_plan*.
    """
    plan_order = ["free", "starter", "business", "agency", "enterprise"]
    required_rank = plan_order.index(min_plan)

    subscription = await get_user_subscription(user)
    user_plan = subscription.get("plan") or "free"
    if user_plan in plan_order:
        user_rank = plan_order.index(user_plan)
    else:
        logger.warning("Unknown plan %r for user %s; treating as free", user_plan, user.id)
        user_rank = 0

    if user_rank < required_rank:
        raise HTTPException(
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            detail=f"This feature requires {min_plan} plan or higher",
        )
    return user