"""
Script to trigger demand forecast model training for one or more companies.

Triggers the Celery training task for each specified company ID.
"""
import sys
from pathlib import Path

# Add the project root to the Python path so the absolute ``src.``-prefixed
# import below can be resolved when this file is executed directly.
# The four ``.parent`` hops resolve to the directory three levels above the
# directory that contains this file.
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

# Import the Celery task eagerly so a broken environment is reported before
# the user is prompted for any input.
try:
    from src.smart_inventory.tasks.demand_forecast_task import train_demand_forecast_model
    print("[SUCCESS] Successfully imported train_demand_forecast_model task")
except ImportError as e:
    # Without the task object there is nothing to trigger: report the import
    # failure and terminate with a non-zero exit status.
    print(f"[ERROR] Error importing Celery task: {e}")
    print("Make sure Celery is running and the task module is available")
    sys.exit(1)


def run_forecast_training(company_ids: list):
    """
    Trigger demand forecast model training for the specified companies.

    Lists the companies, asks for interactive confirmation, then submits one
    ``train_demand_forecast_model`` task per company via ``.delay()`` and
    prints a summary of the submitted and failed companies. A failure to
    submit one company is recorded and does not stop the remaining
    companies from being submitted.

    Args:
        company_ids: List of company IDs to train models for. When empty,
            an error is printed and the function returns without prompting
            or submitting anything.

    Returns:
        None. All results are reported on stdout.
    """
    # Nothing to do: avoid prompting for confirmation of an empty batch.
    if not company_ids:
        print("[ERROR] No company IDs provided; nothing to train")
        return

    total = len(company_ids)
    separator = "=" * 60

    print(f"\nCompanies to train: {', '.join(map(str, company_ids))}")
    print(f"Total: {total} company(ies)\n")

    # Confirmation: anything other than an explicit yes cancels the run.
    response = input("Proceed with training? [y/N]: ").strip().lower()
    if response not in ('y', 'yes'):
        print("[CANCELLED]")
        return

    print()

    task_ids = []
    failed_companies = []

    # Loop through each company and trigger the task
    for idx, company_id in enumerate(company_ids, 1):
        # Flush explicitly: the line has no trailing newline, so it could
        # otherwise stay buffered until the submission call has returned.
        print(
            f"[{idx}/{total}] Triggering training for company {company_id}...",
            end=" ",
            flush=True,
        )

        try:
            # Trigger the Celery task
            task_id = train_demand_forecast_model.delay(company_id).id
        except Exception as e:
            # Best effort per company: record the failure and keep going so
            # one bad submission does not block the rest of the batch.
            print("[FAILED]")
            print(f"    Error: {e}")
            failed_companies.append({
                'company_id': company_id,
                'error': str(e),
            })
        else:
            task_ids.append({
                'company_id': company_id,
                'task_id': task_id,
            })
            print(f"[SUCCESS] Task ID: {task_id}")

    # Summary
    print(f"\n{separator}")
    print(f"Tasks triggered: {len(task_ids)} | Failed: {len(failed_companies)}")
    print(f"{separator}\n")

    if task_ids:
        print("Successfully triggered:")
        for item in task_ids:
            print(f"   Company {item['company_id']} -> Task {item['task_id']}")

    if failed_companies:
        print("\nFailed:")
        for item in failed_companies:
            print(f"   Company {item['company_id']} -> {item['error']}")

    print(f"\n{separator}")
    if task_ids:
        print("Training tasks are running in Celery. Check logs for progress.")
        print("Note: Training may take several minutes depending on data size.")
    else:
        # Do not claim that training is running when every submission failed.
        print("No training tasks were triggered. Review the errors above.")
    print(f"{separator}\n")


def _parse_company_ids(raw: str) -> list:
    """
    Convert a comma-separated string of company IDs into a list of ints.

    Empty pieces (e.g. from a trailing comma) are skipped and duplicate IDs
    are dropped, keeping the first occurrence, so that no company is
    submitted for training more than once.

    Args:
        raw: User-entered text such as ``"1"`` or ``"1, 2,3"``.

    Returns:
        The parsed IDs in the order they were first entered; may be empty.

    Raises:
        ValueError: If any non-empty piece is not a valid integer.
    """
    pieces = (piece.strip() for piece in raw.split(','))
    parsed = [int(piece) for piece in pieces if piece]
    # dict.fromkeys removes duplicates while preserving insertion order.
    return list(dict.fromkeys(parsed))


def main():
    """
    Main entry point.

    Prompts for one or more comma-separated company IDs, re-prompting until
    the input is non-empty, numeric and strictly positive, then hands the
    de-duplicated IDs to ``run_forecast_training``.
    """
    print("\nDemand Forecast Model Training - Celery Trigger\n")
    print("=" * 60)

    # Get company IDs from user
    print("\nEnter company ID(s) to train models for:")
    print("  - Single company: enter one ID (e.g., 1)")
    print("  - Multiple companies: enter comma-separated IDs (e.g., 1,2,3)")
    print("=" * 60)

    while True:
        company_input = input("\nCompany ID(s): ").strip()

        if not company_input:
            print("[ERROR] Please enter at least one company ID")
            continue

        # Only the parsing step is guarded, so a ValueError raised by
        # input() itself is not mistaken for bad user input and retried
        # forever.
        try:
            company_ids = _parse_company_ids(company_input)
        except ValueError:
            print("[ERROR] Invalid input. Please enter numeric company ID(s)")
            continue

        # Input made only of separators (e.g. ",") yields no IDs.
        if not company_ids:
            print("[ERROR] No valid company IDs provided")
            continue

        # Validate all IDs are positive
        if any(cid <= 0 for cid in company_ids):
            print("[ERROR] All company IDs must be positive numbers")
            continue

        break

    # Run the training
    run_forecast_training(company_ids)


if __name__ == "__main__":
    # Script entry point: run the interactive flow and map the outcome to a
    # process exit status.
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl+C is treated as a deliberate cancellation, not a failure.
        print("\n\n[CANCELLED] Operation cancelled by user")
        sys.exit(0)
    except Exception as exc:
        # Last-resort handler: show the message and the full traceback,
        # then signal failure to the caller.
        import traceback

        print(f"\n[ERROR] Unexpected error: {exc!s}")
        traceback.print_exc()
        sys.exit(1)
