import csv
import os
import sys
from datetime import datetime
from pathlib import Path

# Make the project root (four directory levels above this file) importable so
# that the absolute ``src.*`` imports below resolve when this file is executed
# directly as a script.
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from src.utils.db import get_db
from src.smart_inventory.apps.products.models import (
    Company, Category, Product, Location, Vendor, ProductVendor, ProductLocation, ProductPrice
)
from src.smart_inventory.apps.inventory.models import (
    PurchaseOrder, PurchaseOrderLine, InventoryBatch, 
    SalesOrder, SalesOrderLine, InventoryMovement,
    PurchaseOrderStatus, InventoryBatchStatus, MovementType,
    ReorderPolicy
)

# Configuration
# SCRIPT_DIR: absolute path of the directory that contains this script.
# DATA_DIR: the "data" sub-directory next to the script; load_csv() resolves
# every CSV filename relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(SCRIPT_DIR, "data")


def parse_datetime(dt_str):
    """Parse an ISO 8601 datetime string taken from a CSV cell.

    Args:
        dt_str: Raw cell text, or ``None``.

    Returns:
        A ``datetime`` (timezone-aware when the text carries a UTC offset or a
        trailing ``Z``), or ``None`` when the value is missing (``None``,
        empty/whitespace-only, or the literal ``none`` in any letter case) or
        is not a valid ISO 8601 datetime.
    """
    if dt_str is None:
        return None
    # Tolerate padding around the value, which is common in hand-edited CSVs.
    dt_str = dt_str.strip()
    if not dt_str or dt_str.lower() == 'none':
        return None
    # datetime.fromisoformat() in Python versions before 3.11 rejects a
    # trailing "Z", so spell the UTC offset out explicitly.
    if dt_str.endswith('Z'):
        dt_str = dt_str[:-1] + '+00:00'
    try:
        return datetime.fromisoformat(dt_str)
    except ValueError:
        # Best-effort parsing: an unparseable value is treated as missing.
        # Only ValueError is caught so that KeyboardInterrupt, SystemExit and
        # genuine programming errors are no longer silently swallowed.
        return None


def parse_bool(val):
    """Convert a CSV cell (or an already-typed value) to a bool.

    Strings are true only when they equal ``true``, ``1``, ``yes`` or ``t``
    (case-insensitive); every other value is converted with ``bool()``.
    """
    if not isinstance(val, str):
        # bool(x) returns x unchanged for real booleans, so this single
        # branch covers both bool inputs and other non-string values.
        return bool(val)
    truthy_words = ('true', '1', 'yes', 't')
    return val.lower() in truthy_words


def parse_int(val):
    """Parse an optional integer value.

    Args:
        val: Raw CSV cell text, ``None``, or an already-numeric value.

    Returns:
        ``None`` when the value is missing (``None``, empty/whitespace-only,
        or the literal ``none`` in any letter case); otherwise ``int(val)``.

    Raises:
        ValueError: If a non-missing string is not a valid integer literal.
    """
    if val is None:
        return None
    if isinstance(val, str):
        # Only strings can carry the textual "missing" markers; stripping
        # first means padded cells such as " none " are recognised too.
        val = val.strip()
        if not val or val.lower() == 'none':
            return None
    # Non-string inputs (e.g. an int) go straight to int() instead of
    # crashing on a missing .lower() method.
    return int(val)


def parse_float(val):
    """Parse an optional floating-point value.

    Args:
        val: Raw CSV cell text, ``None``, or an already-numeric value.

    Returns:
        ``None`` when the value is missing (``None``, empty/whitespace-only,
        or the literal ``none`` in any letter case); otherwise ``float(val)``.

    Raises:
        ValueError: If a non-missing string is not a valid float literal.
    """
    if val is None:
        return None
    if isinstance(val, str):
        # Only strings can carry the textual "missing" markers; stripping
        # first means padded cells such as " none " are recognised too.
        val = val.strip()
        if not val or val.lower() == 'none':
            return None
    # Non-string inputs (e.g. an int or float) go straight to float() instead
    # of crashing on a missing .lower() method.
    return float(val)


