Skip to content

Commit

Permalink
Model loading change
Browse files Browse the repository at this point in the history
  • Loading branch information
Vlad Sobol committed Feb 8, 2020
1 parent d5fb5ef commit 8307d99
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 5 deletions.
10 changes: 9 additions & 1 deletion eval_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,15 @@ def get_optimal_pool_size():

def load_models(opt, data_path, device='cuda'):
stats = torch.load(path.join(data_path, 'data_stats.pth'))
forward_model = torch.load(path.join(opt.model_dir, opt.mfile))

model_path = path.join(opt.model_dir, opt.mfile)
if path.exists(model_path):
forward_model = torch.load(model_path)
elif path.exists(opt.mfile):
forward_model = torch.load(opt.mfile)
else:
raise runtime_error(f'couldn\'t find file {opt.mfile}')

if type(forward_model) is dict:
forward_model = forward_model['model']
value_function, policy_network_il, policy_network_mper = None, None, None
Expand Down
2 changes: 1 addition & 1 deletion scripts/submit_eval_mpur.slurm
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#SBATCH --gres gpu:1
#SBATCH --constraint="gpu_12gb&pascal"
#SBATCH --exclude="weaver1, weaver2, weaver3, weaver4, weaver5, vine5, vine11"
#SBATCH --cpus-per-task=3
#SBATCH --cpus-per-task=7
#SBATCH --qos=batch
#SBATCH --nodes=1
#SBATCH --mem=48000
Expand Down
4 changes: 2 additions & 2 deletions scripts/submit_eval_mpur_path.slurm
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#SBATCH --gres gpu:1
#SBATCH --constraint="gpu_12gb&pascal"
#SBATCH --exclude="weaver1, weaver2, weaver3, weaver4, weaver5, vine5, vine11"
#SBATCH --cpus-per-task=3
#SBATCH --cpus-per-task=7
#SBATCH --qos=batch
#SBATCH --nodes=1
#SBATCH --mem=48000
Expand All @@ -18,8 +18,8 @@ conda activate PPUU

cd ../
srun python eval_policy.py \
$@ \
-model_dir $model_dir \
-method policy-MPUR \
-policy_model $policy \
-save_grad_vid

11 changes: 10 additions & 1 deletion train_MPUR.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,18 @@
print('WARNING: You have a CUDA device, so you should probably run without -no_cuda')

# load the model
model = torch.load(path.join(opt.model_dir, opt.mfile))

model_path = path.join(opt.model_dir, opt.mfile)
if path.exists(model_path):
model = torch.load(model_path)
elif path.exists(opt.mfile):
model = torch.load(opt.mfile)
else:
raise runtime_error(f'couldn\'t find file {opt.mfile}')

if not hasattr(model.encoder, 'n_channels'):
model.encoder.n_channels = 3

if type(model) is dict: model = model['model']
model.opt.lambda_l = opt.lambda_l # used by planning.py/compute_uncertainty_batch
model.opt.lambda_o = opt.lambda_o # used by planning.py/compute_uncertainty_batch
Expand Down

0 comments on commit 8307d99

Please sign in to comment.