"""
Training pipeline for monthly regression model with Optuna hyperparameter tuning.
"""
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional
import argparse
from sklearn.model_selection import train_test_split

from src.utils.db import get_db_session
from src.smart_inventory.apps.inventory.models import DailySales
from src.smart_inventory.core.monthly_prediction.monthly_regression_model import MonthlyRegressor


def main(company_id: int,
         data_file: str = 'data/daily_sales.csv',
         model_type: Optional[str] = None,
         n_trials: int = 100,
         test_size: float = 0.2,
         product_id: Optional[int] = None,
         location_id: Optional[int] = None,
         train_all_models: bool = True):
    """
    Main training pipeline for monthly regression model.

    Loads daily sales from the DailySales table, aggregates them to monthly
    records, engineers features, performs an unshuffled train/test split,
    tunes and fits one or all model types, prints metrics, and saves the
    model plus its feature-name list under ``<project_root>/models``.

    Args:
        company_id: Company ID to train model for (required).
        data_file: Deprecated and ignored; data is always read from the
            database. Kept only for backward compatibility with callers.
        model_type: Type of model to train. If None and train_all_models is
            True, every model in the built-in list is trained and the best
            one (lowest test RMSE) is kept. If None and train_all_models is
            False, 'random_forest' is used.
        n_trials: Number of Optuna optimization trials (per model).
        test_size: Proportion of rows held out for testing.
        product_id: Optional filter restricting training data to one product.
        location_id: Optional filter restricting training data to one location.
        train_all_models: If True (default) and model_type is None, train all
            models and select the best. Ignored when model_type is given.

    Returns:
        Tuple ``(regressor, X_train, X_test, y_train, y_test)`` where the
        arrays are float32 numpy arrays.

    Raises:
        ValueError: If no sales rows match the filters, or if every model
            failed to train in train-all mode.
    """
    print("=" * 70)
    print("Monthly Regression Model Training with Optuna Hyperparameter Tuning")
    print("=" * 70)

    # Step 1: Load data from database
    print("\n[1/5] Loading data from database (DailySales table)...")
    print(f"Training model for company_id={company_id}")
    df = _load_daily_sales(company_id, product_id, location_id)

    # Show data scope
    unique_products = df['product_id'].nunique()
    unique_locations = df['location_id'].nunique()
    print(f"Data scope: {unique_products} unique products, {unique_locations} unique locations")

    # Step 2: Aggregate to monthly
    print("\n[2/5] Aggregating daily data to monthly...")
    monthly_regressor = MonthlyRegressor()
    df_monthly = monthly_regressor.aggregate_to_monthly(
        df,
        target_column='quantity_sold',
        date_column='sale_date'
    )
    print(f"Aggregated to {len(df_monthly)} monthly records")
    print(f"Date range: {df_monthly['sale_date'].min()} to {df_monthly['sale_date'].max()}")

    # Step 3: Feature engineering
    print("\n[3/5] Engineering monthly features...")
    X, y = monthly_regressor.create_monthly_features(
        df_monthly,
        target_column='quantity_sold',
        date_column='sale_date'
    )

    print(f"Created {X.shape[1]} features")
    print(f"Feature names: {list(X.columns[:15])}{'...' if len(X.columns) > 15 else ''}")

    # Step 4: Train/test split
    print(f"\n[4/5] Splitting data (test_size={test_size})...")
    X_array = X.values.astype(np.float32)
    y_array = y.values.astype(np.float32)

    # shuffle=False keeps the row order produced by create_monthly_features,
    # so the test set is the tail of the frame. NOTE(review): this is only a
    # chronological hold-out if that frame is sorted by date -- confirm.
    X_train, X_test, y_train, y_test = train_test_split(
        X_array, y_array, test_size=test_size, random_state=42, shuffle=False
    )
    print(f"Training samples: {len(X_train)} months")
    print(f"Test samples: {len(X_test)} months")

    # Step 5: Optimize and train
    if train_all_models and model_type is None:
        monthly_regressor = _train_and_select_best(
            X_train, y_train, X_test, y_test, n_trials
        )
    else:
        # Train single specified model
        if model_type is None:
            model_type = 'random_forest'

        print(f"\n[5/5] Optimizing {model_type} with Optuna ({n_trials} trials)...")
        monthly_regressor = _build_regressor(model_type, n_trials)
        monthly_regressor.optimize(X_train, y_train)
        monthly_regressor.fit(X_train, y_train)

    # Evaluate best/single model
    print("\nEvaluating model...")
    train_metrics = monthly_regressor.evaluate(X_train, y_train)
    test_metrics = monthly_regressor.evaluate(X_test, y_test)
    _print_metrics("Training Set Metrics:", train_metrics)
    _print_metrics("Test Set Metrics:", test_metrics)

    # Feature importance (only for estimators that expose it, e.g. tree models)
    if hasattr(monthly_regressor.best_model, 'feature_importances_'):
        print("\n" + "=" * 70)
        print("Top 10 Most Important Features:")
        print("=" * 70)
        feature_importance = pd.DataFrame({
            'feature': X.columns,
            'importance': monthly_regressor.best_model.feature_importances_
        }).sort_values('importance', ascending=False)
        print(feature_importance.head(10).to_string(index=False))

    if product_id is not None or location_id is not None:
        # NOTE(review): artifact names only encode company_id, so a run narrowed
        # to a product/location replaces the company-wide model files. Naming is
        # left unchanged because downstream loaders may rely on it.
        print("\nWarning: training data was filtered by product/location, but the "
              "model is saved under the company-wide file name.")

    _save_artifacts(monthly_regressor, X.columns, company_id)

    print("\n" + "=" * 70)
    print("Training completed successfully!")
    print("=" * 70)

    return monthly_regressor, X_train, X_test, y_train, y_test


