from datetime import datetime, date, timedelta
from typing import Optional, List, Generic, TypeVar

from pydantic import BaseModel, ConfigDict, field_validator, model_validator

from .models import PurchaseOrderStatus, InventoryBatchStatus, MovementType

# Module-level type variable intended for generic (e.g. paginated) schemas;
# none of the models defined below reference it yet.
T = TypeVar("T")


# Generic response envelope; note that `data` only accepts a dict (or None), not arbitrary types
class ResponseModel(BaseModel):
    # Basic response envelope: a success flag plus an optional payload and message.
    success: bool
    data: Optional[dict] = None  # payload must be a dict/mapping (or None) to validate
    message: Optional[str] = None

# Universal Frontend API Response Schemas
class PaginationInfo(BaseModel):
    # Pagination metadata: current page, page size, and overall item/page totals.
    page: int
    per_page: int
    total_items: int
    total_pages: int


# Error Response Schema
class ErrorResponse(BaseModel):
    # Error envelope: `success` defaults to False and only `message` is required.
    success: bool = False
    message: str
    error_code: Optional[str] = None
    details: Optional[dict] = None
    data: Optional[dict] = None


# Nested object schemas for frontend responses
class CompanyInfo(BaseModel):
    # Nested company reference: id plus an optional name.
    id: int
    name: Optional[str] = None


class LocationInfo(BaseModel):
    # Nested location reference: id plus an optional name.
    id: int
    name: Optional[str] = None


class ProductInfo(BaseModel):
    # Nested product reference: id plus an optional name. Any other keys supplied
    # (e.g. "sku") are not declared fields and are dropped by pydantic's default
    # extra="ignore" behavior.
    id: int
    name: Optional[str] = None


# Base schemas
class SalesOrderBase(BaseModel):
    # Fields shared by SalesOrderCreate and SalesOrderOut.
    company_id: int
    location_id: int
    sold_at: datetime
    channel: str  # free-form string; no allowed values are enforced here


class SalesOrderCreate(SalesOrderBase):
    # Input schema for creating a sales order; identical to the base fields.
    pass


class SalesOrderOut(SalesOrderBase):
    # Output schema for a sales order: the base fields plus id and timestamps.
    id: int
    created_at: datetime
    updated_at: datetime

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)


class SalesOrderLineBase(BaseModel):
    # Fields shared by SalesOrderLineCreate and SalesOrderLineOut.
    sales_order_id: int
    product_id: int
    quantity: int  # no sign/range constraint is enforced here
    unit_price: float
    promotion_id: Optional[int] = None


class SalesOrderLineCreate(SalesOrderLineBase):
    # Input schema for creating a sales order line; identical to the base fields.
    pass


class SalesOrderLineOut(SalesOrderLineBase):
    # Output schema for a sales order line: base fields plus id and timestamps.
    id: int
    created_at: datetime
    updated_at: datetime

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)


class InventoryBatchBase(BaseModel):
    # Fields shared by InventoryBatchCreate and InventoryBatchOut.
    company_id: int
    product_id: int
    location_id: int
    batch_ref: str
    quantity_on_hand: int = 0
    expiry_date: Optional[datetime] = None  # optional; batches may have no expiry
    received_date: datetime
    status: InventoryBatchStatus = InventoryBatchStatus.ACTIVE


class InventoryBatchCreate(InventoryBatchBase):
    # Input schema for creating an inventory batch; identical to the base fields.
    pass


class InventoryBatchOut(InventoryBatchBase):
    # Output schema for an inventory batch: base fields plus id and timestamps.
    id: int
    created_at: datetime
    updated_at: datetime

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)


class InventoryMovementBase(BaseModel):
    # Fields shared by InventoryMovementCreate and InventoryMovementOut.
    company_id: int
    product_id: int
    location_id: int
    batch_id: Optional[int] = None
    movement_type: MovementType
    quantity_delta: int  # quantity change; sign convention is not enforced here
    reference: Optional[str] = None


