from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
from . import schemas, controller
import logging

try:
    from src.utils.db import get_db
except ImportError:
    from utils.db import get_db

# Celery is treated as optional here: when the task module or the worker app
# cannot be imported, the names below fall back to None and
# `celery_available` is set to False, which trigger_product_fetch checks
# before dispatching any work.
try:
    from src.smart_inventory.tasks.products_fetch_task import product_fetch
    from src.utils.celery_worker import celery_app
    celery_available = True
except ImportError as e:
    # Report the failure but keep this module importable.
    print(f"Warning: Could not import Celery components: {e}")
    product_fetch = None
    celery_app = None
    celery_available = False



# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Router on which the endpoints defined in this module are registered.
router = APIRouter()


@router.get("/products", response_model=List[schemas.ProductOut])
def get_products(
    skip: int = Query(0, ge=0, description="Number of products to skip"),
    limit: int = Query(100, ge=1, le=1000, description="Number of products to return"),
    db: Session = Depends(get_db)
):
    """Return one page of products.

    Args:
        skip: Number of products to skip; validated as ``>= 0``.
        limit: Maximum number of products to return; validated as 1-1000.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The result of ``controller.get_products``, serialized as a list of
        ``schemas.ProductOut``.

    Raises:
        HTTPException: Status 500 when the lookup fails unexpectedly. An
            ``HTTPException`` raised by the controller propagates unchanged.
    """
    try:
        return controller.get_products(db, skip=skip, limit=limit)
    except HTTPException:
        # Keep a deliberate status code instead of masking it as a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error fetching products: %s", e)
        raise HTTPException(status_code=500, detail="Error fetching products") from e


@router.get("/locations", response_model=List[schemas.LocationOut])
def get_locations(db: Session = Depends(get_db)):
    """Return all locations.

    Args:
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The result of ``controller.get_locations``, serialized as a list of
        ``schemas.LocationOut``.

    Raises:
        HTTPException: Status 500 when the lookup fails unexpectedly. An
            ``HTTPException`` raised by the controller propagates unchanged.
    """
    try:
        return controller.get_locations(db)
    except HTTPException:
        # Keep a deliberate status code instead of masking it as a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error fetching locations: %s", e)
        raise HTTPException(status_code=500, detail="Error fetching locations") from e


@router.get("/vendors", response_model=List[schemas.VendorOut])
def get_vendors(db: Session = Depends(get_db)):
    """Return all vendors.

    Args:
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The result of ``controller.get_vendors``, serialized as a list of
        ``schemas.VendorOut``.

    Raises:
        HTTPException: Status 500 when the lookup fails unexpectedly. An
            ``HTTPException`` raised by the controller propagates unchanged.
    """
    try:
        return controller.get_vendors(db)
    except HTTPException:
        # Keep a deliberate status code instead of masking it as a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error fetching vendors: %s", e)
        raise HTTPException(status_code=500, detail="Error fetching vendors") from e


@router.get("/categories", response_model=List[schemas.CategoryOut])
def get_categories(db: Session = Depends(get_db)):
    """Return all categories.

    Args:
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The result of ``controller.get_categories``, serialized as a list of
        ``schemas.CategoryOut``.

    Raises:
        HTTPException: Status 500 when the lookup fails unexpectedly. An
            ``HTTPException`` raised by the controller propagates unchanged.
    """
    try:
        return controller.get_categories(db)
    except HTTPException:
        # Keep a deliberate status code instead of masking it as a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error fetching categories: %s", e)
        raise HTTPException(status_code=500, detail="Error fetching categories") from e


@router.get("/products/{product_id}/vendors", response_model=List[schemas.VendorOut])
def get_product_vendors(product_id: int, db: Session = Depends(get_db)):
    """Return the vendors associated with one product.

    Args:
        product_id: Product identifier taken from the URL path.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The result of ``controller.get_product_vendors``, serialized as a
        list of ``schemas.VendorOut``.

    Raises:
        HTTPException: Status 500 when the lookup fails unexpectedly. An
            ``HTTPException`` raised by the controller (for example a
            not-found response) propagates unchanged.
    """
    try:
        return controller.get_product_vendors(db, product_id)
    except HTTPException:
        # Keep a deliberate status code instead of masking it as a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error fetching product vendors: %s", e)
        raise HTTPException(status_code=500, detail="Error fetching product vendors") from e


