import logging
from typing import Dict, Any
from src.utils.celery_worker import celery_app

logger = logging.getLogger(__name__)

# The DB helper is optional at import time: if it cannot be imported,
# ``get_db_session`` is set to None and the task below checks for that and
# reports an error result instead of the whole module failing to import.
try:
    from src.utils.db import get_db_session
except ImportError as e:
    # Use logging rather than print() so the message goes through the
    # worker's configured handlers (or logging's last-resort stderr handler).
    logger.warning("Could not import required modules: %s", e)
    get_db_session = None


def _calendar_day(value: Any) -> Any:
    """Normalize a date-like value to a ``datetime.date`` for key matching.

    ``datetime`` values are truncated to their date part. ISO ``YYYY-MM-DD``
    strings are parsed; SQL ``DATE()`` returns text on some backends, e.g. SQLite.
    Anything else (``date`` objects, ``None``) is returned unchanged.
    """
    from datetime import date, datetime

    # ``datetime`` subclasses ``date``, so it must be checked first.
    if isinstance(value, datetime):
        return value.date()
    if isinstance(value, str):
        return date.fromisoformat(value)
    return value


def _aggregate_sales(db: Any, func: Any, sales_order: Any, sales_order_line: Any) -> list:
    """Sum order lines per (company, product, location, calendar day of sold_at).

    Each returned row exposes ``company_id``, ``product_id``, ``location_id``,
    ``sale_date``, ``quantity_sold`` and ``total_amount``
    (the sum of ``quantity * unit_price``).
    """
    return (
        db.query(
            sales_order.company_id,
            sales_order_line.product_id,
            sales_order.location_id,
            func.date(sales_order.sold_at).label('sale_date'),
            func.sum(sales_order_line.quantity).label('quantity_sold'),
            func.sum(sales_order_line.quantity * sales_order_line.unit_price).label('total_amount'),
        )
        .join(sales_order, sales_order_line.sales_order_id == sales_order.id)
        .group_by(
            sales_order.company_id,
            sales_order_line.product_id,
            sales_order.location_id,
            func.date(sales_order.sold_at),
        )
        .all()
    )


def _upsert_daily_sales(db: Any, daily_sales: Any, aggregated_sales: list):
    """Update or insert one DailySales row per aggregated sales row.

    Existing rows are loaded with a single query and matched in memory on
    (company_id, product_id, location_id, calendar day). This replaces a
    per-row lookup query. Nothing is committed here.

    Returns:
        Tuple ``(inserted_count, updated_count)``.
    """
    from datetime import datetime

    existing_by_key = {}
    for record in db.query(daily_sales).all():
        key = (
            record.company_id,
            record.product_id,
            record.location_id,
            _calendar_day(record.sale_date),
        )
        # If duplicates exist, keep one match (the old lookup used an
        # unordered .first(), so which one was picked was never defined).
        existing_by_key.setdefault(key, record)

    inserted_count = 0
    updated_count = 0
    for sale in aggregated_sales:
        day = _calendar_day(sale.sale_date)
        # SUM() is NULL when every unit_price in the group is NULL; float(None)
        # would abort the whole batch, so store zero revenue instead.
        total_amount = 0.0 if sale.total_amount is None else float(sale.total_amount)

        record = existing_by_key.get((sale.company_id, sale.product_id, sale.location_id, day))
        if record is not None:
            record.quantity_sold = sale.quantity_sold
            record.total_amount = total_amount
            updated_count += 1
        else:
            db.add(daily_sales(
                company_id=sale.company_id,
                product_id=sale.product_id,
                location_id=sale.location_id,
                sale_date=datetime.combine(day, datetime.min.time()),
                quantity_sold=sale.quantity_sold,
                total_amount=total_amount,
            ))
            inserted_count += 1

    return inserted_count, updated_count


@celery_app.task(bind=True)
def compute_daily_sales(self) -> Dict[str, Any]:
    """
    Celery task that recomputes the DailySales table from SalesOrderLine rows.

    Order lines are joined to their SalesOrder and summed per
    (company_id, product_id, location_id, calendar day of ``sold_at``).
    Each aggregate then updates the matching DailySales row or inserts a new one.
    The full sales history is re-aggregated on every run.

    Errors are not raised. They are logged and reported in the returned dict
    with ``success=False``.

    Returns:
        Dict containing ``success`` and ``task_id``, plus either ``message``
        and ``processed_count`` (with ``inserted_count``/``updated_count``
        when rows were written) or ``error``.
    """
    task_id = self.request.id

    # Import models here to avoid circular imports
    try:
        from src.smart_inventory.apps.inventory.models import SalesOrderLine, SalesOrder, DailySales
        from sqlalchemy import func
    except ImportError as e:
        logger.error("Failed to import required modules: %s", e)
        return {
            "success": False,
            "error": f"Module import error: {str(e)}",
            "task_id": task_id,
        }

    # get_db_session is None when its module-level import failed.
    if not get_db_session:
        error_msg = "Database session not available"
        logger.error(error_msg)
        return {
            "success": False,
            "error": error_msg,
            "task_id": task_id,
        }

    # NOTE: failures below deliberately do NOT call
    # update_state(state='FAILURE', meta=<dict>). Celery treats FAILURE meta as
    # serialized exception info, and a plain dict cannot be decoded when the
    # result is read. The task returns normally, so Celery stores the returned
    # dict under SUCCESS in any case.
    try:
        logger.info("Starting daily sales computation...")
        self.update_state(state='PROGRESS', meta={'message': 'Initializing daily sales computation...'})

        db = get_db_session()
        try:
            self.update_state(state='PROGRESS', meta={'message': 'Querying sales data...'})
            aggregated_sales = _aggregate_sales(db, func, SalesOrder, SalesOrderLine)

            if not aggregated_sales:
                logger.info("No sales data found to process")
                result = {
                    "success": True,
                    "message": "No sales data to process",
                    "processed_count": 0,
                    "task_id": task_id,
                }
                self.update_state(state='SUCCESS', meta=result)
                return result

            self.update_state(state='PROGRESS', meta={'message': f'Processing {len(aggregated_sales)} daily sales records...'})

            inserted_count, updated_count = _upsert_daily_sales(db, DailySales, aggregated_sales)
            db.commit()

            result = {
                "success": True,
                "message": "Daily sales computation completed successfully",
                "processed_count": len(aggregated_sales),
                "inserted_count": inserted_count,
                "updated_count": updated_count,
                "task_id": task_id,
            }
            logger.info("Daily sales computation completed: %s", result)
            self.update_state(state='SUCCESS', meta=result)
            return result

        except Exception as e:
            # Discard any partial inserts/updates from this run.
            db.rollback()
            error_msg = f"Error during daily sales computation: {str(e)}"
            logger.exception(error_msg)
            return {
                "success": False,
                "error": error_msg,
                "task_id": task_id,
            }

        finally:
            # Always close the database session
            db.close()

    except Exception as e:
        # Reached when session creation, a state update or the rollback fails.
        error_msg = f"Daily sales computation task failed: {str(e)}"
        logger.exception(error_msg)
        return {
            "success": False,
            "error": error_msg,
            "task_id": task_id,
        }
