"""
Celery task that trains a company's monthly demand forecast model using Random Forest.
Training is delegated to the monthly_prediction module (``train_monthly_regression.main``).
"""
from __future__ import annotations
from typing import Dict, Optional
from datetime import datetime

from src.utils.celery_worker import celery_app
from src.smart_inventory.core.monthly_prediction.train_monthly_regression import main as train_monthly_model


@celery_app.task(bind=True, name='src.smart_inventory.tasks.monthly_demand_forecast_task.train_monthly_demand_forecast_model')
def train_monthly_demand_forecast_model(
    self,
    company_id: int,
    n_trials: int = 50
) -> Dict:
    """
    Train the monthly demand forecast model (Random Forest) for one company.

    This task runs the monthly regression training pipeline with:
    - model_type: 'random_forest' (hardcoded)
    - train_all_models: False (single model mode)
    - test_size: 0.2
    - all products and locations of the specified company

    If a ``CeleryTaskTracker`` row exists for this task's id, its status is
    set to STARTED before training and to SUCCESS or FAILURE afterwards.

    Args:
        company_id: Company ID to train the model for (required).
        n_trials: Number of Optuna optimization trials (default: 50).

    Returns:
        On success, a dict with ``success=True``, ``message``, ``model_type``,
        ``company_id``, ``metrics`` (test RMSE / R2 / MAE) and ``model_path``.
        On failure, a dict with ``success=False``, ``error`` and
        ``company_id``. Exceptions are not re-raised; they are reported in
        the returned dict and on the tracker row.
    """
    from pathlib import Path

    from src.utils.db import get_db
    from src.smart_inventory.apps.inventory.models import CeleryTaskTracker, CeleryTaskStatus

    # Get database session
    db = next(get_db())
    task_tracker = None
    task_id = self.request.id

    print(f"[CELERY TASK] Starting task with ID: {task_id} for company_id: {company_id}")

    try:
        # Get and update task status to STARTED
        task_tracker = db.query(CeleryTaskTracker).filter(
            CeleryTaskTracker.task_id == task_id
        ).first()

        print(f"[CELERY TASK] Found task_tracker: {task_tracker}")

        if task_tracker:
            task_tracker.status = CeleryTaskStatus.STARTED
            task_tracker.started_at = datetime.now()
            db.commit()
            print("[CELERY TASK] Updated status to STARTED")
        else:
            print(f"[CELERY TASK] WARNING: No task_tracker found for task_id: {task_id}")

        # Train using Random Forest only (--model random_forest --single-model equivalent)
        monthly_regressor, _X_train, X_test, _y_train, y_test = train_monthly_model(
            company_id=company_id,
            model_type='random_forest',
            n_trials=n_trials,
            test_size=0.2,
            train_all_models=False  # Single model mode
        )

        # Get test metrics for the response
        test_metrics = monthly_regressor.evaluate(X_test, y_test)

        # NOTE(review): the path is reconstructed here rather than returned by
        # the trainer; it assumes the trainer saves to
        # <project_root>/models/company_<id>_monthly_demand_forecast_model.pkl
        # — confirm against train_monthly_regression.
        project_root = Path(__file__).parent.parent.parent.parent
        model_path = project_root / 'models' / f"company_{company_id}_monthly_demand_forecast_model.pkl"

        # Update task status to SUCCESS
        if task_tracker:
            task_tracker.status = CeleryTaskStatus.SUCCESS
            task_tracker.completed_at = datetime.now()
            db.commit()

        return {
            "success": True,
            "message": f"Monthly demand forecast model trained successfully for company {company_id}",
            "model_type": "random_forest",
            "company_id": company_id,
            "metrics": {
                "test_rmse": float(test_metrics.get('RMSE', 0)),
                "test_r2": float(test_metrics.get('R2', 0)),
                "test_mae": float(test_metrics.get('MAE', 0))
            },
            "model_path": str(model_path)
        }

    except ValueError as e:
        # ValueError messages are reported verbatim.
        return _handle_training_failure(db, task_tracker, task_id, company_id, str(e))
    except Exception as e:
        return _handle_training_failure(
            db, task_tracker, task_id, company_id,
            f"Error training monthly model: {str(e)}",
        )
    finally:
        db.close()


def _handle_training_failure(db, task_tracker, task_id, company_id: int, error_message: str) -> Dict:
    """
    Mark the task's tracker row as FAILURE and build the error response.

    Updating the tracker is best-effort: any error raised while doing so is
    printed and swallowed, so the caller always gets the error payload.

    Args:
        db: Open database session used by the task.
        task_tracker: Tracker row already loaded by the task, or None to look
            it up by ``task_id``.
        task_id: Celery task id used to find the tracker row.
        company_id: Company the training run was for (echoed in the response).
        error_message: Message stored on the tracker and returned to the caller.

    Returns:
        Dict with ``success=False``, ``error`` and ``company_id``.
    """
    from src.smart_inventory.apps.inventory.models import CeleryTaskTracker, CeleryTaskStatus

    try:
        # The original exception may have come from the session itself (e.g. a
        # failed commit), leaving it unusable until rolled back. Rolling back
        # lets the FAILURE status actually be persisted.
        db.rollback()

        if task_tracker is None:
            task_tracker = db.query(CeleryTaskTracker).filter(
                CeleryTaskTracker.task_id == task_id
            ).first()

        if task_tracker:
            task_tracker.status = CeleryTaskStatus.FAILURE
            task_tracker.completed_at = datetime.now()
            task_tracker.error_message = error_message
            db.commit()
    except Exception as db_error:
        print(f"Failed to update task tracker: {db_error}")

    return {
        "success": False,
        "error": error_message,
        "company_id": company_id
    }