class InventoryMovementCreate(InventoryMovementBase):
    # Input schema for recording an inventory movement; identical to the base fields.
    pass


class InventoryMovementOut(InventoryMovementBase):
    # Output schema for an inventory movement: base fields plus id and creation
    # time (movements carry no updated_at).
    id: int
    created_at: datetime

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)


class PurchaseOrderBase(BaseModel):
    # Fields shared by PurchaseOrderCreate and PurchaseOrderOut.
    company_id: int
    supplier_id: int
    location_id: int
    status: PurchaseOrderStatus = PurchaseOrderStatus.DRAFT  # new orders default to DRAFT
    expected_delivery_date: Optional[datetime] = None


class PurchaseOrderCreate(PurchaseOrderBase):
    # Input schema for creating a purchase order; identical to the base fields.
    pass


class PurchaseOrderOut(PurchaseOrderBase):
    # Output schema for a purchase order: base fields plus id and timestamps.
    id: int
    created_at: datetime
    updated_at: datetime

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)


class PurchaseOrderLineBase(BaseModel):
    # Fields shared by PurchaseOrderLineCreate and PurchaseOrderLineOut.
    purchase_order_id: int
    product_id: int
    ordered_qty: int
    received_qty: int = 0  # not validated against ordered_qty here
    unit_cost: float


class PurchaseOrderLineCreate(PurchaseOrderLineBase):
    # Input schema for creating a purchase order line; identical to the base fields.
    pass


class PurchaseOrderLineOut(PurchaseOrderLineBase):
    # Output schema for a purchase order line: base fields plus id and timestamps.
    id: int
    created_at: datetime
    updated_at: datetime

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)


class DailySalesBase(BaseModel):
    # Fields shared by DailySalesCreate and DailySalesOut; one product/location/day.
    product_id: int
    location_id: int
    sale_date: datetime  # typed as datetime, not date
    quantity_sold: int = 0
    total_amount: float = 0.0


class DailySalesCreate(DailySalesBase):
    # Input schema for creating a daily sales record; identical to the base fields.
    pass


class DailySalesOut(DailySalesBase):
    # Output schema for a daily sales record: base fields plus id and timestamps.
    id: int
    created_at: datetime
    updated_at: datetime

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)


class DailySalesExtendedOut(BaseModel):
    # Flat daily-sales row enriched with display names. Unlike DailySalesBase it
    # also carries company_id and does not provide defaults for the quantities.
    id: int
    company_id: int
    product_id: int
    location_id: int
    sale_date: datetime
    quantity_sold: int
    total_amount: float
    created_at: datetime
    updated_at: datetime
    # Optional names resolved from related records.
    product_name: Optional[str] = None
    brand_name: Optional[str] = None
    location_name: Optional[str] = None

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)


# Pagination Response Schema
class PaginatedDailySalesResponse(BaseModel):
    # One page of enriched daily-sales rows.
    # NOTE: uses `perpage`/`total`, not PaginationInfo's `per_page`/`total_items`.
    data: List[DailySalesExtendedOut]
    page: int
    perpage: int
    total: int


# Analytics Schemas
class MonthlyServiceLevelItem(BaseModel):
    """Single month data point for service level vs demand"""
    # Quantities are integers here; the daily ServiceLevel* schemas use floats.
    month: str  # Format: YYYY-MM (string; the format is not validated)
    demand_qty: int
    fulfilled_qty: int
    lost_sales_qty: int
    service_level: float  # 0.0 - 1.0


class ServiceLevelVsDemandData(BaseModel):
    """Data structure for service level vs demand analytics"""
    # Echoes the request filters alongside the monthly series; location_id and
    # product_id are None when that filter was not applied.
    company_id: int
    location_id: Optional[int] = None
    product_id: Optional[int] = None
    month_from: str  # Format: YYYY-MM
    month_to: str  # Format: YYYY-MM
    monthly_data: List[MonthlyServiceLevelItem]


