#!/usr/bin/env python3
"""
Smart Inventory - Daily Snapshot Task
====================================

Celery task for computing daily inventory snapshots.

This task computes inventory levels by:
1. Taking the previous day's snapshot as the starting point (or, when no
   snapshot exists for that day, the cumulative sum of all earlier movements)
2. Applying all inventory movements for the target day
3. Computing final on-hand, inbound, and outbound quantities
4. Storing results in the inventory_snapshot_daily table and removing the
   previous day's snapshot (rolling window)

Formula:
--------
today_on_hand = previous_day_on_hand + sum(today_movements)
today_inbound = sum(positive_movements_today)
today_outbound = sum(abs(negative_movements_today))
"""

import os
import sys
from datetime import datetime, date, timedelta
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass
from sqlalchemy.orm import Session
from sqlalchemy import text, and_
from sqlalchemy.exc import IntegrityError

# Add src to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))

from src.utils.db import SessionLocal, engine
#from src.smart_inventory.apps.inventory.models import InventorySnapshotDaily, InventoryMovement

@dataclass()
class MovementSummary:
    """Aggregated movement totals for one (company, location, product) key."""

    # Running sum of positive quantity deltas.
    inbound_qty: float = 0.0
    # Running sum of the absolute values of non-positive quantity deltas.
    outbound_qty: float = 0.0
    # Signed running sum of every quantity delta.
    net_movement: float = 0.0

@dataclass()
class SnapshotComputationResult:
    """Outcome and counters of a single snapshot computation run."""

    # Row counters reported back to the caller.
    records_processed: int = 0
    records_created: int = 0
    records_updated: int = 0
    # Wall-clock duration in seconds; assigned by the caller after the run.
    computation_time: float = 0.0
    # Stays False unless the run completes; error_message carries the failure text.
    success: bool = False
    error_message: Optional[str] = None

def create_celery_task():
    """
    Build and register the Celery task wrapper for snapshot computation.

    The Celery app is imported lazily inside this function (rather than at
    module top) to avoid circular imports.

    Returns:
        The registered Celery task, or None when ``src.utils.celery_worker``
        cannot be imported (e.g. when this module is run standalone).
    """
    task_name = "src.smart_inventory.tasks.snapshot_task.compute_inventory_snapshot"
    try:
        from src.utils.celery_worker import celery_app

        def compute_inventory_snapshot_task(self, target_date: Optional[str] = None) -> Dict:
            # Thin adapter; with bind=True Celery passes the task itself as ``self``.
            return compute_inventory_snapshot_impl(self, target_date)

        register = celery_app.task(name=task_name, bind=True)
        return register(compute_inventory_snapshot_task)
    except ImportError:
        # Expected when running standalone - not an error.
        return None

