diff --git a/flyemflows/workflow/sparsemeshes.py b/flyemflows/workflow/sparsemeshes.py
index 8901e1ef..17478e2f 100644
--- a/flyemflows/workflow/sparsemeshes.py
+++ b/flyemflows/workflow/sparsemeshes.py
@@ -1,55 +1,62 @@
 import os
 import copy
 import logging
-from math import log2, pow, ceil
-import numpy as np
 import pandas as pd
-
-import dask.bag as db
+import distributed
 from dvid_resource_manager.client import ResourceManagerClient
-
-from neuclease.util import Timer
-from neuclease.dvid import fetch_sparsevol_coarse, fetch_sparsevol
+from confiddler import flow_style
+from neuclease.util import Timer, tqdm_proxy
+from neuclease.dvid import fetch_sparsevol, set_default_dvid_session_timeout
+from neuclease.dvid.rle import blockwise_masks_from_ranges
 from vol2mesh import Mesh
 
 from ..volumes import VolumeService, DvidVolumeService, DvidSegmentationVolumeSchema
+from ..util import as_completed_synchronous
 from .util import BodyListSchema, load_body_list
 from . import Workflow
 
 logger = logging.getLogger(__name__)
 
+
 class SparseMeshes(Workflow):
     """
-    This workflow 'naively' computes meshes from downloaded sparsevols.
-    It will download each sparsevol at the best scale it can, ensuring that
-    the bounding-box of the body at that scale doesn't exceed a certain size.
-    Then it computes the entire mesh all at once (not in blocks, no stitching required).
+    Compute meshes for a set of bodies from their sparsevol representations.
 
     It saves the resulting mesh files to a directory.
     """
-
     OptionsSchema = {
         "type": "object",
         "description": "Settings specific to the SparseMeshes workflow",
         "default": {},
         "additionalProperties": False,
         "properties": {
-            "min-scale": {
-                "description": "Minimum scale at which to fetch sparsevols.\n"
-                               "For individual bodies, the scale may be forced higher\n"
-                               "if needed according to max-analysis-volume.",
+            "bodies": BodyListSchema,
+            "scale": {
+                "description":
+                    "Scale at which to fetch sparsevols.\n"
+                    "Mesh vertices will be converted to scale-0.\n",
                 "type": "integer",
                 "default": 0
             },
-            "max-analysis-volume": {
-                "description": "The above scale will be overridden (to something higher, i.e. lower resolution) \n"
-                               "if the body would still be too large to generate a mesh for, as defined by this setting.\n",
-                "type": "number",
-                "default": 1e9  # 1 GB max
+            "rescale-factor": {
+                "description": "Optionally rescale the vertex positions before storing the mesh.\n",
+                "type": "array",
+                "items": { "type": "number" },
+                "minItems": 3,
+                "maxItems": 3,
+                "default": flow_style([1, 1, 1])
+            },
+            "block-shape": {
+                "description": "The mesh will be generated in blocks and the blocks will be stitched together.\n",
+                "type": "array",
+                "items": { "type": "integer" },
+                "minItems": 3,
+                "maxItems": 3,
+                "default": flow_style([-1,-1,-1])
             },
-            "bodies": BodyListSchema,
             "smoothing-iterations": {
                 "description": "How many iterations of smoothing to apply before decimation",
                 "type": "integer",
@@ -63,15 +70,18 @@ class SparseMeshes(Workflow):
                                " the decimation fraction will be auto-increased for that mesh.",
                 "type": "number",
                 "minimum": 0.0000001,
-                "maximum": 1.0, # 1.0 == disable
-                "default": 0.1
+                "maximum": 1.0,  # 1.0 == disable
+                "default": 0.1
             },
             "format": {
                 "description": "Format to save the meshes in",
                 "type": "string",
-                "enum": ["obj",     # Wavefront OBJ (.obj)
-                         "drc",     # Draco (compressed) (.drc)
-                         "ngmesh"], # "neuroglancer mesh" format -- a custom binary format.  Note: Data is presumed to be 8nm resolution
+                "enum": [
+                    "obj",     # Wavefront OBJ (.obj)
+                    "drc",     # Draco (compressed) (.drc)
+                    "ngmesh"   # "neuroglancer mesh" format -- a custom binary format.
+                               # Note: Data is presumed to be 8nm resolution, so you need to use rescale-factor
+                ],
                 "default": "obj"
             },
             "output-directory": {
@@ -82,140 +92,80 @@ class SparseMeshes(Workflow):
         }
     }
 
-
     Schema = copy.deepcopy(Workflow.schema())
     Schema["properties"].update({
         "input": DvidSegmentationVolumeSchema,
         "sparsemeshes": OptionsSchema
     })
 
-
     @classmethod
     def schema(cls):
         return SparseMeshes.Schema
 
-    def _init_service(self):
-        options = self.config["sparsemeshes"]
+    def execute(self):
         input_config = self.config["input"]
         mgr_options = self.config["resource-manager"]
-
-        self.mgr_client = ResourceManagerClient( mgr_options["server"], mgr_options["port"] )
-        self.input_service = VolumeService.create_from_config( input_config, self.mgr_client )
-        assert isinstance(self.input_service, DvidVolumeService), \
+        mgr_client = ResourceManagerClient(mgr_options["server"], mgr_options["port"])
+        input_service = VolumeService.create_from_config(input_config, mgr_client)
+        assert isinstance(input_service, DvidVolumeService), \
             "Input must be plain dvid source, not scaled, transposed, etc."
-
-        min_scale = options["min-scale"]
-        max_scale = max(self.input_service.available_scales)
-        assert min_scale <= max_scale, \
-            f"Largest available scale in the input ({max_scale}) is smaller than the min-scale you provided ({min_scale})."
-
-    def execute(self):
-        self._init_service()
-        mgr_client = self.mgr_client
 
         options = self.config["sparsemeshes"]
-        max_box_voxels = options["max-analysis-volume"]
-        min_scale = options["min-scale"]
-        max_scale = max(self.input_service.available_scales)
+        scale = options["scale"]
+        halo = 1
         smoothing_iterations = options["smoothing-iterations"]
         decimation_fraction = options["decimation-fraction"]
+        block_shape = options["block-shape"][::-1]
+        output_dir = options["output-directory"]
+        rescale = options["rescale-factor"]
+        fmt = options["format"]
 
-        server, uuid, instance = self.input_service.base_service.instance_triple
-        is_supervoxels = self.input_service.base_service.supervoxels
+        server, uuid, instance = input_service.base_service.instance_triple
+        is_supervoxels = input_service.base_service.supervoxels
         bodies = load_body_list(options["bodies"], is_supervoxels)
 
         logger.info(f"Input is {len(bodies)} bodies")
-        os.makedirs(options["output-directory"], exist_ok=True)
-
+        os.makedirs(output_dir, exist_ok=True)
+
         def compute_mesh_and_write(body):
+            set_default_dvid_session_timeout(600.0, 600.0)
             with Timer() as timer:
-                # Fetch the sparsevol to determine the bounding-box size (in scale-0 voxels)
                 try:
                     with mgr_client.access_context(server, True, 1, 0):
-                        # sparsevol-coarse is at scale-6
-                        coords_s6 = fetch_sparsevol_coarse(server, uuid, instance, body, is_supervoxels)
-                except:
-                    return (body, 0, 0, 0, 0.0, timer.seconds, 'error-sparsevol-coarse')
-
-                box_s6 = np.array([coords_s6.min(axis=0), 1+coords_s6.max(axis=0)])
-                box_s0 = (2**6) * box_s6
-                shape_s0 = (box_s0[1] - box_s0[0])
-                box_voxels_s0 = np.prod(shape_s0.astype(float))
-
-                # Determine the scale we'll use.
-                # Solve for 'scale' in the following relationship:
-                #
-                #   box_voxels_s0/((2^scale)^3) <= max_box_voxels
-                #
-                scale = log2(pow(box_voxels_s0 / max_box_voxels, 1/3))
-                scale = max(ceil(scale), min_scale)
-
-                if scale > max_scale:
-                    raise RuntimeError(f"Can't compute mesh for body {body}. Bounding box is {box_s0[:, ::-1].tolist()}, "
-                                       f"which is too large to fit in desired RAM, even at scale {max_scale}")
-
-                try:
-                    with mgr_client.access_context(server, True, 1, 0):
-                        coords = fetch_sparsevol(server, uuid, instance, body, scale=scale, supervoxels=is_supervoxels, dtype=np.int16)
-                except:
-                    return (body, 0, 0, 0, 0.0, timer.seconds, 'error-sparsevol')
-
-                box = box_s0 // (2**scale)
-                coords -= box[0]
-                num_voxels = len(coords)
-
-                shape = box[1] - box[0]
-                vol = np.zeros(shape, np.uint8)
-                vol[(*coords.transpose(),)] = 1
-                del coords
-
-                try:
-                    mesh = Mesh.from_binary_vol(vol, box_s0)
-                except:
-                    return (body, scale, num_voxels, 0, 0.0, timer.seconds, 'error-construction')
-
-                del vol
-                try:
-                    mesh.laplacian_smooth(smoothing_iterations)
-                except:
-                    return (body, scale, num_voxels, 0.0, len(mesh.vertices_zyx), timer.seconds, 'error-smoothing')
-
-                fraction = decimation_fraction
-                if scale > min_scale:
-                    # Since we're starting from a lower resolution than the user requested,
-                    # Reduce the decimation we're applying accordingly.
-                    # Since meshes are 2D surfaces, we approximate the difference in
-                    # vertexes as the SQUARE of the difference in resolution.
-                    fraction *= (2**(scale - min_scale))**2
-                    fraction = min(fraction, 1.0)
-
-                try:
-                    mesh.simplify_openmesh(fraction)
-                except:
-                    return (body, scale, num_voxels, 0.0, len(mesh.vertices_zyx), timer.seconds, 'error-decimation')
-
-                output_path = f'{options["output-directory"]}/{body}.{options["format"]}'
-                mesh.serialize(output_path)
-
-                return (body, scale, num_voxels, fraction, len(mesh.vertices_zyx), timer.seconds, 'success')
-
-        # Run the computation -- scatter first to ensure fair distribution (fixme: does this make a difference?)
-        # And use a lot of partitions to enable work-stealing if some meshes are slow to compute.
-        bodies_bag = db.from_sequence(bodies, npartitions=2000)
-        bodies_bag = self.client.scatter(bodies_bag).result()
-        stats = bodies_bag.map(compute_mesh_and_write).compute()
-
-        # Save stats
-        stats_df = pd.DataFrame(stats, columns=['body', 'scale', 'voxels', 'decimation_fraction', 'vertices', 'total_seconds', 'result'])
-        stats_df.to_csv('mesh-stats.csv', index=False, header=True)
-
+                        rng = fetch_sparsevol(server, uuid, instance, body, scale, format='ranges')
+
+                    boxes, masks = blockwise_masks_from_ranges(rng, block_shape, halo)
+                    m = Mesh.from_binary_blocks(masks, boxes * 2**scale)
+                    m.laplacian_smooth(smoothing_iterations)
+                    m.simplify(decimation_fraction)
+                    m.vertices_zyx *= rescale
+                    output_path = f'{output_dir}/{body}.{fmt}'
+                    m.serialize(output_path)
+                    return (body, len(m.vertices_zyx), timer.seconds, 'success', '')
+                except Exception as ex:
+                    return (body, 0, timer.seconds, 'failed', str(ex))
+
+        futures = self.client.map(compute_mesh_and_write, bodies)
+
+        # Support synchronous testing with a fake 'as_completed' object
+        if hasattr(self.client, 'DEBUG'):
+            ac = as_completed_synchronous(futures, with_results=True)
+        else:
+            ac = distributed.as_completed(futures, with_results=True)
+
+        try:
+            stats = []
+            for f, r in tqdm_proxy(ac, total=len(futures)):
+                stats.append(r)
+                body, vertices, total_seconds, result, err = r
+                if result != "success":
+                    logger.warning(f"Body {body} failed: {err}")
+        finally:
+            stats_df = pd.DataFrame(stats, columns=['body', 'vertices', 'total_seconds', 'result', 'errors'])
+            stats_df.to_csv('mesh-stats.csv', index=False, header=True)
+
         failed_df = stats_df.query('result != "success"')
         if len(failed_df) > 0:
             logger.warning(f"{len(failed_df)} meshes could not be generated. See mesh-stats.csv")
             logger.warning(f"Results:\n{stats_df['result'].value_counts()}")
-
-        scales_histogram = stats_df.query("result == 'success'")['scale'].value_counts().sort_index()
-        logger.info(f"Scales chosen:\n{scales_histogram}")
-
-
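
For reference, here is a minimal standalone sketch of the per-body pipeline this diff introduces, outside the dask/Workflow machinery. The `server`, `uuid`, `instance`, `body`, and parameter values below are placeholders; the calls mirror the ones used in the diff itself (`fetch_sparsevol(..., format='ranges')`, `blockwise_masks_from_ranges`, `Mesh.from_binary_blocks`).

```python
# Standalone sketch of the new mesh pipeline. All concrete values are placeholders.
from neuclease.dvid import fetch_sparsevol
from neuclease.dvid.rle import blockwise_masks_from_ranges
from vol2mesh import Mesh

server, uuid, instance = 'http://example-dvid:8000', 'abc123', 'segmentation'  # placeholders
body = 12345                  # placeholder body ID
scale = 2                     # fetch at reduced resolution; vertices end up at scale-0
block_shape = (64, 64, 64)    # ZYX, as in the workflow code (config lists XYZ, hence [::-1])

# Fetch the body's voxels as RLE ranges and materialize them as blockwise masks.
# The third argument is the halo; the workflow uses 1 so adjacent block meshes
# overlap and stitch cleanly.
rng = fetch_sparsevol(server, uuid, instance, body, scale, format='ranges')
boxes, masks = blockwise_masks_from_ranges(rng, block_shape, 1)

# Mesh the blocks and combine them, placing vertices in scale-0 coordinates.
mesh = Mesh.from_binary_blocks(masks, boxes * 2**scale)
mesh.laplacian_smooth(1)      # 'smoothing-iterations'
mesh.simplify(0.1)            # 'decimation-fraction'
mesh.serialize(f'{body}.obj')
```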