builtin-programs/recognition/sam2.folk

# Lazy-load trigger: reacts to a `When` whose pattern is
# "the SAM2 segmenter is /any/" and wishes for the segmenter to be loaded,
# so the model is only loaded when something asks for the segmenter.
# NOTE(review): `-keep 500ms` presumably keeps the wish alive for 500ms
# after the triggering `When` goes away -- confirm against Wish's docs.
When when the SAM2 segmenter is /any/ /any/ with environment /any/ {
    Wish -keep 500ms to load the SAM2 segmenter
}

# Loader: runs when someone wishes to load the SAM2 segmenter and an
# image uvx argtype definer has been claimed. Everything below (Python
# environment, segment function, mask helper library) lives in this body.
When /someone/ wishes to load the SAM2 segmenter &\
     the image uvx argtype definer is /defineImageArgtype/ {
    # NOTE(review): presumably makes the matched definer callable by name
    # in this scope -- confirm against the definition of `fn`.
    fn defineImageArgtype

    # Python environment handle created through Uvx, requesting the
    # packages that the embedded Python code below imports (torch, numpy,
    # sam2) plus pillow and huggingface_hub.
    set py [Uvx --with pillow --with torch --with numpy \
                --with huggingface_hub \
                --with "git+https://github.com/facebookresearch/sam2.git"]

    # NOTE(review): presumably registers the `Image` argument type that
    # `$py def segment` uses for its first parameter -- confirm.
    defineImageArgtype $py

    $py exec {
        # One-time setup for the SAM2 Python environment: picks a torch
        # device, loads the pretrained predictor, and creates the lock that
        # the segment function uses to serialize access to the predictor.

        # sys must be imported before the first print to sys.stderr below;
        # it is imported on its own so the boot message still appears
        # before the slower torch/sam2 imports.
        import sys

        print("sam2: Boot", file=sys.stderr)

        import torch
        import numpy as np
        import threading
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        # Prefer CUDA, then Apple MPS, and fall back to the CPU.
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

        print("sam2: Loading.", file=sys.stderr)
        predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-small", device=device)
        # Guards the shared predictor; set_image and predict are called
        # as a pair under this lock.
        predictor_lock = threading.Lock()
        print(f"sam2: Using device: {device}", file=sys.stderr, flush=True)
    }

    $py def segment {Image image {list list num} pointCoords {list num} pointLabels} {
        # Segment `image` with SAM2 using the given prompt points.
        #   image       -- input image; converted to RGB before prediction
        #   pointCoords -- list of point coordinate lists, passed as point_coords
        #   pointLabels -- one label per point, passed as point_labels
        # Returns a dict with the highest-scoring mask as nested lists
        # under "mask" and its score as a float under "score".
        rgb_pixels = np.array(image.convert("RGB"))

        # set_image and predict must run as a pair on the shared
        # predictor, so both happen while holding the lock.
        with predictor_lock:
            with torch.inference_mode():
                predictor.set_image(rgb_pixels)
                candidate_masks, candidate_scores, _unused = predictor.predict(
                    point_coords=pointCoords,
                    point_labels=pointLabels,
                    multimask_output=True,
                )

        # Several candidate masks come back; keep the best-scoring one.
        top = int(candidate_scores.argmax())
        top_mask = candidate_masks[top]
        top_score = float(candidate_scores[top])

        # Note: 10-20ms is spent just on serializing/deserializing the
        # mask, since it is basically a whole image.
        return {"mask": top_mask.tolist(), "score": top_score}
    }

    # Thin wrapper that forwards all of its arguments to the Python
    # `segment` function, published as the SAM2 segmenter claim.
    fn SAM2 args { return [$py segment {*}$args] }
    Claim the SAM2 segmenter is [fn SAM2]

    # Once the image library is available, build a small C library with
    # helpers that turn a mask (as returned by segment) into images.
    When the image library is /imageLib/ {
        set cc [C]
        # Extends the compiler with the image library; the procs below use
        # Image and imageNew, which are not defined in this file.
        $cc extend $imageLib
        $cc proc maskToBinaryImage {Jim_Obj* mask} Image {
            // Converts a mask (a list of rows, each a list of numbers) into
            // a 1-component image: 255 where the value is above 0.5, else 0.
            // The image is as tall as the mask and as wide as its first row.
            // Rows shorter than the first row, and elements that do not
            // parse as numbers, are treated as 0 instead of being read
            // through a NULL pointer or an uninitialised double.
            int height = Jim_ListLength(interp, mask);
            int width = 0;
            if (height > 0) {
                Jim_Obj *firstRow = Jim_ListGetIndex(interp, mask, 0);
                width = Jim_ListLength(interp, firstRow);
            }

            Image result = imageNew(width, height, 1, 0);
            for (int y = 0; y < height; y++) {
                Jim_Obj *row = Jim_ListGetIndex(interp, mask, y);
                // Jim_ListGetIndex returns NULL past the end of a list, so
                // bound every element lookup by this row's own length.
                int rowLen = row != NULL ? Jim_ListLength(interp, row) : 0;
                for (int x = 0; x < width; x++) {
                    double val = 0.0;
                    if (x < rowLen) {
                        Jim_Obj *elem = Jim_ListGetIndex(interp, row, x);
                        if (elem == NULL || Jim_GetDouble(interp, elem, &val) != JIM_OK) {
                            val = 0.0;
                        }
                    }
                    result.data[y * result.bytesPerRow + x] = val > 0.5 ? 255 : 0;
                }
            }
            return result;
        }
        $cc proc applyMaskToImage {Image im Jim_Obj* mask} Image {
            // Returns a copy of im in which every pixel whose mask value is
            // not above 0.5 is set to 0 in all components.
            // The loops run over the image, not the mask, so that:
            //  - a mask larger than the image cannot index past the end of
            //    im.data or result.data;
            //  - pixels not covered by the mask (smaller mask, short row,
            //    unparsable element) are explicitly written as 0 rather
            //    than left as whatever imageNew allocated.
            // NOTE(review): im.uniq is forwarded to imageNew unchanged, as
            // in the original; confirm that sharing it with the source
            // image is intended.
            int maskHeight = Jim_ListLength(interp, mask);
            int imHeight = (int)im.height;
            int imWidth = (int)im.width;

            Image result = imageNew(im.width, im.height, im.components, im.uniq);
            for (int y = 0; y < imHeight; y++) {
                // Jim_ListGetIndex returns NULL past the end of a list, so
                // only look up rows and elements that actually exist.
                Jim_Obj *row = y < maskHeight ? Jim_ListGetIndex(interp, mask, y) : NULL;
                int rowLen = row != NULL ? Jim_ListLength(interp, row) : 0;
                for (int x = 0; x < imWidth; x++) {
                    double val = 0.0;
                    if (x < rowLen) {
                        Jim_Obj *elem = Jim_ListGetIndex(interp, row, x);
                        if (elem == NULL || Jim_GetDouble(interp, elem, &val) != JIM_OK) {
                            val = 0.0;
                        }
                    }
                    int keep = val > 0.5;
                    for (int c = 0; c < im.components; c++) {
                        result.data[y * result.bytesPerRow + x * im.components + c] =
                            keep ? im.data[y * im.bytesPerRow + x * im.components + c] : 0;
                    }
                }
            }
            return result;
        }
        # Compile the two helper procs above and publish the resulting
        # library so other programs can match on it and call them.
        set maskToImageLib [$cc compile]
        Claim the SAM2 mask-to-image library is $maskToImageLib
    }
}