import csv
import random
import datetime
import os

try:
    from faker import Faker
except ImportError:
    # Faker is the script's only third-party dependency, so fail fast with an
    # install hint. Raising SystemExit with a string writes the message to
    # stderr and exits with status 1; unlike the `exit` helper it does not
    # depend on the `site` module having been loaded.
    raise SystemExit("Faker is not installed. Please run: pip install faker")

# Shared Faker instance (used for city names, vendor names, descriptions and UUIDs)
fake = Faker()

# Base Configuration
# Timestamps in the generated CSVs are rendered as `isoformat() + "Z"`, i.e.
# they are labelled as UTC. Read the clock in UTC and drop tzinfo so the naive
# value really is UTC and isoformat() does not append a "+00:00" offset.
START_DATE = (
    datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    - datetime.timedelta(days=365)
)
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
OUTPUT_DIR = os.path.join(SCRIPT_DIR, "data")

# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)

def get_data_size_choice():
    """Show the size menu and prompt until the user picks 1, 2 or 3.

    Returns:
        int: the chosen option (1 = small, 2 = medium, 3 = large).
    """
    divider = "=" * 50
    menu_lines = (
        divider,
        "Smart Inventory Dummy Data Generator",
        divider,
        "Choose data size to generate:",
        "1. Small dataset (1 month)",
        "2. Medium dataset (6 months)",
        "3. Large dataset (1 year)",
        divider,
    )
    for line in menu_lines:
        print(line)

    # Map the accepted (whitespace-stripped) answers to their integer value.
    accepted = {"1": 1, "2": 2, "3": 3}
    while True:
        answer = input("Enter your choice [1/2/3]: ").strip()
        selection = accepted.get(answer)
        if selection is not None:
            return selection
        print("Please enter 1, 2, or 3.")

def get_config_for_size(size_choice):
    """Return the generation parameters for a menu choice.

    Args:
        size_choice: 1 (small), 2 (medium) or 3 (large).

    Returns:
        dict: a fresh dict with the keys 'days', 'companies', 'products',
        'locations_per_company', 'vendors_per_company' and 'description'.

    Raises:
        KeyError: if size_choice is not 1, 2 or 3.
    """
    field_names = (
        'days',
        'companies',
        'products',
        'locations_per_company',
        'vendors_per_company',
        'description',
    )
    # One row per menu option; values are in the same order as field_names.
    presets = {
        1: (30, 3, 25, 2, 3, 'Small dataset (1 month)'),
        2: (180, 6, 60, 3, 5, 'Medium dataset (6 months)'),
        3: (365, 10, 100, 5, 7, 'Large dataset (1 year)'),
    }
    return dict(zip(field_names, presets[size_choice]))

# Real category data from your system
CATEGORIES = [
    (10147, "Better Homes & Gardens"),
    (10146, "Luggage"),
    (10163, "Paint Brush"),
    (10148, "category@24"),
    (10077, "Automobile"),
    (10111, "COTTON"),
    (10056, "Women's Wear"),
    (10165, "GifCard"),
    (4, "Coffee"),
    (7, "Soda"),
    (10107, "Wallet Category"),
    (30, "Natural Cheese"),
    (10140, "Women's "),
    (5, "Cold Beverages"),
    (10139, "Perfumes"),
    (10136, "Skincare"),
    (12, "Chocolates"),
    (10076, "Decorative"),
    (21, "Milk Prod"),
    (10118, "Chicken"),
    (10055, "Men's Wear"),
]

# Real product name patterns from your system
PRODUCT_NAMES = [
    "Better Homes & Gardens Stemless Cocktail Glass",
    "3 Trolley",
    "Acrylic Paint Brush",
    "AI Translation Earbuds Real Time",
    "Birdfy Smart Bird Feeder with Camera",
    "Black Recycled Plastic Hopper Wild Bird Feeder",
    "Bundle Swagger",
    "BURBERRY Her Eau de Parfum Intense",
    "Cadbury Chocolate",
    "Chicken Sausage",
    "Better Homes & Gardens River Crest Bookcase",
    "Box of Treats Mini Perfume Sampler",
    "Baggy Shirt",
    "Boho bag-multicolor",
    "Black Gel Pen",
    "Candle",
    "Candy"
]

BRANDS = ["Gucci", "Bellavita", "BURBERRY", "LACOSTE", "Cadbury", "VIP", "Kellogg Company", "Lipton", "Boat-X", "New Brand", None]

