from sqlalchemy import Column, Integer, String, Float, DateTime, Date, ForeignKey, Enum, Index, Date, Index, Boolean, UniqueConstraint
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
import enum

try:
    from src.utils.db import Base
    from src.smart_inventory.apps.products.models import Company
except ImportError:
    from utils.db import Base
    from smart_inventory.apps.products.models import Company


# Enums for status and type fields
@enum.unique
class PurchaseOrderStatus(str, enum.Enum):
    """Allowed status values for a purchase order.

    Because members also subclass ``str``, each one compares equal to its
    lowercase value, e.g. ``PurchaseOrderStatus.DRAFT == "draft"``.
    ``@enum.unique`` guards against two members accidentally sharing a value
    (which would silently turn one into an alias of the other).
    """

    DRAFT = "draft"
    PENDING_APPROVAL = "pending_approval"
    SENT = "sent"
    RECEIVED = "received"
    CLOSED = "closed"


@enum.unique
class InventoryBatchStatus(str, enum.Enum):
    """Allowed status values for an inventory batch (lot).

    Members subclass ``str`` and compare equal to their lowercase value,
    e.g. ``InventoryBatchStatus.ACTIVE == "active"``.
    """

    ACTIVE = "active"
    SOLD_OUT = "sold_out"
    EXPIRED = "expired"
    DISPOSED = "disposed"
    DONATED = "donated"


@enum.unique
class MovementType(str, enum.Enum):
    """Kinds of inventory movement recorded in the movement ledger.

    Members subclass ``str`` and compare equal to their lowercase value,
    e.g. ``MovementType.SALE == "sale"``.
    """

    SALE = "sale"
    RECEIPT = "receipt"
    ADJUSTMENT = "adjustment"
    TRANSFER_IN = "transfer_in"
    TRANSFER_OUT = "transfer_out"


class SalesOrder(Base):
    """Header row of a sales transaction (table ``sales_orders``).

    Individual line items are stored in ``SalesOrderLine`` and are removed
    together with their order through the ``delete-orphan`` cascade below.
    """
    __tablename__ = "sales_orders"

    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
    company_id = Column(Integer, ForeignKey("companies.company_id"), nullable=False, index=True)
    location_id = Column(Integer, ForeignKey("locations.location_id"), nullable=False, index=True)
    sold_at = Column(DateTime(timezone=True), nullable=False)
    channel = Column(String(50), nullable=False)  # sales channel, e.g. "store" or "online"

    # Audit timestamps: the database fills created_at on insert; updated_at is
    # also refreshed through SQLAlchemy's onupdate on ORM/Core UPDATEs.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    # Relationships
    company = relationship("Company", foreign_keys=[company_id])
    sales_order_lines = relationship("SalesOrderLine", back_populates="sales_order", cascade="all, delete-orphan")
    location = relationship("Location", foreign_keys=[location_id])

    def __repr__(self):
        details = f"id={self.id}, location_id={self.location_id}, channel='{self.channel}'"
        return "<SalesOrder(" + details + ")>"


class SalesOrderLine(Base):
    """Single line item of a ``SalesOrder`` (table ``sales_order_lines``).

    Each row records one product sold on an order, with its quantity and the
    unit price charged.
    """
    __tablename__ = "sales_order_lines"

    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
    company_id = Column(Integer, ForeignKey("companies.company_id"), nullable=False, index=True)
    # Indexed like every other FK column in this module: the
    # SalesOrder.sales_order_lines relationship (and its delete-orphan cascade)
    # looks lines up by this column, and some databases (e.g. PostgreSQL) do not
    # index FK columns automatically. Existing databases need a migration to
    # pick the index up.
    sales_order_id = Column(Integer, ForeignKey("sales_orders.id"), nullable=False, index=True)
    product_id = Column(Integer, ForeignKey("products.product_id"), nullable=False, index=True)
    quantity = Column(Integer, nullable=False)
    # NOTE(review): Float is binary floating point; Numeric is the usual choice
    # for currency amounts. Changing it requires a schema migration.
    unit_price = Column(Float, nullable=False)
    # NOTE(review): plain integer, not a ForeignKey - no referential integrity
    # is enforced for promotions here.
    promotion_id = Column(Integer, nullable=True)

    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    # Relationships
    company = relationship("Company", foreign_keys=[company_id])
    sales_order = relationship("SalesOrder", back_populates="sales_order_lines")
    product = relationship("Product", foreign_keys=[product_id])

    def __repr__(self):
        return f"<SalesOrderLine(id={self.id}, sales_order_id={self.sales_order_id}, product_id={self.product_id})>"


