import sys
from pathlib import Path

# Add project root to path so the ``src.`` package imports below resolve
# (presumably needed when this file is run directly as a script). The root is
# four ``.parent`` steps up from this file's path, and it is inserted at the
# front of ``sys.path`` so it takes precedence over other entries.
# NOTE(review): if ``__file__`` were relative (possible for the main script on
# Python < 3.9), ``.parent`` stops at "." and cannot climb above the current
# directory; ``Path(__file__).resolve()`` would make this independent of cwd.
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from src.utils.db import get_db
from src.smart_inventory.apps.products.models import (
    Company, Category, Product, Location, Vendor, ProductVendor, ProductLocation, ProductPrice
)
from src.smart_inventory.apps.inventory.models import (
    PurchaseOrder, PurchaseOrderLine, InventoryBatch, 
    SalesOrder, SalesOrderLine, InventoryMovement, DailySales,
    ServiceLevelDaily, StockoutEvent, InventorySnapshotDaily, DemandForecast,
    ReorderPolicy, SlowMoverSnapshot, InventoryPlanningSnapshot
)


def show_menu():
    """Print the main menu and return the user's validated choice.

    Keeps prompting until the input, with surrounding whitespace removed, is
    exactly "1" or "2".

    Returns:
        str: "1" to delete every table, "2" to pick tables one at a time.
    """
    separator = "=" * 50
    menu_lines = (
        separator,
        "Smart Inventory Data Deletion Tool",
        separator,
        "1. Delete all smart_inventory related tables",
        "2. Select tables one by one",
        separator,
    )
    for text in menu_lines:
        print(text)

    accepted = ('1', '2')
    selection = input("Choose option [1/2]: ").strip()
    while selection not in accepted:
        print("Please enter '1' or '2'.")
        selection = input("Choose option [1/2]: ").strip()
    return selection


def confirm_deletion(table_name):
    """Ask whether all rows of ``table_name`` should be deleted.

    The answer is lower-cased and stripped, then must be "y"/"yes" or
    "n"/"no"; anything else prints a hint and asks again.

    Args:
        table_name: Name shown in the prompt.

    Returns:
        bool: True if the user agreed to the deletion, False otherwise.
    """
    verdicts = {'y': True, 'yes': True, 'n': False, 'no': False}
    while True:
        answer = input(f"Delete all data from {table_name} table? [y/n]: ").lower().strip()
        if answer in verdicts:
            return verdicts[answer]
        print("Please enter 'y' for yes or 'n' for no.")


def _deletion_plan():
    """Return the smart-inventory tables to clear, in deletion order.

    Each entry is ``(model, table_name, deleted_label, skipped_label)``.
    ``table_name`` is what prompts and summaries show. The two labels are the
    wording used by ``delete_selective_tables`` in its "Deleted {count} ..."
    and "Skipped ..." messages; some tables use different wording in each.

    The order is meant to remove dependent tables before the tables they
    reference ("reverse order of foreign key dependencies"). Check the models'
    foreign keys before adding or reordering entries.

    The list is built on each call, not at import time, so the model classes
    are only looked up when a deletion actually runs. Both deletion modes read
    this single list, so their ordering cannot drift apart.
    """
    return [
        (DemandForecast, "demand_forecasts",
         "demand forecasts", "demand forecasts"),
        (InventoryPlanningSnapshot, "inventory_planning_snapshot",
         "inventory planning snapshots", "inventory planning snapshot"),
        (SlowMoverSnapshot, "slow_mover_snapshot",
         "slow mover snapshots", "slow mover snapshot"),
        (ReorderPolicy, "reorder_policies",
         "reorder policies", "reorder policies"),
        (InventoryMovement, "inventory_movements",
         "inventory movements", "inventory movements"),
        (SalesOrderLine, "sales_order_lines",
         "sales order lines", "sales order lines"),
        (SalesOrder, "sales_orders",
         "sales orders", "sales orders"),
        (DailySales, "daily_sales",
         "daily sales", "daily sales"),
        (ServiceLevelDaily, "service_level_daily",
         "service level daily records", "service level daily"),
        (StockoutEvent, "stockout_events",
         "stockout events", "stockout events"),
        (InventorySnapshotDaily, "inventory_snapshot_daily",
         "inventory snapshot daily records", "inventory snapshot daily"),
        (InventoryBatch, "inventory_batches",
         "inventory batches", "inventory batches"),
        (PurchaseOrderLine, "purchase_order_lines",
         "purchase order lines", "purchase order lines"),
        (PurchaseOrder, "purchase_orders",
         "purchase orders", "purchase orders"),
        (ProductPrice, "product_prices",
         "product prices", "product prices"),
        (ProductLocation, "product_locations",
         "product-location relationships", "product-location relationships"),
        (ProductVendor, "product_vendors",
         "product-vendor relationships", "product-vendor relationships"),
        (Product, "products",
         "products", "products"),
        (Vendor, "vendors",
         "vendors", "vendors"),
        (Location, "locations",
         "locations", "locations"),
        (Category, "categories",
         "categories", "categories"),
        (Company, "companies",
         "companies", "companies"),
    ]


