Skip to content

Commit

Permalink
final code
Browse files Browse the repository at this point in the history
  • Loading branch information
xingchenzhao committed Apr 25, 2021
0 parents commit 07ade0b
Show file tree
Hide file tree
Showing 44 changed files with 3,966 additions and 0 deletions.
133 changes: 133 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

wandb/
ISIC2018/
path/
.vscode/
old_scripts/
curr_script/
wandb/
.vscode/
# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
67 changes: 67 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# MICCAI-2021 PAC Bayesian Performance Guarantees for Deep(Stochastic) Networks in Medical Imaging


## Introduction



This code repository is an implementation of "**PAC Bayesian Performance Guarantees for Deep(Stochastic) Networks in Medical Imaging.**"

## Preparation



### Prerequisites

- Python 3.6
- Pytorch 1.4
- numpy
- tqdm
- pandas
- PIL

### Dataset Preparation

- Run `get_data.sh` to retrieve the ISIC2018 challenge data.
- Run `make_split.py` to generate a train test split.
- Run `python3 -m src.main **kwargs` to train models and compute bounds.

## Training



To reproduce the results showed in the fig a, b, c, and d, please run the following scripts.

### Fig a

- `sh scripts/fig_a/LW.sh`
- `sh scripts/fig_a/LW-PBB.sh`
- `sh scripts/fig_a/U-Net.sh`
- `sh scripts/fig_a/U-Net-PBB.sh`

### Fig b

- `sh scripts/fig_b/sigma_prior_0.001.sh`
- `sh scripts/fig_b/sigma_prior_0.005.sh`
- `sh scripts/fig_b/sigma_prior_0.01.sh`
- `sh scripts/fig_b/sigma_prior_0.02.sh`
- `sh scripts/fig_b/sigma_prior_0.03.sh`
- `sh scripts/fig_b/sigma_prior_0.04.sh`
- `sh scripts/fig_b/sigma_prior_0.05.sh`

### Fig c

- `sh scripts/fig_c/sigma_prior_0.001.sh`
- `sh scripts/fig_c/sigma_prior_0.005.sh`
- `sh scripts/fig_c/sigma_prior_0.01.sh`
- `sh scripts/fig_c/sigma_prior_0.02.sh`
- `sh scripts/fig_c/sigma_prior_0.03.sh`
- `sh scripts/fig_c/sigma_prior_0.04.sh`
- `sh scripts/fig_c/sigma_prior_0.05.sh`
- `sh scripts/fig_c/sigma_prior_0.1.sh`
- `sh scripts/fig_c/sigma_prior_0.2.sh`

### Fig d

- `sh scripts/fig_d/LW.sh`
- `sh scripts/fig_d/U-Net.sh`
19 changes: 19 additions & 0 deletions get_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# This script is untested, but should work (and, otherwise, give the general idea
# for how to construct the ISIC2018 data directory)
mkdir ISIC2018
wget https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1-2_Training_Input.zip
unzip ISIC2018_Task1-2_Training_Input.zip
rm ISIC2018_Task1-2_Training_Input.zip
mv ISIC2018_Task1-2_Training_Input ISIC2018
wget https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1_Training_GroundTruth.zip
unzip ISIC2018_Task1_Training_GroundTruth.zip
rm ISIC2018_Task1_Training_GroundTruth.zip
mv ISIC2018_Task1_Training_GroundTruth ISIC2018
wget https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task3_Training_Input.zip
unzip ISIC2018_Task3_Training_Input.zip
rm ISIC2018_Task3_Training_Input.zip
mv ISIC2018_Task3_Training_Input ISIC2018
wget https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task3_Training_GroundTruth.zip
unzip ISIC2018_Task3_Training_GroundTruth.zip
rm ISIC2018_Task3_Training_GroundTruth.zip
mv ISIC2018_Task3_Training_GroundTruth ISIC2018
92 changes: 92 additions & 0 deletions make_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import random

from pathlib import Path
import os
from PIL import Image

TASK1_IMG_DIR = 'ISIC2018/ISIC2018_Task1-2_Training_Input'
TASK3_IMG_DIR = 'ISIC2018/ISIC2018_Task3_Training_Input'

def isimage(fname):

try:
_ = Image.open(fname)
return True
except IOError:
return False

raise NotImplementedError('Unhandled case encountered.')

def write(paths, fname):
with open(fname, 'w') as out:
for p in paths:
out.write(f'{p}\n')

if __name__ == '__main__':

random.seed(0)

tasks = [(TASK1_IMG_DIR, 'task1'),
(TASK3_IMG_DIR, 'task3')]

for task_dir, task_name in tasks:

path = os.path.join('path', task_name)