class ServiceLevelVsDemandResponse(BaseModel):
    """Response for service level vs demand analytics endpoint"""
    # Standard success/data/message envelope around ServiceLevelVsDemandData.
    success: bool
    data: ServiceLevelVsDemandData
    message: str


# Service Level Daily Schemas
class ServiceLevelBase(BaseModel):
    # Fields shared by ServiceLevelCreate and ServiceLevelOut (daily records).
    date: datetime  # field is named `date` but typed as datetime
    company_id: int
    location_id: int
    product_id: int
    demand_qty: float = 0.0
    fulfilled_qty: float = 0.0
    lost_sales_qty: float = 0.0
    service_level: float = 1.0  # defaults to 1.0 (fully served)


class ServiceLevelCreate(ServiceLevelBase):
    # Input schema for creating a daily service-level record; identical to the base.
    pass


class ServiceLevelOut(ServiceLevelBase):
    # Output schema for a daily service-level record: base fields plus id and
    # creation time (no updated_at).
    id: int
    created_at: datetime

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)


class ServiceLevelExtendedOut(BaseModel):
    # Flat daily service-level row enriched with display names; all numeric
    # fields are required here (no defaults, unlike ServiceLevelBase).
    id: int
    date: datetime  # field is named `date` but typed as datetime
    company_id: int
    location_id: int
    product_id: int
    demand_qty: float
    fulfilled_qty: float
    lost_sales_qty: float
    service_level: float
    created_at: datetime
    # Optional names resolved from related records.
    product_name: Optional[str] = None
    brand_name: Optional[str] = None
    location_name: Optional[str] = None

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)


# Pagination Response Schema for Service Level
class PaginatedServiceLevelResponse(BaseModel):
    # One page of enriched service-level rows (`perpage`/`total` naming).
    data: List[ServiceLevelExtendedOut]
    page: int
    perpage: int
    total: int


# Inventory Snapshot Daily Schemas
class InventorySnapshotBase(BaseModel):
    # Fields shared by InventorySnapshotCreate and InventorySnapshotOut.
    snapshot_date: datetime
    company_id: int
    location_id: int
    product_id: int
    # Quantities are floats and default to 0.0.
    on_hand_qty: float = 0.0
    inbound_qty: float = 0.0
    outbound_qty: float = 0.0


class InventorySnapshotCreate(InventorySnapshotBase):
    # Input schema for creating an inventory snapshot; identical to the base fields.
    pass


class InventorySnapshotOut(InventorySnapshotBase):
    # Output schema for an inventory snapshot: base fields plus id and creation
    # time (no updated_at).
    id: int
    created_at: datetime

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)


class InventorySnapshotExtendedOut(BaseModel):
    # Flat inventory-snapshot row enriched with display names; quantities are
    # required here (no defaults, unlike InventorySnapshotBase).
    id: int
    snapshot_date: datetime
    company_id: int
    location_id: int
    product_id: int
    on_hand_qty: float
    inbound_qty: float
    outbound_qty: float
    created_at: datetime
    # Optional names resolved from related records.
    product_name: Optional[str] = None
    brand_name: Optional[str] = None
    location_name: Optional[str] = None

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)


# Pagination Response Schema for Inventory Snapshot
class PaginatedInventorySnapshotResponse(BaseModel):
    # One page of enriched inventory-snapshot rows (`perpage`/`total` naming).
    data: List[InventorySnapshotExtendedOut]
    page: int
    perpage: int
    total: int


