Skip to content

Commit

Permalink
[FIX] CUDA example & accessing correct model attr
Browse files Browse the repository at this point in the history
  • Loading branch information
kathryn-baker committed Nov 22, 2023
1 parent dbbf19a commit 3fa3637
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
2 changes: 1 addition & 1 deletion base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def forward(self, x):
pass

def to(self, device: str):
self.model._model.to(device)
self.model.to(device)
super().to(device)


Expand Down
32 changes: 23 additions & 9 deletions basic_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"metadata": {},
"outputs": [],
"source": [
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(device)"
]
},
Expand All @@ -50,12 +50,12 @@
" nn.Linear(5, 10),\n",
" nn.ReLU(),\n",
" nn.Linear(10, 2),\n",
" ).double()\n",
").double()\n",
"base_model.requires_grad_(False)\n",
"base_model.to(device)\n",
"\n",
"# create example data\n",
"x = torch.rand((100, 5), dtype=torch.double)\n",
"x = torch.rand((100, 5), dtype=torch.double, device=device)\n",
"pred = base_model(x)"
]
},
Expand Down Expand Up @@ -204,7 +204,13 @@
"# define data set\n",
"class Dataset(torch.utils.data.Dataset):\n",
" def __init__(self, x, y):\n",
" self.x, self.y, = x, y\n",
" (\n",
" self.x,\n",
" self.y,\n",
" ) = (\n",
" x,\n",
" y,\n",
" )\n",
"\n",
" def __len__(self):\n",
" return self.y.shape[0]\n",
Expand All @@ -220,7 +226,7 @@
" batch_size=pred.shape[0],\n",
" shuffle=True,\n",
" num_workers=0,\n",
" pin_memory= not device =='cuda', # we can't use this if we're on the GPU\n",
" pin_memory=not \"cuda\" in device, # we can't use this if we're on the GPU\n",
")\n",
"\n",
"# define optimizer and loss function\n",
Expand Down Expand Up @@ -341,8 +347,12 @@
"outputs": [],
"source": [
"# calibration parameters\n",
"df_x = pd.DataFrame(columns=[[\"x_offset\"] * 2 + [\"x_scale\"] * 2, [\"target\", \"learned\"] * 2])\n",
"df_y = pd.DataFrame(columns=[[\"y_offset\"] * 2 + [\"y_scale\"] * 2, [\"target\", \"learned\"] * 2])\n",
"df_x = pd.DataFrame(\n",
" columns=[[\"x_offset\"] * 2 + [\"x_scale\"] * 2, [\"target\", \"learned\"] * 2]\n",
")\n",
"df_y = pd.DataFrame(\n",
" columns=[[\"y_offset\"] * 2 + [\"y_scale\"] * 2, [\"target\", \"learned\"] * 2]\n",
")\n",
"for df in [df_x, df_y]:\n",
" for col in df.columns:\n",
" model = miscal_model\n",
Expand Down Expand Up @@ -563,7 +573,9 @@
"\n",
"\n",
"# create transformed model\n",
"transformed_model = TransformedModel(miscal_model, input_transformer, output_transformer)"
"transformed_model = TransformedModel(\n",
" miscal_model, input_transformer, output_transformer\n",
")"
]
},
{
Expand Down Expand Up @@ -591,7 +603,9 @@
"for i in range(y_size):\n",
" idx_sort = torch.argsort(pred[:, i])\n",
" ax[i].plot(pred[idx_sort, i].cpu(), \"C0x\", label=\"original\")\n",
" ax[i].plot(transformed_pred[idx_sort, i].cpu(), \"C1x\", label=\"calibrated (transformers)\")\n",
" ax[i].plot(\n",
" transformed_pred[idx_sort, i].cpu(), \"C1x\", label=\"calibrated (transformers)\"\n",
" )\n",
" ax[i].grid(color=\"gray\", linestyle=\"dashed\")\n",
" ax[i].set_xlabel(\"index\")\n",
" ax[i].set_ylabel(f\"y$_{i}$\")\n",
Expand Down

0 comments on commit 3fa3637

Please sign in to comment.