def _load_daily_sales(company_id: int,
                      product_id: Optional[int] = None,
                      location_id: Optional[int] = None) -> pd.DataFrame:
    """Query DailySales rows for a company (optionally narrowed) into a DataFrame.

    Raises:
        ValueError: If no rows match the filters.
    """
    db = get_db_session()
    try:
        query = db.query(
            DailySales.company_id,
            DailySales.product_id,
            DailySales.location_id,
            DailySales.sale_date,
            DailySales.quantity_sold
        ).filter(DailySales.company_id == company_id)

        # Optional filters for training a narrower model
        if product_id is not None:
            query = query.filter(DailySales.product_id == product_id)
        if location_id is not None:
            query = query.filter(DailySales.location_id == location_id)

        rows = query.all()
        if not rows:
            # Mention every active filter so an empty result is not blamed on
            # the company alone.
            scope = f"company_id={company_id}"
            if product_id is not None:
                scope += f", product_id={product_id}"
            if location_id is not None:
                scope += f", location_id={location_id}"
            raise ValueError(f"No data found in database for {scope}")

        df = pd.DataFrame(rows, columns=['company_id', 'product_id', 'location_id', 'sale_date', 'quantity_sold'])
        print(f"Loaded {len(df)} daily sales records from database for company {company_id}")
        return df
    finally:
        db.close()


def _build_regressor(model_type: str, n_trials: int) -> "MonthlyRegressor":
    """Create a MonthlyRegressor with the pipeline's fixed CV/scoring/seed settings."""
    return MonthlyRegressor(
        model_type=model_type,
        n_trials=n_trials,
        cv_folds=5,
        scoring='neg_mean_squared_error',
        random_state=42
    )


def _train_and_select_best(X_train, y_train, X_test, y_test, n_trials: int):
    """Tune and fit every built-in model type; return the one with the lowest test RMSE.

    A model type that raises during tuning/fitting/evaluation is reported and
    skipped so one failing backend does not abort the comparison.

    Raises:
        ValueError: If no model type trained successfully.
    """
    all_models = ['random_forest', 'xgboost', 'lightgbm', 'gradient_boosting',
                  'ridge', 'lasso', 'elastic_net']

    print("\n[5/5] Training all models and selecting the best performer...")
    print(f"Models to train: {', '.join(all_models)}")
    print(f"Trials per model: {n_trials}")

    model_results = {}
    for current_model_type in all_models:
        print(f"\n{'='*70}")
        print(f"Training {current_model_type.upper()}...")
        print(f"{'='*70}")

        try:
            regressor = _build_regressor(current_model_type, n_trials)
            regressor.optimize(X_train, y_train)
            regressor.fit(X_train, y_train)

            test_metrics = regressor.evaluate(X_test, y_test)
            model_results[current_model_type] = {
                'regressor': regressor,
                'test_metrics': test_metrics,
                'test_rmse': test_metrics['RMSE'],
                'test_r2': test_metrics['R2']
            }
            print(f"{current_model_type.upper()} - Test RMSE: {test_metrics['RMSE']:.4f}, R²: {test_metrics['R2']:.4f}")
        except Exception as e:
            # Deliberate best-effort: keep comparing the remaining model types.
            print(f"Error training {current_model_type}: {e}")
            continue

    if not model_results:
        raise ValueError("No models were successfully trained!")

    # NOTE(review): the winner is chosen on the same test set whose metrics are
    # reported afterwards, so those test metrics are optimistically biased.
    best_model_type = min(model_results, key=lambda k: model_results[k]['test_rmse'])
    best_result = model_results[best_model_type]

    print(f"\n{'='*70}")
    print("MODEL COMPARISON RESULTS")
    print(f"{'='*70}")
    print(f"{'Model':<20} {'Test RMSE':<15} {'Test R²':<15} {'Test MAE':<15}")
    print("-" * 70)
    for model_name, result in sorted(model_results.items(), key=lambda x: x[1]['test_rmse']):
        metrics = result['test_metrics']
        marker = " <-- BEST" if model_name == best_model_type else ""
        print(f"{model_name:<20} {metrics['RMSE']:<15.4f} {metrics['R2']:<15.4f} {metrics['MAE']:<15.4f}{marker}")
    print(f"{'='*70}")
    print(f"\nBest model: {best_model_type.upper()} (Test RMSE: {best_result['test_rmse']:.4f})")

    return best_result['regressor']