# Slow mover schema with nested objects
class SlowMoverData(BaseModel):
    # Slow-mover row exposing company/location/product as nested objects.
    id: int
    snapshot_date: datetime
    company: CompanyInfo
    location: LocationInfo
    product: ProductInfo
    on_hand_qty: float
    total_sold_90d: float
    ads_90d: float
    doh_90d: float
    days_since_last_sale: int
    is_slow_mover: bool
    slow_mover_severity: Optional[str] = None
    slow_mover_reason: Optional[str] = None
    created_at: datetime

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)

    @model_validator(mode='before')
    @classmethod
    def transform_flat_data(cls, data):
        """Reshape a flat row dict into the nested structure this model declares.

        Only dicts containing a ``company_name`` key are treated as flat rows.
        Anything else (already-nested dicts, attribute objects) is returned
        unchanged for normal validation.

        Scalar fields are copied only when present. A missing optional field
        (``slow_mover_severity`` / ``slow_mover_reason``) then falls back to its
        ``None`` default, and a missing required field is reported as a pydantic
        ValidationError rather than a raw ``KeyError``. Pydantic does not wrap
        ``KeyError`` raised from a validator into a ValidationError.
        """
        if not (isinstance(data, dict) and 'company_name' in data):
            return data

        passthrough_keys = (
            "id",
            "snapshot_date",
            "on_hand_qty",
            "total_sold_90d",
            "ads_90d",
            "doh_90d",
            "days_since_last_sale",
            "is_slow_mover",
            "slow_mover_severity",
            "slow_mover_reason",
            "created_at",
        )
        nested = {key: data[key] for key in passthrough_keys if key in data}

        nested["company"] = {
            "id": data.get("company_id"),
            "name": data.get("company_name"),
        }
        nested["location"] = {
            "id": data.get("location_id"),
            "name": data.get("location_name"),
        }
        nested["product"] = {
            "id": data.get("product_id"),
            "name": data.get("product_name"),
            # NOTE: ProductInfo declares no `sku` field, so pydantic drops this
            # value under its default extra="ignore" behavior.
            "sku": data.get("product_sku"),
        }
        return nested


class SlowMoverPaginatedData(BaseModel):
    # One page of slow-mover rows (`items`/`perpage`/`total` naming).
    items: List[SlowMoverData]
    page: int
    perpage: int
    total: int


class SlowMoverResponse(BaseModel):
    # Standard success/data/message envelope around a page of slow movers.
    success: bool
    data: SlowMoverPaginatedData
    message: str
    
    
class ForecastRequest(BaseModel):
    # Forecast request for one company/location/product over a date range.
    company_id: int
    location_id: int
    product_id: int
    start_date: date
    end_date: date

    @property
    def dates(self) -> List[date]:
        """Return every calendar day from ``start_date`` through ``end_date``.

        Both ends are inclusive. The list is empty when ``end_date`` is before
        ``start_date``.
        """
        day_count = (self.end_date - self.start_date).days + 1
        return [self.start_date + timedelta(days=offset) for offset in range(day_count)]


class ForecastResponseItem(BaseModel):
    # One forecast value for a single day.
    # NOTE: the field name matches the `date` type. Do not give it a default in
    # the class body: the assignment runs before the annotation is evaluated, so
    # the annotation would resolve to that default value instead of the type.
    date: date
    forecast_qty: float


class ForecastResponse(BaseModel):
    # Forecast series for one company/location/product.
    company_id: int
    location_id: int
    product_id: int
    forecasts: List[ForecastResponseItem]


class TrainDemandForecastRequest(BaseModel):
    # Request to train the demand-forecast model; scoped by company only.
    company_id: int


class TrainDemandForecastResponse(BaseModel):
    # Acknowledgement of a training request, identified by task_id.
    task_id: str
    status: str  # free-form string; allowed values are not enforced here
    message: str
    company_id: int


# Analytics Overview Schemas
class AnalyticsOverviewData(BaseModel):
    # Headline counts for the analytics overview.
    slow_mover_count: int
    over_stock_count: int
    under_stock_count: int


class AnalyticsOverviewResponse(BaseModel):
    # Standard success/data/message envelope around AnalyticsOverviewData.
    success: bool
    data: AnalyticsOverviewData
    message: str


