import ArgumentParser
import AppKit
import Vision
import Foundation

/// Command-line tool that crops images around the largest face found by the
/// Vision framework and scales the result to a fixed pixel size.
///
/// The input may be a single image file or a directory. Directories are scanned
/// non-recursively for files with a known image extension and each one is
/// written to the output directory under its original file name.
@main
struct FaceCrop: ParsableCommand {
    static let configuration = CommandConfiguration(
        abstract: "Crop images centered on detected faces using Apple Vision framework.",
        discussion: """
            This tool detects faces in images and crops them to a specified size,
            keeping the face centered in the frame. Supports bulk processing of directories.

            Examples:
              facecrop image.jpg -o cropped.jpg -w 400 -t 400
              facecrop ./photos -o ./cropped -w 800 -t 800
            """
    )

    @Argument(help: "Input image file or directory containing images.")
    var input: String

    @Option(name: .shortAndLong, help: "Output file or directory. Defaults to '<name>_cropped.<ext>' next to an input file, or a 'cropped' subdirectory of an input directory.")
    var output: String?

    @Option(name: .shortAndLong, help: "Width of the cropped image in pixels.")
    var width: Int = 400

    // `-h` is taken by the built-in help flag, so height uses `-t` as its short name.
    @Option(name: [.customShort("t"), .long], help: "Height of the cropped image in pixels.")
    var height: Int = 400

    @Option(name: .shortAndLong, help: "Padding around the face as a percentage (0.0-1.0). Default 0.5 means 50% padding.")
    var padding: Double = 0.5

    @Option(name: .shortAndLong, help: "JPEG quality for output (0.0-1.0). Default 0.9.")
    var quality: Double = 0.9

    @Flag(name: .shortAndLong, help: "Process images concurrently for faster bulk processing.")
    var concurrent: Bool = false

    @Flag(name: .long, help: "Skip images where no face is detected instead of erroring.")
    var skipNoFace: Bool = false

    @Flag(name: .shortAndLong, help: "Verbose output.")
    var verbose: Bool = false

    /// Rejects option values that cannot produce a valid crop.
    ///
    /// Called by ArgumentParser after parsing and before `run()`, so bad values
    /// are reported as usage errors instead of failing later per image.
    ///
    /// - Throws: `ValidationError` when width/height are not positive, padding is
    ///   negative or not finite, or quality is outside `0.0...1.0`.
    func validate() throws {
        guard width > 0, height > 0 else {
            throw ValidationError("Width and height must be greater than zero.")
        }
        // NaN fails the `>= 0` comparison and is rejected as well.
        guard padding >= 0, padding.isFinite else {
            throw ValidationError("Padding must be a non-negative number.")
        }
        guard (0.0...1.0).contains(quality) else {
            throw ValidationError("Quality must be between 0.0 and 1.0.")
        }
    }

    /// Entry point: dispatches to directory or single-file processing depending
    /// on what `input` points at.
    ///
    /// - Throws: `ValidationError` when the input path does not exist, plus any
    ///   error thrown by the selected processing path.
    mutating func run() throws {
        let inputURL = URL(fileURLWithPath: input)
        let fileManager = FileManager.default

        var isDirectory: ObjCBool = false
        guard fileManager.fileExists(atPath: inputURL.path, isDirectory: &isDirectory) else {
            throw ValidationError("Input path does not exist: \(input)")
        }

        if isDirectory.boolValue {
            try processDirectory(inputURL: inputURL, fileManager: fileManager)
        } else {
            try processSingleFile(inputURL: inputURL)
        }
    }