def compute_inventory_snapshot_impl(self, target_date: Optional[str] = None) -> Dict:
    """
    Compute daily inventory snapshots for a given date (body of the Celery task).

    Args:
        self: The bound Celery task (only its ``request.id`` is read, for log
            prefixes), or ``None`` when called directly.
        target_date: Date string in YYYY-MM-DD format. If None, uses yesterday
            relative to ``date.today()``.

    Returns:
        Dict with a ``success`` flag. On success it also contains
        ``records_processed``, ``records_created``, ``records_updated``,
        ``computation_time`` and ``target_date``. On failure it contains
        ``error``, plus ``computation_time`` whenever the computation itself
        was attempted.

    Errors (bad input, failed or crashed computation) are reported through the
    returned dict; this function does not raise for them.
    """

    start_time = datetime.now()
    # Read the id via attribute access so any bound task object works; fall back
    # to 'manual' when there is no task, no request, or no id (direct calls).
    request = getattr(self, 'request', None)
    task_id = getattr(request, 'id', None) or 'manual'

    print(f"[Task {task_id}] Starting inventory snapshot computation...")

    # Parse target date
    if target_date:
        try:
            target_date_obj = datetime.strptime(target_date, '%Y-%m-%d').date()
        except (ValueError, TypeError):
            # TypeError covers non-string input (e.g. an int), which would
            # otherwise escape as an uncaught exception.
            error_msg = f"Invalid date format: {target_date}. Expected YYYY-MM-DD"
            print(f"[Task {task_id}] ERROR: {error_msg}")
            return {"success": False, "error": error_msg}
    else:
        target_date_obj = date.today() - timedelta(days=1)

    previous_date = target_date_obj - timedelta(days=1)

    print(f"[Task {task_id}] Target Date: {target_date_obj}")
    print(f"[Task {task_id}] Previous Date: {previous_date}")

    db = SessionLocal()

    try:
        result = _compute_snapshots_for_date(db, target_date_obj, previous_date)

        computation_time = (datetime.now() - start_time).total_seconds()
        result.computation_time = computation_time

        if not result.success:
            print(f"[Task {task_id}] Computation failed: {result.error_message}")
            return {
                "success": False,
                "error": result.error_message,
                "computation_time": computation_time
            }

        print(f"[Task {task_id}] Computation completed in {computation_time:.2f}s")
        print(f"   Records processed: {result.records_processed}")
        print(f"   Records created: {result.records_created}")
        print(f"   Records updated: {result.records_updated}")

        return {
            "success": True,
            "records_processed": result.records_processed,
            "records_created": result.records_created,
            "records_updated": result.records_updated,
            "computation_time": computation_time,
            "target_date": target_date_obj.isoformat()
        }

    except Exception as e:
        # Top-level task boundary: report instead of crashing the worker, and
        # include timing like the other failure path does.
        computation_time = (datetime.now() - start_time).total_seconds()
        error_msg = f"Unexpected error during snapshot computation: {str(e)}"
        print(f"[Task {task_id}] ERROR: {error_msg}")
        return {"success": False, "error": error_msg, "computation_time": computation_time}

    finally:
        db.close()

# =============================================================================
# CORE COMPUTATION LOGIC
# =============================================================================

