|
248 | 248 | "# check if pretraining datasets overlap\n",
|
249 | 249 | "dm = PubChemDeepSMILES()\n",
|
250 | 250 | "processed_path = dm.processed_dir\n",
|
251 |
| - "test_set = torch.load(os.path.join(processed_path, \"test.pt\"))\n", |
252 |
| - "val_set = torch.load(os.path.join(processed_path, \"validation.pt\"))\n", |
253 |
| - "train_set = torch.load(os.path.join(processed_path, \"train.pt\"))\n", |
| 251 | + "test_set = torch.load(os.path.join(processed_path, \"test.pt\"), weights_only=False)\n", |
| 252 | + "val_set = torch.load(os.path.join(processed_path, \"validation.pt\"), weights_only=False)\n", |
| 253 | + "train_set = torch.load(os.path.join(processed_path, \"train.pt\"), weights_only=False)\n", |
254 | 254 | "print(processed_path)\n",
|
255 | 255 | "test_smiles = [entry[\"features\"] for entry in test_set]\n",
|
256 | 256 | "val_smiles = [entry[\"features\"] for entry in val_set]\n",
|
|
320 | 320 | "data_module_v200 = ChEBIOver100()\n",
|
321 | 321 | "data_module_v148 = ChEBIOver100(chebi_version_train=148)\n",
|
322 | 322 | "data_module_v227 = ChEBIOver100(chebi_version_train=227)\n",
|
323 |
| - "# dataset = torch.load(data_path)\n", |
| 323 | + "# dataset = torch.load(data_path, weights_only=False)\n", |
324 | 324 | "# processors = [CustomResultsProcessor()]\n",
|
325 | 325 | "# factory = ResultFactory(model, data_module, processors)\n",
|
326 | 326 | "# factory.execute(data_path)"
|
|
653 | 653 | " if test_file is None:\n",
|
654 | 654 | " test_file = data_module.processed_file_names_dict[\"test\"]\n",
|
655 | 655 | " data_path = os.path.join(data_module.processed_dir, test_file)\n",
|
656 |
| - " data_list = torch.load(data_path)\n", |
| 656 | + " data_list = torch.load(data_path, weights_only=False)\n", |
657 | 657 | " preds_list = []\n",
|
658 | 658 | " labels_list = []\n",
|
659 | 659 | " # if common_classes_mask is not N\n",
|
|
0 commit comments