import logging
from src.utils.celery_worker import celery_app

logger = logging.getLogger(__name__)

try:
    from src.utils.db import get_db_session
except ImportError as e:
    # Degrade gracefully: the task below checks for a missing session factory
    # and reports the failure, instead of the whole module failing to import.
    logger.warning("Could not import required modules: %s", e)
    get_db_session = None


@celery_app.task(bind=True)
def compute_service_level_daily(self, target_date_str: str | None = None):
    """
    Compute daily service level per (company, location, product) from
    sales (fulfilled demand) + stockout events (lost demand).

    service_level = fulfilled_qty / (fulfilled_qty + lost_sales_qty)

    Stockout events are not wired in yet, so lost_sales_qty is currently
    always 0.0 and every computed row ends up with service_level 1.0.

    Existing ServiceLevelDaily rows for the target date are replaced. The
    delete and the inserts are committed in a single transaction, so a
    failure part-way through rolls back and leaves the previous rows intact.

    Args:
        target_date_str: ISO date string ("YYYY-MM-DD") to compute. When
            omitted, ``date.today()`` is used (the worker's local date).

    Returns:
        dict: ``{"success": True, "date": <iso date>, "rows_inserted": <int>}``
        on success, otherwise ``{"success": False, "error": <message>}``.
    """
    # Imports inside task to avoid circular imports
    try:
        from src.smart_inventory.apps.inventory.models import DailySales, ServiceLevelDaily  # StockoutEvent
        from sqlalchemy import func
        from datetime import date
    except ImportError as e:
        logger.error("Failed to import required modules: %s", e)
        return {"success": False, "error": str(e)}

    if not get_db_session:
        return {"success": False, "error": "Database session not available"}

    # 1) Resolve date; a malformed argument is reported like any other failure
    #    instead of raising out of the task.
    try:
        target_date = date.fromisoformat(target_date_str) if target_date_str else date.today()
    except (TypeError, ValueError) as e:
        logger.error("Invalid target_date_str %r: %s", target_date_str, e)
        return {"success": False, "error": str(e)}

    db = get_db_session()
    try:
        # 2) Aggregate sales (fulfilled demand) for this date.
        # NOTE(review): wrapping sale_date in date() may prevent index use on
        # that column; consider a half-open range filter if this gets slow.
        sales_rows = (
            db.query(
                DailySales.company_id.label("company_id"),
                DailySales.location_id.label("location_id"),
                DailySales.product_id.label("product_id"),
                func.coalesce(func.sum(DailySales.quantity_sold), 0.0).label("fulfilled_qty"),
            )
            .filter(func.date(DailySales.sale_date) == target_date)
            .group_by(
                DailySales.company_id,
                DailySales.location_id,
                DailySales.product_id,
            )
            .all()
        )

        fulfilled_map = {
            (r.company_id, r.location_id, r.product_id): float(r.fulfilled_qty or 0.0)
            for r in sales_rows
        }

        # 3) Lost demand from stockout events.
        # TODO: StockoutEvent is not used yet. To enable, aggregate
        # sum(StockoutEvent.lost_sales_qty) grouped by (company_id,
        # location_id, product_id) where StockoutEvent.date == target_date,
        # mirroring the sales query above, and build this map from it.
        lost_sales_map: dict = {}

        # 4) One row per key seen in either source.
        service_level_rows = []
        for key in fulfilled_map.keys() | lost_sales_map.keys():
            company_id, location_id, product_id = key
            fulfilled_qty = fulfilled_map.get(key, 0.0)
            lost_sales_qty = lost_sales_map.get(key, 0.0)
            demand_qty = fulfilled_qty + lost_sales_qty

            # No (or non-positive) demand counts as fully served rather than 0/0.
            service_level_value = fulfilled_qty / demand_qty if demand_qty > 0 else 1.0

            service_level_rows.append(
                ServiceLevelDaily(
                    date=target_date,
                    company_id=company_id,
                    location_id=location_id,
                    product_id=product_id,
                    demand_qty=demand_qty,
                    fulfilled_qty=fulfilled_qty,
                    lost_sales_qty=lost_sales_qty,
                    service_level=service_level_value,
                )
            )

        # 5) Idempotent replace: delete this date's rows and insert the new
        #    ones in the same transaction, committed once. If anything above
        #    or here fails, the rollback below keeps the old rows.
        db.query(ServiceLevelDaily).filter(
            ServiceLevelDaily.date == target_date
        ).delete(synchronize_session=False)
        if service_level_rows:
            db.bulk_save_objects(service_level_rows)
        db.commit()

        return {
            "success": True,
            "date": target_date.isoformat(),
            "rows_inserted": len(service_level_rows),
        }

    except Exception as e:
        db.rollback()
        logger.exception("Error in compute_service_level_daily: %s", e)
        return {"success": False, "error": str(e)}
    finally:
        db.close()
