"""
Script to backfill historical forecast data for past dates.
Generates forecasts as if they were made in the past, for demonstration/testing purposes.

This simulates what forecasts would have looked like if we had made predictions
on historical dates, allowing for comparison of actual vs forecasted data.
"""
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../')))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../src')))

import pandas as pd
from datetime import date, datetime, timedelta
from itertools import product as itertools_product
from sqlalchemy.orm import Session

from src.utils.db import SessionLocal
from src.smart_inventory.core.demand_forecaster import CompanyDemandForecaster
from src.smart_inventory.apps.inventory import models
import logging

# Configure the root logger so INFO-level (and above) records from this script are emitted
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration: fixed scope of the backfill run.
# main() processes every (location, product) pair built from the lists below
# for the single company given by COMPANY_ID.
COMPANY_ID = 1
PRODUCT_IDS = list(range(1, 35))  # 1 to 34 (34 products)
LOCATION_IDS = list(range(1, 4))  # 1 to 3
# Number of consecutive target dates generated per forecast; the first target
# date is the forecast date itself (see generate_historical_forecast).
FORECAST_HORIZON_DAYS = 90  # How many days ahead to forecast from each historical date


def load_history_up_to_date(db: Session, company_id: int, location_id: int, product_id: int, up_to_date: date) -> pd.DataFrame:
    """
    Fetch the daily sales history of a single company/location/product series.

    Only rows from ``daily_sales`` whose ``sale_date`` is strictly earlier than
    ``up_to_date`` are selected, ordered by ``sale_date``.

    Args:
        db: Session whose ``bind`` is handed to pandas to execute the query
        company_id: Company to filter on
        location_id: Location to filter on
        product_id: Product to filter on
        up_to_date: Exclusive upper bound for ``sale_date``

    Returns:
        DataFrame with the selected columns, where ``sale_date`` is exposed as
        ``date`` and ``quantity_sold`` as ``qty_sold``.
    """
    sql = """
        SELECT company_id, location_id, product_id, sale_date, quantity_sold
        FROM daily_sales
        WHERE company_id = %(company_id)s
          AND location_id = %(location_id)s
          AND product_id = %(product_id)s
          AND sale_date < %(up_to_date)s
        ORDER BY sale_date
    """
    # Values are passed as bound parameters rather than formatted into the SQL text
    bind_params = dict(
        company_id=company_id,
        location_id=location_id,
        product_id=product_id,
        up_to_date=up_to_date,
    )
    sales = pd.read_sql(sql, db.bind, params=bind_params)

    column_aliases = {"sale_date": "date", "quantity_sold": "qty_sold"}
    return sales.rename(columns=column_aliases)


def generate_historical_forecast(
    db: Session,
    company_id: int,
    location_id: int,
    product_id: int,
    forecast_date: date,
    horizon_days: int
) -> pd.DataFrame:
    """
    Produce the forecast that would have been issued on ``forecast_date``.

    Args:
        db: Session used to read the sales history
        company_id: Company the series belongs to
        location_id: Location the series belongs to
        product_id: Product the series belongs to
        forecast_date: The date when the forecast is "made"; only sales strictly
            before this date are used as history
        horizon_days: Number of consecutive target dates to predict, the first
            one being ``forecast_date`` itself

    Returns:
        Whatever ``predict_for_series`` returns for the series, or an empty
        DataFrame when there is no history before ``forecast_date`` or the
        prediction raises.
    """
    history_df = load_history_up_to_date(db, company_id, location_id, product_id, forecast_date)

    # Guard clause: nothing to base a prediction on
    if history_df.empty:
        message = (
            f"No historical data available before {forecast_date} for "
            f"company={company_id}, location={location_id}, product={product_id}"
        )
        logger.warning(message)
        return pd.DataFrame()

    # Target dates: forecast_date, forecast_date + 1 day, ... (horizon_days entries)
    future_dates = [forecast_date + timedelta(days=offset) for offset in range(horizon_days)]

    forecaster = CompanyDemandForecaster(model_dir="models")

    try:
        predictions = forecaster.predict_for_series(
            company_id=company_id,
            history_df=history_df,
            location_id=location_id,
            product_id=product_id,
            future_dates=future_dates,
        )
    except Exception as exc:
        # A failing series is reported and skipped by the caller via the empty frame
        logger.error(f"Error generating forecast: {exc}")
        return pd.DataFrame()

    return predictions