    /// Processes every image file directly inside `inputURL` and prints a summary.
    ///
    /// Per-image failures are counted and reported but do not abort the batch.
    ///
    /// - Parameters:
    ///   - inputURL: Directory to scan (not recursive).
    ///   - fileManager: File manager used to create the output directory and list files.
    /// - Throws: Errors from creating the output directory or listing the input directory.
    private func processDirectory(inputURL: URL, fileManager: FileManager) throws {
        let outputDir: URL
        if let output = output {
            outputDir = URL(fileURLWithPath: output)
        } else {
            outputDir = inputURL.appendingPathComponent("cropped")
        }

        // Create output directory if needed
        try fileManager.createDirectory(at: outputDir, withIntermediateDirectories: true)

        // Find all image files. The directory listing order is unspecified, so
        // sort by name to make progress output reproducible between runs.
        let imageExtensions = Set(["jpg", "jpeg", "png", "heic", "heif", "tiff", "tif", "bmp", "gif"])
        let contents = try fileManager.contentsOfDirectory(at: inputURL, includingPropertiesForKeys: nil)
        let imageFiles = contents
            .filter { imageExtensions.contains($0.pathExtension.lowercased()) }
            .sorted { $0.lastPathComponent.localizedStandardCompare($1.lastPathComponent) == .orderedAscending }

        if imageFiles.isEmpty {
            print("No image files found in directory.")
            return
        }

        print("Found \(imageFiles.count) images to process...")

        var successCount = 0
        var skipCount = 0
        var errorCount = 0

        if concurrent {
            let results = processImagesConcurrently(imageFiles: imageFiles, outputDir: outputDir)
            successCount = results.success
            skipCount = results.skipped
            errorCount = results.errors
        } else {
            for (index, imageURL) in imageFiles.enumerated() {
                let outputURL = outputDir.appendingPathComponent(imageURL.lastPathComponent)
                let progress = "[\(index + 1)/\(imageFiles.count)]"

                do {
                    try processImage(inputURL: imageURL, outputURL: outputURL)
                    successCount += 1
                    if verbose {
                        print("\(progress) Processed: \(imageURL.lastPathComponent)")
                    }
                } catch FaceCropError.noFaceDetected {
                    if skipNoFace {
                        skipCount += 1
                        if verbose {
                            print("\(progress) Skipped (no face): \(imageURL.lastPathComponent)")
                        }
                    } else {
                        errorCount += 1
                        print("\(progress) Error (no face): \(imageURL.lastPathComponent)")
                    }
                } catch {
                    errorCount += 1
                    print("\(progress) Error: \(imageURL.lastPathComponent) - \(error.localizedDescription)")
                }
            }
        }

        print("\nComplete: \(successCount) processed, \(skipCount) skipped, \(errorCount) errors")
    }

    /// Processes `imageFiles` in parallel with `DispatchQueue.concurrentPerform`.
    ///
    /// The progress number printed for each file is the count of files finished
    /// so far (in completion order), not the file's position in `imageFiles`.
    ///
    /// - Parameters:
    ///   - imageFiles: Image files to process.
    ///   - outputDir: Directory that receives one output per input, same file name.
    /// - Returns: Totals of processed, skipped and failed images.
    private func processImagesConcurrently(imageFiles: [URL], outputDir: URL) -> (success: Int, skipped: Int, errors: Int) {
        let counter = ConcurrentCounter()
        let total = imageFiles.count
        let verbose = self.verbose
        let skipNoFace = self.skipNoFace

        DispatchQueue.concurrentPerform(iterations: imageFiles.count) { index in
            let imageURL = imageFiles[index]
            let outputURL = outputDir.appendingPathComponent(imageURL.lastPathComponent)

            do {
                try processImage(inputURL: imageURL, outputURL: outputURL)
                let current = counter.incrementSuccess()
                if verbose {
                    print("[\(current)/\(total)] Processed: \(imageURL.lastPathComponent)")
                }
            } catch FaceCropError.noFaceDetected {
                if skipNoFace {
                    let current = counter.incrementSkipped()
                    if verbose {
                        print("[\(current)/\(total)] Skipped (no face): \(imageURL.lastPathComponent)")
                    }
                } else {
                    // Errors carry the same progress prefix as every other line.
                    let current = counter.incrementError()
                    print("[\(current)/\(total)] Error (no face): \(imageURL.lastPathComponent)")
                }
            } catch {
                let current = counter.incrementError()
                print("[\(current)/\(total)] Error: \(imageURL.lastPathComponent) - \(error.localizedDescription)")
            }
        }

        return counter.results
    }

