import cv2
import numpy as np
import pickle
import os
from typing import List, Tuple, Optional
from PIL import Image
import json
from sqlalchemy.orm import Session
from database import Image as ImageModel, FaceEncoding
from config import settings

# Try to import face_recognition, fall back to dummy implementation if not available
try:
    import face_recognition
    FACE_RECOGNITION_AVAILABLE = True
except ImportError:
    FACE_RECOGNITION_AVAILABLE = False
    print("Warning: face_recognition not available. Face recognition features will be disabled.")

class FaceRecognitionService:
    """Detect, encode, persist and match faces in user images.

    Methods generally handle errors by printing a message and returning an
    empty/negative result (``[]``, ``None`` or ``False``) rather than raising.
    When the optional ``face_recognition`` package could not be imported,
    detection and matching methods short-circuit to those empty results.
    """

    def __init__(self):
        # Distance threshold passed to face_recognition.compare_faces.
        self.tolerance = settings.FACE_RECOGNITION_TOLERANCE

    def detect_faces_in_image(self, image_path: str) -> Tuple[List[np.ndarray], List[Tuple[int, int, int, int]]]:
        """Detect faces in the image file at *image_path*.

        Returns:
            ``(face_encodings, face_locations)`` -- parallel lists with one entry
            per detected face. Locations are ``(top, right, bottom, left)`` pixel
            boxes. Both lists are empty when the library is unavailable, no face
            is found, or loading/detection fails.
        """
        if not FACE_RECOGNITION_AVAILABLE:
            print("Face recognition not available - returning empty results")
            return [], []

        try:
            image = face_recognition.load_image_file(image_path)
            face_locations = face_recognition.face_locations(image, model="hog")
            # Encode at the already-detected boxes so detection runs only once.
            face_encodings = face_recognition.face_encodings(image, face_locations)
            return face_encodings, face_locations
        except Exception as e:
            print(f"Error detecting faces in {image_path}: {str(e)}")
            return [], []

    def process_image_for_faces(self, db: Session, image_model: ImageModel) -> bool:
        """Extract face encodings from an uploaded image and persist them.

        One ``FaceEncoding`` row is added per detected face (encoding pickled,
        location JSON-encoded), and ``image_model.has_faces`` /
        ``image_model.face_processed`` are updated in the same commit.

        The image is marked ``face_processed`` even when no faces are found or
        processing fails.
        NOTE(review): this includes the case where ``face_recognition`` is not
        installed -- such images end up recorded as having no faces. Confirm
        that is intended if the library may be installed later.

        Returns:
            True if at least one face was found and stored, otherwise False.
        """
        try:
            face_encodings, face_locations = self.detect_faces_in_image(image_model.file_path)

            for encoding, location in zip(face_encodings, face_locations):
                db.add(FaceEncoding(
                    image_id=image_model.id,
                    encoding=pickle.dumps(encoding),
                    face_location=json.dumps(location),
                ))

            has_faces = bool(face_encodings)
            image_model.has_faces = has_faces
            image_model.face_processed = True
            db.commit()
            return has_faces

        except Exception as e:
            image_id = image_model.id
            print(f"Error processing image {image_id}: {str(e)}")
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; rolling back also discards half-added encoding rows.
            db.rollback()
            try:
                image_model.face_processed = True
                db.commit()
            except Exception as commit_error:
                db.rollback()
                print(f"Error marking image {image_id} as processed: {str(commit_error)}")
            return False

    def compare_faces(self, known_encoding: np.ndarray, unknown_encoding: np.ndarray) -> bool:
        """Return True if the encodings match under ``self.tolerance``.

        Returns False when the library is unavailable or the comparison fails.
        """
        if not FACE_RECOGNITION_AVAILABLE:
            return False

        try:
            results = face_recognition.compare_faces([known_encoding], unknown_encoding, tolerance=self.tolerance)
            # The library yields numpy.bool_ values; hand callers a plain bool.
            return bool(results[0]) if results else False
        except Exception as e:
            print(f"Error comparing faces: {str(e)}")
            return False

    def find_matching_images(self, db: Session, target_encoding: np.ndarray, user_id: str) -> List[str]:
        """Return ids of *user_id*'s images containing a face matching *target_encoding*.

        Only images already marked ``face_processed`` are considered. Each image
        id appears once, in the order its first matching face was returned by
        the query. Returns ``[]`` on error or when the library is unavailable.
        """
        if not FACE_RECOGNITION_AVAILABLE:
            # Every comparison would return False; skip loading all encodings.
            print("Face recognition not available - cannot match faces")
            return []

        try:
            face_encodings = db.query(FaceEncoding).join(ImageModel).filter(
                ImageModel.owner_id == user_id,
                ImageModel.face_processed == True,  # noqa: E712 - SQLAlchemy column expression
            ).all()

            print(f"Face matching: Checking {len(face_encodings)} face encodings for user {user_id}")

            matching_image_ids = []
            for i, face_encoding_record in enumerate(face_encodings, start=1):
                try:
                    # SECURITY: pickle.loads can execute code embedded in a crafted
                    # blob. These rows are written by process_image_for_faces, but
                    # anyone able to write this column could exploit it; consider
                    # storing raw float bytes (ndarray.tobytes) instead.
                    stored_encoding = pickle.loads(face_encoding_record.encoding)
                    is_match = self.compare_faces(stored_encoding, target_encoding)
                    print(f"Face {i}: Image {face_encoding_record.image_id} - Match: {is_match}")
                    if is_match:
                        matching_image_ids.append(face_encoding_record.image_id)
                except Exception as e:
                    print(f"Error processing face encoding {i}: {str(e)}")

            # An image may contain several matching faces; dict.fromkeys
            # de-duplicates while keeping first-seen order (set() would not).
            unique_matches = list(dict.fromkeys(matching_image_ids))
            print(f"Face matching: Found {len(unique_matches)} unique matching images")
            return unique_matches

        except Exception as e:
            print(f"Error finding matching images: {str(e)}")
            return []

    def process_face_scan_image(self, image_data: bytes) -> Optional[np.ndarray]:
        """Return the encoding of the first face found in encoded image bytes.

        Args:
            image_data: encoded image bytes (any format ``cv2.imdecode``
                accepts), e.g. a camera capture.

        Returns:
            The first encoding returned by ``face_recognition.face_encodings``,
            or None if the data is empty or undecodable, no face is found, the
            library is unavailable, or processing fails.
        """
        if not FACE_RECOGNITION_AVAILABLE:
            print("Face recognition not available - cannot process face scan")
            return None

        if not image_data:
            print("Error processing face scan image: empty image data")
            return None

        try:
            nparr = np.frombuffer(image_data, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is None:
                # imdecode reports undecodable/unsupported data by returning None.
                print("Error processing face scan image: could not decode image data")
                return None

            # OpenCV decodes to BGR; face_recognition expects RGB.
            rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            face_locations = face_recognition.face_locations(rgb_img)
            face_encodings = face_recognition.face_encodings(rgb_img, face_locations)

            return face_encodings[0] if face_encodings else None

        except Exception as e:
            print(f"Error processing face scan image: {str(e)}")
            return None

    def get_face_locations_in_image(self, image_path: str) -> List[Tuple[int, int, int, int]]:
        """Return ``(top, right, bottom, left)`` face boxes for the image at *image_path*.

        Returns ``[]`` when the library is unavailable or detection fails.
        """
        if not FACE_RECOGNITION_AVAILABLE:
            return []

        try:
            image = face_recognition.load_image_file(image_path)
            # Same detector as detect_faces_in_image ("hog" is also the library default).
            return face_recognition.face_locations(image, model="hog")
        except Exception as e:
            print(f"Error getting face locations: {str(e)}")
            return []

    def create_face_thumbnails(self, image_path: str, output_dir: str, padding: int = 20) -> List[str]:
        """Save a cropped image of each detected face and return the file paths.

        Args:
            image_path: source image on disk.
            output_dir: directory to write crops into (must already exist);
                files are named ``face_<index>_<source basename>``.
            padding: extra pixels kept around each face box, clamped to the
                image bounds.

        Returns:
            Paths of the written crops, or ``[]`` if no face is found or any
            step fails.
        """
        try:
            face_locations = self.get_face_locations_in_image(image_path)
            if not face_locations:
                return []

            thumbnail_paths = []
            # The context manager releases the file handle PIL keeps open lazily.
            with Image.open(image_path) as image:
                for i, (top, right, bottom, left) in enumerate(face_locations):
                    face_image = image.crop((
                        max(0, left - padding),
                        max(0, top - padding),
                        min(image.width, right + padding),
                        min(image.height, bottom + padding),
                    ))
                    thumbnail_filename = f"face_{i}_{os.path.basename(image_path)}"
                    thumbnail_path = os.path.join(output_dir, thumbnail_filename)
                    face_image.save(thumbnail_path)
                    thumbnail_paths.append(thumbnail_path)

            return thumbnail_paths

        except Exception as e:
            print(f"Error creating face thumbnails: {str(e)}")
            return []

# Module-level singleton, built once when this module is first imported;
# construction reads the tolerance from the application settings.
face_recognition_service: FaceRecognitionService = FaceRecognitionService()
