import logging
from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import func, and_, desc
from datetime import datetime, timezone

from src.marketing.apps.hwGpt.model import ChatThread, ChatMessage
from src.marketing.apps.hwGpt import schema

logger = logging.getLogger(__name__)


class ThreadService:
    """Service for managing chat threads and their messages.

    Every lookup is scoped to the requesting ``user_id`` and to threads whose
    ``is_active`` flag is set; ``delete_thread`` is a soft delete that clears
    that flag. Methods commit or roll back directly on the injected session.
    """

    def __init__(self, db: Session):
        # SQLAlchemy session; this service commits/rolls back on it directly.
        self.db = db

    def create_thread(self, thread_data: schema.ThreadCreateRequest) -> schema.ThreadResponse:
        """Create and persist a new chat thread.

        Args:
            thread_data: Request payload. When ``title`` is empty/None a default
                title containing the current date and time is generated.

        Returns:
            The created thread as a ``ThreadResponse``.

        Raises:
            Exception: Any error is logged, the session is rolled back, and the
                error is re-raised to the caller.
        """
        try:
            thread = ChatThread(
                store_id=thread_data.store_id,
                branch_id=thread_data.branch_id,
                user_id=thread_data.user_id,
                # NOTE: naive datetime.now() -> server-local time, unlike the
                # UTC timestamps used by the other methods of this service.
                title=thread_data.title or f"Chat Thread {datetime.now().strftime('%Y-%m-%d %H:%M')}",
                description=thread_data.description,
                thread_type=thread_data.thread_type,
            )

            self.db.add(thread)
            self.db.commit()
            self.db.refresh(thread)

            logger.info("Created new thread %s for user %s", thread.id, thread_data.user_id)

            return self._build_thread_response(thread)

        except Exception as e:
            logger.exception("Error creating thread: %s", e)
            self.db.rollback()
            raise

    def get_thread(self, thread_id: int, user_id: int) -> Optional[schema.ThreadResponse]:
        """Return one active thread owned by ``user_id``.

        Returns ``None`` when no matching active thread exists or when a
        database error occurs (the error is logged and the session rolled back).
        """
        try:
            thread = self._get_active_thread(thread_id, user_id)
            if thread is None:
                return None

            return self._build_thread_response(thread)

        except Exception as e:
            logger.exception("Error getting thread %s: %s", thread_id, e)
            # A failed statement can leave the transaction unusable; reset it
            # so later calls on the same session still work.
            self.db.rollback()
            return None

    def get_thread_with_messages(
        self,
        thread_id: int,
        user_id: int,
        limit: int = 50,
        offset: int = 0
    ) -> Optional[schema.ThreadWithMessagesResponse]:
        """Return a thread together with one page of its non-deleted messages.

        Args:
            thread_id: Thread to load.
            user_id: Owner the thread must belong to.
            limit: Maximum number of messages in the page.
            offset: Number of messages (oldest first) to skip.

        Returns:
            The thread, the requested page of messages ordered by ``created_at``
            ascending, and ``total_messages`` = the number of non-deleted
            messages in the whole thread (not just this page). ``None`` when no
            matching active thread exists or a database error occurs.
        """
        try:
            thread = self._get_active_thread(thread_id, user_id)
            if thread is None:
                return None

            messages = (
                self.db.query(ChatMessage)
                .filter(
                    ChatMessage.thread_id == thread_id,
                    ChatMessage.is_deleted == False,
                )
                .order_by(ChatMessage.created_at.asc())
                .offset(offset)
                .limit(limit)
                .all()
            )

            # Count once and reuse it: the thread's message_count and the
            # response's total_messages describe the same quantity.
            message_count = self._count_messages(thread.id)

            return schema.ThreadWithMessagesResponse(
                thread=self._build_thread_response(thread, message_count),
                messages=[self._build_message_response(msg) for msg in messages],
                total_messages=message_count,
            )

        except Exception as e:
            logger.exception("Error getting thread with messages %s: %s", thread_id, e)
            self.db.rollback()
            return None

    def list_user_threads(
        self,
        user_id: int,
        store_id: Optional[int] = None,
        branch_id: Optional[int] = None,
        thread_type: Optional[str] = None,
        is_archived: Optional[bool] = None,
        page: int = 1,
        page_size: int = 20
    ) -> schema.ThreadListResponse:
        """List a user's active threads, newest ``updated_at`` first, paginated.

        Falsy ``store_id``/``branch_id``/``thread_type`` values (None, 0, "")
        skip that filter; ``is_archived`` filters only when it is not None.
        ``page`` is 1-based; values below 1 are treated as the first page for
        the offset computation. On a database error an empty list is returned
        (the error is logged and the session rolled back).
        """
        try:
            query = self.db.query(ChatThread).filter(
                ChatThread.user_id == user_id,
                ChatThread.is_active == True,
            )

            if store_id:
                query = query.filter(ChatThread.store_id == store_id)

            if branch_id:
                query = query.filter(ChatThread.branch_id == branch_id)

            if thread_type:
                query = query.filter(ChatThread.thread_type == thread_type)

            if is_archived is not None:
                query = query.filter(ChatThread.is_archived == is_archived)

            total_threads = query.count()

            # Clamp so page <= 0 cannot produce a negative OFFSET.
            offset = max(page - 1, 0) * page_size
            threads = (
                query.order_by(desc(ChatThread.updated_at))
                .offset(offset)
                .limit(page_size)
                .all()
            )

            # One grouped COUNT for the whole page instead of one query per thread.
            counts = self._count_messages_by_thread([thread.id for thread in threads])
            thread_responses = [
                self._build_thread_response(thread, counts.get(thread.id, 0))
                for thread in threads
            ]

            return schema.ThreadListResponse(
                threads=thread_responses,
                total_threads=total_threads,
                page=page,
                page_size=page_size,
            )

        except Exception as e:
            logger.exception("Error listing threads for user %s: %s", user_id, e)
            self.db.rollback()
            return schema.ThreadListResponse(
                threads=[],
                total_threads=0,
                page=page,
                page_size=page_size,
            )

    def update_thread(
        self,
        thread_id: int,
        user_id: int,
        update_data: schema.ThreadUpdateRequest
    ) -> Optional[schema.ThreadResponse]:
        """Apply the non-None fields of ``update_data`` to an active thread.

        ``updated_at`` is always refreshed to the current UTC time. Returns the
        updated thread, or ``None`` when no matching active thread exists or
        the update fails (the error is logged and the session rolled back).
        """
        try:
            thread = self._get_active_thread(thread_id, user_id)
            if thread is None:
                return None

            if update_data.title is not None:
                thread.title = update_data.title
            if update_data.description is not None:
                thread.description = update_data.description
            if update_data.thread_type is not None:
                thread.thread_type = update_data.thread_type
            if update_data.is_archived is not None:
                thread.is_archived = update_data.is_archived

            thread.updated_at = datetime.now(timezone.utc)
            self.db.commit()
            self.db.refresh(thread)

            logger.info("Updated thread %s", thread_id)

            return self._build_thread_response(thread)

        except Exception as e:
            logger.exception("Error updating thread %s: %s", thread_id, e)
            self.db.rollback()
            return None

    def delete_thread(self, thread_id: int, user_id: int) -> bool:
        """Soft delete a thread and mark its remaining messages as deleted.

        Returns ``True`` on success, ``False`` when no matching active thread
        exists or the delete fails (the error is logged and rolled back).
        """
        try:
            thread = self._get_active_thread(thread_id, user_id)
            if thread is None:
                return False

            # Single timestamp so the thread and its messages agree.
            now = datetime.now(timezone.utc)
            thread.is_active = False
            thread.updated_at = now

            # Only touch messages that are not already deleted, so messages
            # deleted earlier keep their existing updated_at.
            messages = (
                self.db.query(ChatMessage)
                .filter(
                    ChatMessage.thread_id == thread_id,
                    ChatMessage.is_deleted == False,
                )
                .all()
            )

            for message in messages:
                message.is_deleted = True
                message.updated_at = now

            self.db.commit()

            logger.info("Deleted thread %s and %s messages", thread_id, len(messages))

            return True

        except Exception as e:
            logger.exception("Error deleting thread %s: %s", thread_id, e)
            self.db.rollback()
            return False

    def archive_thread(self, thread_id: int, user_id: int) -> bool:
        """Set ``is_archived`` on an active thread; ``False`` if missing or on error."""
        try:
            thread = self._get_active_thread(thread_id, user_id)
            if thread is None:
                return False

            thread.is_archived = True
            thread.updated_at = datetime.now(timezone.utc)
            self.db.commit()

            logger.info("Archived thread %s", thread_id)

            return True

        except Exception as e:
            logger.exception("Error archiving thread %s: %s", thread_id, e)
            self.db.rollback()
            return False

    def unarchive_thread(self, thread_id: int, user_id: int) -> bool:
        """Clear ``is_archived`` on an active thread; ``False`` if missing or on error."""
        try:
            thread = self._get_active_thread(thread_id, user_id)
            if thread is None:
                return False

            thread.is_archived = False
            thread.updated_at = datetime.now(timezone.utc)
            self.db.commit()

            logger.info("Unarchived thread %s", thread_id)

            return True

        except Exception as e:
            logger.exception("Error unarchiving thread %s: %s", thread_id, e)
            self.db.rollback()
            return False

    def get_thread_stats(self, user_id: int, store_id: Optional[int] = None) -> Dict[str, Any]:
        """Return thread/message counts for a user's active threads.

        Keys: ``total_threads``, ``active_threads`` (not archived),
        ``archived_threads`` and ``total_messages`` (non-deleted messages in
        those threads). A falsy ``store_id`` (None or 0) skips the store
        filter. All counts are 0 on a database error (logged, rolled back).
        """
        try:
            query = self.db.query(ChatThread).filter(
                ChatThread.user_id == user_id,
                ChatThread.is_active == True,
            )

            if store_id:
                query = query.filter(ChatThread.store_id == store_id)

            # Query.filter returns a new Query, so `query` itself stays unfiltered
            # by is_archived for each of these counts.
            total_threads = query.count()
            active_threads = query.filter(ChatThread.is_archived == False).count()
            archived_threads = query.filter(ChatThread.is_archived == True).count()

            message_query = self.db.query(func.count(ChatMessage.id)).join(ChatThread).filter(
                ChatThread.user_id == user_id,
                ChatThread.is_active == True,
                ChatMessage.is_deleted == False,
            )

            if store_id:
                message_query = message_query.filter(ChatThread.store_id == store_id)

            total_messages = message_query.scalar() or 0

            return {
                "total_threads": total_threads,
                "active_threads": active_threads,
                "archived_threads": archived_threads,
                "total_messages": total_messages
            }

        except Exception as e:
            logger.exception("Error getting thread stats for user %s: %s", user_id, e)
            self.db.rollback()
            return {
                "total_threads": 0,
                "active_threads": 0,
                "archived_threads": 0,
                "total_messages": 0
            }

    def _get_active_thread(self, thread_id: int, user_id: int) -> Optional[ChatThread]:
        """Return the active thread ``thread_id`` owned by ``user_id``, or None."""
        # `== True` is deliberate: it builds a SQL comparison, not a Python test.
        return (
            self.db.query(ChatThread)
            .filter(
                ChatThread.id == thread_id,
                ChatThread.user_id == user_id,
                ChatThread.is_active == True,
            )
            .first()
        )

    def _count_messages(self, thread_id: int) -> int:
        """Return the number of non-deleted messages in one thread."""
        return self.db.query(ChatMessage).filter(
            ChatMessage.thread_id == thread_id,
            ChatMessage.is_deleted == False,
        ).count()

    def _count_messages_by_thread(self, thread_ids: List[int]) -> Dict[int, int]:
        """Return ``{thread_id: non-deleted message count}`` via one grouped query.

        Threads without any non-deleted message are absent from the result;
        callers should default them to 0.
        """
        if not thread_ids:
            return {}

        rows = (
            self.db.query(ChatMessage.thread_id, func.count(ChatMessage.id))
            .filter(
                ChatMessage.thread_id.in_(thread_ids),
                ChatMessage.is_deleted == False,
            )
            .group_by(ChatMessage.thread_id)
            .all()
        )
        return {thread_id: count for thread_id, count in rows}

    def _build_thread_response(
        self,
        thread: ChatThread,
        message_count: Optional[int] = None,
    ) -> schema.ThreadResponse:
        """Build a ``ThreadResponse`` for ``thread``.

        Args:
            thread: ORM thread to serialize.
            message_count: Precomputed non-deleted message count; when None it
                is queried for this thread.
        """
        if message_count is None:
            message_count = self._count_messages(thread.id)

        return schema.ThreadResponse(
            id=thread.id,
            store_id=thread.store_id,
            branch_id=thread.branch_id,
            user_id=thread.user_id,
            title=thread.title,
            description=thread.description,
            thread_type=thread.thread_type,
            is_active=thread.is_active,
            is_archived=thread.is_archived,
            created_at=thread.created_at,
            updated_at=thread.updated_at,
            message_count=message_count
        )

    def _build_message_response(self, message: ChatMessage) -> schema.MessageResponse:
        """Copy the fields of an ORM ``ChatMessage`` into a ``MessageResponse``."""
        return schema.MessageResponse(
            id=message.id,
            thread_id=message.thread_id,
            user_id=message.user_id,
            role=message.role,
            content=message.content,
            image_url=message.image_url,
            message_type=message.message_type,
            tokens_used=message.tokens_used,
            model_used=message.model_used,
            created_at=message.created_at,
            updated_at=message.updated_at,
            is_deleted=message.is_deleted
        )
