Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port to SAM2.1 #1

Open
aakash-vwo opened this issue Oct 23, 2024 · 0 comments
Open

Port to SAM2.1 #1

aakash-vwo opened this issue Oct 23, 2024 · 0 comments

Comments

@aakash-vwo
Copy link

Saw you are assoicated with Roboflow, so I am creating this issue
Could you nudge me in the right direction to port SoM to SAM2.1?

I cloned
https://huggingface.co/spaces/Roboflow/SoM/blob/main/app.py

and rewrote sam_utils.py

import numpy as np
import supervision as sv
import torch
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

SAM_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"
SAM_CONFIG = "sam2_hiera_s.yaml"


def load_sam_model(device: torch.device) -> SAM2ImagePredictor:
    model = build_sam2(SAM_CONFIG, SAM_CHECKPOINT, device=device)
    return SAM2ImagePredictor(sam_model=model)


def sam_inference(image: np.ndarray, model: SAM2ImagePredictor) -> sv.Detections:
    pil_image = Image.fromarray(image)
    model.set_image(np.array(pil_image.convert("RGB")))
    masks, _, _ = model.predict()

    return sv.Detections(xyxy=sv.mask_to_xyxy(masks), mask=masks.astype(bool))


def sam_interactive_inference(
    image: np.ndarray, mask: np.ndarray, model: SAM2ImagePredictor
) -> sv.Detections:
    pil_image = Image.fromarray(image)
    model.set_image(np.array(pil_image.convert("RGB")))
    masks = []
    for polygon in sv.mask_to_polygons(mask.astype(bool)):
        random_point_indexes = np.random.choice(polygon.shape[0], size=5, replace=True)
        input_point = polygon[random_point_indexes]
        input_label = np.ones(5)
        mask, _, _ = model.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )
        masks.append(mask[0])  # SAM2 returns a list of masks, we take the first one
    masks = np.array(masks, dtype=bool)
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks), mask=masks)

I think I am doing something wrong because I only get 1 annotation

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant