"""
Generate incremental daily dummy data for today based on the existing database state.
This script reads the current inventory state (active batches and the latest snapshots)
and generates realistic transactions for today.

IMPORTANT: This script does NOT create or duplicate master data (companies, products, 
locations, vendors, etc.). It only generates transactional data (purchase orders, sales 
orders, inventory movements) for today, using existing master data IDs from the database.

All foreign key constraints are maintained by referencing existing IDs.
"""
import csv
import random
import os
import sys
from datetime import datetime, timedelta
from pathlib import Path
from faker import Faker

# Add the project root to the Python path so that the `src.*` package imported
# below resolves when this file is run directly as a script.
# The root is taken to be four directory levels above this file.
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

# Database dependencies are imported defensively: if SQLAlchemy or the project
# settings module cannot be imported, print the reason and exit with status 1
# instead of failing later with a traceback.
try:
    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import sessionmaker
    from src.utils.settings import settings
    print("[SUCCESS] Successfully imported database dependencies")
except ImportError as e:
    print(f"[ERROR] Error importing dependencies: {e}")
    sys.exit(1)

# Configuration
# TODAY is captured once at import time from the local clock; every generated
# record and the output folder name (data/data-YYYY-MM-DD next to this script)
# are derived from it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
TODAY = datetime.now().date()
OUTPUT_DIR = os.path.join(SCRIPT_DIR, "data", f"data-{TODAY.strftime('%Y-%m-%d')}")

# Create output directory
# NOTE: this runs at import time, so OUTPUT_DIR always exists by the time any
# function in this module executes; a bare existence check on the folder can
# therefore not tell whether data has been generated yet.
os.makedirs(OUTPUT_DIR, exist_ok=True)


def get_db_session():
    """Open and return a new SQLAlchemy session bound to settings.DATABASE_URL.

    A fresh engine and session factory are built on every call; the caller is
    responsible for closing the returned session.
    """
    session_factory = sessionmaker(bind=create_engine(settings.DATABASE_URL))
    return session_factory()


def fetch_master_data(session):
    """
    Read the master-data IDs that generated transactions are allowed to reference.

    Nothing is written: the function only runs SELECT statements for companies,
    products (with cost/retail price, defaulting to 10.0 / 20.0 when no price
    row is joined), locations, vendors and product-location mappings.

    Returns:
        dict with the keys 'companies', 'products', 'locations', 'vendors' and
        'product_locations', each holding a list of plain dicts.
    """
    print("Fetching master data IDs from database...")

    def run_query(sql):
        # Every query here is a read-only SELECT whose rows are fully fetched.
        return session.execute(text(sql)).fetchall()

    # Companies
    company_rows = run_query("""
        SELECT DISTINCT company_id 
        FROM companies
        ORDER BY company_id
    """)

    # Products together with perishability flag and prices
    product_rows = run_query("""
        SELECT 
            p.product_id,
            p.company_id,
            p.is_perishable,
            COALESCE(pp.cost_price_per_unit, 10.0) as cost_price,
            COALESCE(pp.retail_price_excl_tax, 20.0) as retail_price
        FROM products p
        LEFT JOIN product_prices pp ON p.product_id = pp.product_id
        GROUP BY p.product_id, p.company_id, p.is_perishable, pp.cost_price_per_unit, pp.retail_price_excl_tax
    """)

    # Locations
    location_rows = run_query("""
        SELECT location_id, company_id
        FROM locations
        ORDER BY location_id
    """)

    # Vendors
    vendor_rows = run_query("""
        SELECT vendor_id, company_id
        FROM vendors
        ORDER BY vendor_id
    """)

    # Product-location mappings: the only (product, location) pairs that may be used
    mapping_rows = run_query("""
        SELECT product_id, location_id, company_id
        FROM product_locations
        ORDER BY product_id, location_id
    """)

    print(f"  Companies: {len(company_rows)} (read-only, no duplication)")
    print(f"  Products: {len(product_rows)} (read-only, no duplication)")
    print(f"  Locations: {len(location_rows)} (read-only, no duplication)")
    print(f"  Vendors: {len(vendor_rows)} (read-only, no duplication)")
    print(f"  Product-Location mappings: {len(mapping_rows)} (FK constraints respected)")

    master = {}
    master['companies'] = [{'company_id': row[0]} for row in company_rows]
    master['products'] = [
        {
            'product_id': product_id,
            'company_id': company_id,
            'is_perishable': is_perishable,
            'cost_price': float(cost_price),
            'retail_price': float(retail_price),
        }
        for product_id, company_id, is_perishable, cost_price, retail_price in product_rows
    ]
    master['locations'] = [
        {'location_id': location_id, 'company_id': company_id}
        for location_id, company_id in location_rows
    ]
    master['vendors'] = [
        {'vendor_id': vendor_id, 'company_id': company_id}
        for vendor_id, company_id in vendor_rows
    ]
    master['product_locations'] = [
        {'product_id': product_id, 'location_id': location_id, 'company_id': company_id}
        for product_id, location_id, company_id in mapping_rows
    ]
    return master


