run_encoding.py
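"""Encode prepared cell image patches into latent vectors.

Applies a trained encoder to the static-patch dataset of each split and
writes the embeddings to <split>_embeddings.npy under a model-named output
folder. Example invocation (config path is a placeholder):

    python run_encoding.py --config /path/to/inference_config.yml
"""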
import argparse
import logging
import os

import numpy as np
import torch
import zarr
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset.dataset import ImageDataset
from train import resnet
from utils.config_reader import YamlReader
from utils.logger import make_logger


def encode_patches(raw_dir: str,
                   config_: YamlReader,
                   gpu: int = 0,
                   ):
    """Wrapper method for patch encoding.

    Loads the prepared dataset for each split and applies the trained
    encoder to the static patches of each well. The resulting latent
    vectors are saved under a model-named subfolder of raw_dir, e.g.:
        train_embeddings.npy
        val_embeddings.npy

    Args:
        raw_dir (str): folder containing raw data, segmentation and
            summarized results
        config_ (YamlReader): reads fields from the "INFERENCE" category
        gpu (int): ID of the GPU to run inference on
    """
    log = logging.getLogger('dynacontrast.log')
    model_dir = config_.inference.weights_dirs
    n_chan = config_.inference.n_channels
    network = config_.inference.network
    network_width = config_.inference.network_width
    batch_size = config_.inference.batch_size
    num_workers = config_.inference.num_workers
    normalization = config_.inference.normalization
    projection = config_.inference.projection
    splits = config_.inference.splits
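    # Illustrative sketch of the INFERENCE fields read above (values are
    # placeholders, not shipped defaults):
    #     INFERENCE:
    #         weights_dirs: /path/to/model_dir   # or a list of dirs
    #         n_channels: 2
    #         network: ResNet50
    #         network_width: 1
    #         batch_size: 128
    #         num_workers: 4
    #         normalization: patch               # or "dataset"
    #         projection: false
    #         splits: [train, val]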
    model_name = os.path.basename(model_dir)
    if projection:
        # Encode with the projection head output ('z')
        output_dir = os.path.join(raw_dir, model_name + '_proj')
        encode_layer = 'z'
    else:
        # Encode with the backbone feature output ('h')
        output_dir = os.path.join(raw_dir, model_name)
        encode_layer = 'h'
    os.makedirs(output_dir, exist_ok=True)
    datasets = {}
    for split in splits:
        if normalization == 'dataset':
            zarr_path = os.path.join(raw_dir, 'cell_patches_datasetnorm_{}.zarr'.format(split))
        elif normalization == 'patch':
            zarr_path = os.path.join(raw_dir, 'cell_patches_{}.zarr'.format(split))
        else:
            raise ValueError('Parameter "normalization" must be "dataset" or "patch"')
        if not os.path.isdir(zarr_path):
            msg = '{} is not found.'.format(zarr_path)
            log.error(msg)
            raise FileNotFoundError(msg)
        datasets[split] = ImageDataset(zarr.open(zarr_path))
    device = torch.device('cuda:%d' % gpu)
    print('Encoding images using gpu {}...'.format(gpu))
    # Only ResNet backbones are supported at the moment
    if 'ResNet' not in network:
        raise ValueError('Network {} is not available'.format(network))
    # The same trained weights are applied to every split, so the model
    # only needs to be built and loaded once.
    model = resnet.EncodeProject(arch=network, num_inputs=n_chan, width=network_width)
    model = model.to(device)
    model.load_state_dict(torch.load(os.path.join(model_dir, 'model.pt'), map_location=device))
    model.eval()
    for data_name, dataset in datasets.items():
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 pin_memory=True,
                                 )
        h_s = []
        # Disable gradient tracking during inference to save memory
        with torch.no_grad(), tqdm(data_loader, desc='inference batch') as batch_pbar:
            for batch in batch_pbar:
                batch = batch.to(device)
                # Drop singleton dimensions from the encoder output
                code = model.encode(batch, out=encode_layer).cpu().numpy().squeeze()
                h_s.append(code)
        dats = np.concatenate(h_s, axis=0)
        output_fname = '{}_embeddings.npy'.format(data_name)
        print(f"\tsaving {os.path.join(output_dir, output_fname)}")
        with open(os.path.join(output_dir, output_fname), 'wb') as f:
            np.save(f, dats)
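

# Downstream usage sketch (illustrative): each split's vectors can be
# reloaded later with, e.g.,
#     embeddings = np.load(os.path.join(output_dir, 'train_embeddings.npy'))
# where output_dir is the model-named folder created by encode_patches.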


def main(raw_dir, config_):
    make_logger(
        log_dir=raw_dir,
        log_level=logging.INFO,
    )
    gpu_id = config_.inference.gpu_id
    encode_patches(raw_dir, config_, gpu=gpu_id)


def parse_args():
    """Parse command line arguments for the CLI.

    :return: namespace containing the arguments passed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-c', '--config',
        type=str,
        required=True,
        help='path to yaml configuration file',
    )
    return parser.parse_args()


if __name__ == '__main__':
    arguments = parse_args()
    config = YamlReader()
    config.read_config(arguments.config)
    if not isinstance(config.inference.weights_dirs, list):
        weights = [config.inference.weights_dirs]
    else:
        weights = config.inference.weights_dirs
    # Batch run: encode every raw directory with every set of weights
    for raw_dir in config.inference.raw_dirs:
        for weight in weights:
            config.inference.weights_dirs = weight
            main(raw_dir, config)