# KPI Summary Schemas
class KPISummaryData(BaseModel):
    """KPI summary metrics for frontend dashboard cards"""
    # NOTE: field names here use `overstock`/`understock`, whereas
    # AnalyticsOverviewData uses `over_stock`/`under_stock`.
    snapshot_date: Optional[date] = None  # Date the KPI data is for
    avg_service_level: float  # Average service level (0.0 - 1.0)
    stockouts_count: int  # Number of stockout events / out of stock products
    overstock_count: int  # Number of overstocked products
    understock_count: int  # Number of understocked products
    slow_mover_count: int  # Number of slow-moving products
    # Stock turn metrics
    total_products_tracked: int  # Total number of products being tracked
    avg_days_on_hand: Optional[float] = None  # Average days of inventory on hand


class KPISummaryResponse(BaseModel):
    # Standard success/data/message envelope around KPISummaryData.
    success: bool
    data: KPISummaryData
    message: str


# Stock vs Demand Analytics Schemas
class MonthlyStockVsDemandItem(BaseModel):
    # One month of stock vs demand; both values are integers, so any fractional
    # daily average must be converted by the producer before validation.
    month: str  # Format: "YYYY-MM"
    avg_daily_demand: int
    available_stock: int


class StockVsDemandData(BaseModel):
    # Echoes the request filters alongside the monthly stock-vs-demand series;
    # location_id and product_id are None when that filter was not applied.
    company_id: int
    location_id: Optional[int] = None
    product_id: Optional[int] = None
    month_from: str  # Format: YYYY-MM
    month_to: str  # Format: YYYY-MM
    monthly_data: List[MonthlyStockVsDemandItem]


class StockVsDemandResponse(BaseModel):
    # Standard success/data/message envelope around StockVsDemandData.
    success: bool
    data: StockVsDemandData
    message: str