def fetch_inventory_state(session):
    """Load the current stock picture used to drive reorders and sales.

    Returns:
        (inventory_state, snapshot_data) where
        - inventory_state maps (location_id, product_id) to the list of ACTIVE
          batches with stock, ordered by expiry date as returned by the query;
        - snapshot_data maps (location_id, product_id) to the on-hand quantity
          and 90-day days-on-hand value from the most recent snapshot dates.
    """
    print("Fetching current inventory state...")

    # Active batches that still hold stock
    batch_rows = session.execute(text("""
        SELECT 
            ib.id as batch_id,
            ib.product_id,
            ib.location_id,
            ib.company_id,
            ib.quantity_on_hand,
            ib.expiry_date,
            ib.status,
            ib.batch_ref
        FROM inventory_batches ib
        WHERE ib.status = 'ACTIVE' AND ib.quantity_on_hand > 0
        ORDER BY ib.location_id, ib.product_id, ib.expiry_date
    """)).fetchall()

    # Latest daily snapshot, enriched with the latest slow-mover DOH figure
    snapshot_rows = session.execute(text("""
        SELECT 
            s.product_id,
            s.location_id,
            s.company_id,
            s.on_hand_qty,
            COALESCE(sm.doh_90d, 0) as doh_90
        FROM inventory_snapshot_daily s
        LEFT JOIN slow_mover_snapshot sm 
            ON s.product_id = sm.product_id 
            AND s.location_id = sm.location_id
            AND s.company_id = sm.company_id
            AND sm.snapshot_date = (
                SELECT MAX(snapshot_date) 
                FROM slow_mover_snapshot
            )
        WHERE s.snapshot_date = (
            SELECT MAX(snapshot_date) 
            FROM inventory_snapshot_daily
        )
    """)).fetchall()

    print(f"  Active batches: {len(batch_rows)}")
    print(f"  Inventory snapshots: {len(snapshot_rows)}")

    # Group batches per (location, product)
    inventory_state = {}
    for batch_id, product_id, location_id, company_id, on_hand, expiry, status, batch_ref in batch_rows:
        inventory_state.setdefault((location_id, product_id), []).append({
            'batch_id': batch_id,
            'product_id': product_id,
            'location_id': location_id,
            'company_id': company_id,
            'quantity_on_hand': on_hand,
            # Expiry is kept as an ISO string with a 'Z' suffix, or None when absent
            'expiry_date': expiry.isoformat() + 'Z' if expiry else None,
            'status': status,
            'batch_ref': batch_ref
        })

    # One snapshot entry per (location, product); a later row overwrites an earlier one
    snapshot_data = {}
    for product_id, location_id, _company_id, on_hand_qty, doh in snapshot_rows:
        snapshot_data[(location_id, product_id)] = {
            'on_hand_qty': on_hand_qty,
            'doh_90': float(doh) if doh else 0
        }

    return inventory_state, snapshot_data


def get_next_ids(session):
    """Return the next free primary-key value (MAX(id) + 1, or 1) per transactional table.

    The result is a dict keyed by table name. Table names are fixed constants,
    so interpolating them into the SQL text is safe.
    """
    print("Getting next available IDs...")

    # None of these names carries a 'smart_inventory.' schema prefix, so they
    # are used in the query exactly as written.
    table_names = (
        'purchase_orders',
        'purchase_order_lines',
        'inventory_batches',
        'inventory_movements',
        'sales_orders',
        'sales_order_lines',
    )

    next_ids = {}
    for name in table_names:
        row = session.execute(text(f"SELECT COALESCE(MAX(id), 0) + 1 FROM {name}")).fetchone()
        next_ids[name] = row[0]

    return next_ids


