-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* updated intro and paths * updated figures, tested data loader * setup.sh fetches correct dataset * finalized the exercise outline * semi-final exercise * parts 1 and 2 tested, part 3 outline ready * clearer variables, train with larger patch size * fix typo * clarify variable names * trying to log graph * match example size with training * reuse globals * fix reference * log sample images from the first batch * wider model * low LR solution * fix path * seed everything * fix test dataset without masks * metrics solution this needs a new test dataset * fetch test data, compute metrics * byass cellpose import error due to numpy version conflicts * final exercise * moved files * fixed formatting - ready for review * viscy -> VisCy (#34) (#39) Introducing capitalization to highlight vision and single-cell aspects of the pipeline. * trying to log graph * log graph * black --------- Co-authored-by: Shalin Mehta <[email protected]> Co-authored-by: Shalin Mehta <[email protected]>
- Loading branch information
1 parent
7a08716
commit 76c3b31
Showing
14 changed files
with
2,596 additions
and
305 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import argparse | ||
from traitlets.config import Config | ||
import nbformat as nbf | ||
from nbconvert.preprocessors import TagRemovePreprocessor, ClearOutputPreprocessor | ||
from nbconvert.exporters import NotebookExporter | ||
|
||
|
||
def get_arg_parser(): | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument('input_file') | ||
parser.add_argument('output_file') | ||
|
||
return parser | ||
|
||
|
||
def convert(input_file, output_file): | ||
c = Config() | ||
c.TagRemovePreprocessor.remove_cell_tags = ("solution",) | ||
c.TagRemovePreprocessor.enabled = True | ||
c.ClearOutputPreprocesser.enabled = True | ||
c.NotebookExporter.preprocessors = [ | ||
"nbconvert.preprocessors.TagRemovePreprocessor", | ||
"nbconvert.preprocessors.ClearOutputPreprocessor" | ||
] | ||
|
||
exporter = NotebookExporter(config=c) | ||
exporter.register_preprocessor(TagRemovePreprocessor(config=c), True) | ||
exporter.register_preprocessor(ClearOutputPreprocessor(), True) | ||
|
||
output = NotebookExporter(config=c).from_filename(input_file) | ||
with open(output_file, 'w') as f: | ||
f.write(output[0]) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = get_arg_parser() | ||
args = parser.parse_args() | ||
|
||
convert(args.input_file, args.output_file) | ||
print(f'Converted {args.input_file} to {args.output_file}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
|
||
# %% | ||
# %% Imports and paths | ||
|
||
from pathlib import Path | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
import torchview | ||
import torchvision | ||
from iohub import open_ome_zarr | ||
from lightning.pytorch import seed_everything | ||
from lightning.pytorch.loggers import CSVLogger | ||
|
||
# pytorch lightning wrapper for Tensorboard. | ||
from tensorboard import notebook # for viewing tensorboard in notebook | ||
from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard | ||
|
||
# HCSDataModule makes it easy to load data during training. | ||
from viscy.light.data import HCSDataModule | ||
|
||
# Trainer class and UNet. | ||
from viscy.light.engine import VSTrainer, VSUNet | ||
|
||
seed_everything(42, workers=True) | ||
|
||
# Paths to data and log directory | ||
data_path = Path( | ||
Path("~/data/04_image_translation/HEK_nuclei_membrane_pyramid.zarr/") | ||
).expanduser() | ||
|
||
log_dir = Path("~/data/04_image_translation/logs/").expanduser() | ||
|
||
# Create log directory if needed, and launch tensorboard | ||
log_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
# fmt: off | ||
%reload_ext tensorboard | ||
%tensorboard --logdir {log_dir} --port 6007 --bind_all | ||
# fmt: on | ||
|
||
# %% The entire training loop is contained in this cell. | ||
|
||
GPU_ID = 0 | ||
BATCH_SIZE = 10 | ||
YX_PATCH_SIZE = (512, 512) | ||
|
||
|
||
# Dictionary that specifies key parameters of the model. | ||
phase2fluor_config = { | ||
"architecture": "2D", | ||
"num_filters": [24, 48, 96, 192, 384], | ||
"in_channels": 1, | ||
"out_channels": 2, | ||
"residual": True, | ||
"dropout": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data. | ||
"task": "reg", # reg = regression task. | ||
} | ||
|
||
phase2fluor_model = VSUNet( | ||
model_config=phase2fluor_config.copy(), | ||
batch_size=BATCH_SIZE, | ||
loss_function=torch.nn.functional.l1_loss, | ||
schedule="WarmupCosine", | ||
log_num_samples=10, # Number of samples from each batch to log to tensorboard. | ||
example_input_yx_shape=YX_PATCH_SIZE, | ||
) | ||
|
||
# Reinitialize the data module. | ||
phase2fluor_data = HCSDataModule( | ||
data_path, | ||
source_channel="Phase", | ||
target_channel=["Nuclei", "Membrane"], | ||
z_window_size=1, | ||
split_ratio=0.8, | ||
batch_size=BATCH_SIZE, | ||
num_workers=8, | ||
architecture="2D", | ||
yx_patch_size=YX_PATCH_SIZE, | ||
augment=True, | ||
) | ||
phase2fluor_data.setup("fit") | ||
|
||
|
||
# Train for 3 epochs to see if you can log graph. | ||
trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], max_epochs=3, default_root_dir=log_dir) | ||
|
||
# trainer class takes the model and the data module as inputs. | ||
trainer.fit(phase2fluor_model, datamodule=phase2fluor_data) | ||
|
||
# %% Is exmple_input_array present? | ||
print(f'{phase2fluor_model.example_input_array.shape},{phase2fluor_model.example_input_array.dtype}') | ||
trainer.logger.log_graph(phase2fluor_model, phase2fluor_model.example_input_array) | ||
# %% |
Oops, something went wrong.