diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0d6b00f78..929005084 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -91,7 +91,7 @@ jobs: run: docker compose -f label_studio_ml/examples/${{ matrix.backend_dir_name }}/docker-compose.yml up -d --build - name: Wait for stack - timeout-minutes: 10 + timeout-minutes: 20 run: | while [ "$(curl -s -o /dev/null -L -w ''%{http_code}'' "http://localhost:9090/health")" != "200" ]; do echo "=> Waiting for service to become available" && sleep 2s diff --git a/label_studio_ml/examples/yolo/Dockerfile b/label_studio_ml/examples/yolo/Dockerfile index ff417da64..0b7c09f1c 100644 --- a/label_studio_ml/examples/yolo/Dockerfile +++ b/label_studio_ml/examples/yolo/Dockerfile @@ -49,13 +49,11 @@ WORKDIR /app COPY . ./ -WORKDIR /app/models - # Download the YOLO models -RUN yolo predict model=yolov8m.pt source=/app/tests/car.jpg \ - && yolo predict model=yolov8n.pt source=/app/tests/car.jpg \ - && yolo predict model=yolov8n-cls.pt source=/app/tests/car.jpg \ - && yolo predict model=yolov8n-seg.pt source=/app/tests/car.jpg +RUN yolo predict model=/app/models/yolov8m.pt source=/app/tests/car.jpg \ + && yolo predict model=/app/models/yolov8n.pt source=/app/tests/car.jpg \ + && yolo predict model=/app/models/yolov8n-cls.pt source=/app/tests/car.jpg \ + && yolo predict model=/app/models/yolov8n-seg.pt source=/app/tests/car.jpg WORKDIR /app diff --git a/label_studio_ml/examples/yolo/README_TIMELINE_LABELS.md b/label_studio_ml/examples/yolo/README_TIMELINE_LABELS.md index 6a72ebd40..643d8a940 100644 --- a/label_studio_ml/examples/yolo/README_TIMELINE_LABELS.md +++ b/label_studio_ml/examples/yolo/README_TIMELINE_LABELS.md @@ -173,6 +173,7 @@ The cache is used for incremental training on the fly and prediction speedup. - **Early stop on training data**: The model uses early stopping based on the F1 score and accuracy on the training data. This may lead to overfitting on the training data. It was made because of the lack of validation data when updating on one annotation. - **YOLO model limitations**: The model uses a pre-trained YOLO model trained on object classification tasks for feature extraction, which may not be optimal for all use cases such as event detection. This approach doesn't tune the YOLO model, it trains only the LSTM part upon. - **Label balance**: The model may struggle with imbalanced labels. Ensure that the labels are well-distributed in the training data. Consider modifying the loss function (BCEWithLogitsLoss) and using class pos weights to address this issue. +- **Training on all daa**: Training on all data is not yet implemented, so the model trains only on the last annotation. See `timeline_labels.py::fit()` for more details. ## Example use case: detecting a ball in football videos diff --git a/label_studio_ml/examples/yolo/model.py b/label_studio_ml/examples/yolo/model.py index 4752a9f92..9e0087631 100644 --- a/label_studio_ml/examples/yolo/model.py +++ b/label_studio_ml/examples/yolo/model.py @@ -132,27 +132,8 @@ def predict( def fit(self, event, data, **kwargs): """ - This method is called each time an annotation is created or updated - You can run your logic here to update the model and persist it to the cache - It is not recommended to perform long-running operations here, as it will block the main thread - Instead, consider running a separate process or a thread (like RQ worker) to perform the training - :param event: event type can be ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING') - :param data: the payload received from the event - (check [Webhook event reference](https://labelstud.io/guide/webhook_reference.html)) - - # use cache to retrieve the data from the previous fit() runs - old_data = self.get('my_data') - old_model_version = self.get('model_version') - print(f'Old data: {old_data}') - print(f'Old model version: {old_model_version}') - - # store new data to the cache - self.set('my_data', 'my_new_data_value') - self.set('model_version', 'my_new_model_version') - print(f'New data: {self.get("my_data")}') - print(f'New model version: {self.get("model_version")}') - - print('fit() is not implemented!') + This method is called each time an annotation is created or updated. + Or it's called when "Start training" clicked on the model in the project settings. """ results = {} control_models = self.detect_control_models()