def generate_daily_transactions(master_data, inventory_state, snapshot_data, next_ids, today=None):
    """
    Generate one day of transactions using ONLY existing master data IDs.
    No new products, companies, locations, or vendors are created.
    All foreign keys reference existing database records or rows created here.

    Three phases run in order:
      1. Purchase orders: per location, products whose stock is below a
         DOH-dependent threshold are reordered (at most 20 per order); every
         line is received immediately as a new batch plus a RECEIPT movement.
      2. Sales orders: per location, a random number of orders is fulfilled
         from the earliest-expiring batches first, producing SALE movements.
      3. Expiry: ACTIVE batches whose expiry date lies before the start of the
         day are marked EXPIRED and written off with an adjustment movement.

    Args:
        master_data: dict as returned by fetch_master_data().
        inventory_state: dict mapping (location_id, product_id) to a list of
            batch dicts. It is MUTATED: new batches are appended and
            quantities/statuses are updated as stock is sold or expired.
        snapshot_data: dict mapping (location_id, product_id) to a dict with a
            'doh_90' entry; missing keys are treated as DOH 0.
        next_ids: dict with the first free id for each of the six tables.
        today: date to generate for; defaults to the module-level TODAY.

    Returns:
        dict with the row lists 'purchase_orders', 'purchase_order_lines',
        'inventory_batches', 'inventory_movements', 'sales_orders' and
        'sales_order_lines'.
    """
    if today is None:
        today = TODAY

    print(f"\nGenerating transactions for {today}...")
    print("Using existing master data - FK constraints will be maintained")

    current_datetime = datetime.combine(today, datetime.min.time())
    now_iso = current_datetime.isoformat() + "Z"

    def batch_pk(batch):
        # Batches loaded from the database carry their primary key under
        # 'batch_id'; batches created below are CSV rows and carry it under
        # 'id'. Both kinds live side by side in inventory_state.
        return batch['batch_id'] if 'batch_id' in batch else batch['id']

    # Lookup indexes, built once instead of rescanning the master lists for
    # every location/product combination.
    products_by_id = {}
    for p in master_data['products']:
        products_by_id.setdefault(p['product_id'], p)  # first occurrence wins

    mappings_by_location = {}     # (location_id, company_id) -> [mapping, ...]
    product_ids_by_location = {}  # location_id -> {product_id, ...}
    for pl in master_data['product_locations']:
        mappings_by_location.setdefault((pl['location_id'], pl['company_id']), []).append(pl)
        product_ids_by_location.setdefault(pl['location_id'], set()).add(pl['product_id'])

    vendors_by_company = {}
    for v in master_data['vendors']:
        vendors_by_company.setdefault(v['company_id'], []).append(v)

    purchase_orders = []
    purchase_order_lines = []
    inventory_batches = []
    inventory_movements = []
    sales_orders = []
    sales_order_lines = []

    # Counters
    po_id_counter = next_ids['purchase_orders']
    pol_id_counter = next_ids['purchase_order_lines']
    batch_id_counter = next_ids['inventory_batches']
    mov_id_counter = next_ids['inventory_movements']
    so_id_counter = next_ids['sales_orders']
    sol_id_counter = next_ids['sales_order_lines']

    is_weekend = current_datetime.weekday() >= 5
    daily_volume_factor = 1.5 if is_weekend else 1.0

    # 1. Generate Purchase Orders (Reorder based on inventory state)
    print("Generating purchase orders (using existing vendor/location/product IDs)...")

    total_products_checked = 0
    products_with_stock = 0
    products_needing_reorder = 0

    for location in master_data['locations']:
        location_id = location['location_id']
        company_id = location['company_id']

        # Products mapped to this location (FK constraint: only existing product_locations)
        location_products = mappings_by_location.get((location_id, company_id))
        if not location_products:
            continue

        # Vendors of this company (FK constraint: only existing vendors)
        company_vendors = vendors_by_company.get(company_id)
        if not company_vendors:
            continue

        products_to_order = []

        for pl in location_products:
            total_products_checked += 1
            product_id = pl['product_id']
            key = (location_id, product_id)

            # Current stock across all active batches
            current_stock = sum(
                b['quantity_on_hand']
                for b in inventory_state.get(key, [])
                if b['status'] == 'ACTIVE'
            )

            if current_stock > 0:
                products_with_stock += 1

            product = products_by_id.get(product_id)
            if not product:
                continue

            # Days-on-hand from the latest snapshot, 0 when unknown
            doh_90 = snapshot_data.get(key, {}).get('doh_90', 0)

            # Reorder threshold and probability depend on how slowly the item moves
            is_perishable = product.get('is_perishable', False)

            if doh_90 > 180:  # Dead mover
                safety_stock = 50 if is_perishable else 200
                reorder_threshold = safety_stock * 0.8
                reorder_probability = 0.05  # Rarely reorder
            elif doh_90 > 90:  # Slow mover
                safety_stock = 100 if is_perishable else 350
                reorder_threshold = safety_stock * 0.9
                reorder_probability = 0.15
            else:  # Normal
                safety_stock = 150 if is_perishable else 500
                reorder_threshold = safety_stock
                reorder_probability = 0.40

            # Reorder logic:
            # 1. Always reorder if critically low (below 50% of threshold)
            # 2. Probabilistic reorder if below threshold
            # 3. Small chance to reorder even above threshold (proactive stocking)
            critically_low = current_stock < (reorder_threshold * 0.5)
            below_threshold = current_stock < reorder_threshold
            proactive_reorder = current_stock < (safety_stock * 1.5) and random.random() < 0.10

            if critically_low or (below_threshold and random.random() < reorder_probability) or proactive_reorder:
                products_to_order.append(product)
                products_needing_reorder += 1

        # Debug output for this location
        if products_to_order:
            print(f"  Location {location_id}: {len(products_to_order)} products need reorder")

        # Limit products per order
        if len(products_to_order) > 20:
            products_to_order = random.sample(products_to_order, 20)

        if not products_to_order:
            continue

        # Create purchase order (FK: existing vendor_id and location_id)
        supplier = random.choice(company_vendors)

        po = {
            "id": po_id_counter,
            "company_id": company_id,  # FK: existing company
            "supplier_id": supplier['vendor_id'],  # FK: existing vendor
            "location_id": location_id,  # FK: existing location
            "status": "RECEIVED",
            "expected_delivery_date": (current_datetime + timedelta(days=2)).isoformat() + "Z",
            "created_at": now_iso,
            "updated_at": now_iso
        }
        purchase_orders.append(po)

        for product in products_to_order:
            is_perishable = product.get('is_perishable', False)

            # Order quantity: perishables are bought in smaller lots
            if is_perishable:
                qty = random.randint(80, 300)
            else:
                qty = random.randint(400, 1200)

            purchase_order_lines.append({
                "id": pol_id_counter,
                "company_id": company_id,  # FK: existing company
                "purchase_order_id": po["id"],  # FK: just created PO
                "product_id": product['product_id'],  # FK: existing product
                "ordered_qty": qty,
                "received_qty": qty,
                "unit_cost": product['cost_price'],
                "created_at": now_iso,
                "updated_at": now_iso
            })
            pol_id_counter += 1

            # Create inventory batch (FK: existing product_id, location_id, company_id)
            expiry = None
            if is_perishable:
                expiry = (current_datetime + timedelta(days=random.randint(7, 30))).isoformat() + "Z"

            batch = {
                "id": batch_id_counter,
                "company_id": company_id,  # FK: existing company
                "product_id": product['product_id'],  # FK: existing product
                "location_id": location_id,  # FK: existing location
                "batch_ref": f"BATCH-{current_datetime.strftime('%Y%m%d')}-{batch_id_counter}",
                "quantity_on_hand": qty,
                "expiry_date": expiry if expiry else "",  # Empty string for NULL in CSV
                "received_date": now_iso,
                "status": "ACTIVE",
                "created_at": now_iso,
                "updated_at": now_iso
            }
            inventory_batches.append(batch)

            # Make the new stock available to the sales simulation. The SAME
            # dict is shared with the CSV row, so quantity/status changes made
            # by sales below are reflected in the exported batch.
            inventory_state.setdefault((location_id, product['product_id']), []).append(batch)

            # Create inventory movement (FK: existing company, product, location, new batch)
            inventory_movements.append({
                "id": mov_id_counter,
                "company_id": company_id,  # FK: existing company
                "product_id": product['product_id'],  # FK: existing product
                "location_id": location_id,  # FK: existing location
                "batch_id": batch["id"],  # FK: just created batch
                "movement_type": "RECEIPT",
                "quantity_delta": qty,
                "reference": f"PO-{po['id']}",
                "created_at": now_iso
            })

            batch_id_counter += 1
            mov_id_counter += 1

        po_id_counter += 1

    print(f"  Debug: Checked {total_products_checked} product-location combinations")
    print(f"  Debug: {products_with_stock} have current stock")
    print(f"  Debug: {products_needing_reorder} triggered reorder logic")
    print(f"  Debug: Generated {len(purchase_orders)} purchase orders")

    # 2. Generate Sales Orders
    print("Generating sales orders (using existing location/product IDs)...")

    for location in master_data['locations']:
        location_id = location['location_id']
        company_id = location['company_id']

        # Number of sales orders for the day (weekends are busier)
        base_orders = random.randint(2, 6)
        num_orders = int(base_orders * daily_volume_factor)

        # Products of this company that are mapped to this location
        stocked_ids = product_ids_by_location.get(location_id, ())
        location_products = [
            p for p in master_data['products']
            if p['company_id'] == company_id and p['product_id'] in stocked_ids
        ]

        if not location_products:
            continue

        for _ in range(num_orders):
            so = {
                "id": so_id_counter,
                "company_id": company_id,  # FK: existing company
                "location_id": location_id,  # FK: existing location
                "sold_at": (current_datetime + timedelta(hours=random.randint(9, 20))).isoformat() + "Z",
                "channel": random.choice(["store", "online", "store", "store"]),
                "created_at": now_iso,
                "updated_at": now_iso
            }

            # Products per order
            num_items = random.choices([1, 2, 3, 4, 5], weights=[40, 30, 15, 10, 5])[0]
            selected_products = random.sample(location_products, k=min(len(location_products), num_items))

            order_valid = False

            for product in selected_products:
                # Requested quantity depends on product type
                if product.get('is_perishable', False):
                    qty_needed = random.choices([5, 10, 15, 20, 30], weights=[20, 25, 20, 15, 10])[0]
                else:
                    qty_needed = random.choices([10, 25, 50, 75, 100], weights=[15, 25, 25, 15, 10])[0]

                # Fulfil from stock, earliest expiry first; batches without an
                # expiry date (None or "") sort last.
                key = (location_id, product['product_id'])
                valid_batches = [
                    b for b in inventory_state.get(key, [])
                    if b['status'] == 'ACTIVE' and b['quantity_on_hand'] > 0
                ]
                valid_batches.sort(key=lambda x: x['expiry_date'] or '9999-12-31')

                qty_fulfilled = 0

                for batch in valid_batches:
                    if qty_fulfilled >= qty_needed:
                        break

                    take = min(batch['quantity_on_hand'], qty_needed - qty_fulfilled)

                    batch['quantity_on_hand'] -= take
                    qty_fulfilled += take

                    inventory_movements.append({
                        "id": mov_id_counter,
                        "company_id": company_id,  # FK: existing company
                        "product_id": product['product_id'],  # FK: existing product
                        "location_id": location_id,  # FK: existing location
                        "batch_id": batch_pk(batch),  # FK: existing or just created batch
                        "movement_type": "SALE",
                        "quantity_delta": -take,
                        "reference": f"SO-{so['id']}",
                        "created_at": so["sold_at"]
                    })
                    mov_id_counter += 1

                    if batch['quantity_on_hand'] == 0:
                        batch['status'] = 'SOLD_OUT'

                # Only lines with at least partial fulfilment are recorded
                if qty_fulfilled > 0:
                    sales_order_lines.append({
                        "id": sol_id_counter,
                        "company_id": company_id,  # FK: existing company
                        "sales_order_id": so["id"],  # FK: just created SO
                        "product_id": product['product_id'],  # FK: existing product
                        "quantity": qty_fulfilled,
                        "unit_price": product['retail_price'],
                        "promotion_id": None,
                        "created_at": so["sold_at"],
                        "updated_at": so["sold_at"]
                    })
                    sol_id_counter += 1
                    order_valid = True

            # An order without any fulfilled line is dropped and its id reused
            if order_valid:
                sales_orders.append(so)
                so_id_counter += 1

    # 3. Check for expired batches
    # NOTE(review): this runs after the sales simulation, so already-expired
    # batches can still be sold above (they even sort first) - confirm intended.
    print("Checking for expired inventory...")

    for batches in inventory_state.values():
        for batch in batches:
            # Expiry dates are ISO strings, so a plain string comparison is used
            if (batch['status'] == 'ACTIVE' and
                    batch['expiry_date'] and
                    batch['expiry_date'] < now_iso):

                batch['status'] = 'EXPIRED'

                if batch['quantity_on_hand'] > 0:
                    inventory_movements.append({
                        "id": mov_id_counter,
                        "company_id": batch['company_id'],  # FK: existing company
                        "product_id": batch['product_id'],  # FK: existing product
                        "location_id": batch['location_id'],  # FK: existing location
                        "batch_id": batch_pk(batch),  # FK: existing batch
                        # NOTE(review): lowercase, unlike "RECEIPT"/"SALE" -
                        # verify against the values the database expects.
                        "movement_type": "adjustment",
                        "quantity_delta": -batch['quantity_on_hand'],
                        "reference": f"EXPIRY-{batch['batch_ref']}",
                        "created_at": now_iso
                    })
                    mov_id_counter += 1
                    batch['quantity_on_hand'] = 0

    return {
        'purchase_orders': purchase_orders,
        'purchase_order_lines': purchase_order_lines,
        'inventory_batches': inventory_batches,
        'inventory_movements': inventory_movements,
        'sales_orders': sales_orders,
        'sales_order_lines': sales_order_lines
    }


