import logging
from typing import Dict, Any
from src.utils.celery_worker import celery_app

try:
    from src.utils.db import get_db_session
except ImportError as e:
    print(f"Warning: Could not import required modules: {e}")
    get_db_session = None

logger = logging.getLogger(__name__)


def _classify_slow_mover(
    on_hand_qty: float, doh_90d: float, days_since_last_sale: int
) -> tuple[bool, str | None, str | None]:
    """Classify one SKU/location row as a slow mover.

    Returns (is_slow_mover, severity, reason) where severity is 'dead'
    (> 180 days of cover), 'slow' (> 90 days of cover), 'watchlist'
    (no sales in > 60 days), or None when the item is healthy or there
    is no stock on hand. Thresholds mirror the original inline rule set.
    """
    if on_hand_qty <= 0:
        return False, None, None
    if doh_90d > 180:
        return True, "dead", f"DOH {doh_90d:.1f} > 180 days, on_hand={on_hand_qty}"
    if doh_90d > 90:
        return True, "slow", f"DOH {doh_90d:.1f} > 90 days, on_hand={on_hand_qty}"
    if days_since_last_sale > 60:
        return True, "watchlist", f"No sales in {days_since_last_sale} days"
    return False, None, None


@celery_app.task(bind=True, name='src.smart_inventory.tasks.slow_movers_task.compute_slow_movers_90d')
def compute_slow_movers_90d(self, snapshot_date_str: str | None = None) -> Dict[str, Any]:
    """
    Celery task to:
      - aggregate current inventory per company/location/product
      - compute 90-day sales metrics
      - identify slow movers
      - save snapshot into slow_mover_snapshot

    The delete of any pre-existing snapshot for the date and the insert of
    the new rows are committed in a SINGLE transaction, so a failure
    mid-computation rolls back to the previous snapshot instead of leaving
    the date empty.

    snapshot_date_str: optional 'YYYY-MM-DD'. Defaults to today.

    Returns a dict: {'success': True, 'snapshot_date': ..., 'total_rows': ...}
    on success, or {'success': False, 'error': ...} on failure.
    """
    from datetime import date, timedelta, datetime
    from src.smart_inventory.apps.inventory.models import InventoryBatch, DailySales, SlowMoverSnapshot
    from sqlalchemy import func

    snapshot_date = (
        date.fromisoformat(snapshot_date_str) if snapshot_date_str else date.today()
    )

    lookback_days = 90
    start_date = snapshot_date - timedelta(days=lookback_days)

    if not get_db_session:
        return {"success": False, "error": "Database session not available"}

    db = get_db_session()
    try:
        # 1) Clear existing snapshot rows for this date to avoid duplicates.
        #    Deliberately NOT committed here: the delete is committed together
        #    with the new rows below so the whole refresh is atomic.
        db.query(SlowMoverSnapshot).filter(
            SlowMoverSnapshot.snapshot_date == snapshot_date
        ).delete(synchronize_session=False)

        # 2) Subquery: per company/location/product sales totals over the
        #    trailing window [start_date, snapshot_date).
        sales_subq = (
            db.query(
                DailySales.company_id.label("company_id"),
                DailySales.location_id.label("location_id"),
                DailySales.product_id.label("product_id"),
                func.coalesce(func.sum(DailySales.quantity_sold), 0.0).label("total_sold_90d"),
                func.max(DailySales.sale_date).label("last_sale_date"),
            )
            .filter(
                func.date(DailySales.sale_date) >= start_date,
                func.date(DailySales.sale_date) < snapshot_date,
            )
            .group_by(
                DailySales.company_id,
                DailySales.location_id,
                DailySales.product_id,
            )
            .subquery()
        )

        # 3) Aggregate current on-hand inventory and LEFT-join the sales
        #    metrics so SKUs with zero sales in the window are still included.
        inv_q = (
            db.query(
                InventoryBatch.company_id.label("company_id"),
                InventoryBatch.location_id.label("location_id"),
                InventoryBatch.product_id.label("product_id"),
                func.sum(InventoryBatch.quantity_on_hand).label("on_hand_qty"),
                func.coalesce(sales_subq.c.total_sold_90d, 0.0).label("total_sold_90d"),
                sales_subq.c.last_sale_date.label("last_sale_date"),
            )
            .outerjoin(
                sales_subq,
                (InventoryBatch.company_id == sales_subq.c.company_id)
                & (InventoryBatch.location_id == sales_subq.c.location_id)
                & (InventoryBatch.product_id == sales_subq.c.product_id),
            )
            .group_by(
                InventoryBatch.company_id,
                InventoryBatch.location_id,
                InventoryBatch.product_id,
                sales_subq.c.total_sold_90d,
                sales_subq.c.last_sale_date,
            )
        )

        results = inv_q.all()

        snapshots: list[SlowMoverSnapshot] = []

        for row in results:
            on_hand_qty = float(row.on_hand_qty or 0.0)
            total_sold_90d = float(row.total_sold_90d or 0.0)
            last_sale_date = row.last_sale_date  # None when no sales in window

            # 4) Average daily sales over the fixed lookback window.
            ads_90d = total_sold_90d / float(lookback_days) if total_sold_90d > 0 else 0.0

            if ads_90d > 0:
                doh_90d = on_hand_qty / ads_90d
            else:
                # No sales but inventory exists -> treat as very high DOH
                doh_90d = 9999.0 if on_hand_qty > 0 else 0.0

            if last_sale_date is not None:
                # sale_date may come back as a datetime; normalize to a date
                # so the subtraction below yields a timedelta in days.
                if isinstance(last_sale_date, datetime):
                    last_sale_date = last_sale_date.date()
                days_since_last_sale = (snapshot_date - last_sale_date).days
            else:
                # Sentinel: never sold within the lookback window
                # (or ever, depending on your data).
                days_since_last_sale = 999

            # 5) Identify slow movers (thresholds live in _classify_slow_mover)
            is_slow_mover, severity, reason = _classify_slow_mover(
                on_hand_qty, doh_90d, days_since_last_sale
            )

            snapshots.append(
                SlowMoverSnapshot(
                    snapshot_date=snapshot_date,
                    company_id=row.company_id,
                    location_id=row.location_id,
                    product_id=row.product_id,
                    on_hand_qty=on_hand_qty,
                    total_sold_90d=total_sold_90d,
                    ads_90d=ads_90d,
                    doh_90d=doh_90d,
                    days_since_last_sale=days_since_last_sale,
                    is_slow_mover=is_slow_mover,
                    slow_mover_severity=severity,
                    slow_mover_reason=reason,
                )
            )

        if snapshots:
            db.bulk_save_objects(snapshots)
        # Unconditional commit: persists the delete above even when there are
        # no new rows, and makes delete + inserts one atomic refresh.
        db.commit()

        return {
            "success": True,
            "snapshot_date": snapshot_date.isoformat(),
            "total_rows": len(snapshots),
        }

    except Exception as e:
        db.rollback()
        # logger.exception records the traceback; lazy %-args avoid eager
        # string formatting when the handler level filters the record out.
        logger.exception("Error in compute_slow_movers_90d: %s", e)
        return {"success": False, "error": str(e)}
    finally:
        db.close()
