builtin-programs/recognition/trocr.folk

# On-demand trigger: whenever a statement of the form
# "when the TrOCR text recognizer is ... with environment ..." exists,
# wish for the TrOCR recognizer to be loaded.
# NOTE(review): presumably this matches other programs' pending
# `When the TrOCR text recognizer is /x/` handlers, so the model is only
# loaded once something asks for it - confirm against Folk's When-when convention.
When when the TrOCR text recognizer is /any/ /any/ with environment /any/ {
    # NOTE(review): `-keep 500ms` presumably holds the wish for a short time
    # after this match goes away, to avoid rapid unload/reload - confirm.
    Wish -keep 500ms to load the TrOCR text recognizer
}

# Loader for the TrOCR handwriting recognizer.
#
# Runs when someone wishes to load the TrOCR text recognizer and an image uvx
# argtype definer has been claimed. It starts a uvx Python environment, loads
# the TrOCR processor and model (from ~/folk-data/trocr if present, otherwise
# from the Hugging Face hub, saving a copy to that path), and claims a `TrOCR`
# fn that takes an image and returns the recognized text.
When /someone/ wishes to load the TrOCR text recognizer &\
     the image uvx argtype definer is /defineImageArgtype/ {
    # NOTE(review): presumably makes the matched definer callable as a local
    # fn under the same name - confirm against Folk's `fn` semantics.
    fn defineImageArgtype

    # Python environment for the recognizer. transformers and torch are
    # imported below; pillow and protobuf are presumably needed by
    # transformers and/or the Image argtype - TODO confirm.
    set py [Uvx --with transformers --with pillow --with torch --with protobuf]
    # Lets the `Image` argtype be used in the `$py def` signature below
    # (assumed from the definer's name - verify against its implementation).
    defineImageArgtype $py

    # One-time setup: defines the Python globals `device`, `processor` and
    # `model`, which ocrImage reads on every call.
    $py exec {
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        import os
        import sys
        import torch
        import time

        # Determine device (prefer CUDA > MPS > CPU)
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

        # Load TrOCR model. Try the local copy first; on any failure fall
        # back to downloading the pretrained weights and save them to
        # TROCR_PATH so the next start can load from disk.
        TROCR_PATH = os.path.expanduser("~/folk-data/trocr")
        try:
            processor = TrOCRProcessor.from_pretrained(TROCR_PATH)
            model = VisionEncoderDecoderModel.from_pretrained(TROCR_PATH)
            print("trocr: Loaded TrOCR model from disk.", file=sys.stderr, flush=True)
        except Exception:
            # Deliberately broad: any problem with the local copy triggers
            # a fresh download rather than aborting the load.
            print("trocr: Model not saved; loading from Internet.", file=sys.stderr, flush=True)
            processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
            processor.save_pretrained(TROCR_PATH)
            model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
            model.save_pretrained(TROCR_PATH)
            print("trocr: Loaded TrOCR model from Internet.", file=sys.stderr, flush=True)

        # Move the weights to the selected device; ocrImage moves its input
        # tensor to the same device before generating.
        model.to(device)
        print(f"trocr: Using device: {device}", file=sys.stderr, flush=True)
    }

    # ocrImage: run TrOCR over the whole image and return the decoded text.
    # Logs the result and the elapsed time to stderr.
    # NOTE(review): the Hugging Face TrOCR examples pass an RGB PIL image;
    # what the Image argtype delivers here is not visible in this file - confirm.
    $py def ocrImage {Image image} {
        # perf_counter is monotonic, so the logged duration cannot be skewed
        # by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()

        # Run TrOCR on the entire image (inference only, so no gradients)
        with torch.no_grad():
            pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
            generated_ids = model.generate(pixel_values)
            # The batch holds a single image; take its decoded string.
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        elapsed = time.perf_counter() - start_time
        print(f"trocr: Result: {text} ({elapsed:.3f}s)", file=sys.stderr, flush=True)

        return text
    }

    # Folk-side wrapper around the Python function; this is the value that
    # other programs receive from the claim below.
    fn TrOCR {im} { return [$py ocrImage $im] }
    Claim the TrOCR text recognizer is [fn TrOCR]
}