    /// Processes one image file and prints where the result was written.
    ///
    /// Without `--output`, the result is written next to the input as
    /// `<name>_cropped.<ext>`.
    ///
    /// - Parameter inputURL: Image file to crop.
    /// - Throws: Any error from `processImage(inputURL:outputURL:)`.
    private func processSingleFile(inputURL: URL) throws {
        let outputURL: URL
        if let output = output {
            outputURL = URL(fileURLWithPath: output)
        } else {
            let filename = inputURL.deletingPathExtension().lastPathComponent + "_cropped"
            let ext = inputURL.pathExtension
            outputURL = inputURL.deletingLastPathComponent().appendingPathComponent(filename).appendingPathExtension(ext)
        }

        try processImage(inputURL: inputURL, outputURL: outputURL)
        print("Saved cropped image to: \(outputURL.path)")
    }

    /// Runs the full pipeline for one image: load, detect face, crop/resize, save.
    ///
    /// - Parameters:
    ///   - inputURL: Image file to read.
    ///   - outputURL: Destination file; its extension selects the output encoding.
    /// - Throws: `FaceCropError.failedToLoadImage` when the file cannot be decoded,
    ///   plus any error from detection, cropping or saving.
    private func processImage(inputURL: URL, outputURL: URL) throws {
        // Load image
        guard let image = NSImage(contentsOf: inputURL),
              let cgImage = image.cgImage(forProposedRect: nil, context: nil, hints: nil) else {
            throw FaceCropError.failedToLoadImage(inputURL.path)
        }

        // Detect faces
        let faceRect = try detectFace(in: cgImage)

        // Crop image centered on face
        let croppedImage = try cropImage(cgImage, centeredOn: faceRect, targetWidth: width, targetHeight: height)

        // Save output
        try saveImage(croppedImage, to: outputURL)
    }

    /// Finds the largest face in `image`.
    ///
    /// - Parameter image: Image to analyse.
    /// - Returns: The face bounding box in pixel units. The rectangle is the
    ///   normalized Vision bounding box scaled by the image size, so it keeps
    ///   Vision's bottom-left origin.
    /// - Throws: `FaceCropError.noFaceDetected` when Vision reports no faces, or
    ///   any error thrown while performing the request.
    private func detectFace(in image: CGImage) throws -> CGRect {
        let request = VNDetectFaceRectanglesRequest()
        let handler = VNImageRequestHandler(cgImage: image, options: [:])

        try handler.perform([request])

        // If multiple faces, use the largest one (likely the main subject).
        // `max(by:)` yields nil for an empty list, so one guard covers both the
        // "no results" and "empty results" cases without a force unwrap.
        let area: (VNFaceObservation) -> CGFloat = { $0.boundingBox.width * $0.boundingBox.height }
        guard let largestFace = request.results?.max(by: { area($0) < area($1) }) else {
            throw FaceCropError.noFaceDetected
        }

        // Convert normalized coordinates to pixel coordinates
        let imageWidth = CGFloat(image.width)
        let imageHeight = CGFloat(image.height)
        let box = largestFace.boundingBox

        let faceRect = CGRect(
            x: box.origin.x * imageWidth,
            y: box.origin.y * imageHeight,
            width: box.width * imageWidth,
            height: box.height * imageHeight
        )

        if verbose {
            print("  Face detected at: \(faceRect)")
        }

        return faceRect
    }

