import json
import logging
import os
import re
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import openai
import qdrant_client
from fastapi import HTTPException
from openai import AzureOpenAI
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchAny,
    MatchText,
    MatchValue,
    PointStruct,
    Range,
    VectorParams,
)
from sqlalchemy.orm import Session

from src.menu_design.apps.editor.models import Templates
from src.utils.settings import settings




class EmbeddingService:
    """Generate text embeddings through an Azure OpenAI deployment.

    Failures never propagate to callers: empty/invalid input or an API
    error degrades to a zero vector of ``embedding_dimensions`` floats,
    so every caller can rely on a fixed-size result.
    """

    def __init__(self):
        # Azure OpenAI client configured from application settings.
        self.client = AzureOpenAI(
            api_version="2024-02-01",
            azure_endpoint=settings.AZURE_OPENAI_API_ENDPOINT,
            api_key=settings.AZUREKEY_CREDENTIAL_API_KEY
        )
        # Embedding deployment name and its output vector size.
        self.deployment = "EmbedLarge3"
        self.embedding_dimensions = 3072

    def generate_embedding(self, text: str) -> List[float]:
        """Return the embedding for *text*, or a zero vector on any failure.

        Args:
            text: Arbitrary text; it is whitespace-normalized and truncated
                before being sent to the API.

        Returns:
            A list of ``self.embedding_dimensions`` floats. A zero vector
            is returned when the input is empty/non-string, cleans to an
            empty string, or the API call raises.
        """
        if not text or not isinstance(text, str):
            # No valid text -> deterministic zero vector.
            return [0.0] * self.embedding_dimensions

        clean_text = self._clean_text(text)
        if not clean_text:
            return [0.0] * self.embedding_dimensions

        try:
            response = self.client.embeddings.create(
                input=[clean_text],
                model=self.deployment
            )
            return response.data[0].embedding
        except Exception as e:
            # Log (instead of print) so failures reach the app's log
            # handlers; still degrade gracefully to a zero vector.
            logging.error("Embedding generation failed: %s", e)
            return [0.0] * self.embedding_dimensions

    def _clean_text(self, text: str) -> str:
        """Normalize whitespace and truncate *text* for the embedding API."""
        # Collapse all runs of whitespace to single spaces.
        text = ' '.join(text.split())

        # Truncate to stay under Azure's token limits (rough character cap).
        max_length = 8000
        if len(text) > max_length:
            text = text[:max_length]

        return text.strip()
    
##-----------------Qdrant operations --------------Start----------------------------------------------------------#

