-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit ad8d455
Showing
37 changed files
with
5,310 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Ignore local folders | ||
data | ||
env | ||
__pycache__ | ||
.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Maze Extrapolation | ||
|
||
## Models | ||
The current models are: | ||
- `dt_net`: a model from [Bansal et. al. ](https://github.com/aks2203/deep-thinking) | ||
|
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Import necessary modules\n", | ||
"\n", | ||
"import sys\n", | ||
"import os\n", | ||
"\n", | ||
"# Set root folder to project root\n", | ||
"os.chdir(os.path.dirname(os.getcwd()))\n", | ||
"\n", | ||
"# Add root folder to path\n", | ||
"sys.path.append(os.getcwd())\n", | ||
"\n", | ||
"from src.utils.loading import get_mazes\n", | ||
"from src.utils.plotting import plot_mazes" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"1.0\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Load mazes\n", | ||
"\n", | ||
"from src.utils.testing import count_start_neighbors\n", | ||
"\n", | ||
"inputs, _ = get_mazes(dataset='maze-dataset', maze_size=9, num_mazes=100, deadend_start=False)\n", | ||
"\n", | ||
"neighbors = count_start_neighbors(inputs)\n", | ||
"\n", | ||
"maze_idx_start = 0\n", | ||
"maze_idx_end = 10\n", | ||
"print(neighbors[maze_idx_start:maze_idx_end])\n", | ||
"plot_mazes(inputs[maze_idx_start:maze_idx_end])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "env", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Import necessary modules\n", | ||
"\n", | ||
"import sys\n", | ||
"import os\n", | ||
"\n", | ||
"# Set root folder to project root\n", | ||
"os.chdir(os.path.dirname(os.getcwd()))\n", | ||
"\n", | ||
"# Add root folder to path\n", | ||
"sys.path.append(os.getcwd())\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"import torch\n", | ||
"from matplotlib import pyplot as plt\n", | ||
"\n", | ||
"from src.utils.testing import compare_mazes\n", | ||
"from src.utils.loading import load_model, get_mazes\n", | ||
"from src.utils.plotting import plot_mazes" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Using device: cuda\n", | ||
"Loaded pi_net to cuda\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Load model and mazes\n", | ||
"\n", | ||
"model = load_model('dt_net')\n", | ||
"\n", | ||
"inputs, solutions = get_mazes(\n", | ||
" dataset='maze-dataset', \n", | ||
" maze_size=9, \n", | ||
" num_mazes=30,\n", | ||
" percolation=0.0,\n", | ||
" deadend_start=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Compute predictions and compare to solutions\n", | ||
"\n", | ||
"predictions = torch.zeros_like(solutions)\n", | ||
"for i in range(inputs.shape[0]):\n", | ||
" predictions[i:i+1] = model.predict(inputs[i:i+1], iters=300)\n", | ||
"\n", | ||
"corrects = torch.tensor(compare_mazes(predictions, solutions), dtype=torch.bool)\n", | ||
"incorrects = ~corrects" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"No incorrect predictions found.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Plot incorrect predictions\n", | ||
"\n", | ||
"if incorrects.any():\n", | ||
" plot_mazes(inputs[incorrects], \n", | ||
" predictions=predictions[incorrects], \n", | ||
" solutions=solutions[incorrects], \n", | ||
" file_name=f'outputs/mazes/{model.name()}_incorrects.pdf')\n", | ||
"else:\n", | ||
" print('No incorrect predictions found.')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "IndexError", | ||
"evalue": "index 0 is out of bounds for axis 0 with size 0", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", | ||
"Cell \u001b[0;32mIn[5], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m incorrect_inputs \u001b[38;5;241m=\u001b[39m inputs[incorrects]\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\n\u001b[1;32m 2\u001b[0m incorrect_inputs \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmoveaxis(incorrect_inputs, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# Move RGB axis to last\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m plt\u001b[38;5;241m.\u001b[39mimshow(\u001b[43mincorrect_inputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m, cmap\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgray\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", | ||
"\u001b[0;31mIndexError\u001b[0m: index 0 is out of bounds for axis 0 with size 0" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"incorrect_inputs = inputs[incorrects].cpu().numpy()\n", | ||
"incorrect_inputs = np.moveaxis(incorrect_inputs, 1, -1) # Move RGB axis to last\n", | ||
"\n", | ||
"plt.imshow(incorrect_inputs[0], cmap='gray')" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "env", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.