feat: RND-114: Add SAM2 integration for Video Object Tracking (#596)

Co-authored-by: nik <[email protected]>
Co-authored-by: Micaela Kaplan <[email protected]>

1 parent a9cf8e7 · commit d803e87
Showing 11 changed files with 677 additions and 0 deletions.
label_studio_ml/examples/segment_anything_2_video/Dockerfile
59 changes: 59 additions & 0 deletions
```dockerfile
FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
ARG DEBIAN_FRONTEND=noninteractive
ARG TEST_ENV

WORKDIR /app

RUN conda update conda -y

RUN --mount=type=cache,target="/var/cache/apt",sharing=locked \
    --mount=type=cache,target="/var/lib/apt/lists",sharing=locked \
    apt-get -y update \
    && apt-get install -y git \
    && apt-get install -y wget \
    && apt-get install -y g++ freeglut3-dev build-essential libx11-dev \
    libxmu-dev libxi-dev libglu1-mesa libglu1-mesa-dev libfreeimage-dev \
    && apt-get -y install ffmpeg libsm6 libxext6 libffi-dev python3-dev python3-pip gcc

ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_CACHE_DIR=/.cache \
    PORT=9090 \
    WORKERS=2 \
    THREADS=4 \
    CUDA_HOME=/usr/local/cuda \
    SEGMENT_ANYTHING_2_REPO_PATH=/segment-anything-2

RUN conda install -c "nvidia/label/cuda-12.1.1" cuda -y
ENV CUDA_HOME=/opt/conda \
    TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5;8.0;8.6+PTX;8.9;9.0"

# install base requirements
COPY requirements-base.txt .
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
    pip install -r requirements-base.txt

COPY requirements.txt .
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
    pip3 install -r requirements.txt

# install segment-anything-2
RUN cd / && git clone --depth 1 --branch main --single-branch https://github.com/facebookresearch/segment-anything-2.git
WORKDIR /segment-anything-2
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
    pip3 install -e .
RUN cd checkpoints && ./download_ckpts.sh

WORKDIR /app

# install test requirements if needed
COPY requirements-test.txt .
# build only when TEST_ENV="true"
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
    if [ "$TEST_ENV" = "true" ]; then \
        pip3 install -r requirements-test.txt; \
    fi

COPY . ./

CMD ["/app/start.sh"]
```
label_studio_ml/examples/segment_anything_2_video/README.md
63 changes: 63 additions & 0 deletions
This guide describes the simplest way to start using **SegmentAnything 2** with Label Studio.

This repository is specifically for working with object tracking in videos. For working with images,
see the [segment_anything_2_image repository](https://github.com/HumanSignal/label-studio-ml-backend/tree/master/label_studio_ml/examples/segment_anything_2_image).

![sam2](./Sam2Video.gif)

## Running from source

1. To run the ML backend without Docker, clone the repository and install all dependencies using pip:

   ```bash
   git clone https://github.com/HumanSignal/label-studio-ml-backend.git
   cd label-studio-ml-backend
   pip install -e .
   cd label_studio_ml/examples/segment_anything_2_video
   pip install -r requirements.txt
   ```

2. Download the [`segment-anything-2` repo](https://github.com/facebookresearch/segment-anything-2) into the root directory. Install the SegmentAnything model and download the checkpoints using [the official Meta documentation](https://github.com/facebookresearch/segment-anything-2?tab=readme-ov-file#installation). Make sure that you complete the steps for downloading the checkpoint files!

3. Export the following environment variables (fill them in with your credentials!):
   - `LABEL_STUDIO_URL`: the `http://` or `https://` link to your Label Studio instance (include the prefix!)
   - `LABEL_STUDIO_API_KEY`: your API key for Label Studio, available in your profile.
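On a POSIX shell, step 3 looks like this (the URL and key below are hypothetical placeholders; substitute the values for your own instance):

```shell
# Hypothetical example values -- replace with your own instance URL and API key.
export LABEL_STUDIO_URL="http://192.168.1.42:8080"
export LABEL_STUDIO_API_KEY="your-api-key"
```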
4. Start the ML backend on the default port `9090`:

   ```bash
   cd ../
   label-studio-ml start ./segment_anything_2_video
   ```

   Note that if you're running on a cloud server, you'll need to run on an exposed port. To change the port, add `-p <port number>` to the end of the start command above.

5. Connect the running ML backend server to Label Studio: go to your project `Settings -> Machine Learning -> Add Model` and specify `http://localhost:9090` as the URL. Read more in the official [Label Studio documentation](https://labelstud.io/guide/ml#Connect-the-model-to-Label-Studio). Again, if you're running in the cloud, replace `localhost` with the external IP address of your container, along with the exposed port.

# Labeling Config

For your project, you can use any labeling config with video properties. Here's a basic one to get you started!

```xml
<View>
  <Labels name="videoLabels" toName="video" allowEmpty="true">
    <Label value="Player" background="#11A39E"/>
    <Label value="Ball" background="#D4380D"/>
  </Labels>
  <!-- Please specify FPS carefully, it will be used for all project videos -->
  <Video name="video" value="$video" framerate="25.0"/>
  <VideoRectangle name="box" toName="video" smart="true"/>
</View>
<!--{
  "video": "/static/samples/opossum_snow.mp4"
}-->
```
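As a quick sanity check, the labeling config above (minus the trailing sample-data comment) is plain XML and can be validated with Python's standard library; the tag and attribute names below come straight from the config:

```python
import xml.etree.ElementTree as ET

# The labeling config from above, without the sample-data comment.
config = """
<View>
  <Labels name="videoLabels" toName="video" allowEmpty="true">
    <Label value="Player" background="#11A39E"/>
    <Label value="Ball" background="#D4380D"/>
  </Labels>
  <Video name="video" value="$video" framerate="25.0"/>
  <VideoRectangle name="box" toName="video" smart="true"/>
</View>
"""

root = ET.fromstring(config)
labels = [label.get("value") for label in root.iter("Label")]
framerate = root.find("Video").get("framerate")
print(labels, framerate)  # ['Player', 'Ball'] 25.0
```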
# Known limitations

- As of 8/11/2024, SAM2 only runs on GPU servers.
- Currently, we only support tracking one object per video, although SAM2 can support multiple.
- Currently, we do not support video segmentation.
- No Docker support.

If you want to contribute to this repository to help with some of these limitations, you can submit a PR.

# Customization

The ML backend can be customized by adding your own models and logic inside the `./segment_anything_2_video` directory.
label_studio_ml/examples/segment_anything_2_video/_wsgi.py
121 changes: 121 additions & 0 deletions
```python
import os
import argparse
import json
import logging
import logging.config

logging.config.dictConfig({
    "version": 1,
    "formatters": {
        "standard": {
            "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
        }
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": os.getenv('LOG_LEVEL'),
            "stream": "ext://sys.stdout",
            "formatter": "standard"
        }
    },
    "root": {
        "level": os.getenv('LOG_LEVEL'),
        "handlers": [
            "console"
        ],
        "propagate": True
    }
})

from label_studio_ml.api import init_app
from model import NewModel


_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')


def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
    if not os.path.exists(config_path):
        return dict()
    with open(config_path) as f:
        config = json.load(f)
    assert isinstance(config, dict)
    return config


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Label studio')
    parser.add_argument(
        '-p', '--port', dest='port', type=int, default=9090,
        help='Server port')
    parser.add_argument(
        '--host', dest='host', type=str, default='0.0.0.0',
        help='Server host')
    parser.add_argument(
        '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='),
        help='Additional LabelStudioMLBase model initialization kwargs')
    parser.add_argument(
        '-d', '--debug', dest='debug', action='store_true',
        help='Switch debug mode')
    parser.add_argument(
        '--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None,
        help='Logging level')
    parser.add_argument(
        '--model-dir', dest='model_dir', default=os.path.dirname(__file__),
        help='Directory where models are stored (relative to the project directory)')
    parser.add_argument(
        '--check', dest='check', action='store_true',
        help='Validate model instance before launching server')
    parser.add_argument('--basic-auth-user',
                        default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None),
                        help='Basic auth user')
    parser.add_argument('--basic-auth-pass',
                        default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None),
                        help='Basic auth pass')

    args = parser.parse_args()

    # setup logging level
    if args.log_level:
        logging.root.setLevel(args.log_level)

    def isfloat(value):
        try:
            float(value)
            return True
        except ValueError:
            return False

    def parse_kwargs():
        param = dict()
        for k, v in args.kwargs:
            if v.isdigit():
                param[k] = int(v)
            elif v == 'True' or v == 'true':
                param[k] = True
            elif v == 'False' or v == 'false':
                param[k] = False
            elif isfloat(v):
                param[k] = float(v)
            else:
                param[k] = v
        return param

    kwargs = get_kwargs_from_config()

    if args.kwargs:
        kwargs.update(parse_kwargs())

    if args.check:
        print('Check "' + NewModel.__name__ + '" instance creation..')
        model = NewModel(**kwargs)

    app = init_app(model_class=NewModel, basic_auth_user=args.basic_auth_user, basic_auth_pass=args.basic_auth_pass)

    app.run(host=args.host, port=args.port, debug=args.debug)

else:
    # for uWSGI use
    app = init_app(model_class=NewModel)
```
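The `--kwargs KEY=VAL` type coercion in `parse_kwargs` above can be exercised in isolation. A minimal sketch of the same precedence (digits to int, then 'True'/'False' variants to bool, then floats, else plain string):

```python
def coerce(value: str):
    # Mirrors the order used in parse_kwargs: int first, then bool,
    # then float, falling back to the raw string.
    if value.isdigit():
        return int(value)
    if value in ('True', 'true'):
        return True
    if value in ('False', 'false'):
        return False
    try:
        return float(value)
    except ValueError:
        return value

cli_args = ['workers=2', 'debug=true', 'threshold=0.5', 'device=cuda']
params = {k: coerce(v) for k, v in (kv.split('=') for kv in cli_args)}
print(params)  # {'workers': 2, 'debug': True, 'threshold': 0.5, 'device': 'cuda'}
```

Note that this precedence means `--kwargs flag=1` yields the integer `1`, not a boolean; pass `flag=true` when a bool is intended.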
label_studio_ml/examples/segment_anything_2_video/docker-compose.yml
41 changes: 41 additions & 0 deletions
```yaml
version: "3.8"

services:
  segment_anything_2_video:
    container_name: segment_anything_2_video
    image: humansignal/segment_anything_2_video:v0
    build:
      context: .
      args:
        TEST_ENV: ${TEST_ENV}
    environment:
      # specify these parameters if you want to use basic auth for the model server
      - BASIC_AUTH_USER=
      - BASIC_AUTH_PASS=
      # set the log level for the model server
      - LOG_LEVEL=DEBUG
      # any other parameters that you want to pass to the model server
      - ANY=PARAMETER
      # specify the number of workers and threads for the model server
      - WORKERS=1
      - THREADS=8
      # specify the model directory (likely you don't need to change this)
      - MODEL_DIR=/data/models
      # specify device
      - DEVICE=cuda  # or 'cpu' (coming soon)
      # SAM2 model config
      - MODEL_CONFIG=sam2_hiera_l.yaml
      # SAM2 checkpoint
      - MODEL_CHECKPOINT=sam2_hiera_large.pt

      # Specify the Label Studio URL and API key to access
      # uploaded, local storage and cloud storage files.
      # Do not use 'localhost' as it does not work within Docker containers.
      # Always use the 'http://' or 'https://' prefix for the URL.
      # Determine the actual IP using 'ifconfig' (Linux/Mac) or 'ipconfig' (Windows).
      - LABEL_STUDIO_URL=
      - LABEL_STUDIO_API_KEY=
    ports:
      - "9090:9090"
    volumes:
      - "./data/server:/data"
```
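With the compose file above, a typical workflow looks like this (commands are illustrative and assume Docker with GPU support is available on the host):

```shell
# Build the image and start the backend in the background.
docker compose up -d --build

# Check that the model server responds on the mapped port.
curl http://localhost:9090/health
```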