"""
Monthly regression models for demand forecasting with Optuna hyperparameter tuning.
Aggregates daily data to monthly and predicts monthly demand.
Follows SOLID principles with separation of concerns.
"""
import numpy as np
import pandas as pd
from typing import Optional, Dict, Any, Tuple
from abc import ABC, abstractmethod
import optuna
from sklearn.model_selection import cross_val_score, KFold
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import Ridge, Lasso, ElasticNet
import xgboost as xgb
import lightgbm as lgb
import pickle

from src.smart_inventory.core.monthly_prediction.regression_model import RegressionModelFactory


class MonthlyRegressor:
    """Monthly regression model with Optuna hyperparameter optimization.

    Typical workflow::

        monthly = MonthlyRegressor.aggregate_to_monthly(daily_df)
        X, y = MonthlyRegressor.create_monthly_features(monthly)
        reg = MonthlyRegressor(model_type='xgboost', n_trials=50)
        reg.optimize(X, y)   # choose hyperparameters via K-fold CV
        reg.fit(X, y)        # train the best configuration
        reg.predict(X_new)
    """

    def __init__(self,
                 model_type: str = 'random_forest',
                 n_trials: int = 100,
                 cv_folds: int = 5,
                 scoring: str = 'neg_mean_squared_error',
                 random_state: int = 42):
        """
        Initialize monthly Optuna-based regressor.

        Args:
            model_type: Model key handled by ``_create_trial_params`` and passed
                to ``RegressionModelFactory.create_model`` ('random_forest',
                'xgboost', 'lightgbm', 'gradient_boosting', 'ridge', 'lasso',
                'elastic_net').
            n_trials: Number of Optuna trials.
            cv_folds: Number of K-fold cross-validation splits per trial.
            scoring: scikit-learn scoring string. The study *maximizes* it, so
                it must be a "greater is better" scorer (e.g. the ``neg_*`` ones).
            random_state: Seed for the CV shuffling and the TPE sampler.
        """
        self.model_type = model_type
        self.n_trials = n_trials
        self.cv_folds = cv_folds
        self.scoring = scoring
        self.random_state = random_state
        # Populated by optimize() (unfitted) or load_model() (fitted).
        self.best_model: Optional[Any] = None
        self.best_params: Optional[Dict[str, Any]] = None
        self.study: Optional['optuna.Study'] = None

    @staticmethod
    def aggregate_to_monthly(df: pd.DataFrame,
                             target_column: str = 'quantity_sold',
                             date_column: str = 'sale_date') -> pd.DataFrame:
        """
        Sum daily rows into one row per (month, product_id, location_id).

        Args:
            df: Daily data containing ``date_column``, ``target_column``,
                'product_id' and 'location_id'. The input is not modified.
            target_column: Column to sum (default: 'quantity_sold').
            date_column: Date column name (default: 'sale_date'). The output
                stores the first day of each month under this same name.

        Returns:
            DataFrame with columns product_id, location_id, ``target_column``
            and ``date_column``. Rows whose id is NaN are dropped (pandas
            ``groupby`` default).
        """
        df = df.copy()

        if not pd.api.types.is_datetime64_any_dtype(df[date_column]):
            df[date_column] = pd.to_datetime(df[date_column])

        df['year_month'] = df[date_column].dt.to_period('M')

        monthly_df = (
            df.groupby(['year_month', 'product_id', 'location_id'])
            .agg({target_column: 'sum'})
            .reset_index()
        )

        # Period -> Timestamp at the start of the month, written under the
        # caller's date_column so the result feeds directly into
        # create_monthly_features with the same argument.
        monthly_df[date_column] = monthly_df['year_month'].dt.to_timestamp()

        return monthly_df.drop(columns=['year_month'])

    @staticmethod
    def create_monthly_features(df: pd.DataFrame,
                                target_column: str = 'quantity_sold',
                                date_column: str = 'sale_date') -> Tuple[pd.DataFrame, pd.Series]:
        """
        Create calendar + one-hot id features for the monthly model.

        Only year/month/quarter-derived features are produced (no lag or
        rolling features). 'product_id' / 'location_id' are one-hot encoded
        when present.

        NOTE: dummy columns depend on the ids present in ``df``. Frames built
        for training and for inference can therefore have different columns;
        align them before predicting (e.g. ``X.reindex(columns=train_cols,
        fill_value=0)``).

        Args:
            df: DataFrame with monthly aggregated data.
            target_column: Name of target column.
            date_column: Name of date column.

        Returns:
            Tuple of (numeric features DataFrame, target Series), row-aligned
            and sorted by date then ids.
        """
        df = df.copy()

        if not pd.api.types.is_datetime64_any_dtype(df[date_column]):
            df[date_column] = pd.to_datetime(df[date_column])

        # Only sort/encode the id columns that actually exist.
        id_columns = [col for col in ('product_id', 'location_id') if col in df.columns]
        df = df.sort_values([date_column, *id_columns])

        df['year'] = df[date_column].dt.year
        df['month'] = df[date_column].dt.month
        df['quarter'] = df[date_column].dt.quarter

        # Cyclical encoding so December and January are neighbours.
        df['month_sin'] = np.sin(2 * np.pi * df['month'] / 12)
        df['month_cos'] = np.cos(2 * np.pi * df['month'] / 12)
        df['quarter_sin'] = np.sin(2 * np.pi * df['quarter'] / 4)
        df['quarter_cos'] = np.cos(2 * np.pi * df['quarter'] / 4)

        for col in id_columns:
            # get_dummies returns bools; cast so every feature is numeric.
            dummies = pd.get_dummies(df[col], prefix=col, drop_first=False).astype(int)
            df = pd.concat([df.drop(columns=[col]), dummies], axis=1)

        exclude_cols = {target_column, date_column}
        exclude_cols.update(df.select_dtypes(include=['datetime64']).columns)
        feature_columns = [col for col in df.columns if col not in exclude_cols]

        X = df[feature_columns].copy().fillna(0)
        y = df[target_column].copy()

        # Coerce any leftover non-numeric columns, then keep numeric only.
        for col in X.columns:
            if not pd.api.types.is_numeric_dtype(X[col]):
                X[col] = pd.to_numeric(X[col], errors='coerce').fillna(0)
        X = X.select_dtypes(include=[np.number])

        return X, y

    def _create_trial_params(self, trial: 'optuna.Trial', model_type: str) -> Dict[str, Any]:
        """Suggest a hyperparameter set for ``model_type`` from ``trial``.

        Raises:
            ValueError: If ``model_type`` is not one of the supported keys.
        """
        model_type = model_type.lower()

        if model_type == 'random_forest':
            return {
                'n_estimators': trial.suggest_int('n_estimators', 50, 500),
                'max_depth': trial.suggest_int('max_depth', 5, 50),
                'min_samples_split': trial.suggest_int('min_samples_split', 2, 20),
                'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 10),
                'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2', None])
            }
        elif model_type == 'xgboost':
            return {
                'n_estimators': trial.suggest_int('n_estimators', 50, 500),
                'max_depth': trial.suggest_int('max_depth', 3, 10),
                'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
                'subsample': trial.suggest_float('subsample', 0.6, 1.0),
                'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
                'min_child_weight': trial.suggest_int('min_child_weight', 1, 10),
                'gamma': trial.suggest_float('gamma', 0, 5),
                'reg_alpha': trial.suggest_float('reg_alpha', 0, 10),
                'reg_lambda': trial.suggest_float('reg_lambda', 0, 10)
            }
        elif model_type == 'lightgbm':
            return {
                'n_estimators': trial.suggest_int('n_estimators', 50, 500),
                'max_depth': trial.suggest_int('max_depth', 3, 15),
                'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
                'num_leaves': trial.suggest_int('num_leaves', 10, 300),
                'subsample': trial.suggest_float('subsample', 0.6, 1.0),
                'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
                'min_child_samples': trial.suggest_int('min_child_samples', 5, 100),
                'reg_alpha': trial.suggest_float('reg_alpha', 0, 10),
                'reg_lambda': trial.suggest_float('reg_lambda', 0, 10)
            }
        elif model_type == 'gradient_boosting':
            return {
                'n_estimators': trial.suggest_int('n_estimators', 50, 300),
                'max_depth': trial.suggest_int('max_depth', 3, 10),
                'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.2, log=True),
                'subsample': trial.suggest_float('subsample', 0.6, 1.0),
                'min_samples_split': trial.suggest_int('min_samples_split', 2, 20),
                'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 10)
            }
        elif model_type == 'ridge':
            return {
                'alpha': trial.suggest_float('alpha', 0.1, 100, log=True)
            }
        elif model_type == 'lasso':
            return {
                'alpha': trial.suggest_float('alpha', 0.1, 100, log=True)
            }
        elif model_type == 'elastic_net':
            return {
                'alpha': trial.suggest_float('alpha', 0.1, 100, log=True),
                'l1_ratio': trial.suggest_float('l1_ratio', 0.0, 1.0)
            }
        else:
            raise ValueError(f"Unknown model type: {model_type}")

    def optimize(self, X: np.ndarray, y: np.ndarray) -> 'optuna.Study':
        """
        Optimize hyperparameters using Optuna.

        Each trial is scored by the mean of ``self.scoring`` over shuffled
        K-fold CV; the study maximizes that mean. Afterwards ``best_params`` is
        set and ``best_model`` holds an *unfitted* estimator; call ``fit()``.

        NOTE(review): shuffled K-fold lets later months appear in training
        folds; for strict forecasting evaluation a time-ordered split may be
        more appropriate.

        Args:
            X: Feature matrix.
            y: Target vector.

        Returns:
            The completed Optuna study.
        """
        # An int random_state makes KFold yield identical splits on every
        # call, so one instance can be shared by all trials.
        kfold = KFold(n_splits=self.cv_folds, shuffle=True, random_state=self.random_state)

        def objective(trial):
            params = self._create_trial_params(trial, self.model_type)
            model = RegressionModelFactory.create_model(self.model_type, params)
            scores = cross_val_score(
                model, X, y,
                cv=kfold,
                scoring=self.scoring,
                n_jobs=-1
            )
            return scores.mean()

        self.study = optuna.create_study(
            direction='maximize',
            study_name=f'monthly_{self.model_type}_optimization',
            sampler=optuna.samplers.TPESampler(seed=self.random_state)
        )

        print(f"Starting Optuna optimization for monthly {self.model_type} ({self.n_trials} trials)...")
        self.study.optimize(objective, n_trials=self.n_trials, show_progress_bar=True)

        self.best_params = self.study.best_params
        self.best_model = RegressionModelFactory.create_model(self.model_type, self.best_params)

        print(f"\nBest trial: {self.study.best_trial.number}")
        print(f"Best score: {self.study.best_value:.4f}")
        print(f"Best params: {self.best_params}")

        return self.study

    def fit(self, X: np.ndarray, y: np.ndarray) -> 'MonthlyRegressor':
        """Train the best model found during optimization; returns ``self``.

        Raises:
            ValueError: If ``optimize()`` (or ``load_model()``) was not called.
        """
        if self.best_model is None:
            raise ValueError("Model must be optimized before fitting. Call optimize() first.")
        self.best_model.fit(X, y)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Make predictions using the best model.

        Raises:
            ValueError: If no model exists yet.
        """
        if self.best_model is None:
            raise ValueError("Model must be fitted before prediction. Call fit() first.")
        return self.best_model.predict(X)

    @staticmethod
    def _mape(y_true, y_pred) -> float:
        """Mean absolute percentage error (in %) over rows with a non-zero target.

        Zero actuals are excluded: the percentage error is undefined there, and
        dividing by a tiny epsilon instead lets a single zero-demand month
        dominate the metric. Returns NaN when every actual value is zero.
        """
        y_true = np.asarray(y_true, dtype=float)
        y_pred = np.asarray(y_pred, dtype=float)
        nonzero = y_true != 0
        if not nonzero.any():
            return float('nan')
        relative_errors = (y_true[nonzero] - y_pred[nonzero]) / y_true[nonzero]
        return float(np.mean(np.abs(relative_errors)) * 100)

    def evaluate(self, X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
        """Return MSE, MAE, RMSE, MAPE (%, non-zero targets only) and R2 on (X, y)."""
        y_pred = self.predict(X)

        mse = mean_squared_error(y, y_pred)
        return {
            'MSE': mse,
            'MAE': mean_absolute_error(y, y_pred),
            'RMSE': np.sqrt(mse),
            'MAPE': self._mape(y, y_pred),
            'R2': r2_score(y, y_pred)
        }

    @staticmethod
    def _with_pkl_suffix(filepath) -> str:
        """Return ``filepath`` (str or PathLike) with its extension replaced by '.pkl'.

        ``os.path.splitext`` only looks at the final path component, so dots in
        directory names ('models.v2/model') or a leading './' are preserved.
        """
        import os

        filepath = os.fspath(filepath)
        if filepath.endswith('.pkl'):
            return filepath
        root, _ = os.path.splitext(filepath)
        return root + '.pkl'

    def save_model(self, filepath: str) -> str:
        """Pickle the model, its params and model_type to disk.

        Any extension on ``filepath`` is replaced by '.pkl'.

        Returns:
            The path actually written.

        Raises:
            ValueError: If there is no model to save.
        """
        if self.best_model is None:
            raise ValueError("No model to save")
        target = self._with_pkl_suffix(filepath)

        model_data = {
            'model': self.best_model,
            'best_params': self.best_params,
            'model_type': self.model_type
        }
        with open(target, 'wb') as f:
            pickle.dump(model_data, f)
        return target

    def load_model(self, filepath: str):
        """Load a model written by ``save_model`` (extension normalised to '.pkl').

        Security: ``pickle.load`` can execute arbitrary code; only load files
        from trusted sources.
        """
        with open(self._with_pkl_suffix(filepath), 'rb') as f:
            loaded = pickle.load(f)
        self.best_model = loaded['model']
        self.best_params = loaded['best_params']
        self.model_type = loaded['model_type']