def write_csv_files(data, output_dir=None):
    """Write the generated row lists to one CSV file per table.

    Every table always gets a file, even when it has no rows (header only).
    Skipping empty tables would leave a stale file from an earlier run in
    place, or leave the file missing so that load_data_to_db() rejects the
    whole folder.

    Args:
        data: dict as returned by generate_daily_transactions().
        output_dir: target folder; defaults to the module-level OUTPUT_DIR.
            It is created if it does not exist.
    """
    target_dir = OUTPUT_DIR if output_dir is None else output_dir
    print(f"\nWriting CSV files to {target_dir}...")
    os.makedirs(target_dir, exist_ok=True)

    # Header used when a table has no rows. These column lists mirror the
    # columns inserted by load_data_to_db().
    default_columns = (
        ('purchase_orders', ['id', 'company_id', 'supplier_id', 'location_id', 'status',
                             'expected_delivery_date', 'created_at', 'updated_at']),
        ('purchase_order_lines', ['id', 'company_id', 'purchase_order_id', 'product_id',
                                  'ordered_qty', 'received_qty', 'unit_cost',
                                  'created_at', 'updated_at']),
        ('inventory_batches', ['id', 'company_id', 'product_id', 'location_id', 'batch_ref',
                               'quantity_on_hand', 'expiry_date', 'received_date', 'status',
                               'created_at', 'updated_at']),
        ('inventory_movements', ['id', 'company_id', 'product_id', 'location_id', 'batch_id',
                                 'movement_type', 'quantity_delta', 'reference', 'created_at']),
        ('sales_orders', ['id', 'company_id', 'location_id', 'sold_at', 'channel',
                          'created_at', 'updated_at']),
        ('sales_order_lines', ['id', 'company_id', 'sales_order_id', 'product_id', 'quantity',
                               'unit_price', 'promotion_id', 'created_at', 'updated_at']),
    )

    def write_csv(filename, rows, fallback_columns):
        # Column order follows the first row when there is one, so the output
        # for non-empty tables is unchanged.
        fieldnames = list(rows[0].keys()) if rows else fallback_columns
        filepath = os.path.join(target_dir, filename)

        with open(filepath, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)

        if rows:
            print(f"  Generated {filename}: {len(rows)} rows")
        else:
            print(f"  Generated {filename}: 0 rows (header only)")

    for table, columns in default_columns:
        write_csv(f"{table}.csv", data[table], columns)


