A multi-backend (TensorFlow, PyTorch, JAX, and NumPy) implementation of the Segment Anything model in Keras 3.0

tirthasheshpatel/segment_anything_keras

Segment Anything Model in Multi-Backend Keras

This is an implementation of the Segment Anything predictor and automatic mask generator in Keras 3.

The demos use KerasCV's Segment Anything model.

Install the package

pip install git+https://github.com/tirthasheshpatel/segment_anything_keras.git

Install the required dependencies:

pip install -U Pillow numpy keras keras-cv

Install TensorFlow, JAX, or PyTorch, whichever backend you'd like to use.

To install all the dependencies, plus every backend, needed to run the demos:

pip install -r requirements.txt
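Keras 3 picks its backend from the KERAS_BACKEND environment variable, which it reads when keras is first imported, so the choice has to be made before importing keras or keras_cv. A minimal sketch of two helpers for that, assuming the Keras 3 backend names "tensorflow", "jax", "torch", and "numpy"; `available_backends` and `select_backend` are hypothetical names, not part of sam_keras:

```python
import importlib.util
import os
from typing import List

# Keras 3 backend name -> Python package that provides it (assumed mapping).
BACKEND_PACKAGES = {
    "tensorflow": "tensorflow",
    "jax": "jax",
    "torch": "torch",
    "numpy": "numpy",
}


def available_backends() -> List[str]:
    """Hypothetical helper: backend names whose package is importable."""
    return [
        name
        for name, package in BACKEND_PACKAGES.items()
        if importlib.util.find_spec(package) is not None
    ]


def select_backend(name: str) -> str:
    """Hypothetical helper: set KERAS_BACKEND. Call before `import keras`."""
    name = name.strip().lower()
    if name not in BACKEND_PACKAGES:
        raise ValueError(
            f"unknown backend {name!r}; expected one of {list(BACKEND_PACKAGES)}"
        )
    os.environ["KERAS_BACKEND"] = name
    return name


print("installed backends:", available_backends())
```

Calling `select_backend("jax")` at the top of a script has the same effect as the `os.environ` line in the example below; it only adds a check against typos such as "pytorch" instead of "torch".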

Getting the pretrained Segment Anything Model

# Use the TensorFlow backend; any Keras 3 backend works.
# KERAS_BACKEND must be set before keras or keras_cv is imported.
import os
os.environ['KERAS_BACKEND'] = "tensorflow"

from keras_cv.models import SegmentAnythingModel
from sam_keras import SAMPredictor

# Get the huge model trained on the SA-1B dataset.
# Other available options are:
#   - "sam_base_sa1b"
#   - "sam_large_sa1b"
model = SegmentAnythingModel.from_preset("sam_huge_sa1b")

# Create the predictor
predictor = SAMPredictor(model)

# Now you can use the predictor just like the one in the original repo.
# The only difference is that a list of input dicts isn't supported;
# instead, pass each input dict separately to the `predict` method.
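Code written for the original repo's batched, list-of-dicts input can be adapted by looping over the list and calling `predict` once per dict. A minimal sketch of that loop, where `predict_each` is a hypothetical helper, `fake_predict` stands in for `predictor.predict` so the sketch runs without a model, and the dict keys are borrowed from the original SAM batched-input format as an assumption, not this package's documented schema:

```python
from typing import Any, Callable, Dict, List

InputDict = Dict[str, Any]


def predict_each(
    predict_fn: Callable[[InputDict], Any],
    batched_input: List[InputDict],
) -> List[Any]:
    """Hypothetical helper: one predict call per input dict, in order."""
    return [predict_fn(single_input) for single_input in batched_input]


# Placeholder inputs; key names are assumed from the original SAM repo.
batched_input = [
    {
        "image": None,  # placeholder: your image array goes here
        "point_coords": [[[500, 375]]],
        "point_labels": [[1]],
    },
    {
        "image": None,  # placeholder
        "boxes": [[75, 275, 700, 850]],
    },
]


def fake_predict(single_input: InputDict) -> Dict[str, Any]:
    """Stand-in for predictor.predict: reports which prompts it received."""
    return {"prompt_keys": sorted(k for k in single_input if k != "image")}


outputs = predict_each(fake_predict, batched_input)
```

With a real predictor, replace `fake_predict` with `predictor.predict`; results come back in the same order as the inputs.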

Notes

Right now, the JAX and TensorFlow backends have a large compile-time overhead: the prompt encoder recompiles each time a different combination of prompts (points only, points + boxes, boxes only, etc.) is passed. To avoid this, compile the model with run_eagerly=True and jit_compile=False, i.e. model.compile(run_eagerly=True, jit_compile=False).
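The recompilation comes from tracing: a JIT-compiled function is traced once per distinct input structure, and adding or dropping a prompt type changes that structure. The plain-Python toy below mimics the effect with a cache keyed on which prompts are present. It illustrates the idea only and is not how JAX or TensorFlow implement their trace caches; `ToyJit` and `encode_prompts` are hypothetical names.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


class ToyJit:
    """Toy stand-in for a tracing JIT: one 'compile' per input structure."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn
        self.cache: Dict[Tuple[bool, bool, bool], Callable[..., Any]] = {}
        self.compile_count = 0

    def __call__(
        self,
        points: Optional[Any] = None,
        boxes: Optional[Any] = None,
        masks: Optional[Any] = None,
    ) -> Any:
        # The "structure" here is simply which optional prompts were given.
        key = (points is not None, boxes is not None, masks is not None)
        if key not in self.cache:
            self.compile_count += 1  # a real JIT would trace and compile here
            self.cache[key] = self.fn
        return self.cache[key](points=points, boxes=boxes, masks=masks)


def encode_prompts(points=None, boxes=None, masks=None) -> List[str]:
    """Placeholder prompt encoder: reports which prompts were used."""
    pairs = (("points", points), ("boxes", boxes), ("masks", masks))
    return [name for name, value in pairs if value is not None]


encoder = ToyJit(encode_prompts)
encoder(points=[[10, 20]])                      # compile 1: points only
encoder(points=[[30, 40]])                      # cache hit: same structure
encoder(boxes=[[0, 0, 5, 5]])                   # compile 2: boxes only
encoder(points=[[1, 2]], boxes=[[0, 0, 5, 5]])  # compile 3: points + boxes
```

Running eagerly, as suggested above, skips tracing altogether, so there is no per-combination compile step; the trade-off is typically slower execution per call once a trace would have been cached.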
