"""
Prediction script for monthly regression model.
"""
import numpy as np
import pandas as pd
from pathlib import Path
import argparse
import json

from src.utils.db import get_db_session
from src.smart_inventory.apps.inventory.models import DailySales
from src.smart_inventory.core.monthly_prediction.monthly_regression_model import MonthlyRegressor


def _load_expected_features(model_path):
    """Load the training-time feature names stored next to the model file.

    The feature file is expected at ``<model path without .pkl>_features.json``.

    Args:
        model_path: Path to the saved model file.

    Returns:
        The decoded JSON content (the column order used for prediction), or
        ``None`` when the file does not exist.
    """
    model_path = str(model_path)
    suffix = '.pkl'
    # Strip only a trailing extension: str.replace would also rewrite a
    # '.pkl' appearing earlier in the path (e.g. in a directory name), and
    # would leave an extension-less path pointing at the model file itself.
    if model_path.endswith(suffix):
        feature_names_path = model_path[:-len(suffix)] + '_features.json'
    else:
        feature_names_path = model_path + '_features.json'

    if not Path(feature_names_path).exists():
        print("Warning: Feature names file not found. Predictions may fail if feature mismatch.")
        return None

    with open(feature_names_path, 'r', encoding='utf-8') as f:
        expected_features = json.load(f)
    print(f"Loaded feature names from {feature_names_path}")
    return expected_features


def _fetch_daily_sales(product_id=None, location_id=None):
    """Read DailySales rows from the database into a DataFrame.

    Args:
        product_id: When not None, only rows with this product ID are read.
        location_id: When not None, only rows with this location ID are read.

    Returns:
        DataFrame with columns ``product_id``, ``location_id``, ``sale_date``
        and ``quantity_sold``.

    Raises:
        ValueError: If the query returns no rows.
    """
    columns = ['product_id', 'location_id', 'sale_date', 'quantity_sold']
    db = get_db_session()
    try:
        query = db.query(
            DailySales.product_id,
            DailySales.location_id,
            DailySales.sale_date,
            DailySales.quantity_sold
        )
        # Filter in the query so only the relevant rows are transferred.
        if product_id is not None:
            query = query.filter(DailySales.product_id == product_id)
        if location_id is not None:
            query = query.filter(DailySales.location_id == location_id)
        rows = query.all()
    finally:
        # Always release the session, including when the query fails.
        db.close()

    if not rows:
        raise ValueError("No data found for the specified filters")
    return pd.DataFrame(rows, columns=columns)


