from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
from . import models, schemas
import requests
import logging
from sqlalchemy.exc import IntegrityError

logger = logging.getLogger(__name__)


class ProductService:
    """Service class for product-related operations"""
    
    def __init__(self, db: Session) -> None:
        """Keep a reference to the database session used by every method of this service."""
        self.db = db

    def get_or_create_product(self, product_data: schemas.ProductCreate) -> models.Product:
        """Insert the product, or update the row that already has its ``product_id``.

        An existing row only receives the fields returned by
        ``model_dump(exclude_unset=True)``; a new row is built from the full
        ``model_dump()``. The change is committed and the row refreshed
        before it is returned.
        """
        session = self.db
        record = (
            session.query(models.Product)
            .filter(models.Product.product_id == product_data.product_id)
            .first()
        )

        if not record:
            # Nothing stored under this product_id yet: stage a fresh row.
            record = models.Product(**product_data.model_dump())
            session.add(record)
        else:
            # Overwrite only what the caller explicitly provided.
            changes = product_data.model_dump(exclude_unset=True)
            for attr_name in changes:
                setattr(record, attr_name, changes[attr_name])

        session.commit()
        session.refresh(record)
        return record

    def get_or_create_location(self, location_data: schemas.LocationCreate) -> models.Location:
        """Insert the location, or update the row that already has its ``location_id``.

        Existing rows are patched with the explicitly-set fields of
        ``location_data``; otherwise a new row is created from all of its
        fields. The result is committed, refreshed and returned.
        """
        match = self.db.query(models.Location).filter(
            models.Location.location_id == location_data.location_id
        ).first()

        if match:
            # Patch the stored row in place.
            updates = location_data.model_dump(exclude_unset=True)
            for name, new_value in updates.items():
                setattr(match, name, new_value)
            target = match
        else:
            # First time this location is seen.
            target = models.Location(**location_data.model_dump())
            self.db.add(target)

        self.db.commit()
        self.db.refresh(target)
        return target

    def get_or_create_vendor(self, vendor_data: schemas.VendorCreate) -> models.Vendor:
        """Insert the vendor, or update the row that already has its ``vendor_id``.

        Returns the committed and refreshed row in both cases.
        """
        vendor = (
            self.db.query(models.Vendor)
            .filter(models.Vendor.vendor_id == vendor_data.vendor_id)
            .first()
        )

        # Guard clause: unknown vendor, so insert it and stop here.
        if not vendor:
            created = models.Vendor(**vendor_data.model_dump())
            self.db.add(created)
            self.db.commit()
            self.db.refresh(created)
            return created

        # Known vendor: copy over the fields the caller explicitly set.
        for key, val in vendor_data.model_dump(exclude_unset=True).items():
            setattr(vendor, key, val)
        self.db.commit()
        self.db.refresh(vendor)
        return vendor

    def get_or_create_category(self, category_data: schemas.CategoryCreate) -> models.Category:
        """Insert the category, or update the row that already has its ``category_id``.

        Returns the committed and refreshed row in both cases.
        """
        db = self.db
        category = db.query(models.Category).filter(
            models.Category.category_id == category_data.category_id
        ).first()
        is_new = not category

        if is_new:
            # Build the row from every field of the schema.
            category = models.Category(**category_data.model_dump())
            db.add(category)
        else:
            # Apply only the explicitly-set fields to the stored row.
            provided = category_data.model_dump(exclude_unset=True)
            for column, content in provided.items():
                setattr(category, column, content)

        db.commit()
        db.refresh(category)
        return category

    def get_or_create_product_price(self, price_data: schemas.ProductPriceCreate) -> models.ProductPrice:
        """Insert the price, or update the row that already has its ``product_price_id``.

        Returns the committed and refreshed row in both cases.
        """
        price_row = (
            self.db.query(models.ProductPrice)
            .filter(models.ProductPrice.product_price_id == price_data.product_price_id)
            .first()
        )

        if not price_row:
            # No stored price with this id: create one from all fields.
            price_row = models.ProductPrice(**price_data.model_dump())
            self.db.add(price_row)
        else:
            # Stored price found: overwrite the explicitly-set fields only.
            for attribute, amount in price_data.model_dump(exclude_unset=True).items():
                setattr(price_row, attribute, amount)

        self.db.commit()
        self.db.refresh(price_row)
        return price_row

    def create_product_vendor_mapping(self, product_id: int, vendor_id: int) -> Optional[models.ProductVendor]:
        """Return the product-vendor mapping, creating it if it doesn't exist.

        Args:
            product_id: Product side of the mapping.
            vendor_id: Vendor side of the mapping.

        Returns:
            The existing or newly created mapping. ``None`` is returned only
            when the insert is rejected with an ``IntegrityError`` and no
            matching row can be found afterwards.
        """
        def _find_mapping():
            return self.db.query(models.ProductVendor).filter(
                models.ProductVendor.product_id == product_id,
                models.ProductVendor.vendor_id == vendor_id
            ).first()

        existing_mapping = _find_mapping()
        if existing_mapping:
            return existing_mapping

        db_mapping = models.ProductVendor(product_id=product_id, vendor_id=vendor_id)
        try:
            self.db.add(db_mapping)
            self.db.commit()
        except IntegrityError:
            self.db.rollback()
            logger.warning(f"Product-Vendor mapping already exists: {product_id}-{vendor_id}")
            # The row may have been inserted between our lookup and our
            # commit; re-read it so the caller gets the mapping instead of
            # None. Still None if the insert failed for another reason.
            return _find_mapping()

        self.db.refresh(db_mapping)
        return db_mapping

    def create_product_location_mapping(self, product_id: int, location_id: int) -> Optional[models.ProductLocation]:
        """Return the product-location mapping, creating it if it doesn't exist.

        Args:
            product_id: Product side of the mapping.
            location_id: Location side of the mapping.

        Returns:
            The existing or newly created mapping. ``None`` is returned only
            when the insert is rejected with an ``IntegrityError`` and no
            matching row can be found afterwards.
        """
        def _find_mapping():
            return self.db.query(models.ProductLocation).filter(
                models.ProductLocation.product_id == product_id,
                models.ProductLocation.location_id == location_id
            ).first()

        existing_mapping = _find_mapping()
        if existing_mapping:
            return existing_mapping

        db_mapping = models.ProductLocation(product_id=product_id, location_id=location_id)
        try:
            self.db.add(db_mapping)
            self.db.commit()
        except IntegrityError:
            self.db.rollback()
            logger.warning(f"Product-Location mapping already exists: {product_id}-{location_id}")
            # The row may have been inserted between our lookup and our
            # commit; re-read it so the caller gets the mapping instead of
            # None. Still None if the insert failed for another reason.
            return _find_mapping()

        self.db.refresh(db_mapping)
        return db_mapping

    def process_api_product(self, api_product: schemas.APIProduct) -> models.Product:
        """Process a single product from the API response and store it in the database.

        Steps run in a fixed order: category, product, locations (with
        product-location mappings), prices, then vendors (with
        product-vendor mappings). The ``get_or_create_*`` and
        ``create_*_mapping`` methods each commit on their own, so the
        rollback performed on failure only discards work that had not been
        committed yet.

        Args:
            api_product: Parsed product entry from the API response.

        Returns:
            The stored product row.

        Raises:
            Exception: Any error raised along the way is logged with its
                traceback, the session is rolled back, and the error is
                re-raised unchanged.
        """
        try:
            # Only the first image, when there is one, is kept on the product.
            image_path = None
            if api_product.productImages:
                image_path = api_product.productImages[0].imagePath

            # Process category if available
            if api_product.fkProductCategoryId and api_product.categoryName:
                category_data = schemas.CategoryCreate(
                    category_id=api_product.fkProductCategoryId,
                    category_name=api_product.categoryName
                )
                self.get_or_create_category(category_data)

            # Create product data (without price fields)
            product_data = schemas.ProductCreate(
                product_id=api_product.productId,
                product_name=api_product.productName,
                short_name=api_product.shortName,
                description=api_product.description,
                brand_name=api_product.brandName,
                fk_product_category_id=api_product.fkProductCategoryId,
                image_path=image_path,
                eligible_for_return=api_product.eligibleForReturn,
                display_on_pos=api_product.displayOnPOS,
                display_on_online_store=api_product.displayOnOnlineStore
            )
            db_product = self.get_or_create_product(product_data)

            # Locations go first so that prices can reference them.
            known_location_ids = self._sync_product_locations(api_product)
            self._sync_product_prices(api_product, known_location_ids)
            self._sync_product_vendors(api_product)

            return db_product

        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"Error processing product {api_product.productId}: {str(e)}")
            self.db.rollback()
            raise

    def _sync_product_locations(self, api_product: schemas.APIProduct) -> set:
        """Upsert the product's locations and mappings; return the location ids handled."""
        known_location_ids = set()
        for location in api_product.locations or ():
            location_data = schemas.LocationCreate(
                location_id=location.locationId,
                location_name=location.locationName
            )
            self.get_or_create_location(location_data)
            self.create_product_location_mapping(api_product.productId, location.locationId)
            known_location_ids.add(location.locationId)
        return known_location_ids

    def _sync_product_prices(self, api_product: schemas.APIProduct, known_location_ids: set) -> None:
        """Upsert every location-wise price, creating a placeholder location when one is missing."""
        for price in api_product.prices or ():
            location_id = price.fkLocationId

            # Locations already upserted or verified during this call need
            # no further database lookup.
            if location_id not in known_location_ids:
                existing_location = self.db.query(models.Location).filter(
                    models.Location.location_id == location_id
                ).first()

                if not existing_location:
                    logger.warning(f"Location {location_id} not found, creating minimal location record")
                    minimal_location_data = schemas.LocationCreate(
                        location_id=location_id,
                        location_name=f"Location {location_id}"  # Default name
                    )
                    self.get_or_create_location(minimal_location_data)
                    self.create_product_location_mapping(api_product.productId, location_id)

                known_location_ids.add(location_id)

            price_data_item = schemas.ProductPriceCreate(
                product_price_id=price.productPriceId,
                product_id=api_product.productId,
                location_id=location_id,
                cost_price_per_unit=price.costPricePerUnit,
                markup_value=price.markupValue,
                margin_value=price.marginValue,
                retail_price_excl_tax=price.retailPriceExclTax,
                compare_at_price=price.compareAtPrice,
                markup_type_name=price.markupTypeName,
                margin_type_name=price.marginTypeName
            )
            self.get_or_create_product_price(price_data_item)

    def _sync_product_vendors(self, api_product: schemas.APIProduct) -> None:
        """Upsert the product's vendors and their product-vendor mappings."""
        for vendor in api_product.vendors or ():
            vendor_data = schemas.VendorCreate(
                vendor_id=vendor.vendorId,
                vendor_name=vendor.vendorName,
                vendor_code=vendor.vendorCode
            )
            self.get_or_create_vendor(vendor_data)
            self.create_product_vendor_mapping(api_product.productId, vendor.vendorId)

    def fetch_products_from_api(self, request_data: schemas.ProductFetchRequest) -> Dict[str, Any]:
        """Fetch one page of products from the external product-search API.

        Args:
            request_data: Supplies ``company_id``, ``page_number`` and
                ``page_size`` for the request payload.

        Returns:
            The decoded JSON object returned by the API.

        Raises:
            RuntimeError: If the HTTP request fails, the server answers with
                an error status, the body is not valid JSON, or the JSON is
                not an object. The original error is attached as the cause.
        """
        api_url = "https://hubwalletdev-api.myteamconnector.com/api/AIIntegration/ProductSearch"

        payload = {
            "companyId": request_data.company_id,
            "pageNumber": request_data.page_number,
            "pageSize": request_data.page_size
        }

        headers = {
            'accept': '*/*',
            'Content-Type': 'application/json'
        }

        # Transport and HTTP-status failures.
        try:
            response = requests.post(api_url, json=payload, headers=headers, timeout=60)
            response.raise_for_status()
        except (requests.exceptions.RequestException, ValueError) as e:
            logger.error(f"API request failed: {str(e)}")
            # Raise a standard exception to avoid Celery serialization issues with requests exceptions
            raise RuntimeError(f"API request failed: {str(e)}") from e

        # Decoding is handled separately: recent versions of requests raise a
        # JSON error that is also a RequestException, so a shared try block
        # would report bad JSON as a failed request.
        try:
            api_response = response.json()
        except ValueError as e:
            logger.error(f"JSON decode error: {str(e)}")
            raise RuntimeError(f"Invalid JSON response: {str(e)}") from e

        if not isinstance(api_response, dict):
            # Callers rely on dict access, so reject any other JSON value here.
            message = f"Invalid JSON response: expected an object, got {type(api_response).__name__}"
            logger.error(message)
            raise RuntimeError(message)

        logger.info(f"API response received: {api_response.get('success', False)}")
        return api_response

    def process_api_response(self, api_response: Dict[str, Any]) -> Dict[str, Any]:
        """Process the complete API response and store products in database"""
        try:
            if not api_response.get('success', False):
                error_msg = api_response.get('errorMessage', 'Unknown API error')
                raise ValueError(f"API returned error: {error_msg}")

            data = api_response.get('data', {})
            items = data.get('items', [])
            
            processed_count = 0
            errors = []
            
            for item_data in items:
                try:
                    api_product = schemas.APIProduct(**item_data)
                    self.process_api_product(api_product)
                    processed_count += 1
                except Exception as e:
                    error_msg = f"Error processing product {item_data.get('productId', 'unknown')}: {str(e)}"
                    logger.error(error_msg)
                    errors.append(error_msg)

            result = {
                "success": True,
                "total_items": len(items),
                "processed_count": processed_count,
                "error_count": len(errors),
                "errors": errors
            }
            
            if errors:
                result["success"] = False
                
            return result
            
        except Exception as e:
            logger.error(f"Error processing API response: {str(e)}")
            self.db.rollback()
            return {
                "success": False,
                "error": str(e),
                "total_items": 0,
                "processed_count": 0,
                "error_count": 1
            }

    def get_products(self, skip: int = 0, limit: int = 100) -> List[models.Product]:
        """Get products with pagination.

        Rows are ordered by ``product_id``: without an explicit ORDER BY the
        database may return rows in any order, so consecutive pages could
        overlap or skip rows.

        Args:
            skip: Number of rows to skip.
            limit: Maximum number of rows to return.
        """
        return (
            self.db.query(models.Product)
            .order_by(models.Product.product_id)
            .offset(skip)
            .limit(limit)
            .all()
        )

    def get_product_by_id(self, product_id: int) -> Optional[models.Product]:
        """Return the first product whose ``product_id`` matches, or ``None`` when there is none."""
        lookup = self.db.query(models.Product)
        lookup = lookup.filter(models.Product.product_id == product_id)
        return lookup.first()

    def get_locations(self) -> List[models.Location]:
        """Return every stored location."""
        location_query = self.db.query(models.Location)
        return location_query.all()

    def get_vendors(self) -> List[models.Vendor]:
        """Return every stored vendor."""
        vendor_query = self.db.query(models.Vendor)
        return vendor_query.all()

    def get_categories(self) -> List[models.Category]:
        """Return every stored category."""
        category_query = self.db.query(models.Category)
        return category_query.all()

    def get_all_prices(self, skip: int = 0, limit: int = 100) -> List[models.ProductPrice]:
        """Get all prices with pagination.

        Rows are ordered by ``product_price_id``: without an explicit ORDER
        BY the database may return rows in any order, so consecutive pages
        could overlap or skip rows.

        Args:
            skip: Number of rows to skip.
            limit: Maximum number of rows to return.
        """
        return (
            self.db.query(models.ProductPrice)
            .order_by(models.ProductPrice.product_price_id)
            .offset(skip)
            .limit(limit)
            .all()
        )