"""
Task Orchestrator for Smart Inventory
=====================================

This module provides utilities to automatically trigger Celery tasks based on 
affected database tables. It handles:
- Task dependency ordering
- HOLD status for tasks waiting on dependencies
- Polling that releases HOLD tasks once their dependencies complete
- Tracking via CeleryTaskTracker

This is a table-based orchestrator - it determines which tasks to run based on
which database tables were modified. It can be used from:
- CSV upload processing (data_import/controller.py)
- Webhook event processing (inventory/webhookcontroller.py)
- Any other data modification flow

Usage:
    from src.smart_inventory.utils.task_orchestrator import trigger_tasks_for_affected_tables
    
    # After any data modification
    trigger_tasks_for_affected_tables(
        db=db,
        affected_tables=["sales_orders", "sales_order_lines", "inventory_movements"],
        company_id=1,
        unique_dates=["2026-01-01", "2026-01-02"]
    )
"""

import logging
import sys
import uuid
from datetime import datetime, date
from typing import List, Dict, Set, Optional, Any
from enum import Enum
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)


# =============================================================================
# TASK DEFINITIONS AND DEPENDENCIES
# =============================================================================

class TaskName(str, Enum):
    """Celery task names for inventory processing"""
    DAILY_SALES = "compute_daily_sales"
    INVENTORY_SNAPSHOT = "compute_inventory_snapshot"
    SERVICE_LEVEL = "compute_service_level_daily"
    SLOW_MOVERS = "compute_slow_movers_90d"
    INVENTORY_PLANNING = "compute_inventory_planning_snapshot"


# Task dependencies: the value tasks must complete before the key task can start
TASK_DEPENDENCIES: Dict[TaskName, List[TaskName]] = {
    TaskName.DAILY_SALES: [],  # No dependencies
    TaskName.INVENTORY_SNAPSHOT: [],  # No dependencies (can run parallel with daily_sales)
    TaskName.SERVICE_LEVEL: [TaskName.DAILY_SALES],  # Depends on daily_sales
    TaskName.SLOW_MOVERS: [TaskName.DAILY_SALES],  # Depends on daily_sales (also reads inventory_batches, which comes straight from the CSV upload)
    TaskName.INVENTORY_PLANNING: [TaskName.INVENTORY_SNAPSHOT, TaskName.DAILY_SALES],  # Depends on both
}


# Execution order priority (lower = runs first)
TASK_PRIORITY: Dict[TaskName, int] = {
    TaskName.DAILY_SALES: 1,
    TaskName.INVENTORY_SNAPSHOT: 1,  # Same priority as daily_sales (can run parallel)
    TaskName.SERVICE_LEVEL: 2,
    TaskName.SLOW_MOVERS: 2,
    TaskName.INVENTORY_PLANNING: 3,
}


# Tables that trigger each task
# When any of these tables are affected, the corresponding task should run
# NOTE: Tasks with dependencies will only start when their dependencies complete
TABLE_TO_TASKS: Dict[str, List[TaskName]] = {
    # Sales-related tables trigger daily_sales
    "sales_orders": [TaskName.DAILY_SALES],
    "sales_order_lines": [TaskName.DAILY_SALES],
    "sales_return_orders": [TaskName.DAILY_SALES],
    "sales_return_order_lines": [TaskName.DAILY_SALES],
    
    # Inventory movement tables trigger inventory_snapshot
    "inventory_movements": [TaskName.INVENTORY_SNAPSHOT],
    "inventory_movement": [TaskName.INVENTORY_SNAPSHOT],
    
    # Inventory batch changes also trigger inventory_snapshot (not slow_movers directly)
    # slow_movers runs as downstream of daily_sales (requires sales data)
    "inventory_batches": [TaskName.INVENTORY_SNAPSHOT],
    "inventory_batch": [TaskName.INVENTORY_SNAPSHOT],
    
    # Purchase receive updates trigger inventory_snapshot (for in_transit_qty recalculation)
    "purchase_order_receive": [TaskName.INVENTORY_SNAPSHOT],
    "purchase_order_receive_lines": [TaskName.INVENTORY_SNAPSHOT],
    
    # Purchase return (vendor return) updates trigger inventory_snapshot
    "purchase_order_return": [TaskName.INVENTORY_SNAPSHOT],
    "purchase_order_return_lines": [TaskName.INVENTORY_SNAPSHOT],
}


# Downstream tasks: when a task completes, these tasks may need to run
DOWNSTREAM_TASKS: Dict[TaskName, List[TaskName]] = {
    TaskName.DAILY_SALES: [TaskName.SERVICE_LEVEL, TaskName.SLOW_MOVERS, TaskName.INVENTORY_PLANNING],
    TaskName.INVENTORY_SNAPSHOT: [TaskName.INVENTORY_PLANNING],
    TaskName.SERVICE_LEVEL: [],
    TaskName.SLOW_MOVERS: [],
    TaskName.INVENTORY_PLANNING: [],
}


# Full Celery task name mapping (for send_task on Linux)
# NOTE: These must match the 'name' parameter in @celery_app.task() decorators
TASK_NAME_TO_CELERY_NAME: Dict[TaskName, str] = {
    TaskName.DAILY_SALES: "src.smart_inventory.tasks.daily_sales_task.compute_daily_sales",
    TaskName.INVENTORY_SNAPSHOT: "src.smart_inventory.tasks.inventory_snapshot_task.compute_inventory_snapshot",  # module name matches the import in _get_celery_task
    TaskName.SERVICE_LEVEL: "src.smart_inventory.tasks.service_level_task.compute_service_level_daily",
    TaskName.SLOW_MOVERS: "src.smart_inventory.tasks.slow_movers_task.compute_slow_movers_90d",
    TaskName.INVENTORY_PLANNING: "src.smart_inventory.tasks.inventory_planning_task.compute_inventory_planning_snapshot",
}


# =============================================================================
# TASK IMPORT HELPERS
# =============================================================================

def _get_celery_task(task_name: TaskName):
    """Get the actual Celery task function by name (for Windows)"""
    try:
        if task_name == TaskName.DAILY_SALES:
            from src.smart_inventory.tasks.daily_sales_task import compute_daily_sales
            return compute_daily_sales
        elif task_name == TaskName.INVENTORY_SNAPSHOT:
            from src.smart_inventory.tasks.inventory_snapshot_task import compute_inventory_snapshot
            return compute_inventory_snapshot
        elif task_name == TaskName.SERVICE_LEVEL:
            from src.smart_inventory.tasks.service_level_task import compute_service_level_daily
            return compute_service_level_daily
        elif task_name == TaskName.SLOW_MOVERS:
            from src.smart_inventory.tasks.slow_movers_task import compute_slow_movers_90d
            return compute_slow_movers_90d
        elif task_name == TaskName.INVENTORY_PLANNING:
            from src.smart_inventory.tasks.inventory_planning_task import compute_inventory_planning_snapshot
            return compute_inventory_planning_snapshot
    except ImportError as e:
        logger.warning(f"Could not import task {task_name}: {e}")
        return None
    return None


def _dispatch_celery_task(task_name: TaskName, kwargs: Dict[str, Any], task_id: str) -> bool:
    """
    Dispatch a Celery task with platform-specific handling.
    
    Windows (solo pool): Uses direct task import and apply_async
    Linux (prefork pool): Uses send_task with task name string
    
    Args:
        task_name: The TaskName enum
        kwargs: Task keyword arguments
        task_id: Task ID to use
        
    Returns:
        True if task was dispatched successfully
    """
    try:
        if sys.platform == 'win32':
            # Windows: use direct import and apply_async
            celery_task = _get_celery_task(task_name)
            if celery_task:
                celery_task.apply_async(kwargs=kwargs, task_id=task_id)
                return True
        else:
            # Linux: use send_task with task name for prefork pool compatibility
            from src.utils.celery_worker import celery_app
            celery_name = TASK_NAME_TO_CELERY_NAME.get(task_name)
            if celery_name:
                celery_app.send_task(celery_name, kwargs=kwargs, task_id=task_id)
                return True
        return False
    except Exception as e:
        logger.error(f"Failed to dispatch task {task_name}: {e}")
        return False


# =============================================================================
# CORE ORCHESTRATION FUNCTIONS
# =============================================================================

def get_required_tasks(affected_tables: List[str]) -> Set[TaskName]:
    """
    Determine which tasks need to run based on affected tables.
    Also includes downstream dependent tasks.
    
    Args:
        affected_tables: List of table names that were modified
        
    Returns:
        Set of TaskName enums for tasks that should run
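
    Example (a minimal sketch of the dependency closure, using this module's
    mappings; no database access involved):
        >>> get_required_tasks(["sales_orders"]) == {
        ...     TaskName.DAILY_SALES,
        ...     TaskName.SERVICE_LEVEL,
        ...     TaskName.SLOW_MOVERS,
        ...     TaskName.INVENTORY_PLANNING,
        ... }
        True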
    """
    required_tasks: Set[TaskName] = set()
    
    # Normalize table names (lowercase, hyphens to underscores); the map
    # itself lists both singular and plural forms
    normalized_tables = [t.lower().replace("-", "_") for t in affected_tables]
    
    # Find directly triggered tasks
    for table in normalized_tables:
        if table in TABLE_TO_TASKS:
            for task in TABLE_TO_TASKS[table]:
                required_tasks.add(task)
    
    # Add downstream tasks
    tasks_to_check = list(required_tasks)
    while tasks_to_check:
        task = tasks_to_check.pop()
        for downstream_task in DOWNSTREAM_TASKS.get(task, []):
            if downstream_task not in required_tasks:
                required_tasks.add(downstream_task)
                tasks_to_check.append(downstream_task)
    
    return required_tasks


def get_tasks_in_order(tasks: Set[TaskName]) -> List[TaskName]:
    """
    Sort tasks by priority/dependency order.
    
    Args:
        tasks: Set of tasks to sort
        
    Returns:
        List of tasks sorted by execution priority
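
    Example (illustrative, using the TASK_PRIORITY values defined above):
        >>> get_tasks_in_order({TaskName.SERVICE_LEVEL, TaskName.DAILY_SALES})[0] is TaskName.DAILY_SALES
        True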
    """
    return sorted(tasks, key=lambda t: TASK_PRIORITY.get(t, 99))


def check_dependencies_met(
    db: Session,
    task_name: TaskName,
    company_id: int,
    target_date: str,
    batch_id: str
) -> bool:
    """
    Check if all dependencies for a task have completed successfully.
    
    Only checks dependencies that were actually triggered in the current batch.
    If a dependency task doesn't exist in this batch, it's considered "not required"
    and skipped - this handles cases like purchase_receive where DAILY_SALES
    is never triggered but INVENTORY_PLANNING still needs to run.
    
    Special case: INVENTORY_PLANNING for today's date needs to wait for ALL
    DAILY_SALES tasks in the batch (not just today's), because the planning
    calculation uses historical sales data from the past 90 days.
    
    Args:
        db: Database session
        task_name: The task to check dependencies for
        company_id: Company ID
        target_date: Target date string (YYYY-MM-DD)
        batch_id: Batch ID to group related tasks
        
    Returns:
        True if all dependencies are met (SUCCESS status) or not required
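
    Tracker rows are matched by the composite name this module writes,
    f"{task.value}_{target_date}_{batch_id}", e.g. (illustrative values)
    "compute_daily_sales_2026-01-01_a1b2c3d4".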
    """
    from src.smart_inventory.apps.inventory.models import CeleryTaskTracker, CeleryTaskStatus
    
    dependencies = TASK_DEPENDENCIES.get(task_name, [])
    
    if not dependencies:
        return True
    
    # Special handling for INVENTORY_PLANNING on today's date:
    # It needs ALL daily_sales tasks in the batch to complete (not just today's),
    # because planning calculates avg_daily_demand from the last 90 days of sales
    today_str = date.today().isoformat()
    is_planning_for_today = (
        task_name == TaskName.INVENTORY_PLANNING and 
        target_date == today_str
    )
    
    for dep_task in dependencies:
        if is_planning_for_today and dep_task == TaskName.DAILY_SALES:
            # For INVENTORY_PLANNING on today: wait for ALL daily_sales in batch
            pattern = f"{TaskName.DAILY_SALES.value}_%_{batch_id}"
            all_daily_sales = db.query(CeleryTaskTracker).filter(
                CeleryTaskTracker.company_id == company_id,
                CeleryTaskTracker.task_name.like(pattern)
            ).all()
            
            # If any daily_sales tasks exist and are not complete, wait
            for ds_tracker in all_daily_sales:
                if ds_tracker.status != CeleryTaskStatus.SUCCESS:
                    logger.debug(f"Waiting for {ds_tracker.task_name} to complete for {task_name.value}")
                    return False
            
            logger.debug(f"All daily_sales tasks complete for {task_name.value}")
            continue
        
        # Standard dependency check for other cases
        dep_task_name = f"{dep_task.value}_{target_date}_{batch_id}"
        
        # Check if this dependency task exists in this batch
        dep_tracker = db.query(CeleryTaskTracker).filter(
            CeleryTaskTracker.company_id == company_id,
            CeleryTaskTracker.task_name == dep_task_name
        ).first()
        
        # If dependency wasn't triggered in this batch, skip it
        # (e.g., purchase_receive doesn't trigger daily_sales, so don't wait for it)
        if not dep_tracker:
            logger.debug(f"Dependency {dep_task.value} not in batch {batch_id}, skipping")
            continue
        
        # If dependency exists but not SUCCESS yet, wait for it
        if dep_tracker.status != CeleryTaskStatus.SUCCESS:
            logger.debug(f"Dependency {dep_task.value} not yet complete for {task_name.value}")
            return False
    
    return True


def trigger_single_task(
    db: Session,
    task_name: TaskName,
    company_id: int,
    target_date: str,
    batch_id: str,
    check_dependencies: bool = True
) -> Optional[str]:
    """
    Trigger a single Celery task with tracking.
    
    Args:
        db: Database session
        task_name: Task to trigger
        company_id: Company ID
        target_date: Target date (YYYY-MM-DD)
        batch_id: Batch ID for grouping
        check_dependencies: Whether to check dependencies and place the task on HOLD when they are unmet
        
    Returns:
        Task ID if triggered, None if failed
    """
    from src.smart_inventory.apps.inventory.models import CeleryTaskTracker, CeleryTaskStatus
    
    # Create unique task name with date and batch
    tracker_task_name = f"{task_name.value}_{target_date}_{batch_id}"
    
    # Check if task already exists for this combination
    existing = db.query(CeleryTaskTracker).filter(
        CeleryTaskTracker.task_name == tracker_task_name,
        CeleryTaskTracker.company_id == company_id
    ).first()
    
    if existing:
        logger.debug(f"Task {tracker_task_name} exists: {existing.status}")
        return existing.task_id
    
    # Generate task ID
    task_id = str(uuid.uuid4())
    
    # Check dependencies
    dependencies_met = True
    if check_dependencies:
        dependencies_met = check_dependencies_met(db, task_name, company_id, target_date, batch_id)
    
    # Determine initial status
    initial_status = CeleryTaskStatus.PENDING if dependencies_met else CeleryTaskStatus.HOLD
    
    # Create tracker record FIRST
    tracker = CeleryTaskTracker(
        task_id=task_id,
        task_name=tracker_task_name,
        company_id=company_id,
        status=initial_status,
        started_at=datetime.now() if dependencies_met else None
    )
    db.add(tracker)
    db.commit()
    
    # If dependencies are met, trigger the task
    if dependencies_met:
        # Build kwargs based on task type
        if task_name == TaskName.DAILY_SALES:
            task_kwargs = {"target_date": target_date}
        elif task_name == TaskName.INVENTORY_SNAPSHOT:
            task_kwargs = {"target_date": target_date}
        elif task_name == TaskName.SERVICE_LEVEL:
            task_kwargs = {"target_date_str": target_date}
        elif task_name == TaskName.SLOW_MOVERS:
            task_kwargs = {"snapshot_date_str": target_date}
        elif task_name == TaskName.INVENTORY_PLANNING:
            task_kwargs = {"snapshot_date_str": target_date}
        else:
            logger.warning(f"Unknown task type: {task_name}")
            return None
        
        try:
            # Dispatch task with platform-specific handling
            if _dispatch_celery_task(task_name, task_kwargs, task_id):
                logger.debug(f"Triggered {task_name.value} for {target_date}")
            else:
                logger.warning(f"Celery task {task_name.value} not available")
                return None
        except Exception as e:
            logger.error(f"Failed to trigger task {task_name.value}: {e}")
            tracker.status = CeleryTaskStatus.FAILURE
            tracker.error_message = str(e)
            db.commit()
            return None
    else:
        logger.debug(f"{task_name.value} set to HOLD")
    
    return task_id


def trigger_tasks_for_affected_tables(
    db: Session,
    affected_tables: List[str],
    company_id: int,
    unique_dates: List[str],
    batch_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    Main function to trigger all required tasks based on affected tables.
    
    This is the primary entry point for the task orchestrator. Call this after
    CSV processing to automatically trigger all necessary background tasks.
    
    EXECUTION ORDER (wave-based):
    1. DAILY_SALES for ALL unique dates in CSV (computes historical sales data)
    2. INVENTORY_SNAPSHOT for TODAY only (current inventory state)
    3. SERVICE_LEVEL for TODAY only (current service level)
    4. SLOW_MOVERS for TODAY only (current slow mover analysis)
    5. INVENTORY_PLANNING for TODAY only (current planning snapshot)
    
    Each wave waits for the previous wave to complete before starting.
    This ensures proper data availability for dependent calculations.
    
    Args:
        db: Database session
        affected_tables: List of table names that were modified by CSV upload
        company_id: Company ID the data belongs to
        unique_dates: List of unique dates (YYYY-MM-DD) in the uploaded data
        batch_id: Optional batch ID for grouping (auto-generated if not provided)
        
    Returns:
        Dict with triggered task information:
        {
            "batch_id": str,
            "tasks_triggered": int,
            "tasks_on_hold": int,
            "task_details": [...]
        }
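        
        Each task_details entry has this shape (values illustrative):
        {
            "task_name": "compute_daily_sales",
            "task_id": "<uuid4>",
            "target_date": "2026-01-01",
            "status": "pending",  # or "hold" when waiting on a previous wave
            "wave": 1
        }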
    """
    from src.smart_inventory.apps.inventory.models import CeleryTaskTracker, CeleryTaskStatus
    
    if not batch_id:
        batch_id = str(uuid.uuid4())[:8]
    
    # Get required tasks
    required_tasks = get_required_tasks(affected_tables)
    
    if not required_tasks:
        logger.debug(f"No tasks for tables: {affected_tables}")
        return {
            "batch_id": batch_id,
            "tasks_triggered": 0,
            "tasks_on_hold": 0,
            "task_details": []
        }
    
    # Define the execution order (wave-based)
    # Each wave completes before the next starts
    EXECUTION_ORDER = [
        TaskName.DAILY_SALES,
        TaskName.INVENTORY_SNAPSHOT,
        TaskName.SERVICE_LEVEL,
        TaskName.SLOW_MOVERS,
        TaskName.INVENTORY_PLANNING,
    ]
    
    # Filter to only required tasks in order
    ordered_tasks = [t for t in EXECUTION_ORDER if t in required_tasks]
    
    # DAILY_SALES runs for ALL unique dates (to compute historical sales data)
    # Other tasks run ONLY for TODAY (we only care about today's snapshot)
    all_dates = unique_dates  # For DAILY_SALES
    today_str = date.today().isoformat()
    today_only = [today_str]  # For other dependent tasks
    
    logger.info(f"Wave-based execution: {[t.value for t in ordered_tasks]}")
    logger.info(f"DAILY_SALES dates: {all_dates} ({len(all_dates)} dates)")
    logger.info(f"Snapshot/Planning dates: {today_only} (today only)")
    
    task_details = []
    tasks_triggered = 0
    tasks_on_hold = 0
    
    # Process tasks in waves: 
    # - DAILY_SALES: all dates
    # - Other tasks: today only
    for wave_index, task_name in enumerate(ordered_tasks):
        # Determine which dates to process for this task
        if task_name == TaskName.DAILY_SALES:
            dates_to_process = all_dates
        else:
            # INVENTORY_SNAPSHOT, SERVICE_LEVEL, SLOW_MOVERS, INVENTORY_PLANNING
            # Only run for today - we only care about current snapshot
            dates_to_process = today_only
        
        for target_date in dates_to_process:
            task_id = trigger_single_task_wave(
                db=db,
                task_name=task_name,
                company_id=company_id,
                target_date=target_date,
                batch_id=batch_id,
                wave_index=wave_index,
                previous_wave_task=ordered_tasks[wave_index - 1] if wave_index > 0 else None,
                all_dates=all_dates  # Still pass all_dates for dependency checking
            )
            
            if task_id:
                # Get status from tracker
                tracker = db.query(CeleryTaskTracker).filter(
                    CeleryTaskTracker.task_id == task_id
                ).first()
                
                status = tracker.status.value if tracker else "unknown"
                
                task_details.append({
                    "task_name": task_name.value,
                    "task_id": task_id,
                    "target_date": target_date,
                    "status": status,
                    "wave": wave_index + 1
                })
                
                if status == CeleryTaskStatus.HOLD.value:
                    tasks_on_hold += 1
                else:
                    tasks_triggered += 1
    
    # If there are tasks on hold, ensure the poller is running
    if tasks_on_hold > 0:
        _ensure_hold_poller_running(db, company_id, batch_id)
    
    return {
        "batch_id": batch_id,
        "tasks_triggered": tasks_triggered,
        "tasks_on_hold": tasks_on_hold,
        "task_details": task_details
    }


def trigger_single_task_wave(
    db: Session,
    task_name: TaskName,
    company_id: int,
    target_date: str,
    batch_id: str,
    wave_index: int,
    previous_wave_task: Optional[TaskName],
    all_dates: List[str]
) -> Optional[str]:
    """
    Trigger a single task with wave-based dependency checking.
    
    Wave 0 tasks run immediately.
    Wave N tasks wait for ALL tasks in wave N-1 to complete.
    
    Args:
        db: Database session
        task_name: Task to trigger
        company_id: Company ID
        target_date: Target date (YYYY-MM-DD)
        batch_id: Batch ID for grouping
        wave_index: Which wave this task belongs to (0-based)
        previous_wave_task: The task type from the previous wave (to wait for)
        all_dates: All dates in this batch (to check all previous wave tasks)
        
    Returns:
        Task ID if triggered, None if failed
    """
    from src.smart_inventory.apps.inventory.models import CeleryTaskTracker, CeleryTaskStatus
    
    # Create unique task name with date and batch
    tracker_task_name = f"{task_name.value}_{target_date}_{batch_id}"
    
    # Check if task already exists for this combination
    existing = db.query(CeleryTaskTracker).filter(
        CeleryTaskTracker.task_name == tracker_task_name,
        CeleryTaskTracker.company_id == company_id
    ).first()
    
    if existing:
        logger.debug(f"Task {tracker_task_name} exists: {existing.status}")
        return existing.task_id
    
    # Generate task ID
    task_id = str(uuid.uuid4())
    
    # Check if previous wave is complete (wave 0 has no previous)
    dependencies_met = True
    if wave_index > 0 and previous_wave_task:
        dependencies_met = check_previous_wave_complete(
            db, previous_wave_task, company_id, batch_id, all_dates
        )
    
    # Determine initial status
    initial_status = CeleryTaskStatus.PENDING if dependencies_met else CeleryTaskStatus.HOLD
    
    # Create tracker record FIRST
    tracker = CeleryTaskTracker(
        task_id=task_id,
        task_name=tracker_task_name,
        company_id=company_id,
        status=initial_status,
        started_at=datetime.now() if dependencies_met else None
    )
    db.add(tracker)
    db.commit()
    
    # If dependencies are met, trigger the task
    if dependencies_met:
        # Build kwargs based on task type
        if task_name == TaskName.DAILY_SALES:
            task_kwargs = {"target_date": target_date}
        elif task_name == TaskName.INVENTORY_SNAPSHOT:
            task_kwargs = {"target_date": target_date}
        elif task_name == TaskName.SERVICE_LEVEL:
            task_kwargs = {"target_date_str": target_date}
        elif task_name == TaskName.SLOW_MOVERS:
            task_kwargs = {"snapshot_date_str": target_date}
        elif task_name == TaskName.INVENTORY_PLANNING:
            task_kwargs = {"snapshot_date_str": target_date}
        else:
            logger.warning(f"Unknown task type: {task_name}")
            return None
        
        try:
            # Dispatch task with platform-specific handling
            if _dispatch_celery_task(task_name, task_kwargs, task_id):
                logger.debug(f"Wave {wave_index}: Triggered {task_name.value} for {target_date}")
            else:
                logger.warning(f"Celery task {task_name.value} not available")
                return None
        except Exception as e:
            logger.error(f"Failed to trigger task {task_name.value}: {e}")
            tracker.status = CeleryTaskStatus.FAILURE
            tracker.error_message = str(e)
            db.commit()
            return None
    else:
        logger.debug(f"Wave {wave_index}: {task_name.value} for {target_date} set to HOLD")
    
    return task_id


def check_previous_wave_complete(
    db: Session,
    previous_task: TaskName,
    company_id: int,
    batch_id: str,
    all_dates: List[str]
) -> bool:
    """
    Check if ALL tasks of the previous wave (for all dates) have completed.

    Note that only DAILY_SALES is scheduled for every date in the batch;
    later waves run for today only. When the previous wave is a today-only
    task, rows for the other dates will not exist, this check returns False,
    and the task is parked on HOLD until process_hold_tasks (which matches
    only the tracker rows that actually exist) releases it.
    
    Args:
        db: Database session
        previous_task: The task type from the previous wave
        company_id: Company ID
        batch_id: Batch ID
        all_dates: All dates in this batch
        
    Returns:
        True if all previous wave tasks are complete (SUCCESS status)
    """
    from src.smart_inventory.apps.inventory.models import CeleryTaskTracker, CeleryTaskStatus
    
    for target_date in all_dates:
        task_name = f"{previous_task.value}_{target_date}_{batch_id}"
        
        tracker = db.query(CeleryTaskTracker).filter(
            CeleryTaskTracker.task_name == task_name,
            CeleryTaskTracker.company_id == company_id
        ).first()
        
        # If task doesn't exist or isn't complete, wave is not complete
        if not tracker:
            logger.debug(f"Previous wave task not found: {task_name}")
            return False
        
        if tracker.status != CeleryTaskStatus.SUCCESS:
            logger.debug(f"Previous wave task not complete: {task_name} ({tracker.status.value})")
            return False
    
    logger.debug(f"Previous wave ({previous_task.value}) complete for all {len(all_dates)} dates")
    return True


# =============================================================================
# HOLD TASK POLLING
# =============================================================================

def process_hold_tasks(db: Session, company_id: Optional[int] = None, batch_id: Optional[str] = None) -> int:
    """
    Check HOLD tasks and start them if previous wave is complete.
    
    Wave-based approach: tasks wait for ALL tasks of the previous wave to complete.
    
    Args:
        db: Database session
        company_id: Optional filter by company
        batch_id: Optional filter by batch
        
    Returns:
        Number of tasks started
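
    update_task_tracker_status() calls this automatically whenever a task
    reports success, and _ensure_hold_poller_running() schedules the
    background poller when a batch leaves tasks on HOLD, so manual calls
    are rarely needed.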
    """
    from src.smart_inventory.apps.inventory.models import CeleryTaskTracker, CeleryTaskStatus
    
    # Define the execution order for wave checking
    EXECUTION_ORDER = [
        TaskName.DAILY_SALES,
        TaskName.INVENTORY_SNAPSHOT,
        TaskName.SERVICE_LEVEL,
        TaskName.SLOW_MOVERS,
        TaskName.INVENTORY_PLANNING,
    ]
    
    # Query for HOLD tasks
    query = db.query(CeleryTaskTracker).filter(
        CeleryTaskTracker.status == CeleryTaskStatus.HOLD
    )
    
    if company_id is not None:
        query = query.filter(CeleryTaskTracker.company_id == company_id)
    
    hold_tasks = query.all()
    
    if not hold_tasks:
        return 0
    
    tasks_started = 0
    
    for tracker in hold_tasks:
        # Parse task name to get original task and date
        # Format: task_name_YYYY-MM-DD_batch_id
        parts = tracker.task_name.rsplit("_", 2)
        if len(parts) < 3:
            continue
        
        task_value = parts[0]
        target_date = parts[1]
        task_batch_id = parts[2]
        
        if batch_id and task_batch_id != batch_id:
            continue
        
        # Find matching TaskName
        task_name = None
        wave_index = -1
        for i, tn in enumerate(EXECUTION_ORDER):
            if tn.value == task_value:
                task_name = tn
                wave_index = i
                break
        
        if not task_name or wave_index < 0:
            continue
        
        # Wave 0 (DAILY_SALES) has no previous wave - should run immediately
        if wave_index == 0:
            dependencies_met = True
        else:
            # Get all dates in this batch by looking at existing tasks
            previous_task = EXECUTION_ORDER[wave_index - 1]
            pattern = f"{previous_task.value}_%_{task_batch_id}"
            
            previous_wave_tasks = db.query(CeleryTaskTracker).filter(
                CeleryTaskTracker.company_id == tracker.company_id,
                CeleryTaskTracker.task_name.like(pattern)
            ).all()
            
            # All previous wave tasks must be SUCCESS. An empty result means
            # the previous task type was never triggered in this batch, and
            # all() over an empty sequence treats that as complete.
            dependencies_met = all(
                prev_tracker.status == CeleryTaskStatus.SUCCESS
                for prev_tracker in previous_wave_tasks
            )
        
        if dependencies_met:
            # Build kwargs based on task type
            if task_name == TaskName.DAILY_SALES:
                task_kwargs = {"target_date": target_date}
            elif task_name == TaskName.INVENTORY_SNAPSHOT:
                task_kwargs = {"target_date": target_date}
            elif task_name == TaskName.SERVICE_LEVEL:
                task_kwargs = {"target_date_str": target_date}
            elif task_name == TaskName.SLOW_MOVERS:
                task_kwargs = {"snapshot_date_str": target_date}
            elif task_name == TaskName.INVENTORY_PLANNING:
                task_kwargs = {"snapshot_date_str": target_date}
            else:
                continue
            
            try:
                # Update tracker status
                tracker.status = CeleryTaskStatus.PENDING
                tracker.started_at = datetime.now()
                db.commit()
                
                # Dispatch task with platform-specific handling
                if _dispatch_celery_task(task_name, task_kwargs, tracker.task_id):
                    logger.debug(f"Started HOLD task: {tracker.task_name} (wave {wave_index})")
                    tasks_started += 1
                else:
                    logger.error(f"Failed to dispatch HOLD task {tracker.task_name}")
                    tracker.status = CeleryTaskStatus.FAILURE
                    tracker.error_message = "Failed to dispatch task"
                    db.commit()
                
            except Exception as e:
                logger.error(f"Failed to start HOLD task {tracker.task_name}: {e}")
                tracker.status = CeleryTaskStatus.FAILURE
                tracker.error_message = str(e)
                db.commit()
    
    return tasks_started


def has_hold_tasks(db: Session, company_id: Optional[int] = None) -> bool:
    """Check if there are any tasks in HOLD status"""
    from src.smart_inventory.apps.inventory.models import CeleryTaskTracker, CeleryTaskStatus
    
    query = db.query(CeleryTaskTracker).filter(
        CeleryTaskTracker.status == CeleryTaskStatus.HOLD
    )
    
    if company_id is not None:
        query = query.filter(CeleryTaskTracker.company_id == company_id)
    
    return query.first() is not None


def _ensure_hold_poller_running(db: Session, company_id: int, batch_id: str):
    """
    Ensure the background poller for HOLD tasks is running.
    Uses a Celery task that polls every 20 seconds.
    """
    try:
        if sys.platform == 'win32':
            # Windows: use direct import and .delay()
            from src.smart_inventory.tasks.hold_task_poller import poll_hold_tasks
            poll_hold_tasks.delay(company_id=company_id, batch_id=batch_id)
        else:
            # Linux: use send_task with task name for prefork pool compatibility
            from src.utils.celery_worker import celery_app
            celery_app.send_task(
                'src.smart_inventory.tasks.hold_task_poller.poll_hold_tasks',
                kwargs={"company_id": company_id, "batch_id": batch_id}
            )
        logger.info(f"Started hold poller for batch {batch_id}")
    except Exception as e:
        logger.warning(f"Failed to start hold task poller: {e}")


# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def get_batch_status(db: Session, batch_id: str) -> Dict[str, Any]:
    """
    Get the status of all tasks in a batch.
    
    Args:
        db: Database session
        batch_id: Batch ID to check
        
    Returns:
        Dict with batch status information
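
        Shape of the returned dict (values illustrative):
        {
            "batch_id": "a1b2c3d4",
            "status_counts": {"pending": 0, "started": 0, "success": 5,
                              "failure": 0, "hold": 1, "total": 6},
            "all_complete": False,
            "tasks": [...]
        }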
    """
    from src.smart_inventory.apps.inventory.models import CeleryTaskTracker
    
    tasks = db.query(CeleryTaskTracker).filter(
        CeleryTaskTracker.task_name.like(f"%_{batch_id}")
    ).all()
    
    status_counts = {
        "pending": 0,
        "started": 0,
        "success": 0,
        "failure": 0,
        "hold": 0,
        "total": len(tasks)
    }
    
    task_list = []
    for task in tasks:
        status_counts[task.status.value] = status_counts.get(task.status.value, 0) + 1
        task_list.append({
            "task_id": task.task_id,
            "task_name": task.task_name,
            "status": task.status.value,
            "started_at": task.started_at.isoformat() if task.started_at else None,
            "completed_at": task.completed_at.isoformat() if task.completed_at else None,
            "error_message": task.error_message
        })
    
    all_complete = status_counts["success"] + status_counts["failure"] == status_counts["total"]
    
    return {
        "batch_id": batch_id,
        "status_counts": status_counts,
        "all_complete": all_complete,
        "tasks": task_list
    }


# =============================================================================
# TASK STATUS UPDATE HELPERS
# =============================================================================

def update_task_tracker_status(
    task_id: str,
    status: str,
    error_message: Optional[str] = None
):
    """
    Update CeleryTaskTracker status for a task.
    Call this from within Celery tasks to update their tracking status.
    
    When status is 'success', automatically triggers processing of HOLD tasks
    that may have been waiting for this task to complete.
    
    Args:
        task_id: The Celery task ID
        status: New status (started, success, failure, pending, or retry)
        error_message: Optional error message for failure status
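
    Example (a sketch; compute_something and its registered name are
    placeholders, and bind=True is assumed so self.request.id is available):
        @celery_app.task(bind=True, name="...")
        def compute_something(self, target_date: str):
            update_task_tracker_status(self.request.id, "started")
            try:
                ...  # the actual computation
                update_task_tracker_status(self.request.id, "success")
            except Exception as exc:
                update_task_tracker_status(self.request.id, "failure", str(exc))
                raise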
    """
    try:
        from src.utils.db import get_db_session
        from src.smart_inventory.apps.inventory.models import CeleryTaskTracker, CeleryTaskStatus
        
        db = get_db_session()
        try:
            tracker = db.query(CeleryTaskTracker).filter(
                CeleryTaskTracker.task_id == task_id
            ).first()
            
            if tracker:
                # Map string status to enum
                status_map = {
                    "started": CeleryTaskStatus.STARTED,
                    "success": CeleryTaskStatus.SUCCESS,
                    "failure": CeleryTaskStatus.FAILURE,
                    "pending": CeleryTaskStatus.PENDING,
                    "retry": CeleryTaskStatus.RETRY,
                }
                
                new_status = status_map.get(status.lower())
                if new_status:
                    tracker.status = new_status
                    
                    if status.lower() == "started":
                        tracker.started_at = datetime.now()
                    elif status.lower() in ("success", "failure"):
                        tracker.completed_at = datetime.now()
                    
                    if error_message:
                        tracker.error_message = error_message[:2000]  # Truncate to fit column
                    
                    db.commit()
                    logger.debug(f"Updated task {task_id} status to {status}")
                    
                    # When a task succeeds, check if any HOLD tasks can now run
                    if status.lower() == "success" and tracker.company_id:
                        # Extract batch_id from task_name (format: task_value_YYYY-MM-DD_batch_id)
                        parts = tracker.task_name.rsplit("_", 2) if tracker.task_name else []
                        batch_id = parts[2] if len(parts) >= 3 else None
                        
                        try:
                            started = process_hold_tasks(db, tracker.company_id, batch_id)
                            if started > 0:
                                logger.info(f"Auto-started {started} HOLD tasks after {task_id} completed")
                        except Exception as hold_err:
                            logger.warning(f"Failed to process hold tasks: {hold_err}")
            else:
                logger.warning(f"No tracker found for task_id {task_id}")
        finally:
            db.close()
    except Exception as e:
        logger.error(f"Failed to update task tracker status: {e}")