def load_data_to_db():
    """Load the generated CSV files from OUTPUT_DIR into the database.

    All six files must be present. They are inserted in dependency order
    (orders before their lines, batches before movements) inside a single
    transaction: either every file is loaded and committed, or the session is
    rolled back and nothing is stored.

    Returns:
        True when the data was committed, False when files are missing or the
        load failed (the error and traceback are printed).
    """
    print(f"\nLoading data from {OUTPUT_DIR} into database...")

    # Table name -> columns, in insertion order. Table and column names are
    # fixed constants, so building the INSERT text from them is safe.
    load_plan = (
        ('purchase_orders', ('id', 'company_id', 'supplier_id', 'location_id', 'status',
                             'expected_delivery_date', 'created_at', 'updated_at')),
        ('purchase_order_lines', ('id', 'company_id', 'purchase_order_id', 'product_id',
                                  'ordered_qty', 'received_qty', 'unit_cost',
                                  'created_at', 'updated_at')),
        ('inventory_batches', ('id', 'company_id', 'product_id', 'location_id', 'batch_ref',
                               'quantity_on_hand', 'expiry_date', 'received_date', 'status',
                               'created_at', 'updated_at')),
        ('inventory_movements', ('id', 'company_id', 'product_id', 'location_id', 'batch_id',
                                 'movement_type', 'quantity_delta', 'reference', 'created_at')),
        ('sales_orders', ('id', 'company_id', 'location_id', 'sold_at', 'channel',
                          'created_at', 'updated_at')),
        ('sales_order_lines', ('id', 'company_id', 'sales_order_id', 'product_id', 'quantity',
                               'unit_price', 'promotion_id', 'created_at', 'updated_at')),
    )
    # Columns whose CSV representation of NULL ('' or 'None') must become None
    nullable_columns = ('expiry_date', 'promotion_id')

    # Check if folder exists
    if not os.path.exists(OUTPUT_DIR):
        print(f"[ERROR] Folder not found: {OUTPUT_DIR}")
        print("Please generate data first (Option 1)")
        return False

    # Check if CSV files exist
    missing_files = [
        f"{table}.csv" for table, _ in load_plan
        if not os.path.exists(os.path.join(OUTPUT_DIR, f"{table}.csv"))
    ]

    if missing_files:
        print(f"[ERROR] Missing files in {OUTPUT_DIR}:")
        for name in missing_files:
            print(f"  - {name}")
        print("\nPlease generate data first (Option 1)")
        return False

    print("Found all required CSV files")

    session = None
    try:
        session = get_db_session()

        print("\nLoading CSV files into database...")

        for step, (table, columns) in enumerate(load_plan, start=1):
            filename = f"{table}.csv"
            print(f"  [{step}/{len(load_plan)}] Loading {filename}...", end=" ")

            with open(os.path.join(OUTPUT_DIR, filename), 'r', encoding='utf-8') as f:
                rows = list(csv.DictReader(f))

            for row in rows:
                for column in nullable_columns:
                    if row.get(column) in ('', 'None'):
                        row[column] = None

            # A list of parameter dicts is sent as one executemany call; an
            # empty list must be skipped because the statement has bind params.
            if rows:
                insert_sql = (
                    f"INSERT INTO {table} ({', '.join(columns)}) "
                    f"VALUES ({', '.join(':' + column for column in columns)})"
                )
                session.execute(text(insert_sql), rows)

            print(f"{len(rows)} rows")

        # Commit all changes
        session.commit()

    except Exception as e:
        print(f"\n[ERROR] Failed to load data: {str(e)}")
        import traceback
        traceback.print_exc()
        if session is not None:
            session.rollback()
        return False

    finally:
        if session is not None:
            session.close()

    print(f"\n{'='*60}")
    print("✅ Data loaded successfully into database!")
    print(f"{'='*60}\n")

    return True


