Skip to content

Commit

Permalink
Fix path model, timeout in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
makseq committed Sep 17, 2024
1 parent 1d2b30d commit 2ee27d5
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ jobs:
run: docker compose -f label_studio_ml/examples/${{ matrix.backend_dir_name }}/docker-compose.yml up -d --build

- name: Wait for stack
timeout-minutes: 10
timeout-minutes: 20
run: |
while [ "$(curl -s -o /dev/null -L -w ''%{http_code}'' "http://localhost:9090/health")" != "200" ]; do
echo "=> Waiting for service to become available" && sleep 2s
Expand Down
10 changes: 4 additions & 6 deletions label_studio_ml/examples/yolo/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,11 @@ WORKDIR /app

COPY . ./

WORKDIR /app/models

# Download the YOLO models
RUN yolo predict model=yolov8m.pt source=/app/tests/car.jpg \
&& yolo predict model=yolov8n.pt source=/app/tests/car.jpg \
&& yolo predict model=yolov8n-cls.pt source=/app/tests/car.jpg \
&& yolo predict model=yolov8n-seg.pt source=/app/tests/car.jpg
RUN yolo predict model=/app/models/yolov8m.pt source=/app/tests/car.jpg \
&& yolo predict model=/app/models/yolov8n.pt source=/app/tests/car.jpg \
&& yolo predict model=/app/models/yolov8n-cls.pt source=/app/tests/car.jpg \
&& yolo predict model=/app/models/yolov8n-seg.pt source=/app/tests/car.jpg

WORKDIR /app

Expand Down
1 change: 1 addition & 0 deletions label_studio_ml/examples/yolo/README_TIMELINE_LABELS.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ The cache is used for incremental training on the fly and prediction speedup.
- **Early stop on training data**: The model uses early stopping based on the F1 score and accuracy on the training data. This may lead to overfitting on the training data. It was made because of the lack of validation data when updating on one annotation.
- **YOLO model limitations**: The model uses a pre-trained YOLO model trained on object classification tasks for feature extraction, which may not be optimal for all use cases such as event detection. This approach doesn't tune the YOLO model, it trains only the LSTM part upon.
- **Label balance**: The model may struggle with imbalanced labels. Ensure that the labels are well-distributed in the training data. Consider modifying the loss function (BCEWithLogitsLoss) and using class pos weights to address this issue.
- **Training on all daa**: Training on all data is not yet implemented, so the model trains only on the last annotation. See `timeline_labels.py::fit()` for more details.

## Example use case: detecting a ball in football videos

Expand Down
23 changes: 2 additions & 21 deletions label_studio_ml/examples/yolo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,27 +132,8 @@ def predict(

def fit(self, event, data, **kwargs):
"""
This method is called each time an annotation is created or updated
You can run your logic here to update the model and persist it to the cache
It is not recommended to perform long-running operations here, as it will block the main thread
Instead, consider running a separate process or a thread (like RQ worker) to perform the training
:param event: event type can be ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING')
:param data: the payload received from the event
(check [Webhook event reference](https://labelstud.io/guide/webhook_reference.html))
# use cache to retrieve the data from the previous fit() runs
old_data = self.get('my_data')
old_model_version = self.get('model_version')
print(f'Old data: {old_data}')
print(f'Old model version: {old_model_version}')
# store new data to the cache
self.set('my_data', 'my_new_data_value')
self.set('model_version', 'my_new_model_version')
print(f'New data: {self.get("my_data")}')
print(f'New model version: {self.get("model_version")}')
print('fit() is not implemented!')
This method is called each time an annotation is created or updated.
Or it's called when "Start training" clicked on the model in the project settings.
"""
results = {}
control_models = self.detect_control_models()
Expand Down

0 comments on commit 2ee27d5

Please sign in to comment.