    /// Crops `image` around `faceRect` at the target aspect ratio and scales the
    /// result to exactly `targetWidth` x `targetHeight` pixels.
    ///
    /// The crop's shorter side is the larger face dimension expanded by `padding`
    /// on each side. If that does not fit, the crop is shrunk to fit the image
    /// while keeping the aspect ratio, then shifted to stay inside the image, so
    /// faces near an edge end up off-center rather than padded with empty pixels.
    ///
    /// - Parameters:
    ///   - image: Source image.
    ///   - faceRect: Face rectangle in pixels with a bottom-left origin.
    ///   - targetWidth: Output width in pixels; must be positive.
    ///   - targetHeight: Output height in pixels; must be positive.
    /// - Returns: The cropped and resized image.
    /// - Throws: `FaceCropError.failedToCrop` or `FaceCropError.failedToResize`.
    private func cropImage(_ image: CGImage, centeredOn faceRect: CGRect, targetWidth: Int, targetHeight: Int) throws -> CGImage {
        let imageWidth = CGFloat(image.width)
        let imageHeight = CGFloat(image.height)

        // Calculate face center
        let faceCenterX = faceRect.midX
        let faceCenterY = faceRect.midY

        // Calculate crop size based on face size and padding
        // The crop should be large enough to include the face plus padding
        let faceSize = max(faceRect.width, faceRect.height)
        let cropSize = faceSize * (1.0 + CGFloat(padding) * 2)

        // Determine crop dimensions maintaining target aspect ratio
        let targetAspect = CGFloat(targetWidth) / CGFloat(targetHeight)
        var cropWidth: CGFloat
        var cropHeight: CGFloat

        if targetAspect > 1 {
            // Wider than tall
            cropWidth = cropSize * targetAspect
            cropHeight = cropSize
        } else {
            // Taller than wide
            cropWidth = cropSize
            cropHeight = cropSize / targetAspect
        }

        // Ensure crop area fits within image bounds
        cropWidth = min(cropWidth, imageWidth)
        cropHeight = min(cropHeight, imageHeight)

        // Adjust to maintain aspect ratio if we hit bounds
        let actualAspect = cropWidth / cropHeight
        if actualAspect > targetAspect {
            cropWidth = cropHeight * targetAspect
        } else if actualAspect < targetAspect {
            cropHeight = cropWidth / targetAspect
        }

        // Calculate crop origin (centered on face)
        var cropX = faceCenterX - cropWidth / 2
        var cropY = faceCenterY - cropHeight / 2

        // Clamp to image bounds
        cropX = max(0, min(cropX, imageWidth - cropWidth))
        cropY = max(0, min(cropY, imageHeight - cropHeight))

        // Create crop rect (CGImage uses top-left origin, but Vision uses bottom-left)
        // We need to flip Y coordinate
        let flippedY = imageHeight - cropY - cropHeight
        let cropRect = CGRect(x: cropX, y: flippedY, width: cropWidth, height: cropHeight)

        if verbose {
            print("  Crop rect: \(cropRect)")
        }

        // Crop the image
        guard let croppedCGImage = image.cropping(to: cropRect) else {
            throw FaceCropError.failedToCrop
        }

        // Resize to target dimensions
        return try resizeImage(croppedCGImage, to: CGSize(width: targetWidth, height: targetHeight))
    }

    /// Scales `image` to `size` using high-quality interpolation.
    ///
    /// A bitmap context matching the source pixel format is tried first so bit
    /// depth and color space are kept when possible. Bitmap contexts accept only
    /// a limited set of pixel formats, so when that fails the image is drawn
    /// into a standard 8-bit sRGB premultiplied-RGBA context instead of failing.
    ///
    /// - Parameters:
    ///   - image: Image to scale.
    ///   - size: Target size in pixels.
    /// - Returns: The scaled image.
    /// - Throws: `FaceCropError.failedToResize` when no context can be created or
    ///   the context cannot produce an image.
    private func resizeImage(_ image: CGImage, to size: CGSize) throws -> CGImage {
        let pixelWidth = Int(size.width)
        let pixelHeight = Int(size.height)

        let nativeContext = CGContext(
            data: nil,
            width: pixelWidth,
            height: pixelHeight,
            bitsPerComponent: image.bitsPerComponent,
            bytesPerRow: 0,
            space: image.colorSpace ?? CGColorSpaceCreateDeviceRGB(),
            bitmapInfo: image.bitmapInfo.rawValue
        )

        // The right-hand side of `??` is only evaluated when the native format
        // was rejected.
        guard let context = nativeContext ?? CGContext(
            data: nil,
            width: pixelWidth,
            height: pixelHeight,
            bitsPerComponent: 8,
            bytesPerRow: 0,
            space: CGColorSpace(name: CGColorSpace.sRGB) ?? CGColorSpaceCreateDeviceRGB(),
            bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue
        ) else {
            throw FaceCropError.failedToResize
        }

        context.interpolationQuality = .high
        context.draw(image, in: CGRect(origin: .zero, size: size))

        guard let resizedImage = context.makeImage() else {
            throw FaceCropError.failedToResize
        }

        return resizedImage
    }