class InventoryBatch(Base):
    """Lot-level stock record (table ``inventory_batches``).

    Tracks how much of a product sits at a location per batch, together with
    received/expiry dates, for perishable and waste management.
    """
    __tablename__ = "inventory_batches"

    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
    company_id = Column(Integer, ForeignKey("companies.company_id"), nullable=False, index=True)
    product_id = Column(Integer, ForeignKey("products.product_id"), nullable=False, index=True)
    location_id = Column(Integer, ForeignKey("locations.location_id"), nullable=False, index=True)
    # Globally unique batch reference (unique=True + index=True -> one unique index).
    batch_ref = Column(String(100), nullable=False, unique=True, index=True)
    quantity_on_hand = Column(Integer, nullable=False, default=0)
    expiry_date = Column(DateTime(timezone=True), nullable=True)  # NULL for non-perishables
    received_date = Column(DateTime(timezone=True), nullable=False)
    # NOTE(review): without values_callable, SQLAlchemy's Enum persists member
    # *names* ("ACTIVE"), not the lowercase values ("active").
    status = Column(Enum(InventoryBatchStatus), nullable=False, default=InventoryBatchStatus.ACTIVE)

    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    # Relationships
    company = relationship("Company", foreign_keys=[company_id])
    product = relationship("Product", foreign_keys=[product_id])
    location = relationship("Location", foreign_keys=[location_id])
    # NOTE(review): the delete-orphan cascade deletes a batch's movement history
    # when the batch is deleted - confirm that is intended for an audit trail.
    inventory_movements = relationship("InventoryMovement", back_populates="batch", cascade="all, delete-orphan")

    def __repr__(self):
        return "<InventoryBatch(id={}, batch_ref='{}', quantity={})>".format(
            self.id, self.batch_ref, self.quantity_on_hand
        )


class InventoryMovement(Base):
    """Ledger row for a single inventory change (table ``inventory_movements``).

    ``quantity_delta`` is signed: positive for stock increases, negative for
    decreases. Rows carry only ``created_at`` (no ``updated_at``).
    """
    __tablename__ = "inventory_movements"

    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
    company_id = Column(Integer, ForeignKey("companies.company_id"), nullable=False, index=True)
    product_id = Column(Integer, ForeignKey("products.product_id"), nullable=False, index=True)
    location_id = Column(Integer, ForeignKey("locations.location_id"), nullable=False, index=True)
    # Indexed like the other FK columns: InventoryBatch.inventory_movements (and
    # its delete-orphan cascade) filters on this column, and some databases
    # (e.g. PostgreSQL) do not index FK columns automatically. Existing
    # databases need a migration to pick the index up.
    batch_id = Column(Integer, ForeignKey("inventory_batches.id"), nullable=True, index=True)
    # NOTE(review): SQLAlchemy's Enum persists member names ("SALE") unless
    # values_callable is supplied.
    movement_type = Column(Enum(MovementType), nullable=False)
    quantity_delta = Column(Integer, nullable=False)  # Positive for increase, negative for decrease
    reference = Column(String(100), nullable=True)  # PO number, SO number, etc.

    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # Relationships
    company = relationship("Company", foreign_keys=[company_id])
    product = relationship("Product", foreign_keys=[product_id])
    location = relationship("Location", foreign_keys=[location_id])
    batch = relationship("InventoryBatch", back_populates="inventory_movements")

    def __repr__(self):
        return f"<InventoryMovement(id={self.id}, type='{self.movement_type}', delta={self.quantity_delta})>"


