Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

model_choice.py: remove define_model() function #461

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions inference_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from dataset.aoi import aois_from_csv
from dataset.stacitem import SingleBandItemEO
from utils.logger import get_logger, set_tracker
from models.model_choice import define_model, read_checkpoint
from models.model_choice import read_checkpoint, define_model_architecture
from utils import augmentation
from utils.utils import get_device_ids, get_key_def, \
add_metadata_from_raster_to_sample, _window_2D, set_device
Expand Down Expand Up @@ -401,14 +401,13 @@ def main(params: Union[DictConfig, dict]) -> None:
bands_requested = [SingleBandItemEO.band_to_cname(band) for band in bands_requested]
logging.warning(f"Will request: {bands_requested}")

model = define_model(
model = define_model_architecture(
net_params=params.model,
in_channels=num_bands,
out_classes=num_classes,
main_device=device,
devices=[list(gpu_devices_dict.keys())],
state_dict_path=state_dict,
)
model.to(device)
model.load_state_dict(state_dict=checkpoint['model_state_dict'])

# GET LIST OF INPUT IMAGES FOR INFERENCE
list_aois = aois_from_csv(csv_path=raw_data_csv, bands_requested=bands_requested,
Expand Down
28 changes: 1 addition & 27 deletions models/model_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def define_model_architecture(
return instantiate(net_params, in_channels=in_channels, classes=out_classes)


def read_checkpoint(filename, out_dir: str = 'checkpoints', update=True) -> DictConfig:
def read_checkpoint(filename, out_dir: str = 'checkpoints', update=False) -> DictConfig:
"""
Loads checkpoint from provided path to GDL's expected format,
ie model's state dictionary should be under "model_state_dict" and
Expand Down Expand Up @@ -116,29 +116,3 @@ def to_dp_model(model, devices: List):
f"Trying devices with ids {list(range(len(devices)))}")
model = nn.DataParallel(model, device_ids=list(range(len(devices))))
return model


def define_model(
net_params: dict,
in_channels: int,
out_classes: int,
main_device: str = 'cpu',
devices: List = [],
state_dict_path: str = None,
state_dict_strict_load: bool = True,
):
"""
Defines model's architecture with weights from provided checkpoint and pushes to device(s)
@return:
"""
model = define_model_architecture(
net_params=net_params,
in_channels=in_channels,
out_classes=out_classes,
)
model = to_dp_model(model=model, devices=devices[1:]) if len(devices) > 1 else model
model.to(main_device)
if state_dict_path:
checkpoint = read_checkpoint(state_dict_path)
model.load_state_dict(state_dict=checkpoint['model_state_dict'], strict=state_dict_strict_load)
return model
13 changes: 7 additions & 6 deletions tests/model/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import models.unet
from models import unet
from models.model_choice import read_checkpoint, adapt_checkpoint_to_dp_model, define_model, define_model_architecture
from models.model_choice import read_checkpoint, adapt_checkpoint_to_dp_model, define_model_architecture, to_dp_model
from utils.utils import get_device_ids, set_device


Expand Down Expand Up @@ -102,12 +102,13 @@ class TestDefineModelMultigpu(object):
if len(gpu_devices_dict.keys()) == 0:
logging.critical(f"No GPUs available. Cannot perform multi-gpu testing.")
else:
define_model(
model = define_model_architecture(
net_params={'_target_': 'models.unet.UNet'},
in_channels=4,
out_classes=4,
main_device=device,
devices=list(gpu_devices_dict.keys()),
state_dict_path=filename,
state_dict_strict_load=True,
)
devices = list(gpu_devices_dict.keys())
model = to_dp_model(model=model, devices=devices[1:]) if len(devices) > 1 else model
model.to(device)
checkpoint = read_checkpoint(filename)
model.load_state_dict(state_dict=checkpoint['model_state_dict'])
17 changes: 11 additions & 6 deletions train_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.model_choice import read_checkpoint, define_model, adapt_checkpoint_to_dp_model
from models.model_choice import read_checkpoint, adapt_checkpoint_to_dp_model, to_dp_model, \
define_model_architecture
from tiling_segmentation import Tiler
from utils import augmentation as aug
from dataset import create_dataset
Expand Down Expand Up @@ -629,15 +630,19 @@ def train(cfg: DictConfig) -> None:
device = set_device(gpu_devices_dict=gpu_devices_dict)

# INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH
model = define_model(
model = define_model_architecture(
net_params=cfg.model,
in_channels=num_bands,
out_classes=num_classes,
main_device=device,
devices=list(gpu_devices_dict.keys()),
state_dict_path=train_state_dict_path,
state_dict_strict_load=state_dict_strict,
)
devices = list(gpu_devices_dict.keys())
model = to_dp_model(model=model, devices=devices[1:]) if len(devices) > 1 else model
model.to(device)

if train_state_dict_path:
checkpoint = read_checkpoint(train_state_dict_path)
model.load_state_dict(state_dict=checkpoint['model_state_dict'], strict=state_dict_strict)

criterion = define_loss(loss_params=cfg.loss, class_weights=class_weights)
criterion = criterion.to(device)
optimizer = instantiate(cfg.optimizer, params=model.parameters())
Expand Down