def predict_monthly(model_path: str,
                   data_path: str = None,
                   months_ahead: int = 6,
                   product_id: int = None,
                   location_id: int = None):
    """
    Make monthly predictions using trained regression model.

    Historical sales are read from the database (DailySales); they determine
    the last observed month and are printed as context for the forecast.

    Args:
        model_path: Path to saved model
        data_path: Deprecated and ignored; kept so existing callers keep working
        months_ahead: Number of months to predict ahead (must be >= 0)
        product_id: Product ID; when None, the product of the first
            aggregated monthly row is used
        location_id: Location ID; when None, the location of the first
            aggregated monthly row is used

    Returns:
        DataFrame with one row per forecast month and the columns
        ``year_month``, ``predicted_quantity``, ``product_id`` and
        ``location_id``.

    Raises:
        ValueError: If ``months_ahead`` is negative, or if no sales data
            matches the requested product/location.
    """
    # Fail fast, before the model is loaded or the database is queried.
    if months_ahead < 0:
        raise ValueError(f"months_ahead must be non-negative, got {months_ahead}")

    print("Loading model...")
    regressor = MonthlyRegressor()
    regressor.load_model(model_path)

    expected_features = _load_expected_features(model_path)

    print("Loading and preparing data from database...")
    df = _fetch_daily_sales(product_id=product_id, location_id=location_id)
    print(f"Loaded {len(df)} daily sales records from database")

    # Aggregate to monthly
    df_monthly = regressor.aggregate_to_monthly(
        df,
        target_column='quantity_sold',
        date_column='sale_date'
    )

    # Fall back to the first aggregated row for any ID the caller omitted.
    if product_id is None or location_id is None:
        if product_id is None:
            product_id = df_monthly['product_id'].iloc[0]
        if location_id is None:
            location_id = df_monthly['location_id'].iloc[0]
        print(f"Using product_id={product_id}, location_id={location_id} from data")

    # Filter to specific product and location. The explicit copy makes the
    # column assignment below operate on an independent frame instead of a
    # slice of the aggregated data.
    matches = (
        (df_monthly['product_id'] == product_id) &
        (df_monthly['location_id'] == location_id)
    )
    df_monthly = df_monthly[matches].copy()

    if len(df_monthly) == 0:
        raise ValueError(f"No monthly data found for product_id={product_id} and location_id={location_id}")

    # Get latest month and create future months
    df_monthly['sale_date'] = pd.to_datetime(df_monthly['sale_date'])
    last_date = df_monthly['sale_date'].max()

    # Create future months (first day of each month)
    future_months = pd.date_range(
        start=last_date + pd.offsets.MonthBegin(1),
        periods=months_ahead,
        freq='MS'  # Month Start
    )

    predictions = []

    print(f"\nPredicting {months_ahead} months ahead...")
    print(f"For product_id={product_id}, location_id={location_id}")

    for month_number, future_month in enumerate(future_months, start=1):
        # One synthetic row per future month; quantity_sold is only a
        # placeholder for the target column.
        future_row = pd.DataFrame({
            'sale_date': [future_month],
            'product_id': [product_id],
            'location_id': [location_id],
            'quantity_sold': [0]
        })

        # Prepare features for this future month
        X_pred, _ = regressor.create_monthly_features(
            future_row,
            target_column='quantity_sold',
            date_column='sale_date'
        )

        # Ensure feature order matches training; features missing from the
        # generated frame are filled with 0.
        if expected_features:
            X_pred = X_pred.reindex(columns=expected_features, fill_value=0)

        # Make prediction
        pred = regressor.predict(X_pred.values.astype(np.float32))
        predictions.append(pred[0])

        print(f"  Month {month_number} ({future_month.strftime('%Y-%m')}): {pred[0]:.2f} units")

    # Create results DataFrame
    results = pd.DataFrame({
        'year_month': future_months.strftime('%Y-%m'),
        'predicted_quantity': predictions,
        'product_id': product_id,
        'location_id': location_id
    })

    print("\nForecast Results:")
    print("=" * 70)
    print(f"Product ID: {product_id}, Location ID: {location_id}")
    print(f"Last historical month: {last_date.strftime('%Y-%m')}")
    print("-" * 70)
    print(results.to_string(index=False))
    print("=" * 70)

    # Show historical context (last 6 months)
    print("\nHistorical Context (Last 6 Months):")
    print("=" * 70)
    historical = df_monthly.sort_values('sale_date').tail(6)[
        ['sale_date', 'quantity_sold']
    ].copy()
    historical['sale_date'] = pd.to_datetime(historical['sale_date']).dt.strftime('%Y-%m')
    historical.columns = ['year_month', 'actual_quantity']
    print(historical.to_string(index=False))

    return results


def main(argv=None):
    """Command-line entry point: parse arguments and run the monthly forecast.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv``.

    Exits with status 1 when the model file is missing or the prediction
    raises an exception.
    """
    import sys

    parser = argparse.ArgumentParser(
        description='Predict monthly demand using trained regression model',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Predict 6 months ahead (default) - data fetched from database
  python predict_monthly_regression.py --model monthly_regression_random_forest_model.pkl --product-id 12 --location-id 1
  
  # Predict 12 months ahead - data fetched from database
  python predict_monthly_regression.py --model monthly_regression_random_forest_model.pkl --product-id 12 --location-id 1 --months 12
        """
    )

    parser.add_argument('--model', '-m', type=str, required=True,
                       help='Path to trained model file (e.g., models/monthly_regression_random_forest_model.pkl)')
    parser.add_argument('--data', '-d', type=str, default=None,
                       help='Deprecated: Data is now fetched from database (DailySales table)')
    parser.add_argument('--months', type=int, default=6,
                       help='Number of months to predict ahead')
    parser.add_argument('--product-id', '-p', type=int, required=True,
                       help='Product ID (required)')
    parser.add_argument('--location-id', '-l', type=int, required=True,
                       help='Location ID (required)')

    args = parser.parse_args(argv)

    if args.data is not None:
        print("Warning: --data is deprecated and ignored; data is fetched from the database.")

    # The model may be given with or without the '.pkl' extension. Normalise
    # first and check the path that is actually used, so a missing file is
    # reported here instead of failing later inside the prediction.
    model_path = args.model
    if not model_path.endswith('.pkl'):
        model_path = model_path + '.pkl'

    if not Path(model_path).exists():
        print(f"Error: Model file not found at {args.model}")
        print("Please train a model first using: python train_monthly_regression.py")
        sys.exit(1)

    try:
        predict_monthly(
            model_path=model_path,
            data_path=args.data,
            months_ahead=args.months,
            product_id=args.product_id,
            location_id=args.location_id
        )
    except Exception as e:
        # Top-level boundary: report the failure with its traceback and
        # signal it through the exit status.
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()

