#auth_middleware.py
from fastapi import Request, HTTPException
from sqlalchemy.orm import Session
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from src.apps.auth.models import Session as UserSession
from src.apps.auth.services import decode_access_token
from src.utils.db import get_db

class JWTAuthMiddleware(BaseHTTPMiddleware):
    """Middleware for handling JWT authentication.

    Every request whose path is not listed in ``EXEMPT_PATHS`` must carry an
    ``Authorization: Bearer <token>`` header. The token is accepted only if:

    1. a ``UserSession`` row with that exact token exists, so tokens of
       logged-out users are refused even if they would still decode; and
    2. ``decode_access_token`` returns a truthy payload.

    Failures short-circuit with a 401 JSON response whose body has the same
    ``{"detail": ...}`` shape FastAPI uses for ``HTTPException``. A response
    is *returned* rather than an ``HTTPException`` raised because exceptions
    raised in a ``BaseHTTPMiddleware`` bypass FastAPI's exception handlers
    and would reach the client as a 500 error.
    """

    # Paths that can be reached without any token (login and registration).
    EXEMPT_PATHS = frozenset({"/auth/login", "/users/register"})

    _BEARER_PREFIX = "Bearer "

    async def dispatch(self, request: Request, call_next):
        """Authenticate the request, then forward it down the stack.

        Args:
            request: The incoming request.
            call_next: Starlette callback that invokes the next
                middleware/endpoint and returns its response.

        Returns:
            The downstream response, or a 401 ``JSONResponse`` if
            authentication fails.
        """
        if request.url.path in self.EXEMPT_PATHS:
            return await call_next(request)

        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith(self._BEARER_PREFIX):
            return self._unauthorized("Not authenticated")

        # Slice off the scheme instead of split(" ")[1], which yields "" for
        # headers like "Bearer  <token>" (double space).
        token = auth_header[len(self._BEARER_PREFIX):].strip()
        if not token:
            return self._unauthorized("Not authenticated")

        # Check if session exists before decoding. The query is blocking
        # SQLAlchemy I/O, so it runs in the threadpool to keep the event
        # loop free.
        if not await run_in_threadpool(self._session_exists, token):
            return self._unauthorized("Invalid session (logged out)")

        payload = decode_access_token(token)
        if not payload:
            return self._unauthorized("Invalid token")

        return await call_next(request)

    @staticmethod
    def _session_exists(token: str) -> bool:
        """Return True if a ``UserSession`` row stores exactly ``token``.

        The ``get_db()`` generator is kept and closed explicitly. The original
        ``next(get_db())`` dropped it and left its cleanup to garbage
        collection, so the DB session could outlive this check.
        NOTE(review): assumes ``get_db`` is a yield-style dependency that
        closes the session in its cleanup (finally) block; confirm.
        """
        db_gen = get_db()
        db = next(db_gen)
        try:
            # .first() stops at one row; .count() counted every match.
            match = db.query(UserSession).filter(UserSession.token == token).first()
            return match is not None
        finally:
            # Resumes the generator with GeneratorExit so its cleanup runs now.
            db_gen.close()

    @staticmethod
    def _unauthorized(detail: str) -> JSONResponse:
        """Build a 401 response carrying ``detail`` and a Bearer challenge."""
        return JSONResponse(
            status_code=401,
            content={"detail": detail},
            # RFC 6750/7235: a 401 must name the expected auth scheme.
            headers={"WWW-Authenticate": "Bearer"},
        )