class PurchaseOrder(Base):
    """Header row of a purchase order placed with a supplier (table ``purchase_orders``).

    Line items live in ``PurchaseOrderLine`` and are removed together with the
    order through the ``delete-orphan`` cascade below.
    """
    __tablename__ = "purchase_orders"

    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
    company_id = Column(Integer, ForeignKey("companies.company_id"), nullable=False, index=True)
    # NOTE(review): supplier_id is not a ForeignKey, so supplier existence is not enforced.
    supplier_id = Column(Integer, nullable=False, index=True)
    location_id = Column(Integer, ForeignKey("locations.location_id"), nullable=False, index=True)
    # NOTE(review): SQLAlchemy's Enum persists member names ("DRAFT") unless
    # values_callable is supplied.
    status = Column(Enum(PurchaseOrderStatus), nullable=False, default=PurchaseOrderStatus.DRAFT)
    expected_delivery_date = Column(DateTime(timezone=True), nullable=True)

    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    # Relationships
    company = relationship("Company", foreign_keys=[company_id])
    purchase_order_lines = relationship("PurchaseOrderLine", back_populates="purchase_order", cascade="all, delete-orphan")
    location = relationship("Location", foreign_keys=[location_id])

    def __repr__(self):
        return "<PurchaseOrder(id={}, supplier_id={}, status='{}')>".format(
            self.id, self.supplier_id, self.status
        )


class PurchaseOrderLine(Base):
    """Single line item of a ``PurchaseOrder`` (table ``purchase_order_lines``).

    Tracks the quantity ordered, the quantity received so far (defaults to 0)
    and the unit cost for one product.
    """
    __tablename__ = "purchase_order_lines"

    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
    company_id = Column(Integer, ForeignKey("companies.company_id"), nullable=False, index=True)
    # Indexed like the other FK columns: PurchaseOrder.purchase_order_lines (and
    # its delete-orphan cascade) filters on this column, and some databases
    # (e.g. PostgreSQL) do not index FK columns automatically. Existing
    # databases need a migration to pick the index up.
    purchase_order_id = Column(Integer, ForeignKey("purchase_orders.id"), nullable=False, index=True)
    product_id = Column(Integer, ForeignKey("products.product_id"), nullable=False, index=True)
    ordered_qty = Column(Integer, nullable=False)
    received_qty = Column(Integer, nullable=False, default=0)
    # NOTE(review): Float for currency can accumulate rounding error; Numeric is
    # usually preferred (schema migration required to change).
    unit_cost = Column(Float, nullable=False)

    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    # Relationships
    company = relationship("Company", foreign_keys=[company_id])
    purchase_order = relationship("PurchaseOrder", back_populates="purchase_order_lines")
    product = relationship("Product", foreign_keys=[product_id])

    def __repr__(self):
        return f"<PurchaseOrderLine(id={self.id}, purchase_order_id={self.purchase_order_id}, product_id={self.product_id})>"



class DailySales(Base):
    """Aggregated sales per product, per location, per day (table ``daily_sales``)."""
    __tablename__ = "daily_sales"

    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
    company_id = Column(Integer, ForeignKey("companies.company_id"), nullable=False, index=True)
    product_id = Column(Integer, ForeignKey("products.product_id"), nullable=False, index=True)
    location_id = Column(Integer, ForeignKey("locations.location_id"), nullable=False, index=True)
    # NOTE(review): stored as a timezone-aware DateTime although rows are per
    # day; writers presumably normalise to a day boundary - confirm. Other daily
    # tables in this module use Date.
    # NOTE(review): no unique constraint on (company, product, location,
    # sale_date), so duplicate daily rows are not prevented at the DB level.
    sale_date = Column(DateTime(timezone=True), nullable=False, index=True)
    quantity_sold = Column(Integer, nullable=False, default=0)
    total_amount = Column(Float, nullable=False, default=0.0)

    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    # Relationships
    company = relationship("Company", foreign_keys=[company_id])
    product = relationship("Product", foreign_keys=[product_id])
    location = relationship("Location", foreign_keys=[location_id])

    def __repr__(self):
        parts = (
            f"id={self.id}",
            f"product_id={self.product_id}",
            f"location_id={self.location_id}",
            f"date={self.sale_date}",
        )
        return "<DailySales(" + ", ".join(parts) + ")>"