def generate_data():
    """Generate incremental data for today and write it to OUTPUT_DIR.

    Reads master data, inventory state and next free IDs from the database,
    generates today's transactions and writes them as CSV files. If generated
    files from an earlier run are already present, the user is asked to confirm
    before they are replaced.

    Returns:
        True on success, False when the user cancelled or an error occurred
        (the error and traceback are printed).
    """
    print(f"\n{'='*60}")
    print(f"Generating Incremental Data for {TODAY}")
    print(f"{'='*60}")
    print("\nIMPORTANT:")
    print("  - NO master data duplication (companies, products, locations, vendors)")
    print("  - Only transactional data generated (POs, sales, movements)")
    print("  - All foreign key constraints respected")
    print(f"{'='*60}\n")

    generated_files = (
        'purchase_orders.csv',
        'purchase_order_lines.csv',
        'inventory_batches.csv',
        'inventory_movements.csv',
        'sales_orders.csv',
        'sales_order_lines.csv',
    )

    # Check if data already exists. The folder itself is created when this
    # module is imported, so only the presence of generated files is meaningful.
    existing_files = [
        name for name in generated_files
        if os.path.exists(os.path.join(OUTPUT_DIR, name))
    ]
    if existing_files:
        print(f"[WARNING] Folder already exists: {OUTPUT_DIR}")
        response = input("Overwrite existing data? [y/N]: ").strip().lower()
        if response not in ['y', 'yes']:
            print("[CANCELLED]")
            return False

    try:
        # Read everything needed from the database, then release the session
        # even if one of the queries fails.
        session = get_db_session()
        try:
            master_data = fetch_master_data(session)
            inventory_state, snapshot_data = fetch_inventory_state(session)
            next_ids = get_next_ids(session)
        finally:
            session.close()

        # Generate transactions
        daily_data = generate_daily_transactions(master_data, inventory_state, snapshot_data, next_ids)

        # Remove files from the previous run only once new data is ready, so a
        # table that ends up empty this time cannot leave stale rows behind.
        for name in existing_files:
            os.remove(os.path.join(OUTPUT_DIR, name))

        # Write to CSV
        write_csv_files(daily_data)

        print(f"\n{'='*60}")
        print("✅ Daily incremental dummy data generation complete!")
        print(f"{'='*60}")
        print(f"\nOutput directory: {OUTPUT_DIR}")
        print("\nSummary (Today's transactions only - no master data):")
        print(f"  Purchase Orders:      {len(daily_data['purchase_orders']):>6}")
        print(f"  Purchase Order Lines: {len(daily_data['purchase_order_lines']):>6}")
        print(f"  Inventory Batches:    {len(daily_data['inventory_batches']):>6}")
        print(f"  Inventory Movements:  {len(daily_data['inventory_movements']):>6}")
        print(f"  Sales Orders:         {len(daily_data['sales_orders']):>6}")
        print(f"  Sales Order Lines:    {len(daily_data['sales_order_lines']):>6}")
        print("\nAll foreign keys reference existing database records:")
        print("  ✓ company_id (FK to companies)")
        print("  ✓ product_id (FK to products)")
        print("  ✓ location_id (FK to locations)")
        print("  ✓ vendor_id (FK to vendors)")
        print("  ✓ batch_id (FK to inventory_batches)")
        print(f"{'='*60}\n")

        return True

    except Exception as e:
        print(f"\n[ERROR] {str(e)}")
        import traceback
        traceback.print_exc()
        return False