    /// Encodes `image` and writes it to `url`.
    ///
    /// The encoding follows the file extension: PNG, JPEG, TIFF, GIF and BMP are
    /// written natively. Any other extension (including HEIC/HEIF) falls back to
    /// JPEG data at the configured quality.
    ///
    /// - Parameters:
    ///   - image: Image to encode.
    ///   - url: Destination file.
    /// - Throws: `FaceCropError.failedToSave` when encoding fails, or the
    ///   underlying error when writing the file fails.
    private func saveImage(_ image: CGImage, to url: URL) throws {
        let ext = url.pathExtension.lowercased()

        let bitmapRep = NSBitmapImageRep(cgImage: image)
        let data: Data?

        switch ext {
        case "png":
            data = bitmapRep.representation(using: .png, properties: [:])
        case "jpg", "jpeg":
            data = bitmapRep.representation(using: .jpeg, properties: [.compressionFactor: quality])
        case "tiff", "tif":
            data = bitmapRep.representation(using: .tiff, properties: [:])
        case "gif":
            // Encode natively so the file contents match the extension.
            data = bitmapRep.representation(using: .gif, properties: [:])
        case "bmp":
            data = bitmapRep.representation(using: .bmp, properties: [:])
        default:
            // Default to JPEG
            data = bitmapRep.representation(using: .jpeg, properties: [.compressionFactor: quality])
        }

        guard let imageData = data else {
            throw FaceCropError.failedToSave(url.path)
        }

        try imageData.write(to: url)
    }
}

/// Failures that can occur while processing a single image.
enum FaceCropError: Error {
    /// The image at the associated path could not be loaded.
    case failedToLoadImage(String)
    /// No face was found in the image.
    case noFaceDetected
    /// Cropping the image failed.
    case failedToCrop
    /// Resizing the image failed.
    case failedToResize
    /// The image could not be saved to the associated path.
    case failedToSave(String)
}

// MARK: - LocalizedError

extension FaceCropError: LocalizedError {
    /// Human-readable message for the error, including the file path for the
    /// load and save cases.
    var errorDescription: String? {
        switch self {
        case let .failedToLoadImage(location):
            return "Failed to load image: " + location
        case let .failedToSave(location):
            return "Failed to save image: " + location
        case .noFaceDetected:
            return "No face detected in image"
        case .failedToCrop:
            return "Failed to crop image"
        case .failedToResize:
            return "Failed to resize image"
        }
    }
}

/// Lock-protected tallies of processed, skipped and failed images.
///
/// Each `increment…` method bumps its own tally together with a shared
/// "processed" total and returns that total.
final class ConcurrentCounter: @unchecked Sendable {
    private let lock = NSLock()
    private var successTotal = 0
    private var skippedTotal = 0
    private var errorTotal = 0
    private var processedTotal = 0

    /// Runs `bump` and advances the processed total while holding the lock.
    ///
    /// - Parameter bump: Closure that increments one of the per-outcome tallies.
    /// - Returns: The processed total after this call.
    private func record(_ bump: () -> Void) -> Int {
        lock.lock()
        defer { lock.unlock() }
        bump()
        processedTotal += 1
        return processedTotal
    }

    /// Records one successfully processed image.
    /// - Returns: The number of images recorded so far, across all outcomes.
    @discardableResult
    func incrementSuccess() -> Int {
        return record { successTotal += 1 }
    }

    /// Records one skipped image.
    /// - Returns: The number of images recorded so far, across all outcomes.
    @discardableResult
    func incrementSkipped() -> Int {
        return record { skippedTotal += 1 }
    }

    /// Records one failed image.
    /// - Returns: The number of images recorded so far, across all outcomes.
    @discardableResult
    func incrementError() -> Int {
        return record { errorTotal += 1 }
    }

    /// Snapshot of the three per-outcome tallies, read under the lock.
    var results: (success: Int, skipped: Int, errors: Int) {
        lock.lock()
        defer { lock.unlock() }
        return (success: successTotal, skipped: skippedTotal, errors: errorTotal)
    }
}