# Will not be used for now.
class StockoutEvent(Base):
    """
    Log of stockout / lost-sales events.

    Each record represents demand that could NOT be fulfilled because there
    was no stock; ``lost_sales_qty`` is the quantity of that unserved demand.
    """

    __tablename__ = "stockout_events"

    id = Column(Integer, primary_key=True, index=True)

    company_id = Column(Integer, ForeignKey("companies.company_id"), index=True, nullable=False)
    location_id = Column(Integer, ForeignKey("locations.location_id"), index=True, nullable=False)
    product_id = Column(Integer, ForeignKey("products.product_id"), index=True, nullable=False)

    date = Column(Date, index=True, nullable=False)
    # lost_sales_qty could be:
    #  - estimated from POS / order data when orders were rejected/backordered
    #  - or approximated from traffic vs conversions
    lost_sales_qty = Column(Float, nullable=False, default=0.0)

    reason = Column(String(255), nullable=True)  # optional free text

    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # Composite lookup index. NOTE(review): it overlaps the single-column
    # company_id index above (same leading column); non-unique, so several
    # events per day/product/location are allowed.
    __table_args__ = (
        Index(
            "ix_stockout_company_loc_prod_date",
            "company_id", "location_id", "product_id", "date"
        ),
    )

    # Relationships
    company = relationship("Company", foreign_keys=[company_id])
    location = relationship("Location", foreign_keys=[location_id])
    product = relationship("Product", foreign_keys=[product_id])

    def __repr__(self):
        # Added for consistency with the other models in this module.
        return (
            f"<StockoutEvent(id={self.id}, product_id={self.product_id}, "
            f"location_id={self.location_id}, date={self.date}, "
            f"lost_sales_qty={self.lost_sales_qty})>"
        )


class ServiceLevelDaily(Base):
    """
    Derived table: daily service level per company/location/product.

        service_level = fulfilled_demand / total_demand
                      = qty_sold / (qty_sold + lost_sales_qty)

    The model only stores these values; whatever job populates the table is
    responsible for computing them (including the zero-demand case, for which
    the column default is 1.0).
    """

    __tablename__ = "service_level_daily"

    id = Column(Integer, primary_key=True, index=True)

    date = Column(Date, index=True, nullable=False)

    company_id = Column(Integer, ForeignKey("companies.company_id"), index=True, nullable=False)
    location_id = Column(Integer, ForeignKey("locations.location_id"), index=True, nullable=False)
    product_id = Column(Integer, ForeignKey("products.product_id"), index=True, nullable=False)

    demand_qty = Column(Float, nullable=False, default=0.0)         # total demand (fulfilled + lost)
    fulfilled_qty = Column(Float, nullable=False, default=0.0)      # qty_sold
    lost_sales_qty = Column(Float, nullable=False, default=0.0)     # from stockout_events
    service_level = Column(Float, nullable=False, default=1.0)      # 0.0 - 1.0

    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # At most one row per day/company/location/product.
    __table_args__ = (
        Index(
            "ix_servicelvl_company_loc_prod_date",
            "date", "company_id", "location_id", "product_id",
            unique=True,
        ),
    )

    # Relationships
    company = relationship("Company", foreign_keys=[company_id])
    location = relationship("Location", foreign_keys=[location_id])
    product = relationship("Product", foreign_keys=[product_id])

    def __repr__(self):
        # Added for consistency with the other models in this module.
        return (
            f"<ServiceLevelDaily(id={self.id}, date={self.date}, "
            f"product_id={self.product_id}, location_id={self.location_id}, "
            f"service_level={self.service_level})>"
        )



