import logging
from typing import Dict, Any
from src.utils.celery_worker import celery_app

# Import database session
try:
    from src.utils.db import get_db_session
    from src.smart_inventory.apps.products.services import ProductService
    from src.smart_inventory.apps.products.schemas import ProductFetchRequest
except ImportError as e:
    print(f"Warning: Could not import required modules: {e}")
    # Set to None so we can handle gracefully
    get_db_session = None
    ProductService = None
    ProductFetchRequest = None

logger = logging.getLogger(__name__)


def _fetch_and_store_page(company_id: int, page_number: int, page_size: int, report_progress=None) -> Dict[str, Any]:
    """Fetch one page of products from the external API and persist it.

    Opens a fresh database session, builds a ``ProductFetchRequest`` for the
    page, calls ``ProductService.fetch_products_from_api`` and then
    ``ProductService.process_api_response``. The session is always closed.

    Args:
        company_id: Company ID for the API request.
        page_number: Page number to fetch.
        page_size: Number of items per page.
        report_progress: Optional callable taking a single message string.
            It is invoked before the API call and before processing.

    Returns:
        The value returned by ``ProductService.process_api_response``. Callers
        treat it as a dict (they call ``.update`` / ``.get`` on it).

    Raises:
        Any exception raised while creating the session, calling the API or
        processing the response propagates to the caller.
    """
    db = get_db_session()
    try:
        product_service = ProductService(db)
        request_data = ProductFetchRequest(
            company_id=company_id,
            page_number=page_number,
            page_size=page_size,
        )

        if report_progress is not None:
            report_progress('Fetching data from API...')
        api_response = product_service.fetch_products_from_api(request_data)

        if report_progress is not None:
            report_progress('Processing and storing products...')
        return product_service.process_api_response(api_response)
    finally:
        # Always close the session, even if the API call or processing raised.
        db.close()


def _run_single_page(task, company_id: int, page_number: int, page_size: int, total_pages: int) -> Dict[str, Any]:
    """Fetch and store a single page, reporting PROGRESS states on ``task``.

    Returns the service result dict augmented with task metadata.
    """
    logger.info(
        "Starting single page fetch: company_id=%s, page=%s, size=%s",
        company_id, page_number, page_size,
    )
    task.update_state(state='PROGRESS', meta={'message': 'Initializing product fetch...'})

    result = _fetch_and_store_page(
        company_id,
        page_number,
        page_size,
        report_progress=lambda message: task.update_state(state='PROGRESS', meta={'message': message}),
    )

    result.update({
        "task_id": task.request.id,
        "company_id": company_id,
        "page_number": page_number,
        "page_size": page_size,
        "total_pages": total_pages,
    })

    # The final Celery state is recorded from the return value (SUCCESS). Do not
    # set SUCCESS/FAILURE manually here. A manual FAILURE with a plain dict as
    # meta is not valid exception info for result backends, and the return
    # value overwrites it anyway. Failure is signalled via result["success"].
    if result.get("success", False):
        logger.info("Product fetch completed successfully: %s", result)
    else:
        logger.error("Product fetch completed with errors: %s", result)

    return result


def _run_bulk_fetch(task, company_id: int, page_size: int, total_pages: int) -> Dict[str, Any]:
    """Fetch pages 1..total_pages, aggregating per-page outcomes.

    A failure on one page is recorded and does not stop the remaining pages.
    Each page uses its own database session.
    """
    logger.info(
        "Starting bulk product fetch: company_id=%s, total_pages=%s, page_size=%s",
        company_id, total_pages, page_size,
    )

    results = {
        "success": True,
        "total_pages": total_pages,
        "processed_pages": 0,
        "total_items": 0,
        "total_processed": 0,
        "total_errors": 0,
        "page_results": [],
        "task_id": task.request.id,
        "company_id": company_id,
        "page_size": page_size,
    }

    for page in range(1, total_pages + 1):
        try:
            task.update_state(
                state='PROGRESS',
                meta={'message': f'Processing page {page}/{total_pages}', 'current_page': page, 'total_pages': total_pages},
            )

            page_result = _fetch_and_store_page(company_id, page, page_size)
            page_result.update({
                "page": page,
                "company_id": company_id,
                "page_size": page_size,
            })

            results["page_results"].append(page_result)
            results["processed_pages"] += 1

            if page_result.get("success", False):
                results["total_items"] += page_result.get("total_items", 0)
                results["total_processed"] += page_result.get("processed_count", 0)
                results["total_errors"] += page_result.get("error_count", 0)
            else:
                results["success"] = False
                results["total_errors"] += 1

        except Exception as e:
            # Record the failure and keep going with the remaining pages. Keep
            # the traceback in the log, which the original call dropped.
            logger.error("Error processing page %s: %s", page, e, exc_info=True)
            results["success"] = False
            results["total_errors"] += 1
            results["page_results"].append({
                "success": False,
                "error": str(e),
                "page": page,
            })

    logger.info("Bulk product fetch completed: %s", results)
    return results


@celery_app.task(bind=True)
def product_fetch(self, company_id: int = 2, page_number: int = 1, page_size: int = 10, total_pages: int = 1) -> Dict[str, Any]:
    """
    Celery task to fetch products from external API and store in database
    Can handle both single page and bulk operations based on total_pages parameter

    Args:
        company_id: Company ID for API request
        page_number: Page number to fetch when total_pages == 1. It is ignored
            in bulk mode, where pages 1..total_pages are always fetched.
        page_size: Number of items per page
        total_pages: Total number of pages to fetch. 1 means a single page;
            any other value takes the bulk path.

    Returns:
        Dict containing task results. It always includes "success" and
        "task_id". Failures are reported via "success": False (plus "error"),
        not by raising, so the Celery task state itself ends as SUCCESS.
    """
    # The imports at module top are guarded; bail out cleanly if they failed.
    if not all([get_db_session, ProductService, ProductFetchRequest]):
        error_msg = "Required modules not available"
        logger.error(error_msg)
        return {
            "success": False,
            "error": error_msg,
            "task_id": self.request.id,
        }

    try:
        if total_pages == 1:
            return _run_single_page(self, company_id, page_number, page_size, total_pages)
        return _run_bulk_fetch(self, company_id, page_size, total_pages)

    except Exception as e:
        error_msg = f"Product fetch task failed: {str(e)}"
        logger.exception(error_msg)

        # Report the failure through the return value only. Celery stores the
        # return value with state SUCCESS, overwriting any FAILURE set manually,
        # and a manual FAILURE with dict meta is not valid exception info.
        return {
            "success": False,
            "error": error_msg,
            "task_id": self.request.id,
            "company_id": company_id,
            "page_number": page_number,
            "page_size": page_size,
            "total_pages": total_pages,
        }
