Skip to content

Commit

Permalink
embeddings: add GPU OOM recovery fix
Browse files Browse the repository at this point in the history
  • Loading branch information
vicentebolea committed Jul 25, 2024
1 parent 16a953d commit 1bd381b
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions src/nrtk_explorer/library/embeddings_extractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import logging
import numpy as np
import gc
import timm

from nrtk_explorer.library import images_manager
Expand Down Expand Up @@ -75,11 +76,30 @@ def extract(self, paths, content=None, batch_size=32):
transformed_images.append(self.transform_image(img))

# Extract features from images
for batch in DataLoader(ImagesDataset(transformed_images), batch_size=batch_size):
# Copy image to device if using device
if self.device.type == "cuda":
batch = batch.cuda()

features.append(self.model(batch).numpy(force=True))

return np.vstack(features)
adjusted_batch_size = batch_size
while adjusted_batch_size >= 2:
try:
for batch in DataLoader(
ImagesDataset(transformed_images), batch_size=adjusted_batch_size
):
# Copy image to device if using device
if self.device.type == "cuda":
batch = batch.cuda()

features.append(self.model(batch).numpy(force=True))
return np.vstack(features)

except RuntimeError as e:
if "out of memory" in str(e):
adjusted_batch_size //= 2
print(f"WARNING: ran out of memory, Pytorch exception{e}")
print(
f"Recovered from OOM due to large batch size, reducing batch_size to {adjusted_batch_size}"
)

finally:
gc.collect()
torch.cuda.empty_cache()

# We should only reach here in case of an irrecoverable OOM error
return None

0 comments on commit 1bd381b

Please sign in to comment.