def generate_companies(config):
    """Build the company rows.

    Args:
        config: size configuration; config['companies'] is the number of
            companies to create.

    Returns:
        list[dict]: one row per company with sequential "id" (from 1),
        "company_id" (500 + id), a name and START_DATE timestamps.
    """
    known_names = (
        "TechCorp Solutions", "Global Retail Inc", "Metro Foods Ltd",
        "Fashion Forward Co", "Electronics Plus", "Fresh Market Chain",
        "Urban Style Group", "Quick Mart Corp", "Premium Brands Ltd", "City Shopping Network",
    )
    stamp = START_DATE.isoformat() + "Z"

    def name_for(position):
        # Fall back to a generic label once the predefined names run out.
        if position <= len(known_names):
            return known_names[position - 1]
        return f"Company {position}"

    return [
        {
            "id": position,
            "company_id": 500 + position,
            "company_name": name_for(position),
            "created_at": stamp,
            "updated_at": stamp,
        }
        for position in range(1, config['companies'] + 1)
    ]

def generate_products(companies, config):
    """Generate product rows plus the base pricing record for each product.

    Args:
        companies: list of company dicts; each needs a "company_id" key. Every
            product is assigned to a randomly chosen company, so the list must
            not be empty when at least one product is requested
            (random.choice raises IndexError on an empty sequence).
        config: size configuration; config['products'] is the number of
            products to create (100 is used if the key is missing).

    Returns:
        tuple: (products, product_pricing_data). Both lists have one entry per
        product and share the same "product_id" values (300 + sequence number).
    """
    # Honour the configured dataset size instead of always creating 100 products.
    product_count = config.get('products', 100)

    # Coffee, Soda, Chocolates, Milk Prod, Chicken, Natural Cheese
    perishable_category_ids = {4, 7, 12, 21, 10118, 30}
    name_variations = ["Pack", "Set", "Large", "Small", "Medium", "Classic", "Premium", ""]
    timestamp = START_DATE.isoformat() + "Z"

    products = []
    product_pricing_data = []  # Store pricing info for later

    for i in range(1, product_count + 1):
        cat_id = random.choice(CATEGORIES)[0]
        is_perishable = cat_id in perishable_category_ids

        # Assign to random company
        company = random.choice(companies)

        # Generate pricing data: cost is 40-80% of the retail price
        base_price = round(random.uniform(10.0, 650.0), 2)
        cost = round(base_price * random.uniform(0.4, 0.8), 2)
        markup = round(base_price - cost, 2)
        margin = round(((base_price - cost) / base_price) * 100, 2) if base_price > 0 else 0

        # Store pricing data for ProductPrice table generation
        product_pricing_data.append({
            "product_id": 300 + i,
            "company_id": company["company_id"],
            "cost_price_per_unit": cost,
            "markup_value": markup,
            "margin_value": margin,
            "retail_price_excl_tax": base_price
        })

        # Generate realistic product name: base name plus an optional variation
        base_name = random.choice(PRODUCT_NAMES)
        variation = random.choice(name_variations)
        product_name = f"{base_name} {variation}".strip() if variation else base_name

        products.append({
            "id": i,
            "company_id": company["company_id"],
            "product_id": 300 + i,
            "product_name": product_name,
            "short_name": product_name[:15] if random.random() > 0.5 else None,
            "description": fake.sentence(nb_words=15) if random.random() > 0.6 else None,
            "brand_name": random.choice(BRANDS),
            "fk_product_category_id": cat_id,
            "eligible_for_return": random.choice([True, False]),
            "display_on_pos": random.choice([True, True, True, False]),  # mostly True
            "display_on_online_store": random.choice([True, True, False]),
            "is_perishable": is_perishable,
            "image_path": f"https://d1pfe4z4600rkk.cloudfront.net/DEV/{fake.uuid4()}" if random.random() > 0.3 else None,
            "created_at": timestamp,
            "updated_at": timestamp
        })

    return products, product_pricing_data

def _iso_z(moment):
    """Render a datetime as an ISO-8601 string with a trailing "Z"."""
    return moment.isoformat() + "Z"


def _group_by_company(rows):
    """Bucket row dicts by their "company_id", preserving row order in each bucket."""
    grouped = {}
    for row in rows:
        grouped.setdefault(row["company_id"], []).append(row)
    return grouped


