"""
Celery task for running monthly demand prediction job for all product/location combinations.
Uses the generate_monthly_forecasts function from controller to generate and save forecasts.
"""
from __future__ import annotations
from typing import Dict
from datetime import datetime, date

from src.utils.celery_worker import celery_app
from src.smart_inventory.utils.celery_dispatch import dispatch_task


def _safe_rollback(db) -> None:
    """Roll back *db*'s current transaction, printing instead of raising on failure.

    Used on error paths so that a failure during rollback (e.g. a dropped
    connection) cannot mask the error that is already being handled.
    """
    try:
        db.rollback()
    except Exception as rollback_error:
        print(f"[CELERY TASK] Error rolling back session: {rollback_error}")


def _trigger_inventory_planning_snapshot(db, company_id: int) -> None:
    """Dispatch today's inventory planning snapshot task and register a tracker row.

    Best-effort: any error (dispatch or DB) is printed and swallowed so that a
    completed forecast run is not reported as failed because of this follow-up.

    Args:
        db: Open database session used to insert the ``CeleryTaskTracker`` row.
        company_id: Company the tracker row is recorded against.
    """
    from src.smart_inventory.apps.inventory.models import (
        CeleryTaskTracker,
        CeleryTaskStatus,
    )

    try:
        today_str = date.today().strftime("%Y-%m-%d")
        print(f"[CELERY TASK] Triggering inventory planning snapshot for {today_str}")

        # Trigger the task and get the async result
        planning_task = dispatch_task(
            "src.smart_inventory.tasks.inventory_planning_task.compute_inventory_planning_snapshot",
            args=(today_str,)
        )
        planning_task_id = planning_task.id

        # NOTE(review): the tracker row is committed only after dispatch, so a
        # fast worker could start before the row exists — confirm the planning
        # task tolerates a missing tracker.
        planning_tracker = CeleryTaskTracker(
            task_id=planning_task_id,
            task_name="compute_inventory_planning_snapshot",
            status=CeleryTaskStatus.PENDING,
            company_id=company_id,
            created_at=datetime.now()
        )
        db.add(planning_tracker)
        db.commit()

        print(f"[CELERY TASK] Inventory planning snapshot task triggered with ID: {planning_task_id}")
    except Exception as e:
        print(f"[CELERY TASK] Error triggering inventory planning snapshot: {e}")
        # Keep the session usable for the caller after a failed insert/commit.
        _safe_rollback(db)


