Skip to content

Commit 7fc96a9

Browse files
committed
set weights_only parameter of torch.load to False
- #48
1 parent e17a9c0 commit 7fc96a9

16 files changed

+68
-29
lines changed

chebai/models/electra.py

+6-2
Original file line number | Diff line number | Diff line change
@@ -256,7 +256,9 @@ def __init__(
256256
# Load pretrained checkpoint if provided
257257
if pretrained_checkpoint:
258258
with open(pretrained_checkpoint, "rb") as fin:
259-
model_dict = torch.load(fin, map_location=self.device)
259+
model_dict = torch.load(
260+
fin, map_location=self.device, weights_only=False
261+
)
260262
if load_prefix:
261263
state_dict = filter_dict(model_dict["state_dict"], load_prefix)
262264
else:
@@ -414,7 +416,9 @@ def __init__(self, cone_dimensions=20, **kwargs):
414416
model_prefix = kwargs.get("load_prefix", None)
415417
if pretrained_checkpoint:
416418
with open(pretrained_checkpoint, "rb") as fin:
417-
model_dict = torch.load(fin, map_location=self.device)
419+
model_dict = torch.load(
420+
fin, map_location=self.device, weights_only=False
421+
)
418422
if model_prefix:
419423
state_dict = {
420424
str(k)[len(model_prefix) :]: v

chebai/preprocessing/datasets/base.py

+10-4
Original file line number | Diff line number | Diff line change
@@ -200,7 +200,9 @@ def load_processed_data(
200200
filename = self.processed_file_names_dict[kind]
201201
except NotImplementedError:
202202
filename = f"{kind}.pt"
203-
return torch.load(os.path.join(self.processed_dir, filename))
203+
return torch.load(
204+
os.path.join(self.processed_dir, filename), weights_only=False
205+
)
204206

205207
def dataloader(self, kind: str, **kwargs) -> DataLoader:
206208
"""
@@ -519,7 +521,7 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:
519521
DataLoader: DataLoader object for the specified subset.
520522
"""
521523
subdatasets = [
522-
torch.load(os.path.join(s.processed_dir, f"{kind}.pt"))
524+
torch.load(os.path.join(s.processed_dir, f"{kind}.pt"), weights_only=False)
523525
for s in self.subsets
524526
]
525527
dataset = [
@@ -1022,7 +1024,9 @@ def _retrieve_splits_from_csv(self) -> None:
10221024
splits_df = pd.read_csv(self.splits_file_path)
10231025

10241026
filename = self.processed_file_names_dict["data"]
1025-
data = torch.load(os.path.join(self.processed_dir, filename))
1027+
data = torch.load(
1028+
os.path.join(self.processed_dir, filename), weights_only=False
1029+
)
10261030
df_data = pd.DataFrame(data)
10271031

10281032
train_ids = splits_df[splits_df["split"] == "train"]["id"]
@@ -1081,7 +1085,9 @@ def load_processed_data(
10811085

10821086
# If filename is provided
10831087
try:
1084-
return torch.load(os.path.join(self.processed_dir, filename))
1088+
return torch.load(
1089+
os.path.join(self.processed_dir, filename), weights_only=False
1090+
)
10851091
except FileNotFoundError:
10861092
raise FileNotFoundError(f"File {filename} doesn't exist")
10871093

chebai/preprocessing/datasets/chebi.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -407,7 +407,9 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
407407
"""
408408
try:
409409
filename = self.processed_file_names_dict["data"]
410-
data_chebi_version = torch.load(os.path.join(self.processed_dir, filename))
410+
data_chebi_version = torch.load(
411+
os.path.join(self.processed_dir, filename), weights_only=False
412+
)
411413
except FileNotFoundError:
412414
raise FileNotFoundError(
413415
f"File data.pt doesn't exists. "
@@ -428,7 +430,8 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
428430
data_chebi_train_version = torch.load(
429431
os.path.join(
430432
self._chebi_version_train_obj.processed_dir, filename_train
431-
)
433+
),
434+
weights_only=False,
432435
)
433436
except FileNotFoundError:
434437
raise FileNotFoundError(

chebai/preprocessing/datasets/go_uniprot.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -508,7 +508,9 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
508508
"""
509509
try:
510510
filename = self.processed_file_names_dict["data"]
511-
data_go = torch.load(os.path.join(self.processed_dir, filename))
511+
data_go = torch.load(
512+
os.path.join(self.processed_dir, filename), weights_only=False
513+
)
512514
except FileNotFoundError:
513515
raise FileNotFoundError(
514516
f"File data.pt doesn't exists. "

chebai/preprocessing/datasets/pubchem.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -891,10 +891,10 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader:
891891
DataLoader: DataLoader instance.
892892
"""
893893
labeled_data = torch.load(
894-
os.path.join(self.labeled.processed_dir, f"{kind}.pt")
894+
os.path.join(self.labeled.processed_dir, f"{kind}.pt"), weights_only=False
895895
)
896896
unlabeled_data = torch.load(
897-
os.path.join(self.unlabeled.processed_dir, f"{kind}.pt")
897+
os.path.join(self.unlabeled.processed_dir, f"{kind}.pt"), weights_only=False
898898
)
899899
if self.data_limit is not None:
900900
labeled_data = labeled_data[: self.data_limit]

chebai/preprocessing/migration/chebi_data_migration.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -168,7 +168,7 @@ def _combine_pt_splits(
168168
df_list: List[pd.DataFrame] = []
169169
for split, file_name in old_splits_file_names.items():
170170
file_path = os.path.join(old_dir, file_name)
171-
file_df = pd.DataFrame(torch.load(file_path))
171+
file_df = pd.DataFrame(torch.load(file_path, weights_only=False))
172172
df_list.append(file_df)
173173

174174
return pd.concat(df_list, ignore_index=True)

chebai/result/analyse_sem.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -427,7 +427,9 @@ def run_all(
427427
os.path.join(buffer_dir_smoothed, "preds000.pt")
428428
):
429429
preds = torch.load(
430-
os.path.join(buffer_dir_smoothed, "preds000.pt"), DEVICE
430+
os.path.join(buffer_dir_smoothed, "preds000.pt"),
431+
DEVICE,
432+
weights_only=False,
431433
)
432434
labels = None
433435
else:

chebai/result/base.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ def _generate_predictions(self, data_path, raw=False, **kwargs):
5454
else:
5555
data_tuples = [
5656
(x.get("raw_features", x["ident"]), x["ident"], x)
57-
for x in torch.load(data_path)
57+
for x in torch.load(data_path, weights_only=False)
5858
]
5959

6060
for raw_features, ident, row in tqdm.tqdm(data_tuples):

chebai/result/pretraining.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ def evaluate_model(logs_base_path, model_filename, data_module):
3434
collate = data_module.reader.COLLATOR()
3535
test_file = "test.pt"
3636
data_path = os.path.join(data_module.processed_dir, test_file)
37-
data_list = torch.load(data_path)
37+
data_list = torch.load(data_path, weights_only=False)
3838
preds_list = []
3939
labels_list = []
4040

chebai/result/utils.py

+2
Original file line number | Diff line number | Diff line change
@@ -182,6 +182,7 @@ def load_results_from_buffer(
182182
torch.load(
183183
os.path.join(buffer_dir, filename),
184184
map_location=torch.device(device),
185+
weights_only=False,
185186
)
186187
)
187188
i += 1
@@ -194,6 +195,7 @@ def load_results_from_buffer(
194195
torch.load(
195196
os.path.join(buffer_dir, filename),
196197
map_location=torch.device(device),
198+
weights_only=False,
197199
)
198200
)
199201
i += 1

tests/testCustomBalancedAccuracyMetric.py

+5-1
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,9 @@ def test_metric_against_realistic_data(self) -> None:
4949

5050
# load single file to get the num of labels for metric class instantiation
5151
labels = torch.load(
52-
f"{directory_path}/labels{0:03d}.pt", map_location=torch.device(self.device)
52+
f"{directory_path}/labels{0:03d}.pt",
53+
map_location=torch.device(self.device),
54+
weights_only=False,
5355
)
5456
num_labels = labels.shape[1]
5557
balanced_acc_custom = BalancedAccuracy(num_labels=num_labels)
@@ -58,10 +60,12 @@ def test_metric_against_realistic_data(self) -> None:
5860
labels = torch.load(
5961
f"{directory_path}/labels{i:03d}.pt",
6062
map_location=torch.device(self.device),
63+
weights_only=False,
6164
)
6265
preds = torch.load(
6366
f"{directory_path}/preds{i:03d}.pt",
6467
map_location=torch.device(self.device),
68+
weights_only=False,
6569
)
6670
balanced_acc_custom.update(preds, labels)
6771

tests/testCustomMacroF1Metric.py

+5-1
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,9 @@ def test_metric_against_realistic_data(self) -> None:
119119

120120
# Load single file to get the number of labels for metric class instantiation
121121
labels = torch.load(
122-
f"{directory_path}/labels{0:03d}.pt", map_location=torch.device(self.device)
122+
f"{directory_path}/labels{0:03d}.pt",
123+
map_location=torch.device(self.device),
124+
weights_only=False,
123125
)
124126
num_labels = labels.shape[1]
125127
macro_f1_custom = MacroF1(num_labels=num_labels)
@@ -130,10 +132,12 @@ def test_metric_against_realistic_data(self) -> None:
130132
labels = torch.load(
131133
f"{directory_path}/labels{i:03d}.pt",
132134
map_location=torch.device(self.device),
135+
weights_only=False,
133136
)
134137
preds = torch.load(
135138
f"{directory_path}/preds{i:03d}.pt",
136139
map_location=torch.device(self.device),
140+
weights_only=False,
137141
)
138142
macro_f1_standard.update(preds, labels)
139143
macro_f1_custom.update(preds, labels)

tests/testPubChemData.py

+9-3
Original file line number | Diff line number | Diff line change
@@ -37,9 +37,15 @@ def getDataSplitsOverlaps(cls) -> None:
3737
processed_path = os.path.join(os.getcwd(), cls.pubChem.processed_dir)
3838
print(f"Checking Data from - {processed_path}")
3939

40-
train_set = torch.load(os.path.join(processed_path, "train.pt"))
41-
val_set = torch.load(os.path.join(processed_path, "validation.pt"))
42-
test_set = torch.load(os.path.join(processed_path, "test.pt"))
40+
train_set = torch.load(
41+
os.path.join(processed_path, "train.pt"), weights_only=False
42+
)
43+
val_set = torch.load(
44+
os.path.join(processed_path, "validation.pt"), weights_only=False
45+
)
46+
test_set = torch.load(
47+
os.path.join(processed_path, "test.pt"), weights_only=False
48+
)
4349

4450
train_smiles, train_smiles_ids = cls.get_features_ids(train_set)
4551
val_smiles, val_smiles_ids = cls.get_features_ids(val_set)

tests/testTox21MolNetData.py

+9-3
Original file line number | Diff line number | Diff line change
@@ -37,9 +37,15 @@ def getDataSplitsOverlaps(cls) -> None:
3737
processed_path = os.path.join(os.getcwd(), cls.tox21.processed_dir)
3838
print(f"Checking Data from - {processed_path}")
3939

40-
train_set = torch.load(os.path.join(processed_path, "train.pt"))
41-
val_set = torch.load(os.path.join(processed_path, "validation.pt"))
42-
test_set = torch.load(os.path.join(processed_path, "test.pt"))
40+
train_set = torch.load(
41+
os.path.join(processed_path, "train.pt"), weights_only=False
42+
)
43+
val_set = torch.load(
44+
os.path.join(processed_path, "validation.pt"), weights_only=False
45+
)
46+
test_set = torch.load(
47+
os.path.join(processed_path, "test.pt"), weights_only=False
48+
)
4349

4450
train_smiles, train_smiles_ids = cls.get_features_ids(train_set)
4551
val_smiles, val_smiles_ids = cls.get_features_ids(val_set)

tutorials/demo_process_results.ipynb

+5-5
Original file line number | Diff line number | Diff line change
@@ -248,9 +248,9 @@
248248
"# check if pretraining datasets overlap\n",
249249
"dm = PubChemDeepSMILES()\n",
250250
"processed_path = dm.processed_dir\n",
251-
"test_set = torch.load(os.path.join(processed_path, \"test.pt\"))\n",
252-
"val_set = torch.load(os.path.join(processed_path, \"validation.pt\"))\n",
253-
"train_set = torch.load(os.path.join(processed_path, \"train.pt\"))\n",
251+
"test_set = torch.load(os.path.join(processed_path, \"test.pt\"), weights_only=False)\n",
252+
"val_set = torch.load(os.path.join(processed_path, \"validation.pt\"), weights_only=False)\n",
253+
"train_set = torch.load(os.path.join(processed_path, \"train.pt\"), weights_only=False)\n",
254254
"print(processed_path)\n",
255255
"test_smiles = [entry[\"features\"] for entry in test_set]\n",
256256
"val_smiles = [entry[\"features\"] for entry in val_set]\n",
@@ -320,7 +320,7 @@
320320
"data_module_v200 = ChEBIOver100()\n",
321321
"data_module_v148 = ChEBIOver100(chebi_version_train=148)\n",
322322
"data_module_v227 = ChEBIOver100(chebi_version_train=227)\n",
323-
"# dataset = torch.load(data_path)\n",
323+
"# dataset = torch.load(data_path, weights_only=False)\n",
324324
"# processors = [CustomResultsProcessor()]\n",
325325
"# factory = ResultFactory(model, data_module, processors)\n",
326326
"# factory.execute(data_path)"
@@ -653,7 +653,7 @@
653653
" if test_file is None:\n",
654654
" test_file = data_module.processed_file_names_dict[\"test\"]\n",
655655
" data_path = os.path.join(data_module.processed_dir, test_file)\n",
656-
" data_list = torch.load(data_path)\n",
656+
" data_list = torch.load(data_path, weights_only=False)\n",
657657
" preds_list = []\n",
658658
" labels_list = []\n",
659659
" # if common_classes_mask is not N\n",

tutorials/process_results_old_chebi.ipynb

+1-1
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,7 @@
167167
" if test_file is None:\n",
168168
" test_file = data_module.processed_file_names_dict[\"test\"]\n",
169169
" data_path = os.path.join(data_module.processed_dir, test_file)\n",
170-
" data_list = torch.load(data_path)\n",
170+
" data_list = torch.load(data_path, weights_only=False)\n",
171171
" preds_list = []\n",
172172
" labels_list = []\n",
173173
"\n",

0 commit comments

Comments
 (0)