This repository has been archived by the owner on Apr 25, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_segmentation.py
124 lines (98 loc) · 4.33 KB
/
run_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# bchhun, {2020-02-21}
from pipeline.segmentation import segmentation, instance_segmentation
from pipeline.segmentation_validation import segmentation_validation_michael
from multiprocessing import Process
import os
import numpy as np
import logging
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
import argparse
from configs.config_reader import YamlReader
class Worker(Process):
    """Child process that runs one pipeline stage restricted to a single GPU.

    The stage is chosen by ``method`` and called as ``stage(*inputs)`` inside
    the child process, after ``CUDA_VISIBLE_DEVICES`` has been set to ``gpuid``.
    """

    # Stage names accepted by ``method``; must stay in sync with ``run``.
    METHODS = ('segmentation', 'instance_segmentation', 'segmentation_validation')

    def __init__(self, inputs, gpuid=0, method='segmentation'):
        """
        :param inputs: positional arguments forwarded unchanged to the stage.
        :param gpuid: GPU id written to CUDA_VISIBLE_DEVICES in the child.
        :param method: one of :attr:`METHODS`.
        :raises ValueError: if ``method`` is not one of :attr:`METHODS`.
        """
        super().__init__()
        if method not in self.METHODS:
            # Fail in the parent: previously an unknown method made run()
            # silently do nothing in the child process.
            raise ValueError(
                f"method {method!r} not implemented; expected one of {self.METHODS}"
            )
        self.gpuid = gpuid
        self.inputs = inputs
        self.method = method

    def run(self):
        """Restrict this process to ``self.gpuid`` and run the selected stage."""
        # Index GPUs in PCI bus order so ``gpuid`` refers to a stable device.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpuid)
        if self.method == 'segmentation':
            log.info(f"running segmentation worker on {self.gpuid}")
            segmentation(*self.inputs)
        elif self.method == 'instance_segmentation':
            log.info(f"running instance segmentation worker on {self.gpuid}")
            instance_segmentation(*self.inputs)
        else:
            # Only 'segmentation_validation' remains; __init__ rejects anything else.
            log.info(f"running segmentation validation worker on {self.gpuid}")
            segmentation_validation_michael(*self.inputs)
def main(method_, raw_dir_, supp_dir_, val_dir_, config_):
method = method_
inputs = raw_dir_
outputs = supp_dir_
gpus = config_.segmentation.inference.gpu_ids
gpu_count = len(gpus)
assert len(config_.segmentation.inference.channels) > 0, "At least one channel must be specified"
# segmentation validation requires raw, supp, and validation definitions
if method == 'segmentation_validation':
if not val_dir_:
raise AttributeError("validation directory must be specified when method=segmentation_validation")
if not outputs:
raise AttributeError("supplemntary directory must be specifie dwhen method=segmentation_validation")
# segmentation requires raw (NNProb), and weights to be defined
elif method == 'segmentation':
if config_.segmentation.inference.weights is None:
raise AttributeError("Weights supp_dir must be specified when method=segmentation")
# instance segmentation requires raw (stack, NNprob), supp (to write outputs) to be defined
elif method == 'instance_segmentation':
TARGET = ''
else:
raise AttributeError(f"method flag {method} not implemented")
# all methods all require
if config_.segmentation.inference.fov:
sites = config_.segmentation.inference.fov
else:
# get all "XX-SITE_#" identifiers in raw data directory
img_names = [file for file in os.listdir(inputs) if (file.endswith(".npy")) & ('_NN' not in file)]
sites = [os.path.splitext(img_name)[0] for img_name in img_names]
sites = list(set(sites))
segment_sites = [site for site in sites if os.path.exists(os.path.join(inputs, "%s.npy" % site))]
sep = np.linspace(0, len(segment_sites), gpu_count + 1).astype(int)
processes = []
for i, gpu in enumerate(gpus):
_sites = segment_sites[sep[i]:sep[i + 1]]
args = (inputs, outputs, val_dir_, _sites, config_)
process = Worker(args, gpuid=gpu, method=method)
process.start()
processes.append(process)
for p in processes:
p.join()
def parse_args(argv=None):
    """
    Parse command line arguments for CLI.

    :param argv: argument strings to parse. When None (the default), argparse
        reads ``sys.argv[1:]``, so existing callers behave exactly as before.
    :return: namespace with ``method`` and ``config`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-m', '--method',
        type=str,
        required=True,
        # No `default`: the flag is required, so argparse could never use one.
        choices=['segmentation', 'instance_segmentation', 'segmentation_validation'],
        help="Method: one of 'segmentation', 'instance_segmentation', or 'segmentation_validation'",
    )
    parser.add_argument(
        '-c', '--config',
        type=str,
        required=True,
        help='path to yaml configuration file. Run_segmentation takes arguments from "inference" category'
    )
    return parser.parse_args(argv)
if __name__ == '__main__':
    arguments = parse_args()
    config = YamlReader()
    config.read_config(arguments.config)
    inference = config.segmentation.inference

    # Batch run: one main() call per (raw, supp, validation) directory triple.
    dir_lists = (inference.raw_dirs, inference.supp_dirs, inference.validation_dirs)
    if len({len(dirs) for dirs in dir_lists}) > 1:
        # zip() stops at the shortest list, so the extra entries are skipped.
        log.warning(
            "raw_dirs (%d), supp_dirs (%d) and validation_dirs (%d) differ in length; "
            "only the first %d entries will be processed",
            *(len(dirs) for dirs in dir_lists),
            min(len(dirs) for dirs in dir_lists),
        )
    for raw_dir, supp_dir, val_dir in zip(*dir_lists):
        main(arguments.method, raw_dir, supp_dir, val_dir, config)