made torch home for storing checkpoints available project-wide
jonasd4 committed Jul 26, 2023
1 parent 7aefc0e commit 67d3f06
Showing 3 changed files with 30 additions and 24 deletions.
25 changes: 3 additions & 22 deletions thingsvision/core/extraction/extractors.py
@@ -15,6 +15,7 @@

from .tensorflow import TensorFlowExtractor
from .torch import PyTorchExtractor
from thingsvision.utils.checkpointing import get_torch_home

# necessary to prevent gpu memory conflicts between torch and tf
gpus = tf.config.list_physical_devices("GPU")
@@ -172,9 +173,6 @@ def load_model_from_source(self) -> None:


class SSLExtractor(PyTorchExtractor):
ENV_TORCH_HOME = "TORCH_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"
MODELS = {
"simclr-rn50": {
"url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/simclr_rn50_800ep_simclr_8node_resnet_16_07_20.7e8feed1/model_final_checkpoint_phase799.torch",
@@ -301,7 +299,7 @@ def _download_and_save_model(self, model_url: str, output_model_filepath: str):
return converted_model

def _replace_module_prefix(
self, state_dict: Dict[str, Any], prefix: str, replace_with: str = ""
):
"""
Remove prefixes in a state_dict needed when loading models that are not VISSL
@@ -316,23 +314,6 @@ def _replace_module_prefix(
}
return state_dict

def _get_torch_home(self):
"""
Gets the torch home folder used as a cache directory for the vissl models.
"""
torch_home = os.path.expanduser(
os.getenv(
SSLExtractor.ENV_TORCH_HOME,
os.path.join(
os.getenv(
SSLExtractor.ENV_XDG_CACHE_HOME, SSLExtractor.DEFAULT_CACHE_DIR
),
"torch",
),
)
)
return torch_home

def load_model_from_source(self) -> None:
"""
Load a (pretrained) neural network model from vissl. Downloads the model when it is not available.
@@ -341,7 +322,7 @@ def load_model_from_source(self) -> None:
if self.model_name in SSLExtractor.MODELS:
model_config = SSLExtractor.MODELS[self.model_name]
if model_config["type"] == "vissl":
cache_dir = os.path.join(self._get_torch_home(), "vissl")
cache_dir = os.path.join(get_torch_home(), "vissl")
model_filepath = os.path.join(cache_dir, self.model_name + ".torch")
if not os.path.exists(model_filepath):
os.makedirs(cache_dir, exist_ok=True)
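Side note (not part of the commit): a minimal sketch of how the shared helper now resolves an SSLExtractor checkpoint path; the model name below is only an illustrative choice from SSLExtractor.MODELS.

import os

from thingsvision.utils.checkpointing import get_torch_home

model_name = "simclr-rn50"  # illustrative; any SSLExtractor.MODELS key resolves the same way
cache_dir = os.path.join(get_torch_home(), "vissl")
model_filepath = os.path.join(cache_dir, model_name + ".torch")
# with no environment overrides this expands to ~/.cache/torch/vissl/simclr-rn50.torch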
6 changes: 4 additions & 2 deletions thingsvision/custom_models/dreamsim.py
@@ -6,6 +6,7 @@
from torchvision import transforms

from thingsvision.custom_models.custom import Custom
from thingsvision.utils.checkpointing import get_torch_home

Tensor = torch.Tensor

@@ -22,9 +23,10 @@ def __init__(self, model_type, device) -> None:

self.model_type = model_type
self.device = device
model_dir = os.path.join(torch.hub.get_dir(), 'checkpoints')
model_dir = os.path.join(get_torch_home(), 'dreamsim')
self.model, _ = dreamsim(
pretrained=True, dreamsim_type=model_type, normalize_embeds=False, device=device, cache_dir=model_dir
pretrained=True, dreamsim_type=model_type, normalize_embeds=False,
device=device, cache_dir=model_dir
)

def forward(self, x: Tensor) -> Tensor:
23 changes: 23 additions & 0 deletions thingsvision/utils/checkpointing/__init__.py
@@ -0,0 +1,23 @@
import os

ENV_TORCH_HOME = "TORCH_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"


def get_torch_home():
"""
Gets the torch home folder used as a cache directory for model checkpoints.
"""
torch_home = os.path.expanduser(
os.getenv(
ENV_TORCH_HOME,
os.path.join(
os.getenv(
ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR
),
"torch",
),
)
)
return torch_home
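
Side note (not part of the commit): a short usage sketch of the new helper showing how the two environment variables interact; the example paths are hypothetical.

import os

from thingsvision.utils.checkpointing import get_torch_home

print(get_torch_home())  # expanded ~/.cache/torch when neither variable is set
os.environ["XDG_CACHE_HOME"] = "/data/cache"  # hypothetical path
print(get_torch_home())  # /data/cache/torch
os.environ["TORCH_HOME"] = "/data/torch_checkpoints"  # hypothetical path; TORCH_HOME takes precedence
print(get_torch_home())  # /data/torch_checkpoints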