def _link_products(products, targets_by_company, target_key, min_links, max_links):
    """Link each product to a random sample of same-company targets (vendors or locations).

    The sample size is drawn from [min_links, max_links] and clamped to the
    number of targets the product's company actually has, so random.sample
    never receives a k larger than its population. Products whose company has
    no targets get no links.
    """
    links = []
    created_at = _iso_z(START_DATE)
    for product in products:
        candidates = targets_by_company.get(product["company_id"], [])
        if not candidates:
            continue
        wanted = min(len(candidates), random.randint(min_links, max_links))
        for target in random.sample(candidates, k=wanted):
            links.append({
                "id": len(links) + 1,
                "company_id": product["company_id"],
                "product_id": product["product_id"],
                target_key: target[target_key],
                "created_at": created_at
            })
    return links


def _write_csv(filename, data):
    """Write a list of row dicts to OUTPUT_DIR/filename; does nothing for an empty list.

    The header is taken from the keys of the first row.
    """
    if not data:
        return
    filepath = os.path.join(OUTPUT_DIR, filename)
    keys = data[0].keys()
    with open(filepath, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=keys)
        writer.writeheader()
        writer.writerows(data)
    print(f"Generated {filename}: {len(data)} rows")


def generate_dummy_data():
    """Interactively generate the full dummy dataset and write it as CSV files.

    Steps:
      1. Ask for the dataset size and load the matching configuration.
      2. Build static tables: companies, products, locations, vendors,
         product-vendor / product-location links, product prices, categories.
      3. Simulate config['days'] days: weekly purchase orders that create
         inventory batches, daily sales that consume batches in
         earliest-expiry-first order, and a daily expiry write-off.
      4. Derive one reorder policy per (location, product) that received stock.
      5. Write every non-empty table to OUTPUT_DIR and print a summary.
    """
    # Get user choice and configuration
    size_choice = get_data_size_choice()
    config = get_config_for_size(size_choice)

    print(f"\nGenerating {config['description']}...")
    print(f"Companies: {config['companies']}, Products: {config['products']}, Days: {config['days']}")
    print("\nGenerating static data with realistic patterns...")

    # Generate companies first (needed for foreign keys)
    companies = generate_companies(config)
    products, product_pricing_data = generate_products(companies, config)

    start_stamp = _iso_z(START_DATE)

    # Generate Locations (stores) with company_id - dynamic per company
    locations_data = []
    location_id_counter = 1
    cities = [fake.city() for _ in range(20)]  # Generate more cities for variety

    for company in companies:
        for i in range(config['locations_per_company']):
            city = random.choice(cities)
            locations_data.append({
                "id": location_id_counter,
                "company_id": company["company_id"],
                "location_id": 100 + location_id_counter,
                "location_name": f"{city} Store {i+1}",
                "created_at": start_stamp,
                "updated_at": start_stamp
            })
            location_id_counter += 1

    # Generate Vendors with company_id - dynamic per company
    vendors_data = []
    vendor_id_counter = 1

    for company in companies:
        for _ in range(config['vendors_per_company']):
            vendors_data.append({
                "id": vendor_id_counter,
                "company_id": company["company_id"],
                "vendor_id": 700 + vendor_id_counter,
                "vendor_name": fake.company(),
                "vendor_code": f"VEN-{vendor_id_counter:03d}",
                "created_at": start_stamp,
                "updated_at": start_stamp
            })
            vendor_id_counter += 1

    # Lookup tables built once, so the simulation loops below do not rescan
    # the full product / vendor / pricing lists for every order line.
    products_by_company = _group_by_company(products)
    vendors_by_company = _group_by_company(vendors_data)
    locations_by_company = _group_by_company(locations_data)

    pricing_by_product = {}
    for row in product_pricing_data:
        pricing_by_product.setdefault(row["product_id"], row)  # first match wins

    products_by_id = {}
    for row in products:
        products_by_id.setdefault(row["product_id"], row)  # first match wins

    # M2M Relationships with company_id.
    # Candidates are drawn from the product's own company, so every product
    # gets linked whenever its company has vendors / locations.
    product_vendors = _link_products(products, vendors_by_company, "vendor_id", 1, 3)
    product_locations = _link_products(products, locations_by_company, "location_id", 3, 5)

    # Generate ProductPrice data (location-wise pricing) with company_id
    product_prices = []
    pp_id = 1
    for pl in product_locations:
        pricing = pricing_by_product.get(pl["product_id"])
        if not pricing:
            continue
        product_prices.append({
            "id": pp_id,
            "company_id": pl["company_id"],
            "product_price_id": 1000 + pp_id,
            "product_id": pl["product_id"],
            "location_id": pl["location_id"],
            "cost_price_per_unit": pricing["cost_price_per_unit"],
            "markup_value": pricing["markup_value"],
            "margin_value": pricing["margin_value"],
            "retail_price_excl_tax": pricing["retail_price_excl_tax"],
            "compare_at_price": round(pricing["retail_price_excl_tax"] * 1.15, 2) if random.random() > 0.7 else 0,
            "markup_type_name": random.choice(["Percentage", "Fixed", None]),
            "margin_type_name": random.choice(["Percentage", "Fixed", None]),
            "created_at": start_stamp,
            "updated_at": start_stamp
        })
        pp_id += 1

    # Generate categories with company_id (each category goes to a random company)
    categories_data = []
    for i, c in enumerate(CATEGORIES, 1):
        company = random.choice(companies)
        categories_data.append({
            "id": i,
            "company_id": company["company_id"],
            "category_id": c[0],
            "category_name": c[1],
            "created_at": start_stamp,
            "updated_at": start_stamp
        })

    # Transactional Data
    print("Simulating transactions (this may take a moment)...")

    purchase_orders = []
    purchase_order_lines = []
    inventory_batches = []
    inventory_movements = []
    sales_orders = []
    sales_order_lines = []
    reorder_policies = []

    # Counters
    po_id_counter = 1
    pol_id_counter = 1
    batch_id_counter = 1
    mov_id_counter = 1
    so_id_counter = 1
    sol_id_counter = 1
    rp_id_counter = 1

    # Inventory State: { (location_id, product_id): [batch_obj, ...] }
    # The batch dicts are the same objects stored in inventory_batches, so
    # quantity/status changes made during the simulation end up in the CSV.
    inventory_state = {}

    current_date = START_DATE

    # Simulation Loop - use dynamic days
    for day in range(config['days']):
        if day % 30 == 0:  # Progress indicator
            print(f"Processing day {day+1}/{config['days']}...")
        current_date += datetime.timedelta(days=1)
        day_stamp = _iso_z(current_date)
        is_weekend = current_date.weekday() >= 5
        daily_volume_factor = 1.5 if is_weekend else 1.0

        # 1. Check Inventory & Reorder
        if day % 7 == 0:  # Weekly replenishment
            for store in locations_data:
                # Only products and vendors from the same company as the store
                company_products = products_by_company.get(store["company_id"], [])
                if not company_products:
                    continue

                products_to_order = random.sample(
                    company_products,
                    k=min(len(company_products), random.randint(5, 15)),
                )

                company_vendors = vendors_by_company.get(store["company_id"], [])
                if not company_vendors:
                    continue

                supplier = random.choice(company_vendors)

                po = {
                    "id": po_id_counter,
                    "company_id": store["company_id"],
                    "supplier_id": supplier["vendor_id"],
                    "location_id": store["location_id"],
                    "status": "received",
                    "expected_delivery_date": _iso_z(current_date + datetime.timedelta(days=2)),
                    "created_at": day_stamp,
                    "updated_at": day_stamp
                }
                purchase_orders.append(po)

                for p in products_to_order:
                    qty = random.randint(10, 100)

                    # Look up pricing data for this product
                    pricing = pricing_by_product.get(p["product_id"])
                    cost = pricing["cost_price_per_unit"] if pricing else 10.0  # fallback if not found

                    purchase_order_lines.append({
                        "id": pol_id_counter,
                        "company_id": store["company_id"],
                        "purchase_order_id": po["id"],
                        "product_id": p["product_id"],
                        "ordered_qty": qty,
                        "received_qty": qty,
                        "unit_cost": cost,
                        "created_at": day_stamp,
                        "updated_at": day_stamp
                    })
                    pol_id_counter += 1

                    # Receive Inventory -> Create Batch
                    expiry = None
                    if p["is_perishable"]:
                        expiry = _iso_z(current_date + datetime.timedelta(days=random.randint(7, 30)))

                    batch = {
                        "id": batch_id_counter,
                        "company_id": store["company_id"],
                        "product_id": p["product_id"],
                        "location_id": store["location_id"],
                        "batch_ref": f"BATCH-{current_date.strftime('%Y%m%d')}-{batch_id_counter}",
                        "quantity_on_hand": qty,
                        "expiry_date": expiry,
                        "received_date": day_stamp,
                        "status": "active",
                        "created_at": day_stamp,
                        "updated_at": day_stamp
                    }
                    inventory_batches.append(batch)

                    key = (store["location_id"], p["product_id"])
                    inventory_state.setdefault(key, []).append(batch)

                    inventory_movements.append({
                        "id": mov_id_counter,
                        "company_id": store["company_id"],
                        "product_id": p["product_id"],
                        "location_id": store["location_id"],
                        "batch_id": batch["id"],
                        "movement_type": "receipt",
                        "quantity_delta": qty,
                        "reference": f"PO-{po['id']}",
                        "created_at": day_stamp
                    })

                    batch_id_counter += 1
                    mov_id_counter += 1

                po_id_counter += 1

        # 2. Generate Sales
        for store in locations_data:
            # B2B sales - fewer orders but much larger quantities
            base_orders = random.randint(3, 12)  # B2B typically has fewer but larger orders
            num_orders = int(base_orders * daily_volume_factor)

            # Only products from the same company as the store
            company_products = products_by_company.get(store["company_id"], [])
            if not company_products:
                continue

            for _ in range(num_orders):
                # NOTE: current_date carries START_DATE's time of day, so this
                # is an offset of 9-20 hours from that time, not a clock hour.
                sold_at = _iso_z(current_date + datetime.timedelta(hours=random.randint(9, 20)))
                so = {
                    "id": so_id_counter,
                    "company_id": store["company_id"],
                    "location_id": store["location_id"],
                    "sold_at": sold_at,
                    "channel": random.choice(["store", "online", "store", "store"]),
                    "created_at": day_stamp,
                    "updated_at": day_stamp
                }

                # B2B orders typically have fewer product types but bulk quantities
                num_items = random.choices(
                    [1, 2, 3, 4, 5],
                    weights=[40, 30, 15, 10, 5]  # B2B focuses on fewer product types
                )[0]

                selected_products = random.sample(company_products, k=min(len(company_products), num_items))

                order_valid = False

                for p in selected_products:
                    # B2B quantities - much larger volumes
                    if p.get("is_perishable", False):
                        # Perishable items: still large quantities but not as extreme
                        qty_needed = random.choices(
                            [20, 30, 50, 75, 100, 150, 200],
                            weights=[25, 25, 20, 15, 10, 3, 2]
                        )[0]
                    else:
                        # Non-perishable: very large B2B quantities
                        qty_needed = random.choices(
                            [50, 100, 150, 200, 300, 500, 750, 1000, 1500],
                            weights=[20, 25, 20, 15, 10, 5, 3, 1, 1]
                        )[0]

                    # Look up pricing data for this product
                    pricing = pricing_by_product.get(p["product_id"])
                    price = pricing["retail_price_excl_tax"] if pricing else 20.0  # fallback if not found

                    key = (store["location_id"], p["product_id"])
                    valid_batches = [
                        b for b in inventory_state.get(key, [])
                        if b["status"] == "active" and b["quantity_on_hand"] > 0
                    ]
                    # Earliest expiry first; batches without an expiry date go last.
                    valid_batches.sort(key=lambda x: x["expiry_date"] or "9999-12-31")

                    qty_fulfilled = 0

                    for b in valid_batches:
                        if qty_fulfilled >= qty_needed:
                            break

                        take = min(b["quantity_on_hand"], qty_needed - qty_fulfilled)

                        b["quantity_on_hand"] -= take
                        qty_fulfilled += take

                        inventory_movements.append({
                            "id": mov_id_counter,
                            "company_id": store["company_id"],
                            "product_id": p["product_id"],
                            "location_id": store["location_id"],
                            "batch_id": b["id"],
                            "movement_type": "sale",
                            "quantity_delta": -take,
                            "reference": f"SO-{so['id']}",
                            "created_at": sold_at
                        })
                        mov_id_counter += 1

                        if b["quantity_on_hand"] == 0:
                            b["status"] = "sold_out"

                    # Partial fulfilment is allowed: the line records what was actually taken.
                    if qty_fulfilled > 0:
                        sales_order_lines.append({
                            "id": sol_id_counter,
                            "company_id": store["company_id"],
                            "sales_order_id": so["id"],
                            "product_id": p["product_id"],
                            "quantity": qty_fulfilled,
                            "unit_price": price,
                            "promotion_id": None,
                            "created_at": sold_at,
                            "updated_at": sold_at
                        })
                        sol_id_counter += 1
                        order_valid = True

                # Orders with no fulfilled line are dropped and their id is reused.
                if order_valid:
                    sales_orders.append(so)
                    so_id_counter += 1

        # 3. Check Expiry
        # Timestamps are compared as strings; both sides come from _iso_z on
        # datetimes derived from START_DATE by whole-day offsets.
        for batches in inventory_state.values():
            for b in batches:
                if b["status"] == "active" and b["expiry_date"] and b["expiry_date"] < day_stamp:
                    b["status"] = "expired"
                    inventory_movements.append({
                        "id": mov_id_counter,
                        "company_id": b["company_id"],
                        "product_id": b["product_id"],
                        "location_id": b["location_id"],
                        "batch_id": b["id"],
                        "movement_type": "adjustment",
                        "quantity_delta": -b["quantity_on_hand"],
                        "reference": f"EXPIRY-{b['batch_ref']}",
                        "created_at": day_stamp
                    })
                    mov_id_counter += 1
                    b["quantity_on_hand"] = 0

    # Generate Reorder Policies
    print("Generating reorder policies...")

    # Find all unique (location_id, product_id, company_id) triples that have inventory activity
    location_product_pairs = {
        (batch["location_id"], batch["product_id"], batch["company_id"])
        for batch in inventory_batches
    }

    # Generate reorder policy for each triple (sorted so the output order is stable)
    for location_id, product_id, company_id in sorted(location_product_pairs):
        # Find the product to check if it's perishable for different policies
        product = products_by_id.get(product_id)
        if not product:
            continue

        # Different reorder parameters based on product type
        if product.get("is_perishable", False):
            # Perishable items: shorter lead times, more frequent reviews, smaller order quantities
            lead_time_days = random.randint(1, 3)
            review_period_days = random.randint(1, 7)
            min_order_qty = random.randint(10, 40)
            service_level_target = round(random.uniform(0.90, 0.98), 2)
        else:
            # Non-perishable: longer lead times, less frequent reviews, larger order quantities
            lead_time_days = random.randint(3, 14)
            review_period_days = random.randint(7, 30)
            min_order_qty = random.randint(50, 150)
            service_level_target = round(random.uniform(0.92, 0.99), 2)

        # Get a supplier from the same company
        company_vendor_ids = [v["vendor_id"] for v in vendors_by_company.get(company_id, [])]
        supplier_id = random.choice(company_vendor_ids) if company_vendor_ids else None

        reorder_policies.append({
            "id": rp_id_counter,
            "company_id": company_id,
            "location_id": location_id,
            "product_id": product_id,
            "lead_time_days": lead_time_days,
            "review_period_days": review_period_days,
            "service_level_target": service_level_target,
            "min_order_qty": min_order_qty,
            "supplier_id": supplier_id,
            "created_at": start_stamp,
            "updated_at": start_stamp
        })
        rp_id_counter += 1

    # Write to CSV
    print(f"Writing CSV files to {OUTPUT_DIR}...")

    tables = [
        ("companies.csv", companies),
        ("products.csv", products),
        ("categories.csv", categories_data),
        ("locations.csv", locations_data),
        ("vendors.csv", vendors_data),
        ("product_vendors.csv", product_vendors),
        ("product_locations.csv", product_locations),
        ("product_prices.csv", product_prices),
        ("purchase_orders.csv", purchase_orders),
        ("purchase_order_lines.csv", purchase_order_lines),
        ("inventory_batches.csv", inventory_batches),
        ("inventory_movements.csv", inventory_movements),
        ("sales_orders.csv", sales_orders),
        ("sales_order_lines.csv", sales_order_lines),
        ("reorder_policies.csv", reorder_policies),
    ]
    for filename, rows in tables:
        _write_csv(filename, rows)

    print(f"\n✅ {config['description']} generation complete!")
    print(f"Summary:")
    print(f"- Companies: {len(companies)}")
    print(f"- Products: {len(products)}")
    print(f"- Locations: {len(locations_data)}")
    print(f"- Vendors: {len(vendors_data)}")
    print(f"- Purchase Orders: {len(purchase_orders)}")
    print(f"- Sales Orders: {len(sales_orders)}")
    print(f"- Reorder Policies: {len(reorder_policies)}")
    print(f"- Days simulated: {config['days']}")

# Run the interactive generator only when executed as a script, not on import.
if __name__ == "__main__":
    generate_dummy_data()
