Model_choice.py: define_model() function is superfluous and adds constraints and overhead when loading pretrained weights #460

Open
remtav opened this issue Jan 23, 2023 · 0 comments

Collaborator

remtav commented Jan 23, 2023

See define_model().

Adding a define_model() function on top of the existing define_model_architecture() does not seem justified. Beyond what define_model_architecture() already does, define_model() only:
(1) converts the model to a DataParallel model if more than one GPU is used, which is only relevant for training and can be done directly in the main() function;
(2) pushes the model to the device, which can (and should?) be done directly in the main() function for both training and inference, without really adding lines of code;
(3) reads weights from a checkpoint file if one is provided. This creates overhead at inference, however, since the checkpoint must already be read once in the main() function to override default params.
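To make the three points concrete, here is a minimal sketch of how they could be handled inline once the checkpoint is already in memory. It assumes PyTorch; `prepare_model()`, the tiny `nn.Linear` stand-in for define_model_architecture(), and the `"model_state_dict"` key are hypothetical names chosen for illustration, not the project's actual code.

```python
from typing import Optional, Tuple

import torch
from torch import nn


def define_model_architecture(in_features: int = 4, out_features: int = 2) -> nn.Module:
    # Stand-in only: the real define_model_architecture() signature is not shown in this issue.
    return nn.Linear(in_features, out_features)


def prepare_model(checkpoint: Optional[dict], training: bool) -> Tuple[nn.Module, torch.device]:
    """Hypothetical helper showing what main() could do without define_model()."""
    model = define_model_architecture()

    # (3) Reuse the checkpoint dict that was already read once to override params.
    #     The "model_state_dict" key is an assumption about the checkpoint layout.
    if checkpoint is not None:
        model.load_state_dict(checkpoint["model_state_dict"])

    # (2) Push the model to the device, for training and inference alike.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # (1) Wrap in DataParallel only when training on more than one GPU.
    if training and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    return model, device
```

At inference, a caller would pass the dictionary it already holds, e.g. `model, device = prepare_model(checkpoint, training=False)`, so no second read from disk is needed.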

Related to #455:
With the "softcode download directory for checkpoints that are urls" feature, loading weights from the checkpoint inside this function, as it is currently used, would mean:

  1. Reading the checkpoint in main(), and optionally downloading the weights to checkpoint_dir if it is a url.
  2. If the checkpoint was a url, read_checkpoint() would need to return not only the checkpoint dictionary containing params and weights, but also the new path to the downloaded local checkpoint (not the url) for further use by define_model().
  3. define_model() would then take this updated path to read the checkpoint into memory a second time, then perform the "model.load_state_dict" operation from those weights.

Since all of this would require reading the provided checkpoint twice rather than once, it would be suboptimal in my opinion (and would also require more lines of code).
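As a rough illustration of the alternative, the sketch below reads the checkpoint exactly once and hands the in-memory dictionary onward, so the local path of a downloaded file never has to leave read_checkpoint(). Everything here is an assumption made for the example: the read_checkpoint() signature, the injected `download` callable, the `load_once()` helper, the `"params"` and `"model_state_dict"` keys, and the use of `pickle` as a stand-in for however the project actually deserializes checkpoints.

```python
import pickle
from pathlib import Path
from typing import Callable, Tuple
from urllib.parse import urlparse


def is_url(location: str) -> bool:
    # Treat only http(s) locations as urls; anything else is a local path.
    return urlparse(location).scheme in ("http", "https")


def read_checkpoint(
    location: str,
    checkpoint_dir: Path,
    download: Callable[[str, Path], None],
) -> dict:
    """Hypothetical reader: resolve a url to a local file, then read it once."""
    if is_url(location):
        local_path = checkpoint_dir / Path(urlparse(location).path).name
        if not local_path.exists():
            download(location, local_path)  # caller supplies the downloader
    else:
        local_path = Path(location)
    # pickle is a stand-in for the project's real deserialization call.
    with open(local_path, "rb") as f:
        return pickle.load(f)


def load_once(
    location: str,
    checkpoint_dir: Path,
    download: Callable[[str, Path], None],
) -> Tuple[dict, dict]:
    """What main() could do: one read serves both params and weights."""
    checkpoint = read_checkpoint(location, checkpoint_dir, download)
    params = checkpoint["params"]              # used to override default params
    weights = checkpoint["model_state_dict"]   # passed to model.load_state_dict
    return params, weights
```

With this shape, main() would pass `weights` straight to "model.load_state_dict", and whether the checkpoint came from a url or a local file stays an internal detail of read_checkpoint().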