def delete_all_tables(db):
    """Delete every row from every smart-inventory table without prompting.

    Tables are processed in ``_deletion_plan()`` order. The function does
    not commit: the caller (``delete_data_from_db``) commits or rolls back.

    Args:
        db: Session from ``get_db()``; uses its ``query(model).delete()``
            bulk-delete API.

    Returns:
        list[str]: Names of the tables that were cleared, in deletion order.
    """
    tables_deleted = []
    for model, table_name, _deleted_label, _skipped_label in _deletion_plan():
        count = db.query(model).delete()
        print(f"Deleted {count} records from {table_name}")
        tables_deleted.append(table_name)
    return tables_deleted


def delete_selective_tables(db):
    """Delete tables one at a time, asking for confirmation before each.

    Tables are offered in ``_deletion_plan()`` order. The function does not
    commit: the caller (``delete_data_from_db``) commits or rolls back.

    NOTE(review): confirming a parent table while skipping one of its
    dependents may fail with a foreign-key error. The caller then rolls back
    every deletion made in this session. Confirm against the schema.

    Args:
        db: Session from ``get_db()``; uses its ``query(model).delete()``
            bulk-delete API.

    Returns:
        list[str]: Names of the tables the user chose to clear, in order.
    """
    tables_deleted = []
    for model, table_name, deleted_label, skipped_label in _deletion_plan():
        if not confirm_deletion(table_name):
            print(f"Skipped {skipped_label}")
            continue
        count = db.query(model).delete()
        print(f"Deleted {count} {deleted_label}")
        tables_deleted.append(table_name)
    return tables_deleted


def delete_data_from_db():
    """Interactive entry point: clear smart-inventory tables in one transaction.

    Shows the menu, then opens a session via ``get_db()``. Runs either the
    delete-everything path or the per-table confirmation path and commits
    once at the end. Any ``Exception`` is reported, rolled back and re-raised.
    The session is closed whether or not an error occurred.
    """
    mode = show_menu()

    session = next(get_db())
    rule = "=" * 50

    try:
        if mode == '1':
            banner = "\nDeleting all smart inventory tables..."
            run_deletion = delete_all_tables
        else:
            banner = "\nSelecting tables individually..."
            run_deletion = delete_selective_tables
        print(banner)
        deleted = run_deletion(session)

        # The deletion helpers never commit; this is the single commit point.
        session.commit()

        print("\n" + rule)
        print("Data deletion completed successfully!")
        summary = (
            f"Tables affected: {', '.join(deleted)}"
            if deleted
            else "No tables were modified."
        )
        print(summary)
        print(rule)

    except Exception as exc:
        print(f"\nError deleting data: {exc}")
        session.rollback()
        raise
    finally:
        session.close()


if __name__ == "__main__":
    # Start the interactive deletion tool only when run as a script, not on import.
    delete_data_from_db()
