import logging
from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import func, and_, desc, asc
from datetime import datetime, timezone, timedelta

from src.marketing.apps.hwGpt.model import ChatMessage, ChatThread
from src.marketing.apps.hwGpt import schema

logger = logging.getLogger(__name__)


class ThreadHistoryService:
    """Service for managing chat thread history and messages.

    Lookups are scoped to the requesting ``user_id`` through
    ``ChatThread.user_id``, so a user can only read or modify messages that
    live in their own threads. Write methods commit on success and roll the
    session back on failure; read methods log errors, roll back, and return
    an empty/``None`` result instead of raising.
    """

    # Fields of the update request that ``update_message`` copies onto the
    # message when they are provided (i.e. not None).
    _UPDATABLE_FIELDS = ("content", "message_type", "image_url", "is_deleted")

    # Escape character for literal matching of user search text with ILIKE
    # ("/" is also what SQLAlchemy's own ``autoescape`` uses).
    _LIKE_ESCAPE = "/"

    def __init__(self, db: Session):
        """Create the service.

        Args:
            db: SQLAlchemy session used for every query and commit made here.
        """
        self.db = db

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_active_thread(self, thread_id: int, user_id: int) -> Optional[ChatThread]:
        """Return the active thread ``thread_id`` owned by ``user_id``, or None."""
        return self.db.query(ChatThread).filter(
            ChatThread.id == thread_id,
            ChatThread.user_id == user_id,
            ChatThread.is_active == True,  # noqa: E712 - SQLAlchemy column expression
        ).first()

    def _get_owned_message(self, message_id: int, user_id: int) -> Optional[ChatMessage]:
        """Return non-deleted message ``message_id`` in one of ``user_id``'s threads, or None.

        Note: unlike ``_get_active_thread`` this does not require the thread
        to be active.
        """
        return self.db.query(ChatMessage).join(ChatThread).filter(
            ChatMessage.id == message_id,
            ChatThread.user_id == user_id,
            ChatMessage.is_deleted == False,  # noqa: E712
        ).first()

    def _safe_rollback(self) -> None:
        """Roll back the session, logging (not raising) if the rollback itself fails.

        Called after a failed query/commit so the session is usable again and
        so a rollback error never masks the original exception.
        """
        try:
            self.db.rollback()
        except Exception:
            logger.exception("Error rolling back session")

    @classmethod
    def _escape_like(cls, value: str) -> str:
        """Escape LIKE metacharacters so ``value`` is matched literally.

        Without this, ``%`` and ``_`` typed by a user act as wildcards (e.g. a
        search for ``"50%"`` would match every message starting with "50").
        The escape character itself is escaped first so it stays literal.
        """
        esc = cls._LIKE_ESCAPE
        return (
            value.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add_message(self, message_data: schema.MessageCreateRequest) -> schema.MessageResponse:
        """Add a new message to a thread and bump the thread's ``updated_at``.

        The message insert and the thread timestamp update are committed in a
        single transaction, so either both persist or neither does.

        Args:
            message_data: Request payload; ``thread_id`` and ``user_id`` must
                identify an active thread owned by that user.

        Returns:
            The persisted message as a ``MessageResponse``.

        Raises:
            ValueError: If the thread does not exist, is inactive, or belongs
                to another user.
            Exception: Any other error is logged, the session is rolled back,
                and the error is re-raised.
        """
        try:
            thread = self._get_active_thread(message_data.thread_id, message_data.user_id)
            if thread is None:
                raise ValueError("Thread not found or access denied")

            message = ChatMessage(
                thread_id=message_data.thread_id,
                user_id=message_data.user_id,
                role=message_data.role,
                content=message_data.content,
                message_type=message_data.message_type,
                image_url=message_data.image_url,
                model_used=message_data.model_used,
            )
            self.db.add(message)

            # Touch the thread in the same transaction as the insert so a
            # failure cannot leave a committed message with a stale thread.
            thread.updated_at = datetime.now(timezone.utc)
            self.db.commit()
            self.db.refresh(message)

            logger.info("Added message %s to thread %s", message.id, message_data.thread_id)
            return self._build_message_response(message)

        except Exception as e:
            logger.exception("Error adding message: %s", e)
            self._safe_rollback()
            raise

    def get_message(self, message_id: int, user_id: int) -> Optional[schema.MessageResponse]:
        """Get a non-deleted message by ID if it belongs to one of the user's threads.

        Returns:
            The message response, or None when it is missing, deleted, owned by
            another user, or a database error occurred (the error is logged).
        """
        try:
            message = self._get_owned_message(message_id, user_id)
            if message is None:
                return None
            return self._build_message_response(message)

        except Exception as e:
            logger.exception("Error getting message %s: %s", message_id, e)
            self._safe_rollback()
            return None

    def get_thread_messages(
        self,
        thread_id: int,
        user_id: int,
        limit: int = 50,
        offset: int = 0,
        before_message_id: Optional[int] = None,
        after_message_id: Optional[int] = None
    ) -> List[schema.MessageResponse]:
        """Get non-deleted messages from one of the user's active threads.

        Args:
            thread_id: Thread to read.
            user_id: Requesting user; must own the thread.
            limit: Maximum number of messages to return.
            offset: Number of matching messages to skip.
            before_message_id: If given, only messages with a smaller id.
            after_message_id: If given, only messages with a larger id.

        Returns:
            Messages ordered oldest first (by ``created_at``, then ``id``).
            An empty list if the thread is not accessible or on error.
        """
        try:
            if self._get_active_thread(thread_id, user_id) is None:
                return []

            query = self.db.query(ChatMessage).filter(
                ChatMessage.thread_id == thread_id,
                ChatMessage.is_deleted == False,  # noqa: E712
            )

            # Compare against None so a cursor id of 0 is honored, not ignored.
            if before_message_id is not None:
                query = query.filter(ChatMessage.id < before_message_id)
            if after_message_id is not None:
                query = query.filter(ChatMessage.id > after_message_id)

            # ``id`` breaks ties between messages sharing a ``created_at`` so
            # offset pagination is deterministic (no duplicates/gaps across pages).
            messages = (
                query.order_by(asc(ChatMessage.created_at), asc(ChatMessage.id))
                .offset(offset)
                .limit(limit)
                .all()
            )
            return [self._build_message_response(msg) for msg in messages]

        except Exception as e:
            logger.exception("Error getting thread messages %s: %s", thread_id, e)
            self._safe_rollback()
            return []

    def get_thread_messages_by_date_range(
        self,
        thread_id: int,
        user_id: int,
        start_date: datetime,
        end_date: datetime,
        limit: int = 100
    ) -> List[schema.MessageResponse]:
        """Get non-deleted messages created within ``[start_date, end_date]`` (inclusive).

        Returns:
            Up to ``limit`` messages ordered oldest first (by ``created_at``,
            then ``id``). An empty list if the thread is not an active thread
            owned by ``user_id`` or on error.
        """
        try:
            if self._get_active_thread(thread_id, user_id) is None:
                return []

            messages = (
                self.db.query(ChatMessage)
                .filter(
                    ChatMessage.thread_id == thread_id,
                    ChatMessage.is_deleted == False,  # noqa: E712
                    ChatMessage.created_at >= start_date,
                    ChatMessage.created_at <= end_date,
                )
                .order_by(asc(ChatMessage.created_at), asc(ChatMessage.id))
                .limit(limit)
                .all()
            )
            return [self._build_message_response(msg) for msg in messages]

        except Exception as e:
            logger.exception("Error getting thread messages by date range %s: %s", thread_id, e)
            self._safe_rollback()
            return []

    def search_messages(
        self,
        user_id: int,
        query: str,
        thread_id: Optional[int] = None,
        store_id: Optional[int] = None,
        role: Optional[str] = None,
        limit: int = 50
    ) -> List[schema.MessageResponse]:
        """Case-insensitive substring search over the user's messages.

        ``query`` is matched literally: LIKE wildcards (``%``, ``_``) in it are
        escaped. Only non-deleted messages in active threads owned by
        ``user_id`` are searched.

        Args:
            user_id: Requesting user.
            query: Text to look for inside ``ChatMessage.content``.
            thread_id: Optional thread restriction.
            store_id: Optional restriction on ``ChatThread.store_id``.
            role: Optional role restriction; an empty string means no filter.
            limit: Maximum number of results.

        Returns:
            Matches ordered newest first; an empty list on error.
        """
        try:
            pattern = f"%{self._escape_like(query)}%"
            base_query = self.db.query(ChatMessage).join(ChatThread).filter(
                ChatThread.user_id == user_id,
                ChatThread.is_active == True,  # noqa: E712
                ChatMessage.is_deleted == False,  # noqa: E712
                ChatMessage.content.ilike(pattern, escape=self._LIKE_ESCAPE),
            )

            if thread_id is not None:
                base_query = base_query.filter(ChatMessage.thread_id == thread_id)
            if store_id is not None:
                base_query = base_query.filter(ChatThread.store_id == store_id)
            if role:
                base_query = base_query.filter(ChatMessage.role == role)

            messages = (
                base_query.order_by(desc(ChatMessage.created_at), desc(ChatMessage.id))
                .limit(limit)
                .all()
            )
            return [self._build_message_response(msg) for msg in messages]

        except Exception as e:
            logger.exception("Error searching messages: %s", e)
            self._safe_rollback()
            return []

    def update_message(
        self,
        message_id: int,
        user_id: int,
        update_data: schema.MessageUpdateRequest
    ) -> Optional[schema.MessageResponse]:
        """Apply a partial update to one of the user's non-deleted messages.

        Only fields of ``update_data`` that are not None are copied onto the
        message; ``updated_at`` is always refreshed.

        Returns:
            The updated message, or None if it was not found / not owned by the
            user, or if a database error occurred (logged and rolled back).
        """
        try:
            message = self._get_owned_message(message_id, user_id)
            if message is None:
                return None

            for field_name in self._UPDATABLE_FIELDS:
                new_value = getattr(update_data, field_name)
                if new_value is not None:
                    setattr(message, field_name, new_value)

            message.updated_at = datetime.now(timezone.utc)
            self.db.commit()
            self.db.refresh(message)

            logger.info("Updated message %s", message_id)
            return self._build_message_response(message)

        except Exception as e:
            logger.exception("Error updating message %s: %s", message_id, e)
            self._safe_rollback()
            return None

    def delete_message(self, message_id: int, user_id: int) -> bool:
        """Soft-delete one of the user's messages by setting ``is_deleted``.

        Returns:
            True if the message was marked deleted; False if it was not found,
            already deleted, not owned by the user, or a database error occurred.
        """
        try:
            message = self._get_owned_message(message_id, user_id)
            if message is None:
                return False

            message.is_deleted = True
            message.updated_at = datetime.now(timezone.utc)
            self.db.commit()

            logger.info("Deleted message %s", message_id)
            return True

        except Exception as e:
            logger.exception("Error deleting message %s: %s", message_id, e)
            self._safe_rollback()
            return False

    def get_message_stats(
        self,
        user_id: int,
        thread_id: Optional[int] = None,
        store_id: Optional[int] = None,
        days: int = 30
    ) -> Dict[str, Any]:
        """Summarize the user's non-deleted messages from the last ``days`` days.

        Counts cover active threads owned by ``user_id``, optionally narrowed to
        one thread and/or store. ``daily_counts`` holds the last 7 UTC calendar
        days (oldest first), further limited by the ``days`` window.

        Returns:
            A dict of counters; on error the same keys with zero/empty values.
        """
        try:
            end_date = datetime.now(timezone.utc)
            start_date = end_date - timedelta(days=days)

            base_query = self.db.query(ChatMessage).join(ChatThread).filter(
                ChatThread.user_id == user_id,
                ChatThread.is_active == True,  # noqa: E712
                ChatMessage.is_deleted == False,  # noqa: E712
                ChatMessage.created_at >= start_date,
                ChatMessage.created_at <= end_date,
            )
            if thread_id is not None:
                base_query = base_query.filter(ChatMessage.thread_id == thread_id)
            if store_id is not None:
                base_query = base_query.filter(ChatThread.store_id == store_id)

            total_messages = base_query.count()

            user_messages = base_query.filter(ChatMessage.role == "user").count()
            assistant_messages = base_query.filter(ChatMessage.role == "assistant").count()
            system_messages = base_query.filter(ChatMessage.role == "system").count()

            text_messages = base_query.filter(ChatMessage.message_type == "text").count()
            image_messages = base_query.filter(ChatMessage.message_type == "image").count()
            file_messages = base_query.filter(ChatMessage.message_type == "file").count()

            # Per-day counts, oldest day first. Half-open [start, next start)
            # windows avoid relying on a 23:59:59.999999 upper bound.
            daily_counts = []
            for days_ago in range(6, -1, -1):
                day_start = (end_date - timedelta(days=days_ago)).replace(
                    hour=0, minute=0, second=0, microsecond=0
                )
                next_day_start = day_start + timedelta(days=1)
                daily_count = base_query.filter(
                    ChatMessage.created_at >= day_start,
                    ChatMessage.created_at < next_day_start,
                ).count()
                daily_counts.append({
                    "date": day_start.date().isoformat(),
                    "count": daily_count,
                })

            return {
                "total_messages": total_messages,
                "user_messages": user_messages,
                "assistant_messages": assistant_messages,
                "system_messages": system_messages,
                "text_messages": text_messages,
                "image_messages": image_messages,
                "file_messages": file_messages,
                "daily_counts": daily_counts,  # Oldest first
                "period_days": days,
            }

        except Exception as e:
            logger.exception("Error getting message stats for user %s: %s", user_id, e)
            self._safe_rollback()
            return {
                "total_messages": 0,
                "user_messages": 0,
                "assistant_messages": 0,
                "system_messages": 0,
                "text_messages": 0,
                "image_messages": 0,
                "file_messages": 0,
                "daily_counts": [],
                "period_days": days,
            }

    def cleanup_old_messages(self, days: int = 90) -> int:
        """Hard-delete soft-deleted messages last updated more than ``days`` ago (admin function).

        Returns:
            The number of rows deleted as reported by the DELETE statement's
            rowcount, or 0 on error (logged and rolled back).
        """
        try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)

            # Use the DELETE's own rowcount instead of a separate COUNT, which
            # could disagree with what was actually deleted if rows change
            # between the two statements.
            deleted_count = self.db.query(ChatMessage).filter(
                ChatMessage.is_deleted == True,  # noqa: E712
                ChatMessage.updated_at < cutoff_date,
            ).delete()

            self.db.commit()

            logger.info("Cleaned up %s old deleted messages", deleted_count)
            return deleted_count

        except Exception as e:
            logger.exception("Error cleaning up old messages: %s", e)
            self._safe_rollback()
            return 0

    def _build_message_response(self, message: ChatMessage) -> schema.MessageResponse:
        """Convert a ``ChatMessage`` row into its API response model."""
        return schema.MessageResponse(
            id=message.id,
            thread_id=message.thread_id,
            user_id=message.user_id,
            role=message.role,
            content=message.content,
            message_type=message.message_type,
            image_url=message.image_url,
            tokens_used=message.tokens_used,
            model_used=message.model_used,
            created_at=message.created_at,
            updated_at=message.updated_at,
            is_deleted=message.is_deleted,
        )