# =============================================================================
# INVENTORY SNAPSHOT MODEL - Daily Aggregated Inventory Data
# =============================================================================
# This model stores daily snapshots of inventory levels computed from:
# 1. Previous day's snapshot (starting stock)
# 2. Today's inventory movements (transactions)
# 
# Purpose: Fast inventory reporting without real-time calculation overhead
# Computed by: Celery task (scheduled nightly)
# =============================================================================

class InventorySnapshotDaily(Base):
    """
    Snapshot of inventory at the end of a given day, per company/location/product.

    Computed by a Celery task from the previous snapshot plus the day's
    movements. At most one row exists per (snapshot_date, company_id,
    location_id, product_id), enforced by the unique index below.
    """

    __tablename__ = "inventory_snapshot_daily"

    id = Column(Integer, primary_key=True, index=True)

    snapshot_date = Column(Date, index=True, nullable=False)

    # NOTE(review): unlike the other models in this module, these IDs carry no
    # ForeignKey and there are no relationships - confirm this is intentional
    # (e.g. to keep bulk snapshot writes cheap).
    company_id = Column(Integer, index=True, nullable=False)
    location_id = Column(Integer, index=True, nullable=False)
    product_id = Column(Integer, index=True, nullable=False)

    on_hand_qty = Column(Float, nullable=False, default=0.0)
    inbound_qty = Column(Float, nullable=False, default=0.0)
    outbound_qty = Column(Float, nullable=False, default=0.0)

    created_at = Column(DateTime(timezone=True), server_default=func.now())

    __table_args__ = (
        Index(
            "ix_snapshot_company_loc_prod_date",
            "snapshot_date",
            "company_id",
            "location_id",
            "product_id",
            unique=True,
        ),
    )

    def __repr__(self):
        # Added for consistency with the other models in this module.
        return (
            f"<InventorySnapshotDaily(id={self.id}, date={self.snapshot_date}, "
            f"product_id={self.product_id}, location_id={self.location_id}, "
            f"on_hand_qty={self.on_hand_qty})>"
        )

# SLOW MOVER SNAPSHOT MODEL (includes velocity metrics for fast/slow movers)
class SlowMoverSnapshot(Base):
    """Daily per-SKU/location sales-velocity snapshot used to flag slow movers.

    Stores on-hand quantity, units sold and average daily sales (ADS) over
    7/30/90-day windows, days of inventory on hand based on 90-day ADS, and
    the resulting slow-mover classification. One row per
    (snapshot_date, company_id, location_id, product_id).
    """
    __tablename__ = "slow_mover_snapshot"

    id = Column(Integer, primary_key=True, index=True)

    snapshot_date = Column(Date, nullable=False, index=True)

    company_id = Column(Integer, ForeignKey("companies.company_id"), index=True, nullable=False)
    location_id = Column(Integer, ForeignKey("locations.location_id"), index=True, nullable=False)
    product_id = Column(Integer, ForeignKey("products.product_id"), index=True, nullable=False)

    on_hand_qty = Column(Float, nullable=False)

    # Sales velocity metrics (7/30/90 day windows).
    # NOTE(review): the 90-day columns below (and doh_90d /
    # days_since_last_sale) have no default, unlike the 7/30-day ones, so the
    # writer must always supply them - confirm that asymmetry is intended.
    total_sold_7d = Column(Float, nullable=False, default=0.0)
    total_sold_30d = Column(Float, nullable=False, default=0.0)
    total_sold_90d = Column(Float, nullable=False)

    ads_7d = Column(Float, nullable=False, default=0.0)   # Average Daily Sales (7 days)
    ads_30d = Column(Float, nullable=False, default=0.0)  # Average Daily Sales (30 days)
    ads_90d = Column(Float, nullable=False)               # Average Daily Sales (90 days)

    doh_90d = Column(Float, nullable=False)          # Days of Inventory on Hand based on ADS_90d
    days_since_last_sale = Column(Integer, nullable=False)

    is_slow_mover = Column(Boolean, nullable=False, default=False)
    slow_mover_severity = Column(String(20), nullable=True)  # e.g. 'watchlist', 'slow', 'dead'
    slow_mover_reason = Column(String(255), nullable=True)

    created_at = Column(DateTime(timezone=True), server_default=func.now())

    __table_args__ = (
        Index(
            "uq_slow_mover_snapshot_day_sku_loc",
            "snapshot_date", "company_id", "location_id", "product_id",
            unique=True
        ),
    )

    # Relationships
    company = relationship("Company", foreign_keys=[company_id])
    location = relationship("Location", foreign_keys=[location_id])
    product = relationship("Product", foreign_keys=[product_id])

    def __repr__(self):
        # Added for consistency with the other models in this module.
        return (
            f"<SlowMoverSnapshot(id={self.id}, date={self.snapshot_date}, "
            f"product_id={self.product_id}, location_id={self.location_id}, "
            f"is_slow_mover={self.is_slow_mover}, severity='{self.slow_mover_severity}')>"
        )


