Skip to content

Commit

Permalink
Merge pull request #125 from ViCCo-Group/memory_fixes
Browse files Browse the repository at this point in the history
Memory fixes
  • Loading branch information
LukasMut authored Jan 18, 2023
2 parents b4b53dc + 9703877 commit 97696f9
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 136 deletions.
1 change: 0 additions & 1 deletion tests/extractor/extraction/test_pretrained_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np
import tests.helper as helper
import thingsvision.core.extraction.extractor

Array = np.ndarray

Expand Down
4 changes: 2 additions & 2 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,13 @@ def create_extractor_and_dataloader(model_name: str, pretrained: bool, source: s
dataset = ImageDataset(
root=TEST_PATH,
out_path=OUT_PATH,
backend=extractor.backend,
backend=extractor.get_backend(),
transforms=extractor.get_transformations(),
)
batches = DataLoader(
dataset,
batch_size=BATCH_SIZE,
backend=extractor.backend,
backend=extractor.get_backend(),
)
return extractor, dataset, batches

Expand Down
2 changes: 1 addition & 1 deletion thingsvision/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.2.17"
__version__ = "2.2.18"
11 changes: 8 additions & 3 deletions thingsvision/core/extraction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from .helpers import (center_features, create_custom_extractor,
create_model_extractor, get_extractor,
get_extractor_from_model, normalize_features)
from .helpers import (
center_features,
create_custom_extractor,
create_model_extractor,
get_extractor,
get_extractor_from_model,
normalize_features,
)
61 changes: 41 additions & 20 deletions thingsvision/core/extraction/base.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,18 @@
import abc
import os
import warnings
from dataclasses import dataclass, field
from typing import Any, Iterator
from typing import Any, Callable, Iterator, List

import numpy as np
from tqdm import tqdm
from tqdm.auto import tqdm

Array = np.ndarray


@dataclass(init=True, repr=True)
class BaseExtractor:
model_name: str
pretrained: bool
device: str
model_path: str = None
model_parameters: Any = field(default_factory=lambda: {})
model: Any = None
preprocess: Any = None

def __post_init__(self) -> None:
if not self.model:
self.load_model()
class BaseExtractor(metaclass=abc.ABCMeta):
def __init__(self, device, preprocess) -> None:
self.device = device
self.preprocess = preprocess

def show(self) -> None:
warnings.warn(
Expand All @@ -30,9 +21,35 @@ def show(self) -> None:
)
self.show_model()

@abc.abstractmethod
def show_model(self) -> None:
print(self._show_model())
print()
"""Show model."""
raise NotImplementedError()

@abc.abstractmethod
def get_default_transformation(
self,
mean: List[float],
std: List[float],
resize_dim: int = 256,
crop_dim: int = 224,
apply_center_crop: bool = True,
) -> Callable:
raise NotImplementedError()

@abc.abstractmethod
def get_module_names(self):
raise NotImplementedError()

@abc.abstractmethod
def load_model(self):
raise NotImplementedError()

@abc.abstractmethod
def _extract_batch(
self, batch: Array, module_name: str, flatten_acts: bool
) -> Array:
raise NotImplementedError()

def extract_features(
self,
Expand All @@ -41,7 +58,7 @@ def extract_features(
flatten_acts: bool,
output_dir: str = None,
step_size: int = None,
) -> Array:
):
"""Extract hidden unit activations (at specified layer) for every image in the database.
Parameters
Expand Down Expand Up @@ -94,7 +111,7 @@ def extract_features(
enumerate(batches, start=1), desc="Batch", total=len(batches)
):
features.append(
self._extract_features(
self._extract_batch(
batch=batch, module_name=module_name, flatten_acts=flatten_acts
)
)
Expand Down Expand Up @@ -136,3 +153,7 @@ def get_transformations(
mean, std, resize_dim, crop_dim, apply_center_crop
)
return composition

@abc.abstractmethod
def get_backend(self) -> str:
raise NotImplementedError()
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import os
from dataclasses import dataclass
from typing import Any, Dict

import numpy as np
import timm
import torchvision

import tensorflow as tf
import tensorflow.keras.applications as tensorflow_models
import timm
import torch
import torchvision

try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url

from .base import BaseExtractor
from .mixin import PyTorchMixin, TensorFlowMixin
from .tensorflow import TensorFlowExtractor
from .torch import PyTorchExtractor

# neccessary to prevent gpu memory conflicts between torch and tf
gpus = tf.config.list_physical_devices("GPU")
Expand All @@ -34,8 +34,7 @@
Array = np.ndarray


@dataclass(repr=True)
class TorchvisionExtractor(BaseExtractor, PyTorchMixin):
class TorchvisionExtractor(PyTorchExtractor):
def __init__(
self,
model_name: str,
Expand Down Expand Up @@ -105,8 +104,7 @@ def get_default_transformation(
return transforms


@dataclass(repr=True)
class TimmExtractor(BaseExtractor, PyTorchMixin):
class TimmExtractor(PyTorchExtractor):
def __init__(
self,
model_name: str,
Expand Down Expand Up @@ -135,8 +133,7 @@ def load_model_from_source(self) -> None:
)


@dataclass(repr=True)
class KerasExtractor(BaseExtractor, TensorFlowMixin):
class KerasExtractor(TensorFlowExtractor):
def __init__(
self,
model_name: str,
Expand Down Expand Up @@ -175,8 +172,7 @@ def load_model_from_source(self) -> None:
)


@dataclass(repr=True)
class SSLExtractor(BaseExtractor, PyTorchMixin):
class SSLExtractor(PyTorchExtractor):
ENV_TORCH_HOME = "TORCH_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"
Expand Down
23 changes: 14 additions & 9 deletions thingsvision/core/extraction/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
import numpy as np
import thingsvision.custom_models as custom_models
import thingsvision.custom_models.cornet as cornet
import torch
from torchtyping import TensorType

from .base import BaseExtractor
from .extractor import (KerasExtractor, SSLExtractor, TimmExtractor,
TorchvisionExtractor)
from .mixin import PyTorchMixin, TensorFlowMixin
import torch

from .extractors import (
KerasExtractor,
SSLExtractor,
TimmExtractor,
TorchvisionExtractor,
)
from .tensorflow import TensorFlowExtractor
from .torch import PyTorchExtractor

Array = np.ndarray
AxisError = np.AxisError
Expand Down Expand Up @@ -43,9 +48,9 @@ def create_custom_extractor(
f"\nCould not find {model_name} among custom models.\nChoose a different model.\n"
)

backend_mixin = PyTorchMixin if backend == "pt" else TensorFlowMixin
Extractor = PyTorchExtractor if backend == "pt" else TensorFlowExtractor

class CustomExtractor(BaseExtractor, backend_mixin):
class CustomExtractor(Extractor):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -140,9 +145,9 @@ def create_model_extractor(
extractor: Any
The custom extractor class.
"""
backend_mixin = PyTorchMixin if backend == "pt" else TensorFlowMixin
Extractor = PyTorchExtractor if backend == "pt" else TensorFlowExtractor

class ModelExtractor(BaseExtractor, backend_mixin):
class ModelExtractor(Extractor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand Down
92 changes: 92 additions & 0 deletions thingsvision/core/extraction/tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os
from dataclasses import field
from typing import Any, List

import numpy as np

from .base import BaseExtractor

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # suppress tensorflow warnings
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

Array = np.ndarray


class TensorFlowExtractor(BaseExtractor):
def __init__(
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Any = field(default_factory=lambda: {}),
model: Any = None,
preprocess: Any = None,
) -> None:
super().__init__(device, preprocess)
self.model_name = model_name
self.pretrained = pretrained
self.model_path = model_path
self.model_parameters = model_parameters
self.model = model

if not self.model:
self.load_model()

def _extract_batch(
self, batch: Array, module_name: str, flatten_acts: bool
) -> Array:
layer_out = [self.model.get_layer(module_name).output]
activation_model = keras.models.Model(
inputs=self.model.input,
outputs=layer_out,
)
activations = activation_model.predict(batch)
if flatten_acts:
activations = activations.reshape(activations.shape[0], -1)

return activations

def show_model(self) -> str:
return self.model.summary()

def load_model_from_source(self) -> None:
raise NotImplementedError

def load_model(self) -> None:
self.load_model_from_source()
if self.model_path:
self.model.load_weights(self.model_path)
self.model.trainable = False

def get_module_names(self) -> List[str]:
return [l._name for l in self.model.submodules]

def get_default_transformation(
self,
mean: List[float],
std: List[float],
resize_dim: int = 256,
crop_dim: int = 224,
apply_center_crop: bool = True,
) -> Any:
resize_dim = crop_dim
composes = [layers.experimental.preprocessing.Resizing(resize_dim, resize_dim)]
if apply_center_crop:
pass
# TODO: fix center crop problem with Keras
# composes.append(layers.experimental.preprocessing.CenterCrop(crop_dim, crop_dim))

composes += [
layers.experimental.preprocessing.Normalization(
mean=mean, variance=[std_ * std_ for std_ in std]
)
]
composition = tf.keras.Sequential(composes)

return composition

def get_backend(self) -> str:
return "tf"
Loading

0 comments on commit 97696f9

Please sign in to comment.