class QdrantService:
    """High-level Qdrant operations (collections, upserts, semantic search).

    The payload schema is kept consistent across insert, search, index
    creation, filtering and highlighting: each stored point carries the keys
    ``searchable_text``, ``template_url``, ``template_id``, ``template_name``,
    ``template_metadata`` and ``inserted_at`` (see ``insert_documents``).
    """

    def __init__(self):
        self.embedding_service = EmbeddingService()
        self.embedding_dimensions = 3072
        self.client = QdrantClient(
            url=settings.QADERND_URL,
            api_key=settings.QDRANT_API_KEY,
            prefer_grpc=True
        )

    def _get_embedding(self, text: str) -> List[float]:
        """Embed *text*, guaranteeing a vector of the expected dimension.

        Logs and falls back to a zero vector when the embedding service
        raises or returns a vector of the wrong size.
        """
        try:
            embedding = self.embedding_service.generate_embedding(text)
            if len(embedding) != self.embedding_dimensions:
                logging.warning(
                    "Invalid embedding dimension received: %d", len(embedding)
                )
                return [0.0] * self.embedding_dimensions
            return embedding
        except Exception as e:
            logging.error("Embedding service error: %s", e)
            return [0.0] * self.embedding_dimensions

    def _extract_search_text(self, data: Any) -> str:
        """Recursively collect human-readable text from template data.

        Strings are returned as-is; lists and dicts are walked recursively
        collecting every string value. Editor ``layers`` dicts get special
        handling: the ``text`` prop of ``TextLayer`` entries is included
        with HTML tags stripped. The combined text is truncated to 8000
        characters to respect the embedding model's input limit.
        """
        if isinstance(data, str):
            return data  # plain-text leaf: no joining/truncation needed

        text_parts: List[str] = []

        if isinstance(data, list):
            for item in data:
                text_parts.append(self._extract_search_text(item))

        elif isinstance(data, dict):
            # Extract all string values, recursing into containers.
            for value in data.values():
                if isinstance(value, str):
                    text_parts.append(value)
                elif isinstance(value, (dict, list)):
                    text_parts.append(self._extract_search_text(value))

            # Special handling for editor layer structures. Guard against a
            # non-dict "layers" value (previously an AttributeError).
            if 'layers' in data and isinstance(data['layers'], dict):
                for layer in data['layers'].values():
                    if isinstance(layer, dict):
                        if layer.get('type', {}).get('resolvedName') == 'TextLayer':
                            text = layer.get('props', {}).get('text', '')
                            if text:
                                # Strip HTML tags from rich-text content.
                                text_parts.append(re.sub(r'<[^>]+>', '', text))

        # Clean and combine all collected text.
        clean_text = ' '.join(
            t.strip() for t in text_parts if t and isinstance(t, str)
        )
        return clean_text[:8000]  # truncate to model limit

    def create_collection(self, name: str, db: "Session" = None) -> None:
        """Create collection *name* (if absent) optimized for template search.

        Uses cosine distance at the service's embedding dimensionality and
        indexes the common filterable fields under ``template_metadata.*``
        so index keys match both the stored payload (``insert_documents``)
        and the filter prefix used by ``_parse_filters``. When a DB session
        is supplied, the matching template row is marked as processed.
        """
        if not self.client.collection_exists(name):
            self.client.create_collection(
                name,
                vectors_config=VectorParams(
                    size=self.embedding_dimensions,
                    distance=Distance.COSINE
                ),
                optimizers_config={
                    "default_segment_number": 2,
                    "indexing_threshold": 0
                }
            )

            # Index metadata fields that are useful for filtering.
            metadata_fields = [
                ("template_type", "keyword"),
                ("industry", "keyword"),
                ("tags", "keyword"),
                ("design_style", "keyword"),
                ("color_theme", "keyword")
            ]

            for field, field_type in metadata_fields:
                # Prefix must match the payload key written by
                # insert_documents (previously "metadata.*", which is
                # never present in the payload).
                self._create_field_index(
                    name,
                    f"template_metadata.{field}",
                    field_type
                )

            if db:
                self._mark_template_as_processed(db=db, name=name)

    def _mark_template_as_processed(self, db: "Session", name: str):
        """Set ``data_process = 1`` on the Templates row named *name*."""
        template = db.query(Templates).filter(Templates.template_name == name).first()
        if template and template.data_process != 1:
            template.data_process = 1
            db.commit()
            db.refresh(template)

    def _create_field_index(self, collection: str, field: str, field_type: str) -> None:
        """Create a payload index; log (not raise) on failure (best-effort)."""
        try:
            self.client.create_payload_index(
                collection_name=collection,
                field_name=field,
                field_schema=field_type
            )
            logging.info(f"Created index for {field} as {field_type}")
        except Exception as e:
            logging.warning(f"Failed to create index for {field}: {str(e)}")

    def search(
        self,
        query: str,
        collection: str,
        limit: int = 5,
        filters: Optional[Dict] = None,
        score_threshold: float = 0.3  # lower default threshold
    ) -> Dict:
        """Semantic search over *collection*.

        Args:
            query: Free-text query to embed and match.
            collection: Target collection name.
            limit: Maximum number of hits.
            filters: Optional filter spec (see ``_parse_filters``).
            score_threshold: Minimum similarity score, applied by Qdrant.

        Returns:
            Dict with ``query``, ``results`` (id/score/data/searchable_text/
            highlights per hit) and ``total_results``.

        Raises:
            HTTPException: 404 when the collection does not exist.
        """
        if not self.client.collection_exists(collection):
            raise HTTPException(404, f"Collection '{collection}' not found")

        # Generate embedding for the query text.
        vector = self._get_embedding(query)

        # Parse filters into a Qdrant Filter, if any were supplied.
        query_filter = self._parse_filters(filters) if filters else None

        # Execute search; score_threshold is applied server-side, so no
        # post-filtering is needed (the previous hard-coded 0.6 re-filter
        # silently overrode the score_threshold parameter and made
        # total_results disagree with the returned list).
        results = self.client.search(
            collection_name=collection,
            query_vector=vector,
            query_filter=query_filter,
            limit=limit,
            score_threshold=score_threshold,
            with_payload=True
        )

        # Payload keys match what insert_documents stores (previously read
        # "original_id"/"json_data", which are never written).
        return {
            "query": query,
            "results": [{
                "id": r.payload.get("template_id", r.id),
                "score": r.score,
                "data": r.payload.get("template_metadata", {}),
                "searchable_text": r.payload.get("searchable_text", ""),
                "highlights": self._find_highlights(query, r.payload)
            } for r in results],
            "total_results": len(results)
        }

    def insert_documents(self, docs: List[Dict[str, Any]], collection: str, db: "Session") -> Dict:
        """Embed and upsert *docs* into *collection* (created if missing).

        Each doc must provide ``id`` and ``template_metadata``; the optional
        ``template_url`` and ``template_name`` fields are stored alongside.
        Documents that fail to process are logged and skipped.

        Returns:
            ``{"success": True, "inserted_count": ..., "operation_id": ...}``
            on success, otherwise ``{"success": False, "error": ...}``.
        """
        self.create_collection(collection, db=db)

        points = []
        for doc in docs:
            try:
                original_id = str(doc["id"])
                try:
                    # Deterministic UUID so re-inserting the same document
                    # overwrites its point instead of duplicating it.
                    doc_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, original_id)
                except Exception:
                    doc_uuid = uuid.uuid4()

                # Extract meaningful text for the embedding.
                search_text = self._extract_search_text(doc["template_metadata"])
                embedding = self._get_embedding(search_text)

                # Store only the specific fields we need in the payload.
                point = PointStruct(
                    id=str(doc_uuid),
                    vector=embedding,
                    payload={
                        "searchable_text": search_text,
                        "template_url": doc.get("template_url"),
                        "template_id": original_id,
                        "template_name": doc.get("template_name"),
                        "template_metadata": doc.get("template_metadata", {}),
                        # Timezone-aware timestamp (utcnow() is deprecated).
                        "inserted_at": datetime.now(timezone.utc).isoformat()
                    }
                )
                points.append(point)
            except Exception as e:
                logging.error("Failed to process document %s: %s", doc.get('id'), e)
                continue

        if not points:
            return {"success": False, "error": "No valid documents to insert"}

        try:
            operation_result = self.client.upsert(
                collection_name=collection,
                points=points,
                wait=True
            )
            return {
                "success": True,
                "inserted_count": len(points),
                "operation_id": operation_result.operation_id
            }
        except Exception as e:
            logging.error("Qdrant upsert failed: %s", e)
            return {"success": False, "error": str(e)}

    def _parse_filters(self, filters: Dict) -> "Optional[Filter]":
        """Convert the API filter format into a Qdrant ``Filter``.

        Accepts either ``{"must": [...]}`` or a single condition dict. Each
        item is either a structured condition (``{"key": ..., "match"/
        "value": ...}``) or a simple ``{field: value}`` pair. Dict values
        express ranges or ``{"any": [...]}`` (MatchAny); list values become
        MatchAny; scalars become MatchValue. Field names are namespaced
        under ``template_metadata.`` to match the stored payload
        (previously ``json_data.``, which never matched anything).
        """
        conditions = []

        # Normalize filter format: wrap a bare condition dict.
        filter_items = filters['must'] if 'must' in filters else [filters]

        for item in filter_items:
            if not item:
                # Skip empty entries (previously raised StopIteration).
                continue

            if 'key' in item:  # structured filter with explicit key
                field = item['key'].replace('template_metadata.', '')
                value = item.get('match', {}).get('value') or item.get('value')
            else:  # simple key-value or range format
                field, value = next(iter(item.items()))

            field_name = f"template_metadata.{field}"

            if isinstance(value, dict):  # Range or MatchAny
                if "any" in value:
                    conditions.append(FieldCondition(key=field_name, match=MatchAny(any=value["any"])))
                else:
                    conditions.append(FieldCondition(key=field_name, range=Range(**value)))
            elif isinstance(value, list):  # fallback list handling
                conditions.append(FieldCondition(key=field_name, match=MatchAny(any=value)))
            else:  # single match value
                conditions.append(FieldCondition(key=field_name, match=MatchValue(value=value)))

        return Filter(must=conditions) if conditions else None

    def _find_highlights(self, query: str, payload: Dict) -> List[str]:
        """Return up to 5 snippets explaining why *payload* matched *query*.

        Query terms longer than 3 characters found in the searchable text
        are listed first, followed by ``"field:value"`` entries for metadata
        string values that contain any query term.
        """
        query_terms = set(query.lower().split())
        text = payload.get("searchable_text", "").lower()
        # Read the same key insert_documents writes (was "json_data",
        # which is never stored).
        fields = payload.get("template_metadata", {})

        highlights = [
            term for term in query_terms
            if len(term) > 3 and term in text
        ]

        # Add field-specific matches (guard: metadata may not be a dict).
        if isinstance(fields, dict):
            for field, value in fields.items():
                if isinstance(value, str) and any(term in value.lower() for term in query_terms):
                    highlights.append(f"{field}:{value}")

        return highlights[:5]  # return top 5 highlights

    def _find_matched_fields(self, query: str, payload: Dict) -> List[str]:
        """Heuristically identify which fields likely matched the query.

        Args:
            query: Original search query.
            payload: Document payload.

        Returns:
            Names from ``payload["searchable_fields"]`` whose metadata value
            contains any query term longer than 3 characters.
            NOTE(review): insert_documents never writes "searchable_fields",
            so this currently always returns [] — confirm intended schema.
        """
        matched_fields = []
        metadata = payload.get("template_metadata", {})
        search_fields = payload.get("searchable_fields", [])

        # Simple heuristic: check whether query terms appear in each field.
        query_terms = set(query.lower().split())
        for field in search_fields:
            field_value = str(metadata.get(field, "")).lower()
            if any(term in field_value for term in query_terms if len(term) > 3):
                matched_fields.append(field)

        return matched_fields

    # Additional utility methods
    def collection_info(self, name: str) -> Dict:
        """Return status/size/config information for collection *name*.

        Raises:
            HTTPException: 404 when the collection does not exist.
        """
        if not self.client.collection_exists(name):
            raise HTTPException(404, f"Collection '{name}' not found")
        info = self.client.get_collection(name)
        return {
            "status": info.status,
            "vectors_count": info.vectors_count,
            "points_count": info.points_count,
            "config": info.config.dict()
        }

    def delete_collection(self, name: str) -> Dict:
        """Delete collection *name*, 404ing when it does not exist."""
        if not self.client.collection_exists(name):
            raise HTTPException(404, f"Collection '{name}' not found")
        self.client.delete_collection(name)
        return {"success": True, "message": f"Collection '{name}' deleted"}
    
##-----------------Qdrant operations -----------------End-------------------------------------------------------#



    