# Inventory Details Schemas
class InventoryDetailData(BaseModel):
    # Per product/location inventory detail: identity, stock, snapshot,
    # slow-mover, planning and service-level figures in a single payload.
    company: CompanyInfo
    location: LocationInfo
    product: ProductInfo
    brand_name: Optional[str] = None
    category_name: Optional[str] = None
    is_perishable: Optional[bool] = None

    # Stock information
    current_stock: int
    active_batches_count: int
    earliest_expiry: Optional[datetime] = None
    status: str  # in_stock, low_stock, out_of_stock, slow_mover, dead_stock (not enforced)

    # Snapshot data
    snapshot_stock: Optional[int] = None
    snapshot_date: Optional[date] = None

    # Slow mover data
    is_slow_mover: bool = False
    slow_mover_severity: Optional[str] = None
    avg_daily_sales_90d: int = 0
    days_on_hand_90d: int = 0
    days_since_last_sale: Optional[int] = None

    # Planning data
    reorder_point: int = 0
    safety_stock: int = 0
    should_reorder: bool = False
    recommended_order_qty: int = 0

    # Additional planning data from InventoryPlanningSnapshot
    planning_snapshot_date: Optional[date] = None
    avg_daily_demand: int = 0
    sigma_daily_demand: int = 0
    lead_time_days: int = 0
    review_period_days: int = 0
    service_level_target: float = 0.95
    forecast_avg_daily_demand_90d: int = 0
    forecast_safety_stock_90d: int = 0
    forecasted_reorder_point_90d: int = 0
    planning_on_hand_qty: int = 0
    inbound_qty: int = 0
    available_stock: int = 0
    min_target: int = 0
    max_target: int = 0
    stock_status: Optional[str] = None

    # Service level
    latest_service_level: Optional[float] = None
    service_level_date: Optional[date] = None

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`);
    # from_attributes allows validating from plain objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)

    @model_validator(mode='before')
    @classmethod
    def transform_flat_data(cls, data):
        """Reshape a flat row dict into the nested structure this model declares.

        Only dicts containing a ``company_name`` key are treated as flat rows.
        Anything else is returned unchanged for normal validation.

        Optional numeric/boolean inputs may be absent *or* present with an
        explicit ``None``. Both cases fall back to the field's default.
        Previously a present ``None`` hit ``int(None)`` and raised a raw
        ``TypeError``, which pydantic does not convert into a ValidationError.
        Integer fields are still converted with ``int()``, which truncates
        fractional values toward zero.
        """
        if not (isinstance(data, dict) and 'company_name' in data):
            return data

        def value_or(key, default):
            # Treat an explicit None exactly like a missing key.
            value = data.get(key)
            return default if value is None else value

        def int_or(key, default=0):
            return int(value_or(key, default))

        snapshot_stock = data.get("snapshot_stock")

        return {
            "company": {
                "id": data["company_id"],
                "name": data.get("company_name"),
            },
            "location": {
                "id": data["location_id"],
                "name": data.get("location_name"),
            },
            "product": {
                "id": data["product_id"],
                "name": data.get("product_name"),
                # NOTE: ProductInfo declares no `sku` field, so pydantic drops
                # this value under its default extra="ignore" behavior.
                "sku": str(data["product_id"]),
            },
            "brand_name": data.get("brand_name"),
            "category_name": data.get("category_name"),
            "is_perishable": data.get("is_perishable"),
            "current_stock": int(data["current_stock"]),
            "active_batches_count": data["active_batches_count"],
            "earliest_expiry": data.get("earliest_expiry"),
            "status": data["status"],
            "snapshot_stock": int(snapshot_stock) if snapshot_stock is not None else None,
            "snapshot_date": data.get("snapshot_date"),
            "is_slow_mover": value_or("is_slow_mover", False),
            "slow_mover_severity": data.get("slow_mover_severity"),
            "avg_daily_sales_90d": int_or("avg_daily_sales_90d"),
            "days_on_hand_90d": int_or("days_on_hand_90d"),
            "days_since_last_sale": data.get("days_since_last_sale"),
            "reorder_point": int_or("reorder_point"),
            "safety_stock": int_or("safety_stock"),
            "should_reorder": value_or("should_reorder", False),
            "recommended_order_qty": int_or("recommended_order_qty"),
            # Additional planning data from InventoryPlanningSnapshot
            "planning_snapshot_date": data.get("planning_snapshot_date"),
            "avg_daily_demand": int_or("avg_daily_demand"),
            "sigma_daily_demand": int_or("sigma_daily_demand"),
            # These two are left to pydantic's int validation (no int() truncation),
            # matching the original behavior for non-None values.
            "lead_time_days": value_or("lead_time_days", 0),
            "review_period_days": value_or("review_period_days", 0),
            "service_level_target": value_or("service_level_target", 0.95),
            "forecast_avg_daily_demand_90d": int_or("forecast_avg_daily_demand_90d"),
            "forecast_safety_stock_90d": int_or("forecast_safety_stock_90d"),
            "forecasted_reorder_point_90d": int_or("forecasted_reorder_point_90d"),
            "planning_on_hand_qty": int_or("planning_on_hand_qty"),
            "inbound_qty": int_or("inbound_qty"),
            "available_stock": int_or("available_stock"),
            "min_target": int_or("min_target"),
            "max_target": int_or("max_target"),
            "stock_status": data.get("stock_status"),
            "latest_service_level": data.get("latest_service_level"),
            "service_level_date": data.get("service_level_date"),
        }


class InventoryPaginatedData(BaseModel):
    # One page of inventory detail rows (`items`/`perpage`/`total` naming).
    items: List[InventoryDetailData]
    page: int
    perpage: int
    total: int


class InventoryResponse(BaseModel):
    # Standard success/data/message envelope around a page of inventory details.
    success: bool
    data: InventoryPaginatedData
    message: str


# Fastest Moving Products Schemas
class VendorInfo(BaseModel):
    # Nested vendor reference: id plus optional name and code.
    id: int
    name: Optional[str] = None
    code: Optional[str] = None


class CategoryInfo(BaseModel):
    # Nested category reference: id plus an optional name.
    id: int
    name: Optional[str] = None


class FastestMovingProduct(BaseModel):
    # Ranked fast-moving product at a location, with sales totals and averages
    # over 7/30/90-day windows. Category and vendor may be absent.
    rank: int
    product: ProductInfo
    location: LocationInfo
    category: Optional[CategoryInfo] = None
    vendor: Optional[VendorInfo] = None
    on_hand_qty: float
    total_sold_7d: float
    total_sold_30d: float
    total_sold_90d: float
    ads_7d: float  # Average Daily Sales (7 days)
    ads_30d: float  # Average Daily Sales (30 days)
    ads_90d: float  # Average Daily Sales (90 days)
    snapshot_date: date


class FastestMovingResponse(BaseModel):
    # Standard success/data/message envelope around a ranked (unpaginated) list.
    success: bool
    data: List[FastestMovingProduct]
    message: str


class MostStagnantProduct(BaseModel):
    # Ranked stagnant-stock row; float metrics are rounded to 2 decimals on input.
    rank: int
    product: ProductInfo
    location: LocationInfo
    category: Optional[CategoryInfo] = None
    vendor: Optional[VendorInfo] = None
    on_hand_qty: float
    lead_time_demand: float
    ads_90d: float  # Average Daily Sales (90 days)
    excess_stock_level: float
    days_since_last_sale: int
    doh_90d: float  # Days of Inventory on Hand
    snapshot_date: date

    @field_validator(
        'on_hand_qty',
        'lead_time_demand',
        'ads_90d',
        'excess_stock_level',
        'doh_90d',
        mode='before',
    )
    @classmethod
    def round_floats(cls, value):
        """Round incoming numbers to two decimals.

        ``None`` is passed through untouched, so it still fails the
        non-optional ``float`` fields with a normal validation error.
        """
        return value if value is None else round(float(value), 2)


class MostStagnantResponse(BaseModel):
    # Standard success/data/message envelope around a ranked (unpaginated) list.
    success: bool
    data: List[MostStagnantProduct]
    message: str


class MostUrgentProduct(BaseModel):
    # Ranked reorder-urgency row; float metrics are rounded to 2 decimals on input
    # and missing (None) values become 0.0.
    rank: int
    product: ProductInfo
    location: LocationInfo
    category: Optional[CategoryInfo] = None
    vendor: Optional[VendorInfo] = None
    on_hand_qty: float
    inbound_qty: float  # On Order
    days_of_cover: float  # Days left
    recommended_order_qty: float  # Required Order Qty
    urgency_score: float
    snapshot_date: date

    @field_validator(
        'on_hand_qty',
        'inbound_qty',
        'days_of_cover',
        'recommended_order_qty',
        'urgency_score',
        mode='before',
    )
    @classmethod
    def round_floats(cls, value):
        """Round incoming numbers to two decimals, treating None as 0.0."""
        if value is not None:
            return round(float(value), 2)
        return 0.0


class MostUrgentResponse(BaseModel):
    # Standard success/data/message envelope around a ranked (unpaginated) list.
    success: bool
    data: List[MostUrgentProduct]
    message: str


# Demand Forecast Schemas
class DemandForecastItem(BaseModel):
    """Single demand forecast item"""
    # forecast_date and target_date are both plain dates; created_at is a datetime.
    id: int
    company_id: int
    location_id: int
    location_name: str
    product_id: int
    product_name: str
    forecast_date: date
    target_date: date
    forecast_qty: float
    created_at: datetime

    @field_validator('forecast_qty', mode='before')
    @classmethod
    def round_forecast_qty(cls, value):
        """Round the forecast quantity to two decimals; None becomes 0.0."""
        if value is not None:
            return round(float(value), 2)
        return 0.0


class PaginatedDemandForecastResponse(BaseModel):
    """Paginated response for demand forecasts"""
    # Uses the same `data`/`page`/`perpage`/`total` shape as the other paginated responses.
    data: List[DemandForecastItem]
    page: int
    perpage: int
    total: int