def _compute_snapshots_for_date(
    db: Session,
    target_date: date,
    previous_date: date
) -> SnapshotComputationResult:
    """
    Core logic to compute inventory snapshots for a specific date.

    Step 1: Load previous day's snapshots as baseline on-hand quantities
    Step 1b: If no previous snapshots exist, compute the cumulative baseline
             from ALL movements created before ``target_date``
    Step 2: Aggregate all inventory movements created on ``target_date``
    Step 3: Delete previous day's snapshots (rolling window)
    Step 4: Delete any existing snapshots for ``target_date`` (re-run cleanup)
    Step 5: Insert new snapshots for ``target_date``

    All deletes and inserts run in ONE transaction that is committed only after
    the new rows are added. A failure at any step is rolled back to the original
    table state, instead of leaving the previous day's snapshots deleted with
    nothing written in their place.

    Args:
        db: Open SQLAlchemy session; committed on success, rolled back on error.
        target_date: Day whose snapshot is computed.
        previous_date: Day whose snapshot is used as the baseline (callers in
            this module pass ``target_date - 1 day``).

    Returns:
        SnapshotComputationResult. Exceptions are caught: the session is rolled
        back, ``success`` is False and ``error_message`` holds the error text.
    """

    result = SnapshotComputationResult()

    try:
        # Models are imported here to avoid circular imports at module load.
        from src.smart_inventory.apps.inventory.models import InventorySnapshotDaily, InventoryMovement
        from sqlalchemy import func

        # Step 1: previous day's on-hand per (company, location, product).
        print(f"Loading previous snapshots from {previous_date}...")
        previous_snapshots = db.query(InventorySnapshotDaily).filter(
            InventorySnapshotDaily.snapshot_date == previous_date
        ).all()

        baseline_on_hand: Dict[Tuple, float] = {
            (snapshot.company_id, snapshot.location_id, snapshot.product_id): float(snapshot.on_hand_qty)
            for snapshot in previous_snapshots
        }

        # Step 1b: without a previous snapshot, starting from 0 would produce
        # negative on-hand quantities, so rebuild the baseline from history.
        if not baseline_on_hand:
            print(f"No previous snapshots found for {previous_date}. Computing cumulative baseline from historical movements...")

            # Aggregate in the database instead of loading every historical
            # movement row into memory.
            historical_totals = db.query(
                InventoryMovement.company_id,
                InventoryMovement.location_id,
                InventoryMovement.product_id,
                func.sum(InventoryMovement.quantity_delta),
            ).filter(
                func.date(InventoryMovement.created_at) < target_date
            ).group_by(
                InventoryMovement.company_id,
                InventoryMovement.location_id,
                InventoryMovement.product_id,
            ).all()

            for company_id, location_id, product_id, total in historical_totals:
                baseline_on_hand[(company_id, location_id, product_id)] = float(total or 0)

            print(f"Computed baseline for {len(baseline_on_hand)} product/location combinations from history")

        print(f"Loaded {len(baseline_on_hand)} baseline combinations")

        # Step 2: aggregate the target day's movements per combination.
        print(f"Loading movements for {target_date}...")

        # Note: the model has no movement_date, so created_at's calendar date is used.
        movements = db.query(InventoryMovement).filter(
            func.date(InventoryMovement.created_at) == target_date
        ).all()

        print(f"Found {len(movements)} movements for {target_date}")

        movement_summaries: Dict[Tuple, MovementSummary] = {}
        for movement in movements:
            key = (movement.company_id, movement.location_id, movement.product_id)
            if key not in movement_summaries:
                movement_summaries[key] = MovementSummary()
            summary = movement_summaries[key]

            qty = float(movement.quantity_delta)
            if qty > 0:
                summary.inbound_qty += qty
            else:
                summary.outbound_qty += abs(qty)
            summary.net_movement += qty

        # Step 3: delete previous day's snapshots (rolling window).
        # Not committed yet, so a later failure rolls this back.
        previous_count = len(previous_snapshots)
        if previous_count > 0:
            print(f"Deleting {previous_count} previous snapshots for {previous_date}...")
            db.query(InventorySnapshotDaily).filter(
                InventorySnapshotDaily.snapshot_date == previous_date
            ).delete()

        # Step 4: delete any existing snapshots for the target date (re-runs).
        existing_count = db.query(InventorySnapshotDaily).filter(
            InventorySnapshotDaily.snapshot_date == target_date
        ).count()

        if existing_count > 0:
            print(f"Deleting {existing_count} existing snapshots for {target_date}...")
            db.query(InventorySnapshotDaily).filter(
                InventorySnapshotDaily.snapshot_date == target_date
            ).delete()

        # Step 5: build one snapshot per combination.
        records_to_insert = []
        current_time = datetime.now()

        # Combinations with movements today: baseline (0 for new ones) + net movement.
        for key, movement_summary in movement_summaries.items():
            company_id, location_id, product_id = key
            baseline_qty = baseline_on_hand.get(key, 0.0)
            new_on_hand = baseline_qty + movement_summary.net_movement

            records_to_insert.append(InventorySnapshotDaily(
                snapshot_date=target_date,
                company_id=company_id,
                location_id=location_id,
                product_id=product_id,
                on_hand_qty=new_on_hand,
                inbound_qty=movement_summary.inbound_qty,
                outbound_qty=movement_summary.outbound_qty,
                created_at=current_time
            ))

            print(f"WITH Movement ({company_id},{location_id},{product_id}): {baseline_qty} + {movement_summary.net_movement} = {new_on_hand}")

        # Combinations without movements today: carry the baseline forward.
        for key, on_hand in baseline_on_hand.items():
            if key in movement_summaries:
                continue
            company_id, location_id, product_id = key

            records_to_insert.append(InventorySnapshotDaily(
                snapshot_date=target_date,
                company_id=company_id,
                location_id=location_id,
                product_id=product_id,
                on_hand_qty=on_hand,
                inbound_qty=0.0,
                outbound_qty=0.0,
                created_at=current_time
            ))

            print(f"NO Movement ({company_id},{location_id},{product_id}): {on_hand} (carry forward)")

        if records_to_insert:
            db.add_all(records_to_insert)

        # Single commit: the deletes and the new rows become visible together.
        db.commit()

        if records_to_insert:
            print(f"Created {len(records_to_insert)} snapshot records for {target_date}")
        else:
            print(f"No snapshot records to create for {target_date}")

        # One record exists per unique combination. Summing the baseline and
        # movement dict sizes would double-count keys present in both.
        result.records_processed = len(records_to_insert)
        result.records_created = len(records_to_insert)
        result.records_updated = 0
        result.success = True

        return result

    except Exception as e:
        db.rollback()
        result.success = False
        result.error_message = str(e)
        print(f"ERROR: Snapshot computation failed: {e}")
        return result

