from datetime import date

import pandas as pd
from sqlalchemy import text
from sqlalchemy.orm import Session

from src.smart_inventory.core.demand_forecaster import CompanyDemandForecaster
from src.utils.db import SessionLocal  # your SQLAlchemy session factory


def load_training_data_for_company(
    db: Session,
    company_id: int,
    start_date: date = date(2023, 1, 1),
) -> pd.DataFrame:
    """
    Load one company's sales history from the ``daily_sales`` table.

    Args:
        db: SQLAlchemy session; its ``bind`` is handed to pandas as the
            connectable that executes the query.
        company_id: Only rows with this ``company_id`` are selected.
        start_date: Inclusive lower bound on ``sale_date``. Defaults to
            2023-01-01, the cut-off previously hard-coded in the query.

    Returns:
        DataFrame with columns ``company_id``, ``location_id``,
        ``product_id``, ``date`` (renamed from ``sale_date``) and
        ``qty_sold`` (renamed from ``quantity_sold``).
    """
    # text() with :name binds lets SQLAlchemy translate the placeholders to
    # whatever paramstyle the underlying DBAPI driver uses; a raw string with
    # %(name)s only works on pyformat drivers (e.g. psycopg2, PyMySQL).
    query = text(
        """
        SELECT company_id, location_id, product_id, sale_date, quantity_sold
        FROM daily_sales
        WHERE company_id = :company_id
          AND sale_date >= :start_date
        """
    )
    df = pd.read_sql(
        query,
        db.bind,
        params={"company_id": company_id, "start_date": start_date},
    )
    # Rename columns to match expected model columns
    return df.rename(columns={"sale_date": "date", "quantity_sold": "qty_sold"})


def train_company_demand_model(company_id: int) -> str:
    """
    Train the demand model for a single company.

    Opens a session from ``SessionLocal``, loads the company's training
    data with ``load_training_data_for_company``, and passes it to
    ``CompanyDemandForecaster.train_company_model``. The session is closed
    whether or not loading or training raises.

    Args:
        company_id: Company whose sales history is used for training.

    Returns:
        The value returned by ``train_company_model`` (annotated as ``str``;
        presumably the path of the saved model — confirm against the
        forecaster implementation).
    """
    session = SessionLocal()
    try:
        sales_frame = load_training_data_for_company(session, company_id)
        # The forecaster is built only after the data has loaded successfully.
        return CompanyDemandForecaster(model_dir="models").train_company_model(
            sales_frame, company_id
        )
    finally:
        session.close()
