"""
Script to trigger /predict endpoint with multiple combinations
Generates demand forecasts for various product/location combinations
"""
import requests
import json
import time
from datetime import datetime, timedelta
from itertools import product as itertools_product

# Configuration
BASE_URL = "http://localhost:8000/smart-inventory/inventory/predict"  # /predict endpoint URL
COMPANY_ID = 1  # Company every request is issued for
PRODUCT_IDS = [*range(1, 35)]  # Product IDs 1..34 inclusive (34 products)
LOCATION_IDS = [*range(1, 4)]  # Location IDs 1..3 inclusive
REQUEST_DELAY = 0.5  # Pause, in seconds, between consecutive requests

# Forecast window: today through 180 days (~6 months) from today.
# Computed once at import time, so every request in a run shares the same window.
_FORECAST_HORIZON = timedelta(days=180)
start_date = datetime.now().date()
end_date = start_date + _FORECAST_HORIZON

def trigger_prediction(company_id, location_id, product_id, start_date, end_date, timeout=30):
    """
    Trigger the prediction endpoint for a single company/location/product combination.

    Args:
        company_id: Company identifier sent in the JSON payload.
        location_id: Location identifier sent in the JSON payload.
        product_id: Product identifier sent in the JSON payload.
        start_date: First date of the forecast window (serialized via ``isoformat()``).
        end_date: Last date of the forecast window (serialized via ``isoformat()``).
        timeout: Seconds to wait for the server before giving up (default 30).

    Returns:
        dict with a boolean ``"success"`` key and:
          - on HTTP 200 with a JSON body: ``"status_code"`` and ``"response"``
            (the parsed JSON);
          - on any other HTTP status: ``"status_code"`` and ``"response"``
            (the raw body text);
          - on HTTP 200 with a body that is not valid JSON: ``"status_code"``,
            ``"response"`` (raw text) and ``"error"``;
          - on a transport failure (connection error, timeout, ...): ``"error"``.

        Transport and JSON-decoding errors are reported in the returned dict
        rather than raised, so one bad request does not abort a batch.
    """
    payload = {
        "company_id": company_id,
        "location_id": location_id,
        "product_id": product_id,
        "start_date": start_date.isoformat(),
        "end_date": end_date.isoformat(),
    }

    try:
        response = requests.post(
            BASE_URL,
            headers={
                "accept": "application/json",
                "Content-Type": "application/json",
            },
            json=payload,
            timeout=timeout,
        )
    except requests.exceptions.RequestException as e:
        return {
            "success": False,
            "error": str(e),
        }

    if response.status_code != 200:
        return {
            "success": False,
            "status_code": response.status_code,
            "response": response.text,
        }

    try:
        data = response.json()
    except ValueError as e:
        # Every JSON decode error requests can raise derives from ValueError.
        # Older requests versions do not also derive it from RequestException,
        # so it must be caught here rather than escaping to the caller.
        return {
            "success": False,
            "status_code": response.status_code,
            "response": response.text,
            "error": f"Invalid JSON in HTTP 200 response: {e}",
        }

    return {
        "success": True,
        "status_code": response.status_code,
        "response": data,
    }