@celery_app.task(bind=True, name='src.smart_inventory.tasks.monthly_demand_prediction_job_task.run_monthly_demand_prediction_job')
def run_monthly_demand_prediction_job(
    self,
    company_id: int,
    months_ahead: int = 6
) -> Dict:
    """
    Run monthly demand prediction for all product/location combinations for a company.

    This task:
    - Marks its CeleryTaskTracker row (looked up by this Celery task's id) as STARTED
    - Finds all unique (product_id, location_id) pairs in DailySales for the company
    - Calls generate_monthly_forecasts for each pair over the month range
      [current month, current month + months_ahead - 1], committing after each
      pair so that one failing pair does not discard the others
    - Marks the tracker SUCCESS (even when some pairs failed; see the returned
      counts and errors), then triggers the inventory planning snapshot task

    Args:
        company_id: Company ID to run predictions for
        months_ahead: Number of months to predict, starting with the current
            month (default: 6). Must be >= 1.

    Returns:
        Dict with a "success" flag and a "message". On success it also contains
        the month range, unique/total combination counts, success/failure counts
        and up to 10 per-combination error messages. Failures are reported in
        the returned dict (and on the tracker row) rather than raised.
    """
    from src.utils.db import get_db
    from src.smart_inventory.apps.inventory.models import (
        CeleryTaskTracker,
        CeleryTaskStatus,
        DailySales
    )
    from src.smart_inventory.apps.inventory.controller import generate_monthly_forecasts
    from dateutil.relativedelta import relativedelta

    # Get database session
    db = next(get_db())
    task_tracker = None
    task_id = self.request.id

    print(f"[CELERY TASK] Starting monthly demand prediction job with ID: {task_id} for company_id: {company_id}")

    try:
        # Get and update task status to STARTED
        task_tracker = db.query(CeleryTaskTracker).filter(
            CeleryTaskTracker.task_id == task_id
        ).first()

        print(f"[CELERY TASK] Found task_tracker: {task_tracker}")

        if task_tracker:
            task_tracker.status = CeleryTaskStatus.STARTED
            task_tracker.started_at = datetime.now()
            db.commit()
            print("[CELERY TASK] Updated status to STARTED")
        else:
            print(f"[CELERY TASK] WARNING: No task_tracker found for task_id: {task_id}")

        # A non-positive horizon would yield end_month < start_month; fail the
        # job explicitly (handled by the outer except) instead of running it.
        if months_ahead < 1:
            raise ValueError(f"months_ahead must be >= 1, got {months_ahead}")

        # All unique (product_id, location_id) pairs for this company; ordered
        # so progress and error output are deterministic across runs.
        combinations = (
            db.query(DailySales.product_id, DailySales.location_id)
            .filter(DailySales.company_id == company_id)
            .distinct()
            .order_by(DailySales.product_id, DailySales.location_id)
            .all()
        )

        if not combinations:
            # Update task status to FAILURE
            if task_tracker:
                task_tracker.status = CeleryTaskStatus.FAILURE
                task_tracker.error_message = f"No product/location combinations found for company {company_id}"
                task_tracker.completed_at = datetime.now()
                db.commit()

            return {
                "success": False,
                "message": f"No product/location combinations found for company {company_id}",
                "company_id": company_id,
                "total_combinations": 0,
                "successful_predictions": 0,
                "failed_predictions": 0
            }

        # Month range: first day of the current month through months_ahead - 1
        # months later (inclusive), i.e. months_ahead months in total.
        today = date.today()
        start_month = date(today.year, today.month, 1)
        end_month = start_month + relativedelta(months=months_ahead - 1)

        unique_product_ids = {product_id for product_id, _ in combinations}
        unique_location_ids = {location_id for _, location_id in combinations}

        total_combinations = len(combinations)
        successful_predictions = 0
        failed_predictions = 0
        errors = []

        print(f"[CELERY TASK] Processing {total_combinations} combinations for company {company_id}")
        print(f"[CELERY TASK] Unique products: {len(unique_product_ids)}, Unique locations: {len(unique_location_ids)}")
        print(f"[CELERY TASK] Predicting from {start_month} to {end_month} ({months_ahead} months)")

        for processed, (product_id, location_id) in enumerate(combinations, start=1):
            try:
                # Generate and save forecasts for this combination
                generate_monthly_forecasts(
                    db=db,
                    company_id=company_id,
                    location_id=location_id,
                    product_id=product_id,
                    start_month=start_month,
                    end_month=end_month
                )
                # Persist each combination independently so a later failure
                # (and its rollback) cannot discard forecasts already produced.
                db.commit()
                successful_predictions += 1
            except Exception as e:
                # A failed flush/commit leaves the session unusable until it is
                # rolled back; without this, every remaining combination and the
                # final tracker update would fail with the same error.
                _safe_rollback(db)
                failed_predictions += 1
                error_msg = f"product_id={product_id}, location_id={location_id}: {str(e)}"
                errors.append(error_msg)
                print(f"[CELERY TASK] Error: {error_msg}")

            if processed % 100 == 0:
                print(f"[CELERY TASK] Processed {processed}/{total_combinations} combinations")

        # Update task status to SUCCESS; commit unconditionally so any pending
        # work is flushed even when no tracker row exists.
        if task_tracker:
            task_tracker.status = CeleryTaskStatus.SUCCESS
            task_tracker.completed_at = datetime.now()
        db.commit()

        # Trigger inventory planning snapshot task for today's date after successful prediction
        _trigger_inventory_planning_snapshot(db, company_id)

        result = {
            "success": True,
            "message": f"Monthly demand prediction job completed for company {company_id}",
            "company_id": company_id,
            "months_ahead": months_ahead,
            "start_month": start_month.strftime("%Y-%m"),
            "end_month": end_month.strftime("%Y-%m"),
            "unique_product_ids": len(unique_product_ids),
            "unique_location_ids": len(unique_location_ids),
            "total_combinations": total_combinations,
            "successful_predictions": successful_predictions,
            "failed_predictions": failed_predictions,
            "errors": errors[:10]  # Limit error messages
        }

        print(f"[CELERY TASK] Job completed: {successful_predictions} successful, {failed_predictions} failed")
        return result

    except Exception as e:
        # Update task status to FAILURE
        error_message = str(e)
        try:
            # The session may be in a failed state (e.g. the error came from a
            # flush); roll back before reusing it to record the failure.
            db.rollback()

            if not task_tracker:
                task_tracker = db.query(CeleryTaskTracker).filter(
                    CeleryTaskTracker.task_id == task_id
                ).first()

            if task_tracker:
                task_tracker.status = CeleryTaskStatus.FAILURE
                task_tracker.error_message = error_message[:2000]  # Truncate if too long
                task_tracker.completed_at = datetime.now()
                db.commit()
        except Exception as db_error:
            print(f"[CELERY TASK] Error updating task tracker: {db_error}")

        print(f"[CELERY TASK] Job failed with error: {error_message}")

        return {
            "success": False,
            "message": f"Monthly demand prediction job failed: {error_message}",
            "company_id": company_id,
            "error": error_message
        }
    finally:
        db.close()
