from __future__ import annotations
from datetime import date, timedelta

from celery import Celery
from sqlalchemy.orm import Session
import pandas as pd

import os
import sys
from datetime import datetime, date, timedelta
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass
from sqlalchemy.orm import Session
from sqlalchemy import text, and_
from sqlalchemy.exc import IntegrityError
from src.utils.celery_worker import celery_app
from src.utils.db import SessionLocal, engine
from src.smart_inventory.apps.inventory.models import DailySales, DemandForecast
from src.smart_inventory.apps.products.models import Company
from src.smart_inventory.core.demand_forecaster import CompanyDemandForecaster
from src.smart_inventory.core.train_pipeline import train_company_demand_model
# Add src to path for imports
from src.utils.db import get_db_session
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))


@celery_app.task(bind=True)
def train_demand_forecast_model(self, company_id: int) -> Dict:
    """
    Train the demand forecast model for a given company.

    Delegates the work to ``train_company_demand_model`` and reports the
    outcome as a plain dict. Exceptions derived from ``Exception`` are not
    propagated; they are logged and converted into a failure payload, so the
    caller must inspect the ``"success"`` key of the result.

    Args:
        self: Bound Celery task instance (unused in the body; present because
            the task is registered with ``bind=True``).
        company_id: Identifier of the company, passed unchanged to the
            training pipeline.

    Returns:
        On success, ``{"success": True, "message": ..., "model_path": ...}``
        where ``model_path`` is whatever the training pipeline returned.
        On failure, ``{"success": False, "error": ...}``; the error text is
        the bare exception message for ``ValueError`` and is prefixed with
        ``"Error training model: "`` for any other exception.
    """
    # Imported locally so this block does not rely on a module-level logger.
    import logging

    logger = logging.getLogger(__name__)

    # Keep the try body limited to the single call that is expected to raise.
    try:
        model_path = train_company_demand_model(company_id)
    except ValueError as e:
        # NOTE(review): ValueError is reported without a prefix, presumably
        # because the pipeline uses it for expected conditions - confirm.
        logger.warning(
            "Demand forecast training failed for company %s: %s", company_id, e
        )
        return {"success": False, "error": str(e)}
    except Exception as e:
        # Record the full traceback: the returned payload only carries the
        # exception message, which is not enough to diagnose the failure.
        logger.exception(
            "Unexpected error training demand forecast model for company %s",
            company_id,
        )
        return {"success": False, "error": f"Error training model: {str(e)}"}

    return {
        "success": True,
        "message": f"Demand forecast model trained successfully for company {company_id}",
        "model_path": model_path,
    }


# OLD CODE (commented out and not executed; kept for reference only)
# @celery_app.task(bind=True)
# def compute_demand_forecast(self, company_id: int, target_date: date) -> Dict:
#     """
#     Compute demand forecast for a given company and date
#     """
#     db = get_db_session()
#     company = db.query(Company).filter(Company.id == company_id).first()
#     if not company:
#         return {"success": False, "error": "Company not found"}
#     demand_forecaster = CompanyDemandForecaster(company.company_id)
#     demand_forecaster.compute_demand_forecast(target_date)
#     return {"success": True, "data": "Demand forecast computed successfully"}