Skip to content

Commit

Permalink
embeddings,object_detector: add OOM recover fix
Browse files Browse the repository at this point in the history
  • Loading branch information
vicentebolea committed Jul 26, 2024
1 parent 16a953d commit 61361ee
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 26 deletions.
12 changes: 10 additions & 2 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,19 @@ def __init__(self, server):

self.server.controller.add("on_server_ready")(self.on_server_ready)
self._on_hover_fn = None
self.detector = object_detector.ObjectDetector(model_name="facebook/detr-resnet-50")

def on_server_ready(self, *args, **kwargs):
    """Register state-change listeners, then run the initial updates.

    Any positional or keyword arguments are accepted and ignored.
    """
    # Bind instance methods to state change
    listeners = (
        ("current_dataset", self.on_current_dataset_change),
        ("current_num_elements", self.on_current_num_elements_change),
        ("object_detection_model", self.on_object_detection_model_change),
    )
    for state_name, handler in listeners:
        self.state.change(state_name)(handler)

    # The model handler runs before the dataset handler.
    self.on_object_detection_model_change(self.state.object_detection_model)
    self.on_current_dataset_change(self.state.current_dataset)

def on_object_detection_model_change(self, model_name, **kwargs):
    """Replace ``self.detector`` with a new detector built for ``model_name``.

    Extra keyword arguments are accepted and ignored.
    """
    detector_factory = object_detector.ObjectDetector
    self.detector = detector_factory(model_name=model_name)

def set_on_transform(self, fn):
    """Store ``fn`` on the instance as ``_on_transform_fn``.

    NOTE(review): presumably invoked after a transform is applied; the
    call site is not visible here, so confirm against the caller.
    """
    callback = fn
    self._on_transform_fn = callback

Expand Down Expand Up @@ -191,7 +195,11 @@ def compute_annotations(self, ids):
if len(ids) == 0:
return

predictions = self.detector.eval(image_ids=ids, content=self.context.image_objects)
predictions = self.detector.eval(
image_ids=ids,
content=self.context.image_objects,
batch_size=int(self.state.object_detection_batch_size),
)

for id_, annotations in predictions.items():
image_annotations = []
Expand Down
7 changes: 7 additions & 0 deletions src/nrtk_explorer/app/ui/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ def parameters(dataset_paths=[], embeddings_app=None, filtering_app=None, transf
emit_value=True,
map_options=True,
)
quasar.QInput(
v_model=("object_detection_batch_size", 32),
filled=True,
stack_label=True,
label="Batch Size",
type="number",
)

filter_title_slot, filter_content_slot, filter_actions_slot = ui.card("collapse_filter")

Expand Down
41 changes: 32 additions & 9 deletions src/nrtk_explorer/library/embeddings_extractor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch
import gc
import logging
import numpy as np
import timm
import torch

from nrtk_explorer.library import images_manager
from torch.utils.data import DataLoader, Dataset
Expand Down Expand Up @@ -75,11 +76,33 @@ def extract(self, paths, content=None, batch_size=32):
transformed_images.append(self.transform_image(img))

# Extract features from images
for batch in DataLoader(ImagesDataset(transformed_images), batch_size=batch_size):
# Copy image to device if using device
if self.device.type == "cuda":
batch = batch.cuda()

features.append(self.model(batch).numpy(force=True))

return np.vstack(features)
adjusted_batch_size = batch_size
while adjusted_batch_size > 0:
try:
for batch in DataLoader(
ImagesDataset(transformed_images), batch_size=adjusted_batch_size
):
# Copy image to device if using device
if self.device.type == "cuda":
batch = batch.cuda()

features.append(self.model(batch).numpy(force=True))
return np.vstack(features)

except RuntimeError as e:
if "out of memory" in str(e) and adjusted_batch_size > 1:
previous_batch_size = adjusted_batch_size
adjusted_batch_size = adjusted_batch_size // 2
print(
f"OOM (Pytorch exception {e}) due to batch_size={previous_batch_size}, setting batch_size={adjusted_batch_size}"
)
else:
raise

finally:
# PyTorch needs to free its allocations outside of the exception context
gc.collect()
torch.cuda.empty_cache()

# We should never reach here
return None
52 changes: 37 additions & 15 deletions src/nrtk_explorer/library/object_detector.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import gc
import logging
import torch
import transformers
Expand All @@ -6,7 +7,7 @@

from nrtk_explorer.library import images_manager

ImageIdToAnnotations = dict[str, Sequence[dict]]
ImageIdToAnnotations = Optional[dict[str, Sequence[dict]]]


class ObjectDetector:
Expand Down Expand Up @@ -77,17 +78,38 @@ def eval(
batches[img.size][0].append(path)
batches[img.size][1].append(img)

predictions_in_baches = [
zip(
image_ids,
self.pipeline(images, batch_size=batch_size),
)
for image_ids, images in batches.values()
]

predictions_by_image_id = {
image_id: predictions
for batch in predictions_in_baches
for image_id, predictions in batch
}
return predictions_by_image_id
adjusted_batch_size = batch_size
while adjusted_batch_size > 0:
try:
predictions_in_baches = [
zip(
image_ids,
self.pipeline(images, batch_size=adjusted_batch_size),
)
for image_ids, images in batches.values()
]

predictions_by_image_id = {
image_id: predictions
for batch in predictions_in_baches
for image_id, predictions in batch
}
return predictions_by_image_id

except RuntimeError as e:
if "out of memory" in str(e) and adjusted_batch_size > 1:
previous_batch_size = adjusted_batch_size
adjusted_batch_size = adjusted_batch_size // 2
print(
f"OOM (Pytorch exception {e}) due to batch_size={previous_batch_size}, setting batch_size={adjusted_batch_size}"
)
else:
raise

finally:
# PyTorch needs to free its allocations outside of the exception context
gc.collect()
torch.cuda.empty_cache()

# We should never reach here
return None

0 comments on commit 61361ee

Please sign in to comment.