def main():
    """Interactive entry point: show the data status and run the chosen action.

    Options: 1 generate today's data, 2 load generated data into the database,
    3 do both, 4 exit. Invalid input re-prompts; any valid choice runs once and
    then the function returns.
    """
    print(f"\n{'='*60}")
    print(f"Daily Incremental Dummy Data Manager for {TODAY}")
    print(f"{'='*60}\n")

    # The output folder is created when the module is imported, so whether
    # data "exists" is decided by the presence of CSV files, not of the folder.
    folder_exists = os.path.isdir(OUTPUT_DIR)
    csv_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith('.csv')] if folder_exists else []
    data_exists = bool(csv_files)

    if folder_exists:
        print(f"✓ Data folder exists: {OUTPUT_DIR}")
        if csv_files:
            print(f"  Found {len(csv_files)} CSV files")
        else:
            print("  No CSV files found - please generate data first")
    else:
        print(f"✗ Data folder not found: {OUTPUT_DIR}")
        print("  Please generate data first\n")

    print("\nOptions:")
    print("  1. Generate incremental data for today")
    print("  2. Load generated data into database")
    print("  3. Generate and load (both steps)")
    print("  4. Exit")
    print(f"{'='*60}\n")

    while True:
        choice = input("Enter your choice [1/2/3/4]: ").strip()

        if choice == '1':
            print()
            generate_data()
            break
        elif choice == '2':
            print()
            if not data_exists:
                print("[ERROR] No generated CSV files found. Please generate data first (Option 1)")
            else:
                load_data_to_db()
            break
        elif choice == '3':
            print()
            # Only load when generation succeeded (and was not cancelled)
            if generate_data():
                print("\nProceeding to load data into database...\n")
                load_data_to_db()
            break
        elif choice == '4':
            print("\n[CANCELLED]")
            break
        else:
            print("Invalid choice. Please enter 1, 2, 3, or 4.")


if __name__ == "__main__":
    main()