def save_historical_forecast(
    db: Session,
    company_id: int,
    location_id: int,
    product_id: int,
    forecast_date: date,
    df_predictions: pd.DataFrame,
    model_version: str = "v1_backfill"
) -> int:
    """
    Persist one backfilled forecast run, stamped with a custom forecast_date.

    Existing DemandForecast rows of the same company/location/product whose
    target_date lies between the earliest and latest predicted date are deleted
    first, then one row per prediction is inserted. Delete and insert are
    committed together; on any exception the session is rolled back.

    NOTE(review): the delete filter matches on the target_date range only, not
    on forecast_date or model_version, so rows written for a different
    forecast_date that cover the same target dates are removed as well.

    Args:
        db: Session used for the delete, insert, commit and rollback
        company_id: Company the forecast belongs to
        location_id: Location the forecast belongs to
        product_id: Product the forecast belongs to
        forecast_date: Date stored as the day the forecast was "made"
        df_predictions: Frame with a ``date`` and a ``forecast_qty`` column
        model_version: Version label stored on every inserted row

    Returns:
        Number of rows inserted; 0 when there is nothing to save or saving failed.
    """
    if df_predictions is None or df_predictions.empty:
        return 0

    def _as_date(value):
        # Values exposing .date() (datetime-like) are reduced to that date;
        # anything else is passed through unchanged.
        return value.date() if hasattr(value, 'date') else value

    try:
        # Bounds of the target-date window covered by this forecast
        target_dates = df_predictions["date"].apply(_as_date).tolist()
        first_target, last_target = min(target_dates), max(target_dates)

        # Clear the window; this is part of the same transaction as the insert below
        forecast_model = models.DemandForecast
        overlapping = db.query(forecast_model).filter(
            forecast_model.company_id == company_id,
            forecast_model.location_id == location_id,
            forecast_model.product_id == product_id,
            forecast_model.target_date >= first_target,
            forecast_model.target_date <= last_target,
        )
        deleted_count = overlapping.delete(synchronize_session=False)

        if deleted_count > 0:
            logger.info(
                f"Deleted {deleted_count} existing forecasts for company={company_id}, "
                f"location={location_id}, product={product_id}, date_range={first_target} to {last_target}"
            )

        # One record per predicted row, all carrying the historical forecast_date
        records = [
            forecast_model(
                company_id=company_id,
                location_id=location_id,
                product_id=product_id,
                forecast_date=forecast_date,
                target_date=_as_date(row["date"]),
                forecast_qty=float(row["forecast_qty"]),
                model_version=model_version,
            )
            for _, row in df_predictions.iterrows()
        ]

        db.bulk_save_objects(records)

        # Single commit: the delete and the insert succeed together
        db.commit()

        saved = len(records)
        logger.info(
            f"Saved {saved} historical forecasts "
            f"(forecast_date={forecast_date}, updated: {deleted_count}, inserted: {saved})"
        )
        return saved

    except Exception as exc:
        # Undo the pending delete/insert so the previous rows stay in place
        db.rollback()
        logger.error(f"Failed to save historical forecasts: {exc}")
        return 0


def backfill_for_combination(
    db: Session,
    company_id: int,
    location_id: int,
    product_id: int,
    historical_forecast_dates: list
) -> dict:
    """
    Generate and store a forecast for every given historical date of one series.

    Args:
        db: Session passed through to the generate and save steps
        company_id: Company the series belongs to
        location_id: Location the series belongs to
        product_id: Product the series belongs to
        historical_forecast_dates: Dates on which a forecast is "made"

    Returns:
        Dict with ``total_saved`` (rows stored across all dates),
        ``successful_dates`` (dates with at least one stored row) and
        ``failed_dates`` (dates with an empty forecast or nothing stored).
    """
    summary = {
        "total_saved": 0,
        "successful_dates": [],
        "failed_dates": []
    }

    for as_of in historical_forecast_dates:
        logger.info(f"\nGenerating forecast as of {as_of} for "
                   f"company={company_id}, location={location_id}, product={product_id}")

        predictions = generate_historical_forecast(
            db, company_id, location_id, product_id,
            as_of, FORECAST_HORIZON_DAYS
        )

        # An empty forecast is never handed to the save step and counts as zero rows
        rows_saved = 0 if predictions.empty else save_historical_forecast(
            db, company_id, location_id, product_id,
            as_of, predictions
        )

        if rows_saved > 0:
            summary["total_saved"] += rows_saved
            summary["successful_dates"].append(as_of)
        else:
            summary["failed_dates"].append(as_of)

    return summary