Path(path).mkdir(exist_ok=True, parents=True)
task_files = [fname for fname in os.listdir(task_dir)
if isimage(os.path.join(task_dir,fname))]

train = []
final_holdout = []

for fname in task_files:
if random.random() <= 0.9:
train.append(fname)
else:
final_holdout.append(fname)

write(final_holdout,
os.path.join(path, 'final_holdout.txt'))
write(train,
os.path.join(path, 'pac_bayes_full_train.txt'))

hoeffding_holdout = []
hoeffding_train = []

for fname in train:
if random.random() <= 0.9:
hoeffding_train.append(fname)
else:
hoeffding_holdout.append(fname)

write(hoeffding_holdout,
os.path.join(path, 'hoeffding_holdout.txt'))
write(hoeffding_train,
os.path.join(path, 'hoeffding_train.txt'))

prefix = []
bound = []

for fname in train:
if random.random() <= 0.5:
bound.append(fname)
else:
prefix.append(fname)

write(prefix,
os.path.join(path, 'pac_bayes_prefix.txt'))
write(bound,
os.path.join(path, 'pac_bayes_prefix_bound.txt'))












3 changes: 3 additions & 0 deletions scripts/fig_a/LW-PBB.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=light --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.01 --kl_dampening=1 --prior_max_train=30 --use_prefix --estimator=sample --device=0 --freeze_batchnorm --mc_samples=1000
3 changes: 3 additions & 0 deletions scripts/fig_a/LW.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=light --baseline --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --train_bound=none --task=segment --device=0 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_a/U-Net-PBB.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=unet --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.01 --kl_dampening=1 --prior_max_train=30 --use_prefix --estimator=sample --device=1 --freeze_batchnorm --mc_samples=1000
3 changes: 3 additions & 0 deletions scripts/fig_a/U-Net.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=unet --baseline --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --train_bound=none --task=segment --device=0 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_b/sigma_prior_0.001.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=unet --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.001 --kl_dampening=0.1 --prior_max_train=30 --use_prefix --estimator=sample --device=0 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_b/sigma_prior_0.005.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=unet --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.005 --kl_dampening=0.1 --prior_max_train=30 --use_prefix --estimator=sample --device=0 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_b/sigma_prior_0.01.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=unet --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.01 --kl_dampening=0.1 --prior_max_train=30 --use_prefix --estimator=sample --device=0 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_b/sigma_prior_0.02.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=unet --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.02 --kl_dampening=1 --prior_max_train=30 --use_prefix --estimator=sample --device=0 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_b/sigma_prior_0.03.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=unet --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.03 --kl_dampening=1 --prior_max_train=30 --use_prefix --estimator=sample --device=0 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_b/sigma_prior_0.04.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=unet --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.04 --kl_dampening=1 --prior_max_train=30 --use_prefix --estimator=sample --device=0 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_b/sigma_prior_0.05.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=unet --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.05 --kl_dampening=0.1 --prior_max_train=30 --use_prefix --estimator=sample --device=0 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_c/sigma_prior_0.001.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=light --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.001 --kl_dampening=0.1 --prior_max_train=30 --use_prefix --estimator=sample --device=1 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_c/sigma_prior_0.005.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=light --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.005 --kl_dampening=0.1 --prior_max_train=30 --use_prefix --estimator=sample --device=1 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_c/sigma_prior_0.01.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=light --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.01 --kl_dampening=0.1 --prior_max_train=30 --use_prefix --estimator=sample --device=1 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_c/sigma_prior_0.02.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=light --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.02 --kl_dampening=1 --prior_max_train=30 --use_prefix --estimator=sample --device=1 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_c/sigma_prior_0.03.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=light --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.03 --kl_dampening=1 --prior_max_train=30 --use_prefix --estimator=sample --device=1 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_c/sigma_prior_0.04.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=light --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.04 --kl_dampening=1 --prior_max_train=30 --use_prefix --estimator=sample --device=1 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_c/sigma_prior_0.05.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=light --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --task=segment --sigma_prior=0.05 --kl_dampening=0.1 --prior_max_train=30 --use_prefix --estimator=sample --device=1 --freeze_batchnorm
3 changes: 3 additions & 0 deletions scripts/fig_c/sigma_prior_0.1.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python3 -m src.main --model=light --init_lr=1e-2 --lr_step=30 --epochs=120 --momentum=0.95 --batch_size=8 --sigma_prior=0.1 --kl_dampening=1 --train_bound=variational --prior_max_train=30 --use_prefix --mc_samples=100 --estimator=sample --task=segment --device=1 --freeze_batchnorm
Loading

0 comments on commit 07ade0b

Please sign in to comment.