from pathlib import Path
from typing import List
import joblib
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from datetime import date


class CompanyDemandForecaster:
    """
    Train and serve one daily-demand model per company.

    Features:
      - location_id, product_id (categorical, one-hot encoded)
      - date-based features (day of week, ISO week, month, day of year, weekend flag)
      - time-series features: lags + rolling means of qty_sold, computed
        separately for every (location_id, product_id) series
    Target:
      - qty_sold (per day)

    Models are persisted with joblib under ``model_dir`` as
    ``company_<id>_demand.pkl``.
    """

    def __init__(self, model_dir: str = "models"):
        """Remember ``model_dir`` and create it (with parents) if missing."""
        self.model_dir = Path(model_dir)
        self.model_dir.mkdir(parents=True, exist_ok=True)

    # ---------- Paths ----------
    def _model_path(self, company_id: int) -> Path:
        """Return the file path where the model for ``company_id`` is stored."""
        return self.model_dir / f"company_{company_id}_demand.pkl"

    # ---------- Feature helpers ----------
    @staticmethod
    def _add_date_features(df: pd.DataFrame) -> pd.DataFrame:
        """
        Return a copy of ``df`` with calendar features derived from ``date``.

        The ``date`` column is converted with ``pd.to_datetime``; the added
        columns are day_of_week (0=Monday .. 6=Sunday), week_of_year (ISO),
        month, day_of_year and is_weekend (1 for Saturday/Sunday, else 0).
        """
        df = df.copy()
        df["date"] = pd.to_datetime(df["date"])
        df["day_of_week"] = df["date"].dt.weekday     # 0-6
        df["week_of_year"] = df["date"].dt.isocalendar().week.astype(int)
        df["month"] = df["date"].dt.month
        df["day_of_year"] = df["date"].dt.dayofyear
        df["is_weekend"] = df["day_of_week"].isin([5, 6]).astype(int)
        return df

    @staticmethod
    def _add_time_series_features(
        df: pd.DataFrame,
        group_cols: List[str],
        target_col: str = "qty_sold",
    ) -> pd.DataFrame:
        """
        Add lag and rolling-mean features per series (one series per
        distinct combination of ``group_cols``).

        The input does not need to be pre-sorted: a copy is sorted by
        ``group_cols`` + ``date`` here and returned in that order (original
        index labels are preserved).

        Added columns: lag_1, lag_7, lag_14, roll_7_mean, roll_28_mean.
        Every feature only looks at rows strictly before the current one
        within the same series, so the current target never leaks in.
        Values that cannot be computed (start of a series) are filled with 0.

        NOTE: lags are row-based ("k rows earlier"), which equals
        "k days earlier" only when each series has one row per calendar day.
        """
        keys = list(group_cols)
        df = df.copy()
        df = df.sort_values(keys + ["date"])

        grouped = df.groupby(keys)[target_col]

        # Lags (groupby.shift never crosses a series boundary)
        df["lag_1"] = grouped.shift(1)
        df["lag_7"] = grouped.shift(7)
        df["lag_14"] = grouped.shift(14)

        # Rolling means of the previous values. Both the shift and the
        # rolling window must run inside the group: rolling over the
        # already-shifted flat Series would let the window reach back into
        # the preceding series in the sorted frame.
        def _prior_mean(window: int) -> pd.Series:
            return grouped.transform(
                lambda s: s.shift(1).rolling(window=window, min_periods=1).mean()
            )

        df["roll_7_mean"] = _prior_mean(7)
        df["roll_28_mean"] = _prior_mean(28)

        # You can also add rolling std, max, etc., similarly

        # Fill initial NaNs with 0 as a fallback
        lag_cols = ["lag_1", "lag_7", "lag_14", "roll_7_mean", "roll_28_mean"]
        df[lag_cols] = df[lag_cols].fillna(0.0)

        return df

    # ---------- Training ----------
    def train_company_model(self, df: pd.DataFrame, company_id: int) -> str:
        """
        Fit and persist the demand model for one company.

        df must contain:
        ['company_id', 'location_id', 'product_id', 'date', 'qty_sold', ...]

        Returns the path of the saved model file as a string.
        Raises ValueError when ``df`` has no rows for ``company_id``.
        """

        df_company = df[df["company_id"] == company_id].copy()
        if df_company.empty:
            raise ValueError(f"No training data found for company_id={company_id}")

        # Add date features
        df_company = self._add_date_features(df_company)

        # Add time-series features per (location, product)
        group_cols = ["location_id", "product_id"]
        df_company = self._add_time_series_features(
            df_company,
            group_cols=group_cols,
            target_col="qty_sold",
        )

        # Define features & target
        feature_cols = [
            "location_id",
            "product_id",
            "day_of_week",
            "week_of_year",
            "month",
            "day_of_year",
            "is_weekend",
            "lag_1",
            "lag_7",
            "lag_14",
            "roll_7_mean",
            "roll_28_mean",
        ]
        target_col = "qty_sold"

        X = df_company[feature_cols]
        y = df_company[target_col]

        categorical_features = ["location_id", "product_id"]
        numeric_features = [c for c in feature_cols if c not in categorical_features]

        # handle_unknown="ignore": ids unseen during training encode as
        # all-zero instead of raising at prediction time.
        preprocessor = ColumnTransformer(
            transformers=[
                ("cat", OneHotEncoder(handle_unknown="ignore"), categorical_features),
                ("num", "passthrough", numeric_features),
            ]
        )

        model = RandomForestRegressor(
            n_estimators=300,
            random_state=42,
            n_jobs=-1,
        )

        pipeline = Pipeline(
            steps=[
                ("preprocessor", preprocessor),
                ("model", model),
            ]
        )

        pipeline.fit(X, y)

        # Save model + feature metadata so prediction rebuilds the exact
        # same feature layout.
        model_path = self._model_path(company_id)
        joblib.dump(
            {
                "pipeline": pipeline,
                "feature_cols": feature_cols,
                "group_cols": group_cols,
                "target_col": target_col,
            },
            model_path,
        )

        return str(model_path)

    # ---------- Prediction ----------
    def _forecast_recursively(
        self,
        pipeline,
        df_hist: pd.DataFrame,
        company_id: int,
        location_id: int,
        product_id: int,
        future_dates: List[date],
        feature_cols: List[str],
        group_cols: List[str],
    ) -> pd.DataFrame:
        """
        Forecast ``future_dates`` one step at a time for a single series.

        ``pipeline`` is any fitted object exposing ``predict(DataFrame)``.
        ``df_hist`` holds the history of exactly this series. Each forecast
        is written back as that day's qty_sold before the next date is
        processed, so lags and rolling means of later dates are built from
        earlier forecasts instead of missing values.

        Returns the future rows (in date order) with columns
        ['company_id', 'location_id', 'product_id', 'date', 'forecast_qty'].
        """
        out_cols = ["company_id", "location_id", "product_id", "date"]

        future_ts = pd.to_datetime(list(future_dates))
        if len(future_ts) == 0:
            return pd.DataFrame(columns=out_cols + ["forecast_qty"])

        # Normalise both date columns before concatenating so the combined
        # column never holds a mix of strings / date objects / Timestamps.
        hist = df_hist.copy()
        hist["date"] = pd.to_datetime(hist["date"])
        hist["_is_future"] = False

        df_future = pd.DataFrame(
            {
                "company_id": company_id,
                "location_id": location_id,
                "product_id": product_id,
                "date": future_ts,
                # qty_sold unknown for future, filled in step by step below
                "qty_sold": float("nan"),
                "_is_future": True,
            }
        )

        # Combine history + future; a stable sort keeps history rows ahead
        # of future rows that share a date.
        df_all = pd.concat([hist, df_future], ignore_index=True)
        df_all = df_all.sort_values("date", kind="stable").reset_index(drop=True)

        # Date features do not depend on qty_sold: compute them once.
        df_all = self._add_date_features(df_all)

        # An explicit marker (rather than matching on date values) keeps
        # history rows that happen to share a future date out of the result.
        future_idx = df_all.index[df_all["_is_future"].to_numpy()]

        forecasts = []
        for idx in future_idx:
            # Features for row `idx` only need the rows up to and including it.
            feats = self._add_time_series_features(
                df_all.loc[:idx],
                group_cols=group_cols,
                target_col="qty_sold",
            )
            pred = float(pipeline.predict(feats.loc[[idx], feature_cols])[0])
            # Feed the forecast back in as history for the following dates.
            df_all.loc[idx, "qty_sold"] = pred
            forecasts.append(pred)

        result = df_all.loc[future_idx, out_cols].copy()
        result["forecast_qty"] = forecasts
        return result

    def predict_for_series(
        self,
        company_id: int,
        history_df: pd.DataFrame,
        location_id: int,
        product_id: int,
        future_dates: List[date],
    ) -> pd.DataFrame:
        """
        Predict demand for a given (company, location, product) for future_dates.

        history_df must contain at least:
        ['company_id', 'location_id', 'product_id', 'date', 'qty_sold']
        for THAT company/location/product, up to the DAY BEFORE the first future date.

        Forecasting is recursive: each predicted day is used as history for
        the days after it. An empty ``future_dates`` yields an empty frame.

        Returns a DataFrame with columns
        ['company_id', 'location_id', 'product_id', 'date', 'forecast_qty'].

        Raises FileNotFoundError when no model was trained for the company
        and ValueError when ``history_df`` has no rows for the series.
        """

        model_path = self._model_path(company_id)
        if not model_path.exists():
            raise FileNotFoundError(
                f"Model file not found for company_id={company_id}. Train it first."
            )

        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code; model_dir must only contain files written by this class.
        saved = joblib.load(model_path)
        pipeline = saved["pipeline"]
        feature_cols = saved["feature_cols"]
        group_cols = saved["group_cols"]

        # Filter history for this series
        df_hist = history_df[
            (history_df["company_id"] == company_id)
            & (history_df["location_id"] == location_id)
            & (history_df["product_id"] == product_id)
        ].copy()

        if df_hist.empty:
            raise ValueError(
                f"No history found for company={company_id}, "
                f"location={location_id}, product={product_id}"
            )

        return self._forecast_recursively(
            pipeline,
            df_hist,
            company_id,
            location_id,
            product_id,
            future_dates,
            feature_cols=feature_cols,
            group_cols=group_cols,
        )
