Skip to content

Commit

Permalink
create repo
Browse files Browse the repository at this point in the history
  • Loading branch information
bknutson0 committed Oct 3, 2024
0 parents commit ad8d455
Show file tree
Hide file tree
Showing 37 changed files with 5,310 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Ignore local folders
data
env
__pycache__
.DS_Store
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Maze Extrapolation

## Models
The current models are:
- `dt_net`: a model from [Bansal et. al. ](https://github.com/aks2203/deep-thinking)

Binary file added models/dt_net.pth
Binary file not shown.
174 changes: 174 additions & 0 deletions notebooks/check_first_betti_num.ipynb

Large diffs are not rendered by default.

81 changes: 81 additions & 0 deletions notebooks/check_neighbor_counting.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Import necessary modules\n",
"\n",
"import sys\n",
"import os\n",
"\n",
"# Set root folder to project root\n",
"os.chdir(os.path.dirname(os.getcwd()))\n",
"\n",
"# Add root folder to path\n",
"sys.path.append(os.getcwd())\n",
"\n",
"from src.utils.loading import get_mazes\n",
"from src.utils.plotting import plot_mazes"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0\n"
]
}
],
"source": [
"# Load mazes\n",
"\n",
"from src.utils.testing import count_start_neighbors\n",
"\n",
"inputs, _ = get_mazes(dataset='maze-dataset', maze_size=9, num_mazes=100, deadend_start=False)\n",
"\n",
"neighbors = count_start_neighbors(inputs)\n",
"\n",
"maze_idx_start = 0\n",
"maze_idx_end = 10\n",
"print(neighbors[maze_idx_start:maze_idx_end])\n",
"plot_mazes(inputs[maze_idx_start:maze_idx_end])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
173 changes: 173 additions & 0 deletions notebooks/create_mazes.ipynb

Large diffs are not rendered by default.

977 changes: 977 additions & 0 deletions notebooks/demo_maze-dataset.ipynb

Large diffs are not rendered by default.

143 changes: 143 additions & 0 deletions notebooks/explore_failed_predictions.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Import necessary modules\n",
"\n",
"import sys\n",
"import os\n",
"\n",
"# Set root folder to project root\n",
"os.chdir(os.path.dirname(os.getcwd()))\n",
"\n",
"# Add root folder to path\n",
"sys.path.append(os.getcwd())\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from matplotlib import pyplot as plt\n",
"\n",
"from src.utils.testing import compare_mazes\n",
"from src.utils.loading import load_model, get_mazes\n",
"from src.utils.plotting import plot_mazes"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n",
"Loaded pi_net to cuda\n"
]
}
],
"source": [
"# Load model and mazes\n",
"\n",
"model = load_model('dt_net')\n",
"\n",
"inputs, solutions = get_mazes(\n",
" dataset='maze-dataset', \n",
" maze_size=9, \n",
" num_mazes=30,\n",
" percolation=0.0,\n",
" deadend_start=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Compute predictions and compare to solutions\n",
"\n",
"predictions = torch.zeros_like(solutions)\n",
"for i in range(inputs.shape[0]):\n",
" predictions[i:i+1] = model.predict(inputs[i:i+1], iters=300)\n",
"\n",
"corrects = torch.tensor(compare_mazes(predictions, solutions), dtype=torch.bool)\n",
"incorrects = ~corrects"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"No incorrect predictions found.\n"
]
}
],
"source": [
"# Plot incorrect predictions\n",
"\n",
"if incorrects.any():\n",
" plot_mazes(inputs[incorrects], \n",
" predictions=predictions[incorrects], \n",
" solutions=solutions[incorrects], \n",
" file_name=f'outputs/mazes/{model.name()}_incorrects.pdf')\n",
"else:\n",
" print('No incorrect predictions found.')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"ename": "IndexError",
"evalue": "index 0 is out of bounds for axis 0 with size 0",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[5], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m incorrect_inputs \u001b[38;5;241m=\u001b[39m inputs[incorrects]\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\n\u001b[1;32m 2\u001b[0m incorrect_inputs \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmoveaxis(incorrect_inputs, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# Move RGB axis to last\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m plt\u001b[38;5;241m.\u001b[39mimshow(\u001b[43mincorrect_inputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m, cmap\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgray\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
"\u001b[0;31mIndexError\u001b[0m: index 0 is out of bounds for axis 0 with size 0"
]
}
],
"source": [
"incorrect_inputs = inputs[incorrects].cpu().numpy()\n",
"incorrect_inputs = np.moveaxis(incorrect_inputs, 1, -1) # Move RGB axis to last\n",
"\n",
"plt.imshow(incorrect_inputs[0], cmap='gray')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit ad8d455

Please sign in to comment.