from __future__ import annotations

from datetime import date, timedelta

from sqlalchemy import func
from sqlalchemy.orm import Session

from src.utils.db import SessionLocal
from src.utils.celery_worker import celery_app
from src.smart_inventory.apps.inventory.models import (
    InventorySnapshotDaily,
    DailySales,
    DemandForecast,
    ReorderPolicy,
    InventoryPlanningSnapshot,
)
from src.smart_inventory.apps.inventory.utils.inventory_planning_utils import (
    compute_safety_stock,
    compute_reorder_point,
)


@celery_app.task
def compute_inventory_planning_snapshot(snapshot_date_str: str | None = None):
    """
    Build inventory_planning_snapshot rows for a single snapshot date.

    For every (company, location, product) row found in
    InventorySnapshotDaily on the snapshot date, the task derives:

      - current_reorder_point (history-based, trailing 90 days of DailySales)
      - forecasted_reorder_point_90d (forecast-based, next 90 days of
        DemandForecast, falling back to the historical average when no
        positive forecast exists)
      - stock_status (Understocked / Overstocked / On Target)
      - recommended_order_qty, should_reorder

    The task is idempotent per date: rows already stored for the date are
    replaced. The delete and the insert are committed together in a single
    transaction, so a failure while computing or writing leaves the
    previously stored rows for that date untouched.

    Args:
        snapshot_date_str: Date string accepted by ``date.fromisoformat``
            (e.g. ``"2024-05-31"``). When ``None`` or empty, ``date.today()``
            is used.

    Returns:
        A dict with ``"snapshot_date"`` (ISO string) and ``"rows_inserted"``
        (number of planning rows written).

    Raises:
        ValueError: If ``snapshot_date_str`` cannot be parsed by
            ``date.fromisoformat``.
    """
    # 1) Determine snapshot date
    if snapshot_date_str:
        snapshot_date = date.fromisoformat(snapshot_date_str)
    else:
        # NOTE(review): date.today() uses the worker's local clock/timezone.
        snapshot_date = date.today()

    db: Session = SessionLocal()
    try:
        # 2) Inventory snapshot for that date
        inv_snap_subq = (
            db.query(
                InventorySnapshotDaily.company_id.label("company_id"),
                InventorySnapshotDaily.location_id.label("location_id"),
                InventorySnapshotDaily.product_id.label("product_id"),
                InventorySnapshotDaily.on_hand_qty.label("on_hand_qty"),
                InventorySnapshotDaily.inbound_qty.label("inbound_qty"),
            )
            .filter(InventorySnapshotDaily.snapshot_date == snapshot_date)
            .subquery()
        )

        # 3) Historical demand stats: the 90 days strictly before the
        #    snapshot date, i.e. [snapshot_date - 90, snapshot_date).
        hist_window_days = 90
        hist_start_date = snapshot_date - timedelta(days=hist_window_days)

        sales_stats_subq = (
            db.query(
                DailySales.company_id.label("company_id"),
                DailySales.location_id.label("location_id"),
                DailySales.product_id.label("product_id"),
                func.coalesce(func.avg(DailySales.quantity_sold), 0.0).label("avg_daily_demand"),
                # stddev_samp is NULL for a single observation; treat as 0.
                func.coalesce(func.stddev_samp(DailySales.quantity_sold), 0.0).label("sigma_daily_demand"),
            )
            .filter(DailySales.sale_date >= hist_start_date)
            .filter(DailySales.sale_date < snapshot_date)
            .group_by(
                DailySales.company_id,
                DailySales.location_id,
                DailySales.product_id,
            )
            .subquery()
        )

        # 4) Forecast horizon: the 90 days after the snapshot date,
        #    i.e. [snapshot_date + 1, snapshot_date + 90] inclusive.
        forecast_horizon_days = 90
        horizon_start = snapshot_date + timedelta(days=1)
        horizon_end = snapshot_date + timedelta(days=forecast_horizon_days)

        forecast_90d_subq = (
            db.query(
                DemandForecast.company_id.label("company_id"),
                DemandForecast.location_id.label("location_id"),
                DemandForecast.product_id.label("product_id"),
                func.coalesce(func.sum(DemandForecast.forecast_qty), 0.0).label("forecast_90d_total"),
            )
            .filter(DemandForecast.target_date >= horizon_start)
            .filter(DemandForecast.target_date <= horizon_end)
            .group_by(
                DemandForecast.company_id,
                DemandForecast.location_id,
                DemandForecast.product_id,
            )
            .subquery()
        )

        # 5) Join everything with ReorderPolicy. Outer joins keep every
        #    inventory row even when sales, policy or forecast data is
        #    missing; coalesce supplies the defaults in that case.
        q = (
            db.query(
                inv_snap_subq.c.company_id,
                inv_snap_subq.c.location_id,
                inv_snap_subq.c.product_id,
                inv_snap_subq.c.on_hand_qty,
                inv_snap_subq.c.inbound_qty,
                func.coalesce(sales_stats_subq.c.avg_daily_demand, 0.0).label("avg_daily_demand_hist"),
                func.coalesce(sales_stats_subq.c.sigma_daily_demand, 0.0).label("sigma_daily_demand"),
                func.coalesce(ReorderPolicy.lead_time_days, 0).label("lead_time_days"),
                func.coalesce(ReorderPolicy.review_period_days, 0).label("review_period_days"),
                func.coalesce(ReorderPolicy.service_level_target, 0.95).label("service_level_target"),
                func.coalesce(forecast_90d_subq.c.forecast_90d_total, 0.0).label("forecast_90d_total"),
            )
            .outerjoin(
                sales_stats_subq,
                (inv_snap_subq.c.company_id == sales_stats_subq.c.company_id)
                & (inv_snap_subq.c.location_id == sales_stats_subq.c.location_id)
                & (inv_snap_subq.c.product_id == sales_stats_subq.c.product_id),
            )
            .outerjoin(
                ReorderPolicy,
                (inv_snap_subq.c.company_id == ReorderPolicy.company_id)
                & (inv_snap_subq.c.location_id == ReorderPolicy.location_id)
                & (inv_snap_subq.c.product_id == ReorderPolicy.product_id),
            )
            .outerjoin(
                forecast_90d_subq,
                (inv_snap_subq.c.company_id == forecast_90d_subq.c.company_id)
                & (inv_snap_subq.c.location_id == forecast_90d_subq.c.location_id)
                & (inv_snap_subq.c.product_id == forecast_90d_subq.c.product_id),
            )
        )

        planning_rows: list[InventoryPlanningSnapshot] = []

        for row in q.all():
            # Normalise DB values (possibly Decimal/None) to plain numbers.
            on_hand_qty = float(row.on_hand_qty or 0.0)
            inbound_qty = float(row.inbound_qty or 0.0)
            avg_daily_demand_hist = float(row.avg_daily_demand_hist or 0.0)
            sigma_daily_demand = float(row.sigma_daily_demand or 0.0)
            lead_time_days = int(row.lead_time_days or 0)
            review_period_days = int(row.review_period_days or 0)
            service_level_target = float(row.service_level_target or 0.95)
            forecast_90d_total = float(row.forecast_90d_total or 0.0)

            # --- CURRENT (history-based) ---
            current_safety_stock = compute_safety_stock(
                avg_daily_demand=avg_daily_demand_hist,
                sigma_daily_demand=sigma_daily_demand,
                lead_time_days=lead_time_days,
                service_level=service_level_target,
            )
            current_reorder_point = compute_reorder_point(
                avg_daily_demand=avg_daily_demand_hist,
                safety_stock=current_safety_stock,
                lead_time_days=lead_time_days,
            )

            # --- FORECAST-BASED (next 90 days) ---
            if forecast_90d_total > 0 and forecast_horizon_days > 0:
                forecast_avg_daily_90d = forecast_90d_total / float(forecast_horizon_days)
            else:
                # No usable forecast: fall back to the historical average.
                forecast_avg_daily_90d = avg_daily_demand_hist

            # The historical sigma is reused as a proxy for forecast
            # variability.
            forecast_safety_stock_90d = compute_safety_stock(
                avg_daily_demand=forecast_avg_daily_90d,
                sigma_daily_demand=sigma_daily_demand,
                lead_time_days=lead_time_days,
                service_level=service_level_target,
            )
            forecasted_reorder_point_90d = compute_reorder_point(
                avg_daily_demand=forecast_avg_daily_90d,
                safety_stock=forecast_safety_stock_90d,
                lead_time_days=lead_time_days,
            )

            # --- Inventory position & stock status (current ROP/target) ---
            available_stock = on_hand_qty + inbound_qty
            min_target = current_reorder_point
            # Simple target: reorder point + review-period demand
            # (history-based).
            max_target = current_reorder_point + (avg_daily_demand_hist * review_period_days)

            if available_stock < min_target:
                stock_status = "Understocked"
            elif available_stock > max_target:
                stock_status = "Overstocked"
            else:
                stock_status = "On Target"

            # --- Recommended order qty: top up to the max target ---
            recommended_order_qty = max(max_target - available_stock, 0.0)

            # TODO: integrate PO check to avoid double ordering
            has_open_po = False
            should_reorder = (
                (available_stock <= current_reorder_point)
                and (recommended_order_qty > 0)
                and not has_open_po
            )

            planning_rows.append(
                InventoryPlanningSnapshot(
                    snapshot_date=snapshot_date,
                    company_id=row.company_id,
                    location_id=row.location_id,
                    product_id=row.product_id,
                    on_hand_qty=on_hand_qty,
                    inbound_qty=inbound_qty,
                    available_stock=available_stock,
                    avg_daily_demand=avg_daily_demand_hist,
                    sigma_daily_demand=sigma_daily_demand,
                    lead_time_days=lead_time_days,
                    review_period_days=review_period_days,
                    service_level_target=service_level_target,
                    current_safety_stock=current_safety_stock,
                    current_reorder_point=current_reorder_point,
                    forecast_avg_daily_demand_90d=forecast_avg_daily_90d,
                    forecast_safety_stock_90d=forecast_safety_stock_90d,
                    forecasted_reorder_point_90d=forecasted_reorder_point_90d,
                    min_target=min_target,
                    max_target=max_target,
                    stock_status=stock_status,
                    recommended_order_qty=recommended_order_qty,
                    should_reorder=should_reorder,
                )
            )

        # 6) Replace the rows for this date atomically. The delete is issued
        #    only after all rows were computed successfully and is committed
        #    together with the insert, so an error never leaves the date
        #    without its previous planning data.
        db.query(InventoryPlanningSnapshot).filter(
            InventoryPlanningSnapshot.snapshot_date == snapshot_date
        ).delete(synchronize_session=False)

        if planning_rows:
            db.bulk_save_objects(planning_rows)

        # Commit even when nothing was inserted so stale rows are cleared.
        db.commit()

        return {
            "snapshot_date": snapshot_date.isoformat(),
            "rows_inserted": len(planning_rows),
        }

    except Exception:
        # Undo any pending delete/insert before propagating the error.
        db.rollback()
        raise
    finally:
        db.close()