class DemandForecast(Base):
    """Forecasted demand quantity for one product/location on a target date.

    ``forecast_date`` is when the forecast was produced; ``target_date`` is the
    day being predicted; ``model_version`` optionally identifies the model.
    """
    __tablename__ = "demand_forecasts"

    id = Column(Integer, primary_key=True, index=True)
    # NOTE(review): unlike the other models, these IDs are nullable (the
    # Column default) and carry no ForeignKey - confirm that is intended.
    company_id = Column(Integer, index=True)
    location_id = Column(Integer, index=True)
    product_id = Column(Integer, index=True)

    forecast_date = Column(Date, nullable=False)   # date when the forecast was made
    target_date = Column(Date, nullable=False)     # date being predicted

    forecast_qty = Column(Float, nullable=False)
    # NOTE(review): String without a length is rejected by some backends
    # (e.g. MySQL VARCHAR requires one); fine on PostgreSQL/SQLite.
    model_version = Column(String, nullable=True)

    created_at = Column(DateTime(timezone=True), server_default=func.now())

    def __repr__(self):
        # Added for consistency with the other models in this module.
        return (
            f"<DemandForecast(id={self.id}, product_id={self.product_id}, "
            f"location_id={self.location_id}, target_date={self.target_date}, "
            f"forecast_qty={self.forecast_qty})>"
        )



