from datetime import date, timedelta
from typing import Optional

import pandas as pd
from sqlalchemy import text
from sqlalchemy.orm import Session

from src.smart_inventory.core.demand_forecaster import CompanyDemandForecaster
from src.utils.db import SessionLocal

def load_history_for_series(db: Session, company_id: int, location_id: int, product_id: int) -> pd.DataFrame:
    """Load the daily sales history of one company/location/product series.

    Args:
        db: SQLAlchemy session. Only its bound engine/connection (``db.bind``)
            is used to run the query.
        company_id: Company whose sales rows are selected.
        location_id: Location whose sales rows are selected.
        product_id: Product whose sales rows are selected.

    Returns:
        DataFrame with columns ``company_id``, ``location_id``, ``product_id``,
        ``date`` (renamed from ``sale_date``) and ``qty_sold`` (renamed from
        ``quantity_sold``), ordered by ``date`` ascending.
    """
    # ``text()`` with named ``:param`` binds lets SQLAlchemy render the
    # placeholders in whatever paramstyle the underlying DB-API driver uses.
    # A raw string with ``%(name)s`` placeholders is passed to the driver
    # verbatim and only works on "pyformat" drivers (e.g. psycopg2).
    query = text(
        """
        SELECT company_id, location_id, product_id, sale_date, quantity_sold
        FROM daily_sales
        WHERE company_id = :company_id
          AND location_id = :location_id
          AND product_id = :product_id
        ORDER BY sale_date
        """
    )
    df = pd.read_sql(
        query,
        db.bind,
        params={
            "company_id": company_id,
            "location_id": location_id,
            "product_id": product_id,
        },
    )

    # Rename to the column names used for history frames passed to the
    # forecaster (presumably what CompanyDemandForecaster expects — confirm).
    return df.rename(columns={"sale_date": "date", "quantity_sold": "qty_sold"})

def forecast_series(
    company_id: int,
    location_id: int,
    product_id: int,
    horizon_days: int = 30,
    *,
    model_dir: str = "models",
    start_date: Optional[date] = None,
):
    """Forecast daily demand for one company/location/product series.

    Loads the series' sales history from the database and builds
    ``horizon_days`` consecutive calendar dates. It then hands both to
    ``CompanyDemandForecaster.predict_for_series``.

    Args:
        company_id: Company of the series; also passed to the forecaster.
        location_id: Location of the series.
        product_id: Product of the series.
        horizon_days: Number of consecutive days to forecast. A value <= 0
            produces an empty date list, which is passed to the forecaster
            unchanged.
        model_dir: Directory handed to ``CompanyDemandForecaster``. The
            default ``"models"`` is a relative path (presumably resolved
            against the current working directory — confirm). Pass an
            explicit path when running from elsewhere.
        start_date: First forecast day. Defaults to tomorrow, computed as
            ``date.today()`` (local date) plus one day.

    Returns:
        Whatever ``predict_for_series`` returns for these dates (presumably a
        DataFrame of predictions — confirm).
    """
    # The session is only needed for the history query. Close it right away
    # rather than holding it while the model is loaded and run.
    db = SessionLocal()
    try:
        history_df = load_history_for_series(db, company_id, location_id, product_id)
    finally:
        db.close()

    if start_date is None:
        start_date = date.today() + timedelta(days=1)
    future_dates = [start_date + timedelta(days=offset) for offset in range(horizon_days)]

    forecaster = CompanyDemandForecaster(model_dir=model_dir)
    return forecaster.predict_for_series(
        company_id=company_id,
        history_df=history_df,
        location_id=location_id,
        product_id=product_id,
        future_dates=future_dates,
    )
