import json
import uuid
import logging
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
import redis
from sqlalchemy.orm import Session
from sqlalchemy import func, desc

try:
    from src.utils.settings import settings
except ImportError:
    from utils.settings import settings

from .models import SmartInventoryChatSession, SmartInventoryChatMessage

logger = logging.getLogger(__name__)

# Namespace prepended to every chat-session key stored in Redis.
CHAT_SESSION_PREFIX = "smart_inventory:chat:"
# Lifetime of a Redis chat session in seconds (one hour); reset whenever the session is written.
CHAT_SESSION_TTL = 60 * 60
# Message-count threshold at which the history should be summarized.
MAX_HISTORY_MESSAGES = 20
# Character-count threshold at which the history should be summarized
# (roughly 4000 tokens, assuming about 4 characters per token).
MAX_CONTEXT_CHARS = 15_000


class ChatSessionService:
    """Manage short-lived chat sessions stored as JSON documents in Redis.

    Each session is stored under ``CHAT_SESSION_PREFIX + session_id`` and
    expires ``CHAT_SESSION_TTL`` seconds after it was last written. The JSON
    object holds ``session_id``, ``company_id``, ``created_at``, the list of
    ``messages`` and an optional ``context_summary`` that stands in for older
    messages once the history has been summarized.

    NOTE: ``add_message`` and ``set_context_summary`` read the session, modify
    it locally and write it back with separate Redis commands, so concurrent
    writers to the same session can overwrite each other's changes.
    """

    def __init__(self):
        """Create the service; the Redis client is built lazily on first use."""
        self._redis_client = None

    @property
    def redis_client(self) -> redis.Redis:
        """Return the Redis client, creating it from settings on first access.

        ``decode_responses=True`` makes reads return ``str`` instead of bytes,
        which is what ``json.loads`` in ``get_session`` receives.
        """
        if self._redis_client is None:
            self._redis_client = redis.Redis(
                host=settings.REDIS_HOST,
                port=settings.REDIS_PORT,
                db=settings.REDIS_DB,
                decode_responses=True,
            )
        return self._redis_client

    def _get_session_key(self, session_id: str) -> str:
        """Return the Redis key under which ``session_id`` is stored."""
        return f"{CHAT_SESSION_PREFIX}{session_id}"

    def create_session(self, company_id: int) -> str:
        """Create an empty chat session for ``company_id``.

        Returns:
            The new session ID (a random UUID4 string).
        """
        session_id = str(uuid.uuid4())
        session_data = {
            "session_id": session_id,
            "company_id": company_id,
            # NOTE(review): naive local time; the DB service uses utcnow().
            "created_at": datetime.now().isoformat(),
            "messages": [],
            "context_summary": None,  # set later by set_context_summary()
        }

        self.redis_client.setex(
            self._get_session_key(session_id),
            CHAT_SESSION_TTL,
            json.dumps(session_data),
        )

        logger.info(f"Created new chat session: {session_id}")
        return session_id

    def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return the decoded session data, or None if missing or expired."""
        data = self.redis_client.get(self._get_session_key(session_id))

        if data is None:
            logger.info(f"Session not found or expired: {session_id}")
            return None

        return json.loads(data)

    def update_session(self, session_id: str, session_data: Dict[str, Any]) -> bool:
        """Overwrite an existing session and reset its TTL.

        Returns:
            True if the session was written. False if it does not exist or has
            already expired; in that case nothing is created.
        """
        key = self._get_session_key(session_id)

        # SET ... XX only writes when the key already exists, so the existence
        # check and the write happen in one atomic command. The previous
        # EXISTS-then-SETEX pair could recreate a session that expired between
        # the two calls.
        written = self.redis_client.set(
            key,
            json.dumps(session_data),
            ex=CHAT_SESSION_TTL,
            xx=True,
        )
        if not written:
            logger.warning(f"Cannot update non-existent session: {session_id}")
            return False
        return True

    def add_message(self, session_id: str, role: str, content: str,
                    raw_data: Optional[List[Dict]] = None) -> bool:
        """Append a message to the session history and reset the session TTL.

        Args:
            session_id: Target session.
            role: 'user' or 'assistant'.
            content: Message text.
            raw_data: Optional result rows attached to an assistant reply.
                Only their count is recorded (``has_data`` / ``data_count``);
                the rows themselves are not stored.

        Returns:
            True if the message was stored, False if the session does not exist.
        """
        session_data = self.get_session(session_id)
        if session_data is None:
            return False

        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        }

        # Keep only a lightweight marker for attached data, not the data itself.
        if role == "assistant" and raw_data:
            message["has_data"] = True
            message["data_count"] = len(raw_data)

        # setdefault tolerates a stored session that lacks a "messages" list.
        session_data.setdefault("messages", []).append(message)

        return self.update_session(session_id, session_data)

    def get_conversation_history(self, session_id: str) -> List[Dict[str, str]]:
        """Return the history as ``{role, content}`` dicts for LLM context.

        If a context summary exists, it is prepended as a ``system`` message
        ahead of the retained recent messages. An unknown session yields ``[]``.
        """
        session_data = self.get_session(session_id)
        if session_data is None:
            return []

        history: List[Dict[str, str]] = []

        summary = session_data.get("context_summary")
        if summary:
            history.append({
                "role": "system",
                "content": f"Previous conversation summary: {summary}",
            })

        history.extend(
            {"role": msg["role"], "content": msg["content"]}
            for msg in session_data.get("messages", [])
        )
        return history

    def needs_summarization(self, session_id: str) -> bool:
        """Return True when the history should be summarized.

        This happens when there are at least ``MAX_HISTORY_MESSAGES`` messages,
        or more than ``MAX_CONTEXT_CHARS`` characters of message content.
        An unknown session returns False.
        """
        session_data = self.get_session(session_id)
        if session_data is None:
            return False

        messages = session_data.get("messages", [])

        if len(messages) >= MAX_HISTORY_MESSAGES:
            return True

        # Character count is a rough token estimate; ``or ""`` also covers
        # messages whose content was stored as null.
        total_chars = sum(len(msg.get("content") or "") for msg in messages)
        return total_chars > MAX_CONTEXT_CHARS

    def set_context_summary(self, session_id: str, summary: str,
                            keep_last_n: int = 4) -> bool:
        """Store ``summary`` and keep only the last ``keep_last_n`` messages.

        A non-positive ``keep_last_n`` drops all messages, leaving only the
        summary. Returns False if the session does not exist.
        """
        session_data = self.get_session(session_id)
        if session_data is None:
            return False

        messages = session_data.get("messages", [])
        session_data["context_summary"] = summary
        # ``messages[-0:]`` is the whole list, so a bare ``[-keep_last_n:]``
        # would keep everything when keep_last_n == 0.
        session_data["messages"] = messages[-keep_last_n:] if keep_last_n > 0 else []

        return self.update_session(session_id, session_data)

    def delete_session(self, session_id: str) -> bool:
        """Delete a chat session; return True if a key was actually removed."""
        return self.redis_client.delete(self._get_session_key(session_id)) > 0

    def refresh_session_ttl(self, session_id: str) -> bool:
        """Reset the session TTL without modifying its data.

        Returns True if the key existed and its expiry was updated.
        """
        return bool(self.redis_client.expire(self._get_session_key(session_id), CHAT_SESSION_TTL))


# Shared module-level instance. Constructing it is cheap because the Redis
# client is only created the first time ``redis_client`` is accessed.
chat_session_service = ChatSessionService()


# ============== Database Persistence Services ==============

def generate_chat_name(question: str, max_length: int = 100) -> str:
    """
    Generate a chat name from the first question.

    Surrounding whitespace is stripped. At most one common conversational
    prefix ("what is", "show me", ...) is dropped, and only when it is a
    whole-word match. The first letter is capitalized, and the result is
    shortened to at most ``max_length`` characters. When text is cut, the
    name ends in "...", provided ``max_length`` leaves room for it.

    Args:
        question: The user's first question in the chat.
        max_length: Maximum length of the returned name.

    Returns:
        The generated name, or "New Chat" if nothing usable remains.
    """
    chat_name = question.strip()

    prefixes_to_remove = ("what is", "what are", "tell me about", "show me", "can you")
    lower_name = chat_name.lower()
    for prefix in prefixes_to_remove:
        # Require a word boundary after the prefix, so "Show meat stock" is
        # not turned into "at stock" and "what island" into "land".
        if lower_name.startswith(prefix) and (
            len(lower_name) == len(prefix) or lower_name[len(prefix)].isspace()
        ):
            chat_name = chat_name[len(prefix):].strip()
            break

    if chat_name:
        chat_name = chat_name[0].upper() + chat_name[1:]

    # Only shorten names that actually exceed the limit; the ellipsis counts
    # toward max_length.
    if len(chat_name) > max_length:
        if max_length > 3:
            chat_name = chat_name[:max_length - 3] + "..."
        else:
            chat_name = chat_name[:max_length]

    return chat_name if chat_name else "New Chat"


class ChatHistoryService:
    """Persist chat sessions and their messages in the relational database.

    The service keeps no state of its own; every method receives the
    SQLAlchemy ``Session`` to operate on.
    """

    @staticmethod
    def _session_to_dict(session: SmartInventoryChatSession) -> Dict[str, Any]:
        """Serialize the public fields of a chat session row into a dict."""
        return {
            "id": session.id,
            "chat_session_id": session.chat_session_id,
            "chat_name": session.chat_name,
            "company_id": session.company_id,
            "user_id": session.user_id,
            "store_id": session.store_id,
            "branch_id": session.branch_id,
            "created_at": session.created_at,
            "updated_at": session.updated_at,
        }

    @staticmethod
    def _message_counts(db: Session, session_ids: List[int]) -> Dict[int, int]:
        """Return ``{session_pk: message_count}`` for the given sessions in one query.

        Sessions without messages are absent from the result.
        """
        if not session_ids:
            return {}
        rows = (
            db.query(
                SmartInventoryChatMessage.session_id,
                func.count(SmartInventoryChatMessage.id),
            )
            .filter(SmartInventoryChatMessage.session_id.in_(session_ids))
            .group_by(SmartInventoryChatMessage.session_id)
            .all()
        )
        return {session_pk: count for session_pk, count in rows}

    def create_session(
        self,
        db: Session,
        chat_session_id: str,
        company_id: int,
        user_id: int,
        store_id: int,
        branch_id: int,
        first_question: str
    ) -> SmartInventoryChatSession:
        """
        Create and commit a new chat session row.

        The chat name is derived from ``first_question`` via generate_chat_name().
        Returns the refreshed ORM object.
        """
        chat_name = generate_chat_name(first_question)

        session = SmartInventoryChatSession(
            chat_session_id=chat_session_id,
            company_id=company_id,
            user_id=user_id,
            store_id=store_id,
            branch_id=branch_id,
            chat_name=chat_name
        )

        db.add(session)
        db.commit()
        db.refresh(session)

        logger.info(f"Created chat session in DB: {chat_session_id}")
        return session

    def get_session_by_chat_id(
        self,
        db: Session,
        chat_session_id: str
    ) -> Optional[SmartInventoryChatSession]:
        """
        Return the chat session with the given ``chat_session_id``, or None.
        """
        return db.query(SmartInventoryChatSession).filter(
            SmartInventoryChatSession.chat_session_id == chat_session_id
        ).first()

    def add_message(
        self,
        db: Session,
        chat_session_id: str,
        role: str,
        content: str
    ) -> Optional[SmartInventoryChatMessage]:
        """
        Append a message to an existing chat session and commit.

        The role is stored lower-cased ('user' or 'assistant'), and the
        session's ``updated_at`` is bumped so it sorts as recently active.
        Returns the new message, or None if the session does not exist.
        """
        session = self.get_session_by_chat_id(db, chat_session_id)
        if session is None:
            logger.warning(f"Cannot add message - session not found: {chat_session_id}")
            return None

        message = SmartInventoryChatMessage(
            session_id=session.id,
            role=role.lower(),
            content=content
        )

        db.add(message)

        # NOTE(review): naive UTC timestamp; presumably matches the column's
        # convention - confirm against the model definition.
        session.updated_at = datetime.utcnow()

        db.commit()
        db.refresh(message)

        return message

    def get_last_n_messages(
        self,
        db: Session,
        chat_session_id: str,
        n: int = 10
    ) -> List[Dict[str, str]]:
        """
        Return the most recent ``n`` messages in chronological order.

        The result is intended as LLM context (roughly n/2 question/answer
        pairs), as a list of ``{role, content}`` dicts. Returns ``[]`` for an
        unknown session or a non-positive ``n``.
        """
        if n <= 0:
            return []

        session = self.get_session_by_chat_id(db, chat_session_id)
        if session is None:
            return []

        messages = (
            db.query(SmartInventoryChatMessage)
            .filter(SmartInventoryChatMessage.session_id == session.id)
            # id breaks ties between messages sharing a created_at value, so a
            # question and its answer keep a deterministic order (assumes ids
            # are assigned in insertion order - confirm against the model).
            .order_by(
                desc(SmartInventoryChatMessage.created_at),
                desc(SmartInventoryChatMessage.id),
            )
            .limit(n)
            .all()
        )

        # Fetched newest-first for the LIMIT; reverse into chronological order.
        return [
            {"role": msg.role, "content": msg.content}
            for msg in reversed(messages)
        ]

    def get_chat_history_list(
        self,
        db: Session,
        user_id: int,
        store_id: int,
        branch_id: int,
        page: int = 1,
        page_size: int = 20
    ) -> Tuple[List[Dict[str, Any]], int]:
        """
        Return a page of chat sessions for a user/store/branch.

        Sessions are ordered most recently updated first. Each item carries
        a ``message_count``. ``page`` is 1-based; values below 1 are treated
        as 1.

        Returns:
            ``(sessions_with_count, total_count)``, where ``total_count``
            covers all pages.
        """
        # A page below 1 would produce a negative OFFSET.
        page = max(page, 1)

        query = db.query(SmartInventoryChatSession).filter(
            SmartInventoryChatSession.user_id == user_id,
            SmartInventoryChatSession.store_id == store_id,
            SmartInventoryChatSession.branch_id == branch_id
        )

        total = query.count()

        sessions = (
            query.order_by(
                desc(SmartInventoryChatSession.updated_at),
                # Tie-breaker keeps pagination stable when updated_at values match.
                desc(SmartInventoryChatSession.id),
            )
            .offset((page - 1) * page_size)
            .limit(page_size)
            .all()
        )

        # One grouped query for the whole page instead of one COUNT per session.
        counts = self._message_counts(db, [session.id for session in sessions])

        result = []
        for session in sessions:
            item = self._session_to_dict(session)
            item["message_count"] = counts.get(session.id, 0)
            result.append(item)

        return result, total

    def get_chat_session_detail(
        self,
        db: Session,
        chat_session_id: str
    ) -> Optional[Dict[str, Any]]:
        """
        Return a chat session with all of its messages in chronological order.

        Returns None if the session does not exist.
        """
        session = self.get_session_by_chat_id(db, chat_session_id)
        if session is None:
            return None

        messages = (
            db.query(SmartInventoryChatMessage)
            .filter(SmartInventoryChatMessage.session_id == session.id)
            # id breaks created_at ties so the order is deterministic.
            .order_by(
                SmartInventoryChatMessage.created_at,
                SmartInventoryChatMessage.id,
            )
            .all()
        )

        detail = self._session_to_dict(session)
        detail["messages"] = [
            {
                "id": msg.id,
                "role": msg.role,
                "content": msg.content,
                "created_at": msg.created_at
            }
            for msg in messages
        ]
        return detail


# Shared module-level instance. The service holds no state of its own;
# every method receives the database session to operate on.
chat_history_service = ChatHistoryService()