def main():
    """
    Interactively send one /predict request per (location, product) pair and report results.

    Prints the run configuration and asks for confirmation. It then calls
    ``trigger_prediction`` for every combination of ``LOCATION_IDS`` x
    ``PRODUCT_IDS``, using ``COMPANY_ID`` and the module-level
    ``start_date``/``end_date``, and sleeps ``REQUEST_DELAY`` seconds between
    requests. Finally it prints a summary of successes and failures.

    An unexpected success payload is handled for that combination only; it
    does not abort the remaining requests.
    """
    print("\nDemand Forecast Prediction Trigger\n")

    # Generate all (location, product) combinations
    combinations = list(itertools_product(LOCATION_IDS, PRODUCT_IDS))
    total_combinations = len(combinations)

    print(f"Company ID: {COMPANY_ID}")
    print(f"Location IDs: {LOCATION_IDS}")
    print(f"Product IDs: {PRODUCT_IDS}")
    print(f"Date range: {start_date} to {end_date}")
    print(f"Total combinations: {total_combinations}\n")

    # Confirmation (anything other than y/yes cancels)
    answer = input(f"Proceed with {total_combinations} prediction requests? [y/N]: ").strip().lower()
    if answer not in ("y", "yes"):
        print("[CANCELLED]")
        return

    print()

    successful = []
    failed = []

    for idx, (location_id, product_id) in enumerate(combinations, 1):
        print(f"\n[{idx}/{total_combinations}] Location {location_id}, Product {product_id}")
        print(f"Payload: company_id={COMPANY_ID}, location_id={location_id}, product_id={product_id}")
        print(f"Date range: {start_date} to {end_date}")

        result = trigger_prediction(COMPANY_ID, location_id, product_id, start_date, end_date)

        if result["success"]:
            print("Status: [SUCCESS]")
            forecast_count = _report_forecasts(result["response"])
            successful.append({
                "location_id": location_id,
                "product_id": product_id,
                "forecast_count": forecast_count,
            })
        else:
            error_msg = _failure_message(result)
            print("Status: [FAILED]")
            print(f"Error: {error_msg}")
            failed.append({
                "location_id": location_id,
                "product_id": product_id,
                "error": error_msg,
            })

        # Add delay between requests to avoid overwhelming the backend
        if idx < total_combinations:
            time.sleep(REQUEST_DELAY)

    _print_summary(successful, failed)


def _report_forecasts(response_data, sample_size=3):
    """
    Print the forecast count and the first ``sample_size`` forecasts; return the count.

    Expects a dict with a ``"forecasts"`` list of dicts holding ``"date"`` and
    ``"forecast_qty"``. Other shapes are tolerated instead of raising, so one
    odd payload cannot abort the batch:
      - a non-dict payload or non-list ``"forecasts"`` counts as 0 forecasts,
        with a warning;
      - a missing ``"forecasts"`` key counts as 0 silently;
      - missing entry fields print as ``None``.
    """
    if isinstance(response_data, dict):
        forecasts = response_data.get("forecasts", [])
    else:
        forecasts = None

    if not isinstance(forecasts, list):
        print(f"Warning: unexpected response format ({type(response_data).__name__})")
        forecasts = []

    forecast_count = len(forecasts)
    print(f"Forecasts generated: {forecast_count}")

    if forecasts:
        print("Sample forecasts:")
        for forecast in forecasts[:sample_size]:
            if isinstance(forecast, dict):
                print(f"  - {forecast.get('date')}: {forecast.get('forecast_qty')} units")
            else:
                print(f"  - {forecast}")
        if forecast_count > sample_size:
            print(f"  ... and {forecast_count - sample_size} more")

    return forecast_count


def _failure_message(result):
    """Return a human-readable error string for a failed ``trigger_prediction`` result."""
    # Chain with `or` at every step: an HTTP error with an empty body has
    # result["response"] == "", which must still fall back to the status code.
    return (
        result.get("error")
        or result.get("response")
        or f"HTTP {result.get('status_code')}"
    )


def _print_summary(successful, failed):
    """Print end-of-run totals followed by one line per successful and failed combination."""
    separator = "=" * 60

    print(f"\n{separator}")
    print(f"Predictions triggered: {len(successful)} | Failed: {len(failed)}")
    print(f"{separator}\n")

    if successful:
        print(f"Successfully triggered ({len(successful)}):")
        for item in successful:
            print(f"   Location {item['location_id']}, Product {item['product_id']} -> {item['forecast_count']} forecasts")

    if failed:
        print(f"\nFailed ({len(failed)}):")
        for item in failed:
            print(f"   Location {item['location_id']}, Product {item['product_id']} -> {item['error']}")

    print(f"\n{separator}")
    print("Prediction requests completed.")
    print(f"{separator}\n")


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n[CANCELLED]")
    except Exception as e:
        print(f"\n[ERROR] {str(e)}")
        import traceback
        traceback.print_exc()