@router.get("/products/{product_id}/locations", response_model=List[schemas.ProductLocationOut])
def get_product_locations(product_id: int, db: Session = Depends(get_db)):
    """Return the locations associated with one product.

    Args:
        product_id: Product identifier taken from the URL path.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The result of ``controller.get_product_locations``, serialized as a
        list of ``schemas.ProductLocationOut``.

    Raises:
        HTTPException: Status 500 when the lookup fails unexpectedly. An
            ``HTTPException`` raised by the controller (for example a
            not-found response) propagates unchanged.
    """
    try:
        return controller.get_product_locations(db, product_id)
    except HTTPException:
        # Keep a deliberate status code instead of masking it as a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error fetching product locations: %s", e)
        raise HTTPException(status_code=500, detail="Error fetching product locations") from e


@router.get("/products/prices", response_model=List[schemas.ProductPriceOut])
def get_all_prices(
    skip: int = Query(0, ge=0, description="Number of prices to skip"),
    limit: int = Query(100, ge=1, le=1000, description="Number of prices to return"),
    db: Session = Depends(get_db)
):
    """Return one page of product prices across all products and locations.

    Args:
        skip: Number of prices to skip; validated as ``>= 0``.
        limit: Maximum number of prices to return; validated as 1-1000.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The result of ``controller.get_all_prices``, serialized as a list of
        ``schemas.ProductPriceOut``.

    Raises:
        HTTPException: Status 500 when the lookup fails unexpectedly. An
            ``HTTPException`` raised by the controller propagates unchanged.
    """
    try:
        return controller.get_all_prices(db, skip=skip, limit=limit)
    except HTTPException:
        # Keep a deliberate status code instead of masking it as a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error fetching all prices: %s", e)
        raise HTTPException(status_code=500, detail="Error fetching all prices") from e


@router.get("/statistics")
def get_statistics(db: Session = Depends(get_db)):
    """Return product statistics.

    Args:
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        Whatever ``controller.get_product_statistics`` returns; no response
        model is declared for this route.

    Raises:
        HTTPException: Status 500 when the lookup fails unexpectedly. An
            ``HTTPException`` raised by the controller propagates unchanged.
    """
    try:
        return controller.get_product_statistics(db)
    except HTTPException:
        # Keep a deliberate status code instead of masking it as a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error fetching statistics: %s", e)
        raise HTTPException(status_code=500, detail="Error fetching statistics") from e


# Celery-based endpoints
@router.post("/celery/fetch", response_model=schemas.TaskResult)
def trigger_product_fetch(
    company_id: int = Query(2, description="Company ID"),
    page_number: int = Query(1, ge=1, description="Starting page number"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page"),
    total_pages: int = Query(1, ge=1, le=50, description="Total pages to fetch")
):
    """Queue a product fetch from the external API as a Celery task.

    Args:
        company_id: Company identifier forwarded to the task (default 2).
        page_number: Starting page number; validated as ``>= 1``.
        page_size: Items per page; validated as 1-100.
        total_pages: Number of pages to fetch; validated as 1-50. A value of
            1 is reported as a "single page" fetch, anything else as "bulk".

    Returns:
        A ``schemas.TaskResult`` carrying the queued task's id, the status
        ``"started"`` and a human-readable summary message.

    Raises:
        HTTPException: Status 503 when the Celery components could not be
            imported; status 500 when queueing the task or building the
            response fails.
    """
    if not celery_available:
        raise HTTPException(
            status_code=503,
            detail="Celery not available. Product fetch service is currently unavailable."
        )

    try:
        # One task handles both single-page and multi-page fetches.
        task = product_fetch.delay(
            company_id=company_id,
            page_number=page_number,
            page_size=page_size,
            total_pages=total_pages
        )

        # Decide the label and the trailing detail in a single branch.
        if total_pages == 1:
            fetch_type, scope = "single page", f"page {page_number}"
        else:
            fetch_type, scope = "bulk", f"{total_pages} pages"
        message = f"Product fetch task started ({fetch_type}): company {company_id}, {scope}"

        return schemas.TaskResult(
            task_id=task.id,
            status="started",
            message=message
        )
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error starting product fetch task: %s", e)
        raise HTTPException(status_code=500, detail="Error starting product fetch task") from e