def _mid_month_forecast_dates(today: date, months_back: int) -> list:
    """Return the 15th of each of the ``months_back`` months preceding ``today``'s month, oldest first."""
    # Count months on a single axis (year * 12 + zero-based month) so that
    # crossing a year boundary needs no special casing.
    current_month = today.year * 12 + (today.month - 1)
    forecast_dates = []
    for months_ago in range(months_back, 0, -1):
        year, month_index = divmod(current_month - months_ago, 12)
        forecast_dates.append(date(year, month_index + 1, 15))
    return forecast_dates


def main():
    """
    Interactive entry point of the backfill script.

    Asks how many months back to go, builds one forecast date per month (the
    15th of each of the preceding months), shows the planned scope, and after
    confirmation runs the backfill for every (location, product) combination
    of COMPANY_ID, printing per-combination results and a final summary.

    Invalid or non-positive month input aborts with an error message before
    any database session is opened. Errors raised during the backfill loop are
    logged with a traceback; the session is always closed.
    """
    print("\n" + "="*70)
    print("HISTORICAL FORECAST BACKFILL SCRIPT")
    print("="*70)
    
    # Get user input for date range
    months_back_input = input("\nHow many months back to generate forecasts? [default: 6]: ").strip()
    try:
        months_back = int(months_back_input) if months_back_input else 6
    except ValueError:
        print(f"[ERROR] Invalid number of months: {months_back_input!r}")
        return
    if months_back < 1:
        # Zero or negative would yield no forecast dates and a no-op run
        print("[ERROR] Number of months must be at least 1")
        return
    
    # Generate list of historical forecast dates (one per month, mid-month).
    # Calendar-month arithmetic is used because subtracting multiples of 30 days
    # drifts away from the 15th (timedelta has no notion of months).
    today = date.today()
    historical_forecast_dates = _mid_month_forecast_dates(today, months_back)
    
    print(f"\nCompany ID: {COMPANY_ID}")
    print(f"Location IDs: {LOCATION_IDS}")
    print(f"Product IDs: {PRODUCT_IDS}")
    print(f"Historical forecast dates: {[d.strftime('%Y-%m-%d') for d in historical_forecast_dates]}")
    print(f"Forecast horizon: {FORECAST_HORIZON_DAYS} days ahead from each date")
    
    # Generate all combinations
    combinations = list(itertools_product(LOCATION_IDS, PRODUCT_IDS))
    total_combinations = len(combinations)
    
    print(f"\nTotal combinations: {total_combinations}")
    print(f"Total forecast operations: {total_combinations * len(historical_forecast_dates)}")
    
    # Confirmation: anything other than an explicit yes cancels
    response = input("\nProceed with backfill? [y/N]: ").strip().lower()
    if response not in ('y', 'yes'):
        print("[CANCELLED]")
        return
    
    print()
    
    db = SessionLocal()
    
    try:
        all_results = []
        
        for idx, (location_id, product_id) in enumerate(combinations, 1):
            print(f"\n{'='*70}")
            print(f"[{idx}/{total_combinations}] Location {location_id}, Product {product_id}")
            print(f"{'='*70}")
            
            result = backfill_for_combination(
                db, COMPANY_ID, location_id, product_id, 
                historical_forecast_dates
            )
            
            # Tag the per-series result so the summary can attribute it
            result["location_id"] = location_id
            result["product_id"] = product_id
            all_results.append(result)
            
            print(f"Total forecasts saved: {result['total_saved']}")
            print(f"Successful dates: {len(result['successful_dates'])}/{len(historical_forecast_dates)}")
            if result['failed_dates']:
                print(f"Failed dates: {[d.strftime('%Y-%m-%d') for d in result['failed_dates']]}")
        
        # Final summary
        print(f"\n{'='*70}")
        print("BACKFILL SUMMARY")
        print(f"{'='*70}")
        
        total_forecasts_saved = sum(r['total_saved'] for r in all_results)
        successful_combinations = sum(1 for r in all_results if r['total_saved'] > 0)
        
        print(f"Total forecasts saved: {total_forecasts_saved}")
        print(f"Successful combinations: {successful_combinations}/{total_combinations}")
        
        print(f"\n{'='*70}")
        print("Backfill completed!")
        print(f"{'='*70}\n")
        
    except Exception as e:
        logger.error(f"Error during backfill: {str(e)}")
        import traceback
        traceback.print_exc()
    finally:
        # Release the session whether the run finished, failed or was interrupted
        db.close()


if __name__ == "__main__":
    # Script entry point: run the interactive backfill. Ctrl+C is reported as a
    # cancellation; any other exception is printed together with its traceback
    # instead of propagating.
    try:
        main()
    except KeyboardInterrupt:
        print("\n[CANCELLED]")
    except Exception as exc:
        print(f"\n[ERROR] {exc}")
        from traceback import print_exc
        print_exc()
