from sqlalchemy.orm import Session
from sqlalchemy import func
from typing import List, Optional, Dict, Any
from . import models, schemas, services
import logging

logger = logging.getLogger(__name__)


def get_products(db: Session, skip: int = 0, limit: int = 100) -> List[models.Product]:
    """Fetch a page of products through the product service.

    Args:
        db: Database session handed to ``services.ProductService``.
        skip: Forwarded unchanged to the service as ``skip``
            (presumably the number of rows to skip; confirm in the service).
        limit: Forwarded unchanged to the service as ``limit``
            (presumably the maximum page size; confirm in the service).

    Returns:
        Whatever ``ProductService.get_products`` returns for the given window.
    """
    return services.ProductService(db).get_products(skip=skip, limit=limit)


def get_locations(db: Session) -> List[models.Location]:
    """Fetch locations through the product service.

    Args:
        db: Database session handed to ``services.ProductService``.

    Returns:
        The result of ``ProductService.get_locations()``.
    """
    return services.ProductService(db).get_locations()


def get_vendors(db: Session) -> List[models.Vendor]:
    """Fetch vendors through the product service.

    Args:
        db: Database session handed to ``services.ProductService``.

    Returns:
        The result of ``ProductService.get_vendors()``.
    """
    return services.ProductService(db).get_vendors()


def get_categories(db: Session) -> List[models.Category]:
    """Fetch categories through the product service.

    Args:
        db: Database session handed to ``services.ProductService``.

    Returns:
        The result of ``ProductService.get_categories()``.
    """
    return services.ProductService(db).get_categories()


def get_product_vendors(db: Session, product_id: int) -> List[models.Vendor]:
    """Return the full ``Vendor`` rows linked to one product.

    The lookup goes through the ``ProductVendor`` mapping table so that
    callers receive complete vendor records rather than bare mapping rows.

    Args:
        db: Database session used to run the query.
        product_id: Value matched against ``ProductVendor.product_id``.

    Returns:
        All ``Vendor`` rows that have a ``ProductVendor`` mapping for
        ``product_id``.
    """
    vendor_query = db.query(models.Vendor)
    # Pair each vendor with its mapping rows via the shared vendor_id.
    mapped_vendors = vendor_query.join(
        models.ProductVendor,
        models.Vendor.vendor_id == models.ProductVendor.vendor_id,
    )
    return mapped_vendors.filter(models.ProductVendor.product_id == product_id).all()


def get_product_locations(db: Session, product_id: int) -> List[models.ProductLocation]:
    """Return the ``ProductLocation`` rows recorded for one product.

    Args:
        db: Database session used to run the query.
        product_id: Value matched against ``ProductLocation.product_id``.

    Returns:
        All ``ProductLocation`` rows whose ``product_id`` equals the argument.
    """
    location_query = db.query(models.ProductLocation)
    for_product = location_query.filter(models.ProductLocation.product_id == product_id)
    return for_product.all()


def get_all_prices(db: Session, skip: int = 0, limit: int = 100) -> List[models.ProductPrice]:
    """Fetch a page of product prices through the product service.

    Args:
        db: Database session handed to ``services.ProductService``.
        skip: Forwarded unchanged to the service as ``skip``
            (presumably the number of rows to skip; confirm in the service).
        limit: Forwarded unchanged to the service as ``limit``
            (presumably the maximum page size; confirm in the service).

    Returns:
        Whatever ``ProductService.get_all_prices`` returns for the given window.
    """
    return services.ProductService(db).get_all_prices(skip=skip, limit=limit)


def get_product_statistics(db: Session) -> Dict[str, Any]:
    """Collect row counts and per-category / per-brand product tallies.

    Args:
        db: Database session used for every query issued here.

    Returns:
        A dict with these keys:

        * ``total_products``, ``total_locations``, ``total_vendors``,
          ``total_categories``, ``total_product_vendor_mappings``,
          ``total_product_location_mappings``: row counts of the
          corresponding model tables.
        * ``categories``: category name -> number of products joined to it.
        * ``brands``: brand name -> number of products carrying it.

        Entries whose category or brand name is ``None`` are left out of the
        two mappings.

    Raises:
        Exception: Any error raised while querying is logged and re-raised
            unchanged.
    """
    def _row_count(model):
        # Number of rows currently stored for the given model.
        return db.query(model).count()

    try:
        stats = {
            "total_products": _row_count(models.Product),
            "total_locations": _row_count(models.Location),
            "total_vendors": _row_count(models.Vendor),
            "total_categories": _row_count(models.Category),
            "total_product_vendor_mappings": _row_count(models.ProductVendor),
            "total_product_location_mappings": _row_count(models.ProductLocation),
        }

        # Product tally per category name; products are matched to the
        # Category table through Product.fk_product_category_id.
        per_category_rows = (
            db.query(
                models.Category.category_name,
                func.count(models.Product.id).label('count'),
            )
            .join(
                models.Product,
                models.Category.category_id == models.Product.fk_product_category_id,
            )
            .group_by(models.Category.category_name)
            .all()
        )

        # Product tally per brand name, taken straight from the Product table.
        per_brand_rows = (
            db.query(
                models.Product.brand_name,
                func.count(models.Product.id).label('count'),
            )
            .group_by(models.Product.brand_name)
            .all()
        )

        # Skip unnamed groups so that None never becomes a key.
        categories = {}
        for category_name, product_total in per_category_rows:
            if category_name is None:
                continue
            categories[category_name] = product_total

        brands = {}
        for brand_name, product_total in per_brand_rows:
            if brand_name is None:
                continue
            brands[brand_name] = product_total

        stats["categories"] = categories
        stats["brands"] = brands
        return stats
    except Exception as e:
        logger.error(f"Error in get_product_statistics: {str(e)}")
        raise