def load_csv(filename):
    """Read a CSV file from DATA_DIR into a list of row dicts.

    Args:
        filename: Name of the CSV file, relative to DATA_DIR.

    Returns:
        One dict per data row, keyed by the column names of the header row.
        When the file does not exist a warning is printed and an empty list
        is returned.
    """
    filepath = os.path.join(DATA_DIR, filename)
    if not os.path.exists(filepath):
        print(f"⚠️  File not found: {filename}")
        return []

    # newline='' hands line-ending handling to the csv module, which is what
    # it needs to parse quoted fields that contain embedded newlines.
    # 'utf-8-sig' reads plain UTF-8 as well, but drops a leading byte-order
    # mark that would otherwise become part of the first column name.
    with open(filepath, 'r', encoding='utf-8-sig', newline='') as f:
        return list(csv.DictReader(f))


def load_companies(db):
    """Insert every row of companies.csv as a Company, then commit.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    print("Loading companies...")
    records = load_csv("companies.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'company_name': record['company_name'],
            'created_at': parse_datetime(record['created_at']),
            'updated_at': parse_datetime(record['updated_at']),
        }
        db.add(Company(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} companies")


def load_categories(db):
    """Insert every row of categories.csv as a Category, then commit.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    print("Loading categories...")
    records = load_csv("categories.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'category_id': int(record['category_id']),
            'category_name': record['category_name'],
            'created_at': parse_datetime(record['created_at']),
            'updated_at': parse_datetime(record['updated_at']),
        }
        db.add(Category(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} categories")


def load_locations(db):
    """Insert every row of locations.csv as a Location, then commit.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    print("Loading locations...")
    records = load_csv("locations.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'location_id': int(record['location_id']),
            'location_name': record['location_name'],
            'created_at': parse_datetime(record['created_at']),
            'updated_at': parse_datetime(record['updated_at']),
        }
        db.add(Location(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} locations")


def load_vendors(db):
    """Insert every row of vendors.csv as a Vendor, then commit.

    ``vendor_code`` is optional: an empty cell or the literal ``none`` (any
    letter case) is stored as ``None``.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    print("Loading vendors...")
    data = load_csv("vendors.csv")
    for row in data:
        raw_code = row['vendor_code']
        vendor = Vendor(
            id=int(row['id']),
            company_id=int(row['company_id']),
            vendor_id=int(row['vendor_id']),
            vendor_name=row['vendor_name'],
            # Treat a literal "None" as missing, the same way every other
            # optional text column in this file is handled, so the string
            # "None" is never stored as a vendor code.
            vendor_code=raw_code if raw_code and raw_code.lower() != 'none' else None,
            created_at=parse_datetime(row['created_at']),
            updated_at=parse_datetime(row['updated_at'])
        )
        db.add(vendor)
    db.commit()
    print(f"✓ Loaded {len(data)} vendors")


def load_products(db):
    """Insert every row of products.csv as a Product, then commit.

    Optional text columns (short_name, description, brand_name, image_path)
    are stored as ``None`` when the cell is empty or holds the literal
    ``none`` in any letter case.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    def _optional_text(value):
        # Empty cells and the literal "None" both mean "no value".
        return value if value and value.lower() != 'none' else None

    print("Loading products...")
    records = load_csv("products.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'product_id': int(record['product_id']),
            'product_name': record['product_name'],
            'short_name': _optional_text(record['short_name']),
            'description': _optional_text(record['description']),
            'brand_name': _optional_text(record['brand_name']),
            'fk_product_category_id': parse_int(record['fk_product_category_id']),
            'eligible_for_return': parse_bool(record['eligible_for_return']),
            'display_on_pos': parse_bool(record['display_on_pos']),
            'display_on_online_store': parse_bool(record['display_on_online_store']),
            'is_perishable': parse_bool(record['is_perishable']),
            'image_path': _optional_text(record['image_path']),
            'created_at': parse_datetime(record['created_at']),
            'updated_at': parse_datetime(record['updated_at']),
        }
        db.add(Product(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} products")


def load_product_vendors(db):
    """Insert every row of product_vendors.csv as a ProductVendor, then commit.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    print("Loading product-vendor relationships...")
    records = load_csv("product_vendors.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'product_id': int(record['product_id']),
            'vendor_id': int(record['vendor_id']),
            'created_at': parse_datetime(record['created_at']),
        }
        db.add(ProductVendor(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} product-vendor relationships")


def load_product_locations(db):
    """Insert every row of product_locations.csv as a ProductLocation, then commit.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    print("Loading product-location relationships...")
    records = load_csv("product_locations.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'product_id': int(record['product_id']),
            'location_id': int(record['location_id']),
            'created_at': parse_datetime(record['created_at']),
        }
        db.add(ProductLocation(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} product-location relationships")


def load_product_prices(db):
    """Insert every row of product_prices.csv as a ProductPrice, then commit.

    Numeric price columns go through ``parse_float``; the two ``*_type_name``
    columns are stored as ``None`` when the cell is empty or holds the
    literal ``none`` in any letter case.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    def _optional_text(value):
        # Empty cells and the literal "None" both mean "no value".
        return value if value and value.lower() != 'none' else None

    print("Loading product prices...")
    records = load_csv("product_prices.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'product_price_id': int(record['product_price_id']),
            'product_id': int(record['product_id']),
            'location_id': int(record['location_id']),
            'cost_price_per_unit': parse_float(record['cost_price_per_unit']),
            'markup_value': parse_float(record['markup_value']),
            'margin_value': parse_float(record['margin_value']),
            'retail_price_excl_tax': parse_float(record['retail_price_excl_tax']),
            'compare_at_price': parse_float(record['compare_at_price']),
            'markup_type_name': _optional_text(record['markup_type_name']),
            'margin_type_name': _optional_text(record['margin_type_name']),
            'created_at': parse_datetime(record['created_at']),
            'updated_at': parse_datetime(record['updated_at']),
        }
        db.add(ProductPrice(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} product prices")


def load_purchase_orders(db):
    """Insert every row of purchase_orders.csv as a PurchaseOrder, then commit.

    The ``status`` cell is converted by calling ``PurchaseOrderStatus`` on
    its text.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    print("Loading purchase orders...")
    records = load_csv("purchase_orders.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'supplier_id': int(record['supplier_id']),
            # This column was previously named store_id.
            'location_id': int(record['location_id']),
            'status': PurchaseOrderStatus(record['status']),
            'expected_delivery_date': parse_datetime(record['expected_delivery_date']),
            'created_at': parse_datetime(record['created_at']),
            'updated_at': parse_datetime(record['updated_at']),
        }
        db.add(PurchaseOrder(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} purchase orders")


def load_purchase_order_lines(db):
    """Insert every row of purchase_order_lines.csv as a PurchaseOrderLine, then commit.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    print("Loading purchase order lines...")
    records = load_csv("purchase_order_lines.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'purchase_order_id': int(record['purchase_order_id']),
            'product_id': int(record['product_id']),
            'ordered_qty': int(record['ordered_qty']),
            'received_qty': int(record['received_qty']),
            'unit_cost': parse_float(record['unit_cost']),
            'created_at': parse_datetime(record['created_at']),
            'updated_at': parse_datetime(record['updated_at']),
        }
        db.add(PurchaseOrderLine(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} purchase order lines")


def load_inventory_batches(db):
    """Insert every row of inventory_batches.csv as an InventoryBatch, then commit.

    The ``status`` cell is converted by calling ``InventoryBatchStatus`` on
    its text.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    print("Loading inventory batches...")
    records = load_csv("inventory_batches.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'product_id': int(record['product_id']),
            # This column was previously named store_id.
            'location_id': int(record['location_id']),
            'batch_ref': record['batch_ref'],
            'quantity_on_hand': int(record['quantity_on_hand']),
            'expiry_date': parse_datetime(record['expiry_date']),
            'received_date': parse_datetime(record['received_date']),
            'status': InventoryBatchStatus(record['status']),
            'created_at': parse_datetime(record['created_at']),
            'updated_at': parse_datetime(record['updated_at']),
        }
        db.add(InventoryBatch(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} inventory batches")


def load_sales_orders(db):
    """Insert every row of sales_orders.csv as a SalesOrder, then commit.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    print("Loading sales orders...")
    records = load_csv("sales_orders.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            # This column was previously named store_id.
            'location_id': int(record['location_id']),
            'sold_at': parse_datetime(record['sold_at']),
            'channel': record['channel'],
            'created_at': parse_datetime(record['created_at']),
            'updated_at': parse_datetime(record['updated_at']),
        }
        db.add(SalesOrder(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} sales orders")


def load_sales_order_lines(db):
    """Insert every row of sales_order_lines.csv as a SalesOrderLine, then commit.

    ``promotion_id`` is optional and goes through ``parse_int``.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    print("Loading sales order lines...")
    records = load_csv("sales_order_lines.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'sales_order_id': int(record['sales_order_id']),
            'product_id': int(record['product_id']),
            'quantity': int(record['quantity']),
            'unit_price': parse_float(record['unit_price']),
            'promotion_id': parse_int(record['promotion_id']),
            'created_at': parse_datetime(record['created_at']),
            'updated_at': parse_datetime(record['updated_at']),
        }
        db.add(SalesOrderLine(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} sales order lines")


def load_inventory_movements(db):
    """Insert every row of inventory_movements.csv as an InventoryMovement, then commit.

    ``batch_id`` is optional (``parse_int``); ``reference`` is stored as
    ``None`` when the cell is empty or holds the literal ``none`` in any
    letter case. ``movement_type`` is converted by calling ``MovementType``
    on the cell text.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    def _optional_text(value):
        # Empty cells and the literal "None" both mean "no value".
        return value if value and value.lower() != 'none' else None

    print("Loading inventory movements...")
    records = load_csv("inventory_movements.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'product_id': int(record['product_id']),
            # This column was previously named store_id.
            'location_id': int(record['location_id']),
            'batch_id': parse_int(record['batch_id']),
            'movement_type': MovementType(record['movement_type']),
            'quantity_delta': int(record['quantity_delta']),
            'reference': _optional_text(record['reference']),
            'created_at': parse_datetime(record['created_at']),
        }
        db.add(InventoryMovement(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} inventory movements")


def load_reorder_policies(db):
    """Insert every row of reorder_policies.csv as a ReorderPolicy, then commit.

    ``supplier_id`` is optional and goes through ``parse_int``; all other
    numeric columns are required and converted directly.

    Args:
        db: Session-like object; only ``add`` and ``commit`` are called on it.
    """
    print("Loading reorder policies...")
    records = load_csv("reorder_policies.csv")
    for record in records:
        fields = {
            'id': int(record['id']),
            'company_id': int(record['company_id']),
            'location_id': int(record['location_id']),
            'product_id': int(record['product_id']),
            'lead_time_days': int(record['lead_time_days']),
            'review_period_days': int(record['review_period_days']),
            'service_level_target': float(record['service_level_target']),
            'min_order_qty': float(record['min_order_qty']),
            'supplier_id': parse_int(record['supplier_id']),
            'created_at': parse_datetime(record['created_at']),
            'updated_at': parse_datetime(record['updated_at']),
        }
        db.add(ReorderPolicy(**fields))
    db.commit()
    print(f"✓ Loaded {len(records)} reorder policies")


def load_data_to_db():
    """Load all CSV data into the database in foreign-key-safe order.

    Each loader commits its own table. If a later loader fails, the tables
    loaded before it therefore stay committed; only the failing loader's
    pending rows are rolled back.

    Raises:
        Exception: Whatever a loader raised is re-raised after the error has
            been printed and the session rolled back.
    """
    print("=" * 50)
    print("Starting data load from CSV files...")
    print("=" * 50)

    # get_db() is consumed with next(), so it returns an iterator (presumably
    # a generator that yields a session and cleans up after the yield --
    # confirm in src.utils.db). Keep a reference to it: a generator dropped
    # right after next() is finalized immediately under CPython's reference
    # counting, which would run its post-yield cleanup before the session is
    # ever used.
    db_provider = get_db()
    db = next(db_provider)

    # Order matters: referenced tables are loaded before the tables that
    # point at them.
    loaders = (
        load_companies,  # Companies must be loaded first
        load_categories,
        load_locations,
        load_vendors,
        load_products,
        load_product_vendors,
        load_product_locations,
        load_product_prices,
        load_purchase_orders,
        load_purchase_order_lines,
        load_inventory_batches,
        load_sales_orders,
        load_sales_order_lines,
        load_inventory_movements,
        load_reorder_policies,
    )

    try:
        for loader in loaders:
            loader(db)

        print("\n" + "=" * 50)
        print("Data loaded successfully into smart-inventory base tables:")
        print("- Companies, Categories, Locations, Vendors")
        print("- Products, ProductVendors, ProductLocations, ProductPrices")
        print("- PurchaseOrders, PurchaseOrderLines, InventoryBatches")
        print("- SalesOrders, SalesOrderLines, InventoryMovements")
        print("- ReorderPolicies")
        print("=" * 50)

    except Exception as e:
        print(f"\nError loading data: {e}")
        db.rollback()
        raise
    finally:
        db.close()
        # Finalize the provider explicitly so that any cleanup it has after
        # its yield runs now rather than at some later garbage collection.
        # getattr() keeps this safe if get_db() returns a plain iterator
        # without a close() method.
        close_provider = getattr(db_provider, "close", None)
        if close_provider is not None:
            close_provider()


# Run the full load only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    load_data_to_db()
