import json
import time
import urllib.request

import jwt

from app.config import get_settings
from app.exceptions import UnauthorizedError

# EC public keys from the Supabase JWKS endpoint, keyed by ``kid``. Synced in
# place on every successful fetch; kids that disappear from the JWKS are
# dropped so a rotated-out / revoked key stops being trusted.
_jwks_cache: dict[str, object] = {}
# time.monotonic() of the last successful JWKS fetch (-inf means "never").
_jwks_fetched_at: float = float("-inf")
# time.monotonic() of the last fetch attempt, successful or not (-inf = never).
_jwks_attempted_at: float = float("-inf")

# How long fetched keys are trusted before the JWKS is re-read (seconds).
_JWKS_TTL_SECONDS = 600.0
# Minimum spacing between fetch attempts (seconds), so tokens carrying unknown
# ``kid`` values, or an unreachable endpoint, cannot trigger an outbound
# request on every call.
_JWKS_MIN_REFETCH_SECONDS = 30.0

# The verification algorithm is chosen server-side per key source. The token's
# own ``alg`` header only selects among these; it is never trusted on its own
# (otherwise an attacker controls how the signature is checked).
_SYMMETRIC_ALGORITHMS = ("HS256",)
# JWKS keys are loaded with ECAlgorithm.from_jwk, so only ECDSA algorithms fit.
_ASYMMETRIC_ALGORITHMS = ("ES256", "ES384", "ES512")


def _refresh_jwks() -> None:
    """Fetch the Supabase JWKS and sync ``_jwks_cache`` with its EC keys.

    Entries without a string ``kid``, and entries that cannot be loaded as EC
    keys (e.g. RSA keys or malformed data), are skipped.

    Raises:
        OSError: the endpoint is unreachable or answers with an HTTP error
            (``urllib.error.URLError``/``HTTPError`` subclass OSError).
        ValueError: the response body is not valid JSON.
    """
    global _jwks_fetched_at, _jwks_attempted_at
    # Record the attempt before doing network I/O so failures are throttled too.
    _jwks_attempted_at = time.monotonic()

    settings = get_settings()
    url = f"{settings.SUPABASE_URL}/auth/v1/.well-known/jwks.json"
    with urllib.request.urlopen(url, timeout=10) as resp:
        jwks = json.loads(resp.read())

    fresh: dict[str, object] = {}
    for key_data in jwks.get("keys", []):
        kid = key_data.get("kid")
        if not isinstance(kid, str) or not kid:
            continue  # a key without a kid can never be selected by a token
        try:
            fresh[kid] = jwt.algorithms.ECAlgorithm.from_jwk(key_data)
        except (jwt.exceptions.InvalidKeyError, ValueError):
            continue  # not an EC key, or malformed: unusable with ES* algorithms

    # Add new keys first, then prune stale ones, so a still-valid kid is never
    # momentarily missing from the cache.
    _jwks_cache.update(fresh)
    for stale_kid in _jwks_cache.keys() - fresh.keys():
        _jwks_cache.pop(stale_kid, None)
    _jwks_fetched_at = time.monotonic()


def _get_public_key(kid: str) -> object:
    """Return the EC public key for ``kid``, refreshing the JWKS when needed.

    The JWKS is re-read when the cache is older than ``_JWKS_TTL_SECONDS`` or
    ``kid`` is unknown (e.g. after key rotation), but at most once per
    ``_JWKS_MIN_REFETCH_SECONDS``. If a refresh fails while a key for ``kid``
    is already cached, that previously trusted key is still returned.

    Args:
        kid: The ``kid`` header value of the token being verified.

    Returns:
        The public key object produced by ``ECAlgorithm.from_jwk``.

    Raises:
        UnauthorizedError: no key with this ``kid`` is available.
        OSError, ValueError: the JWKS could not be fetched or parsed and no
            cached key exists for ``kid``.
    """
    now = time.monotonic()
    key = _jwks_cache.get(kid)
    if key is not None and now - _jwks_fetched_at < _JWKS_TTL_SECONDS:
        return key

    if now - _jwks_attempted_at >= _JWKS_MIN_REFETCH_SECONDS:
        try:
            _refresh_jwks()
        except (OSError, ValueError):
            if key is None:
                raise
            # Endpoint trouble: keep serving the previously trusted key rather
            # than rejecting every token until the endpoint recovers.
        else:
            key = _jwks_cache.get(kid)

    if key is None:
        raise UnauthorizedError("Public key not found")
    return key


def decode_token(token: str) -> dict:
    """Verify a Supabase JWT and return its claims.

    HS256 tokens are verified with ``SUPABASE_JWT_SECRET``. ES256/ES384/ES512
    tokens are verified with the JWKS key named by the ``kid`` header. Any
    other ``alg`` value is rejected before signature verification. The
    audience must be ``"authenticated"``, and an ``exp`` claim is required and
    checked.

    Args:
        token: The encoded JWT (without any ``Bearer `` prefix).

    Returns:
        The decoded claims dictionary.

    Raises:
        UnauthorizedError: for every failure (expired, bad signature, wrong
            audience, unsupported algorithm, missing key, JWKS fetch error).
    """
    try:
        header = jwt.get_unverified_header(token)
        alg = header.get("alg")

        if alg in _SYMMETRIC_ALGORITHMS:
            key = get_settings().SUPABASE_JWT_SECRET
        elif alg in _ASYMMETRIC_ALGORITHMS:
            kid = header.get("kid")
            if not kid:
                raise UnauthorizedError("Invalid token: missing kid header")
            key = _get_public_key(kid)
        else:
            raise UnauthorizedError("Invalid token: unsupported algorithm")

        # ``alg`` is one of the whitelisted values above, paired with the key
        # type it belongs to.
        return jwt.decode(
            token,
            key,
            algorithms=[alg],
            audience="authenticated",
            # Without "require", a token lacking ``exp`` would never expire.
            options={"verify_exp": True, "require": ["exp"]},
        )
    except jwt.ExpiredSignatureError as e:
        raise UnauthorizedError("Token has expired") from e
    except jwt.InvalidTokenError as e:
        raise UnauthorizedError(f"Invalid token: {e}") from e
    except UnauthorizedError:
        raise
    except Exception as e:
        # Fail closed on anything unexpected (network, key loading, config).
        raise UnauthorizedError(f"Token validation failed: {e}") from e