Skip to content

Commit

Permalink
Improve default launch device for train (#3523)
Browse files Browse the repository at this point in the history
* feat: improve launch device

* add licence header

* crash instead of trying to set the available option

* set cuda as default

---------

Co-authored-by: Noé Pion <[email protected]>
Co-authored-by: J.Y. <[email protected]>
  • Loading branch information
3 people authored Dec 5, 2024
1 parent a8888e7 commit 555d554
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
10 changes: 10 additions & 0 deletions nerfstudio/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from nerfstudio.configs.method_configs import AnnotatedBaseConfigUnion
from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.utils import comms, profiler
from nerfstudio.utils.available_devices import get_available_devices
from nerfstudio.utils.rich_utils import CONSOLE

DEFAULT_TIMEOUT = timedelta(minutes=30)
Expand Down Expand Up @@ -226,6 +227,15 @@ def launch(
def main(config: TrainerConfig) -> None:
"""Main function."""

# Check if the specified device type is available
available_device_types = get_available_devices()
if config.machine.device_type not in available_device_types:
raise RuntimeError(
f"Specified device type '{config.machine.device_type}' is not available. "
f"Available device types: {available_device_types}. "
"Please specify a valid device type using the CLI option: --machine.device_type [cuda|mps|cpu]"
)

if config.data:
CONSOLE.log("Using --data alias for --data.pipeline.datamanager.data")
config.pipeline.datamanager.data = config.data
Expand Down
32 changes: 32 additions & 0 deletions nerfstudio/utils/available_devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Literal

import torch


def get_available_devices() -> List[Literal["cpu", "cuda", "mps"]]:
"""Determine the available devices on the machine
Returns:
list: List of available device types
"""
available_devices: List[Literal["cpu", "cuda", "mps"]] = []
if torch.cuda.is_available():
available_devices.append("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
available_devices.append("mps")
available_devices.append("cpu")
return available_devices

0 comments on commit 555d554

Please sign in to comment.