Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Oct 19, 2022
1 parent a6b222f commit 6b6b3f2
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ PRETRAINED_MODEL_STR = "deepset/gbert-large"
EVALUATION_TOKENIZER = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_STR)
labels = ['reflexive']
saved_model_path = os.path.join('..','output','saved_models',
'reflexive_ex_mk_binary_gbert-large_monaco_epochs:20_lamb_0.0001_None_dh:0.3_da:0.0.pt')
'reflexive_ex_mk_binary_gbert-large_monaco-ex-kleist_epochs_20_lamb_0.0001_None_dh_0.3_da_0.0.pt')

device = 'cuda' if cuda.is_available() else 'cpu'
fine_tuned_model = load_model(model_path = saved_model_path, device = device,
Expand Down
2 changes: 1 addition & 1 deletion src/ml/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def load_model(model_path, device, petrained_model_str, no_labels):
#except RuntimeError:
# 'Attempting to deserialize object on a CUDA '
# print('loading a cuda model on the CPU')
trained_model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
trained_model.load_state_dict(torch.load(model_path, map_location=torch.device(device)), strict=False)
trained_model.to(device)
return trained_model

Expand Down

0 comments on commit 6b6b3f2

Please sign in to comment.