from pathlib import Path
from typing import List
import joblib
import pandas as pd
import math
from sklearn.ensemble import RandomForestRegressor
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from datetime import date


class CompanyDemandForecaster:
    """
    Train and serve one daily demand model per company.

    Features:
      - location_id, product_id (categorical, one-hot encoded)
      - date-based features (day of week, ISO week, month, day of year, weekend)
      - time-series features: lags + rolling means of qty_sold, computed
        separately for every (location_id, product_id) series
    Target:
      - qty_sold (per day)

    NOTE: lags are row-based (``shift``), so each series is expected to have
    one row per day without gaps.
    """

    def __init__(self, model_dir: str = "models"):
        """Create the forecaster and make sure ``model_dir`` exists on disk."""
        self.model_dir = Path(model_dir)
        self.model_dir.mkdir(parents=True, exist_ok=True)

    # ---------- Paths ----------
    def _model_path(self, company_id: int) -> Path:
        """Return the artifact path used for ``company_id``'s model."""
        return self.model_dir / f"company_{company_id}_demand.pkl"

    # ---------- Feature helpers ----------
    @staticmethod
    def _add_date_features(df: pd.DataFrame) -> pd.DataFrame:
        """
        Return a copy of ``df`` with calendar features derived from ``date``.

        The ``date`` column is parsed as UTC and then made timezone-naive, so
        tz-aware and tz-naive inputs can be mixed without raising
        "Cannot mix tz-aware with tz-naive values".

        Added columns: ``day_of_week`` (0=Monday .. 6=Sunday),
        ``week_of_year`` (ISO week), ``month``, ``day_of_year`` and
        ``is_weekend`` (1 for Saturday/Sunday, else 0).
        """
        df = df.copy()
        df["date"] = pd.to_datetime(df["date"], utc=True)
        df["date"] = df["date"].dt.tz_localize(None)
        df["day_of_week"] = df["date"].dt.weekday
        df["week_of_year"] = df["date"].dt.isocalendar().week.astype(int)
        df["month"] = df["date"].dt.month
        df["day_of_year"] = df["date"].dt.dayofyear
        df["is_weekend"] = df["day_of_week"].isin([5, 6]).astype(int)
        return df

    @staticmethod
    def _add_time_series_features(
        df: pd.DataFrame,
        group_cols: List[str],
        target_col: str = "qty_sold",
    ) -> pd.DataFrame:
        """
        Return a copy of ``df`` with lag and rolling-mean features per group.

        The copy is sorted by ``group_cols + ["date"]`` here, so callers do
        not need to pre-sort. Added columns: ``lag_1``, ``lag_7``, ``lag_14``,
        ``roll_7_mean`` and ``roll_28_mean``. Rolling means only look at rows
        strictly before the current one (no target leakage) and never cross
        a group boundary. Values with no usable history are filled with 0.0.
        """
        df = df.copy()
        df = df.sort_values(group_cols + ["date"])

        grouped = df.groupby(group_cols)[target_col]

        # Lags (groupby.shift is already evaluated per group).
        df["lag_1"] = grouped.shift(1)
        df["lag_7"] = grouped.shift(7)
        df["lag_14"] = grouped.shift(14)

        def _prior_rolling_mean(window: int) -> pd.Series:
            # Shift AND roll inside each group. Rolling over the result of
            # grouped.shift(1) would operate on one flat Series and blend
            # the end of one series into the start of the next.
            return grouped.transform(
                lambda s: s.shift(1).rolling(window=window, min_periods=1).mean()
            )

        df["roll_7_mean"] = _prior_rolling_mean(7)
        df["roll_28_mean"] = _prior_rolling_mean(28)

        # Rows at the start of a series have no history; fall back to 0.
        lag_cols = ["lag_1", "lag_7", "lag_14", "roll_7_mean", "roll_28_mean"]
        df[lag_cols] = df[lag_cols].fillna(0.0)

        return df

    # ---------- Training ----------
    def train_company_model(self, df: pd.DataFrame, company_id: int) -> str:
        """
        Fit and persist the demand model for one company.

        Args:
            df: Training data containing at least
                ['company_id', 'location_id', 'product_id', 'date', 'qty_sold'].
                Rows of other companies are ignored.
            company_id: Company whose rows are used for training.

        Returns:
            Path of the saved model artifact, as a string.

        Raises:
            ValueError: If ``df`` has no rows for ``company_id``.
        """
        df_company = df[df["company_id"] == company_id].copy()
        if df_company.empty:
            raise ValueError(f"No training data found for company_id={company_id}")

        df_company = self._add_date_features(df_company)

        # Lags / rolling means are computed per (location, product) series.
        group_cols = ["location_id", "product_id"]
        df_company = self._add_time_series_features(
            df_company,
            group_cols=group_cols,
            target_col="qty_sold",
        )

        feature_cols = [
            "location_id",
            "product_id",
            "day_of_week",
            "week_of_year",
            "month",
            "day_of_year",
            "is_weekend",
            "lag_1",
            "lag_7",
            "lag_14",
            "roll_7_mean",
            "roll_28_mean",
        ]
        target_col = "qty_sold"

        X = df_company[feature_cols]
        y = df_company[target_col]

        categorical_features = ["location_id", "product_id"]
        numeric_features = [c for c in feature_cols if c not in categorical_features]

        # handle_unknown="ignore" lets prediction proceed for ids that were
        # not present at training time (they encode as all zeros).
        preprocessor = ColumnTransformer(
            transformers=[
                ("cat", OneHotEncoder(handle_unknown="ignore"), categorical_features),
                ("num", "passthrough", numeric_features),
            ]
        )

        model = RandomForestRegressor(
            n_estimators=300,
            random_state=42,
            n_jobs=-1,
        )

        pipeline = Pipeline(
            steps=[
                ("preprocessor", preprocessor),
                ("model", model),
            ]
        )

        pipeline.fit(X, y)

        # Persist the pipeline together with the metadata prediction needs.
        model_path = self._model_path(company_id)
        joblib.dump(
            {
                "pipeline": pipeline,
                "feature_cols": feature_cols,
                "group_cols": group_cols,
                "target_col": target_col,
            },
            model_path,
        )

        return str(model_path)

    # ---------- Prediction ----------
    def predict_for_series(
        self,
        company_id: int,
        history_df: pd.DataFrame,
        location_id: int,
        product_id: int,
        future_dates: List[date],
    ) -> pd.DataFrame:
        """
        Predict demand for a given (company, location, product) for future_dates.

        Forecasting is recursive: dates are predicted in chronological order
        and every prediction is written back as that day's ``qty_sold`` so
        that the lag / rolling features of later dates are built from
        forecasts instead of from missing values.

        Args:
            company_id: Company whose trained model is used.
            history_df: Must contain at least
                ['company_id', 'location_id', 'product_id', 'date', 'qty_sold']
                for THAT company/location/product, up to the DAY BEFORE the
                first future date.
            location_id: Location of the series to forecast.
            product_id: Product of the series to forecast.
            future_dates: Dates to forecast.

        Returns:
            DataFrame with columns ['company_id', 'location_id', 'product_id',
            'date', 'forecast_qty'], one row per future date ordered by date.
            ``forecast_qty`` is the model output rounded up to an integer.
            Empty (same columns) when ``future_dates`` is empty.

        Raises:
            FileNotFoundError: If no model has been trained for ``company_id``.
            ValueError: If ``history_df`` has no rows for the requested series.
        """
        output_cols = ["company_id", "location_id", "product_id", "date", "forecast_qty"]

        model_path = self._model_path(company_id)
        if not model_path.exists():
            raise FileNotFoundError(
                f"Model file not found for company_id={company_id}. Train it first."
            )

        # SECURITY: joblib artifacts are pickle-based; loading one executes
        # arbitrary code. Only load files written by train_company_model.
        saved = joblib.load(model_path)
        pipeline = saved["pipeline"]
        feature_cols = saved["feature_cols"]
        group_cols = saved["group_cols"]

        # Filter history for this series
        df_hist = history_df[
            (history_df["company_id"] == company_id)
            & (history_df["location_id"] == location_id)
            & (history_df["product_id"] == product_id)
        ].copy()

        if df_hist.empty:
            raise ValueError(
                f"No history found for company={company_id}, "
                f"location={location_id}, product={product_id}"
            )

        future_dates = list(future_dates)
        if not future_dates:
            # Nothing to forecast; avoid calling the model with zero rows.
            return pd.DataFrame(columns=output_cols)

        # Future rows: qty_sold is unknown and gets filled in step by step.
        df_future = pd.DataFrame(
            {
                "company_id": company_id,
                "location_id": location_id,
                "product_id": product_id,
                "date": future_dates,
                "qty_sold": [float("nan")] * len(future_dates),
            }
        )

        # Explicit marker instead of matching on date values: robust against
        # history rows sharing a date and against tz-aware future dates.
        df_hist["_is_future"] = False
        df_future["_is_future"] = True

        df_all = pd.concat([df_hist, df_future], ignore_index=True)
        df_all = self._add_date_features(df_all)

        # Chronological order, history before future on equal dates; after
        # reset_index the index label of a row equals its position.
        df_all = df_all.sort_values(["date", "_is_future"]).reset_index(drop=True)
        future_positions = df_all.index[df_all["_is_future"]].tolist()

        raw_predictions = []
        for pos in future_positions:
            # Only rows up to the date being predicted are visible, so the
            # features use actual history plus earlier forecasts.
            visible = df_all.iloc[: pos + 1]
            feats = self._add_time_series_features(
                visible,
                group_cols=group_cols,
                target_col="qty_sold",
            )
            prediction = float(pipeline.predict(feats.loc[[pos], feature_cols])[0])
            # Feed the unrounded forecast back as the lag source for later dates.
            df_all.at[pos, "qty_sold"] = prediction
            raw_predictions.append(prediction)

        result = df_all.loc[
            future_positions, ["company_id", "location_id", "product_id", "date"]
        ].copy()
        result["forecast_qty"] = [math.ceil(pred) for pred in raw_predictions]

        return result[output_cols]