class InventoryPlanningSnapshot(Base):
    """Daily replenishment-planning snapshot per company/location/product.

    Holds demand statistics, the policy inputs used (lead time, review period,
    service-level target), history- and forecast-based safety stock / reorder
    points, the current inventory position and reorder/urgency flags. One row
    per (snapshot_date, company_id, location_id, product_id).
    """
    __tablename__ = "inventory_planning_snapshot"

    id = Column(Integer, primary_key=True, index=True)
    snapshot_date = Column(Date, index=True, nullable=False)
    company_id = Column(Integer, ForeignKey("companies.company_id"), index=True, nullable=False)
    location_id = Column(Integer, ForeignKey("locations.location_id"), index=True, nullable=False)
    product_id = Column(Integer, ForeignKey("products.product_id"), index=True, nullable=False)

    # Historical (current trend) demand stats
    avg_daily_demand = Column(Float, nullable=False, default=0.0)
    sigma_daily_demand = Column(Float, nullable=False, default=0.0)

    # Policy
    lead_time_days = Column(Integer, nullable=False, default=0)
    review_period_days = Column(Integer, nullable=False, default=0)
    service_level_target = Column(Float, nullable=False, default=0.95)

    # Current (history-based)
    current_safety_stock = Column(Float, nullable=False, default=0.0)
    current_reorder_point = Column(Float, nullable=False, default=0.0)

    # Forecast-based (next 90 days)
    forecast_avg_daily_demand_90d = Column(Float, nullable=False, default=0.0)
    forecast_safety_stock_90d = Column(Float, nullable=False, default=0.0)
    forecasted_reorder_point_90d = Column(Float, nullable=False, default=0.0)

    # Inventory position & status
    on_hand_qty = Column(Float, nullable=False, default=0.0)
    inbound_qty = Column(Float, nullable=False, default=0.0)
    available_stock = Column(Float, nullable=False, default=0.0)
    min_target = Column(Float, nullable=False, default=0.0)
    max_target = Column(Float, nullable=False, default=0.0)
    stock_status = Column(String(20), nullable=True)
    recommended_order_qty = Column(Float, nullable=False, default=0.0)
    should_reorder = Column(Boolean, nullable=False, default=False)

    # Urgency metrics
    days_of_cover = Column(Float, nullable=True)        # how many days inventory will last
    days_until_stockout = Column(Float, nullable=True)  # same as days_of_cover for now
    is_urgent = Column(Boolean, nullable=False, default=False)
    urgency_score = Column(Float, nullable=False, default=0.0)

    # Relationships
    company = relationship("Company", foreign_keys=[company_id])
    location = relationship("Location", foreign_keys=[location_id])
    product = relationship("Product", foreign_keys=[product_id])

    # Unique per day/SKU/location
    __table_args__ = (
        UniqueConstraint(
            "snapshot_date", "company_id", "location_id", "product_id",
            name="uq_planning_snapshot_per_day_sku_loc"
        ),
    )

    def __repr__(self):
        # Added for consistency with the other models in this module.
        return (
            f"<InventoryPlanningSnapshot(id={self.id}, date={self.snapshot_date}, "
            f"product_id={self.product_id}, location_id={self.location_id}, "
            f"stock_status='{self.stock_status}', should_reorder={self.should_reorder})>"
        )

class ReorderPolicy(Base):
    """Replenishment policy parameters for one company/location/product.

    Defaults: 7-day lead time, 7-day review period, 0.95 service-level target
    and no minimum order quantity. At most one policy exists per
    (company_id, location_id, product_id).
    """
    __tablename__ = "reorder_policies"

    id = Column(Integer, primary_key=True, index=True)

    company_id = Column(Integer, ForeignKey("companies.company_id"), index=True, nullable=False)
    location_id = Column(Integer, ForeignKey("locations.location_id"), index=True, nullable=False)
    product_id = Column(Integer, ForeignKey("products.product_id"), index=True, nullable=False)

    # Core parameters
    lead_time_days = Column(Integer, nullable=False, default=7)
    review_period_days = Column(Integer, nullable=False, default=7)
    service_level_target = Column(Float, nullable=False, default=0.95)

    # Optional purchasing constraints
    min_order_qty = Column(Float, nullable=False, default=0.0)

    # NOTE(review): plain integer; could become a FK to a suppliers table.
    supplier_id = Column(Integer, nullable=True)

    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    # Relationships
    company = relationship("Company", foreign_keys=[company_id])
    location = relationship("Location", foreign_keys=[location_id])
    product = relationship("Product", foreign_keys=[product_id])

    __table_args__ = (
        UniqueConstraint(
            "company_id", "location_id", "product_id",
            name="uq_reorder_policy_company_loc_prod"
        ),
    )

    def __repr__(self):
        # Added for consistency with the other models in this module.
        return (
            f"<ReorderPolicy(id={self.id}, product_id={self.product_id}, "
            f"location_id={self.location_id}, lead_time_days={self.lead_time_days}, "
            f"service_level_target={self.service_level_target})>"
        )