def _print_metrics(title: str, metrics) -> None:
    """Print a framed section of ``metric: value`` lines (values to 4 decimals)."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)
    for metric, value in metrics.items():
        print(f"{metric:10s}: {value:.4f}")


def _save_artifacts(regressor, feature_columns, company_id: int) -> None:
    """Save the model and its feature-name list under ``<project_root>/models``.

    Files are named by company only:
    ``company_<id>_monthly_demand_forecast_model.pkl`` and
    ``company_<id>_monthly_demand_forecast_features.json``.
    """
    import json

    # Project root is five directory levels above this file.
    project_root = Path(__file__).parent.parent.parent.parent.parent
    models_dir = project_root / 'models'
    models_dir.mkdir(exist_ok=True)

    model_filename = models_dir / f"company_{company_id}_monthly_demand_forecast_model.pkl"
    regressor.save_model(str(model_filename))
    print(f"\nModel saved to: {model_filename}")

    feature_names_path = models_dir / f"company_{company_id}_monthly_demand_forecast_features.json"
    with open(feature_names_path, 'w', encoding='utf-8') as f:
        json.dump(list(feature_columns), f)
    print(f"Feature names saved to: {feature_names_path}")


if __name__ == "__main__":
    # CLI entry point: parse arguments, run main(), and exit non-zero on failure.
    import sys
    import traceback

    parser = argparse.ArgumentParser(
        description='Train monthly regression model with Optuna hyperparameter tuning',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Train all models and select the best (default behavior)
  python train_monthly_regression.py --company-id 1
  
  # Train all models with 100 trials each
  python train_monthly_regression.py --company-id 1 --trials 100
  
  # Train only XGBoost model
  python train_monthly_regression.py --company-id 1 --model xgboost --single-model
  
  # Train all models for specific product and location
  python train_monthly_regression.py --company-id 1 --product-id 12 --location-id 1
  
  # Train only LightGBM for specific product and location
  python train_monthly_regression.py --company-id 1 --product-id 12 --location-id 1 --model lightgbm --single-model
        """
    )

    parser.add_argument('--data', type=str, default=None,
                        help='Deprecated: Data is now fetched from database (DailySales table)')
    parser.add_argument('--company-id', type=int, required=True,
                        help='Company ID to train model for (required)')
    # NOTE: 'svr' is accepted here but is not part of main()'s train-all list;
    # passing --model svr always trains only that model.
    parser.add_argument('--model', type=str, default=None,
                        choices=['random_forest', 'xgboost', 'lightgbm', 'gradient_boosting',
                                 'ridge', 'lasso', 'elastic_net', 'svr'],
                        help='Type of regression model (if not specified, trains all and selects best)')
    parser.add_argument('--trials', type=int, default=50,
                        help='Number of Optuna optimization trials per model (default: 50)')
    parser.add_argument('--test-size', type=float, default=0.2,
                        help='Proportion of data for testing')
    parser.add_argument('--product-id', type=int, default=None,
                        help='Filter by product ID (optional)')
    parser.add_argument('--location-id', type=int, default=None,
                        help='Filter by location ID (optional)')
    parser.add_argument('--single-model', action='store_true',
                        help='Train only the specified model (or random_forest if none specified)')

    args = parser.parse_args()

    # Reject an impossible split fraction up front with a usage error instead
    # of a traceback from deep inside the pipeline.
    if not 0.0 < args.test_size < 1.0:
        parser.error("--test-size must be strictly between 0 and 1")

    try:
        main(
            company_id=args.company_id,
            data_file=args.data,
            model_type=args.model,
            n_trials=args.trials,
            test_size=args.test_size,
            product_id=args.product_id,
            location_id=args.location_id,
            train_all_models=not args.single_model
        )
    except Exception as e:
        print(f"\nError: {e}")
        traceback.print_exc()
        # sys.exit rather than the site-injected exit(), which is missing
        # under `python -S` and in frozen executables.
        sys.exit(1)