# =============================================================================
# TEST FUNCTION FOR MANUAL TESTING
# =============================================================================

def test_snapshot_computation(target_date: Optional[str] = None):
    """
    Run the snapshot computation once, outside Celery, and print a summary.

    Args:
        target_date: Date string in YYYY-MM-DD format. If None, uses yesterday.
    """

    print("Running manual snapshot computation test...")

    # Passing None as ``self`` makes the implementation log as task 'manual'.
    outcome = compute_inventory_snapshot_impl(None, target_date)

    if not outcome["success"]:
        print(f"ERROR: Computation failed: {outcome.get('error', 'Unknown error')}")
        return

    elapsed = outcome.get('computation_time', 0)
    print(f"SUCCESS: Computation completed in {elapsed:.2f}s")
    counters = (
        ("Records processed", 'records_processed'),
        ("Records created", 'records_created'),
        ("Records updated", 'records_updated'),
    )
    for label, field_name in counters:
        print(f"   {label}: {outcome.get(field_name, 0)}")

def manual_test_core_logic(target_date: Optional[str] = None):
    """
    Run ``_compute_snapshots_for_date`` directly (no Celery wrapper) and print a summary.

    Args:
        target_date: Date string in YYYY-MM-DD format. If None, uses yesterday.
            An unparseable value prints an error and returns without touching
            the database.
    """
    print("Running direct core logic test...")

    if not target_date:
        target_date_obj = date.today() - timedelta(days=1)
    else:
        try:
            target_date_obj = datetime.strptime(target_date, '%Y-%m-%d').date()
        except ValueError:
            print(f"ERROR: Invalid date format: {target_date}. Expected YYYY-MM-DD")
            return

    previous_date = target_date_obj - timedelta(days=1)

    print(f"Target Date: {target_date_obj}")
    print(f"Previous Date: {previous_date}")

    session = SessionLocal()

    try:
        began = datetime.now()
        outcome = _compute_snapshots_for_date(session, target_date_obj, previous_date)
        elapsed = (datetime.now() - began).total_seconds()

        if not outcome.success:
            print(f"ERROR: Computation failed: {outcome.error_message}")
        else:
            print(f"SUCCESS: Computation completed in {elapsed:.2f}s")
            print(f"   Records processed: {outcome.records_processed}")
            print(f"   Records created: {outcome.records_created}")
            print(f"   Records updated: {outcome.records_updated}")

    except Exception as exc:
        print(f"ERROR: Unexpected error: {exc}")

    finally:
        session.close()


# Create and register the celery task at import time.
# create_celery_task() returns None when src.utils.celery_worker cannot be
# imported (e.g. standalone runs), so this name is not always a task object.
compute_inventory_snapshot = create_celery_task()