SparseMeshes: Reimplemented to perform meshing blockwise
stuarteberg committed Nov 3, 2023
1 parent 974c073 · commit 1c4fe97
Showing 1 changed file with 84 additions and 134 deletions.
flyemflows/workflow/sparsemeshes.py: 218 changes (84 additions, 134 deletions)
@@ -1,55 +1,62 @@
 import os
 import copy
 import logging
-from math import log2, pow, ceil

 import numpy as np
 import pandas as pd

-import dask.bag as db
+import distributed

 from dvid_resource_manager.client import ResourceManagerClient

-from neuclease.util import Timer
-from neuclease.dvid import fetch_sparsevol_coarse, fetch_sparsevol
+from confiddler import flow_style
+from neuclease.util import Timer, tqdm_proxy
+from neuclease.dvid import fetch_sparsevol, set_default_dvid_session_timeout
+from neuclease.dvid.rle import blockwise_masks_from_ranges

 from vol2mesh import Mesh

 from ..volumes import VolumeService, DvidVolumeService, DvidSegmentationVolumeSchema
+from ..util import as_completed_synchronous
 from .util import BodyListSchema, load_body_list
 from . import Workflow

 logger = logging.getLogger(__name__)


 class SparseMeshes(Workflow):
     """
-    This workflow 'naively' computes meshes from downloaded sparsevols.
-    It will download each sparsevol at the best scale it can, ensuring that
-    the bounding-box of the body at that scale doesn't exceed a certain size.
-    Then it computes the entire mesh all at once (not in blocks, no stitching required).
+    Compute meshes for a set of bodies from their sparsevol representations.
     It saves the resulting mesh files to a directory.
     """

     OptionsSchema = {
         "type": "object",
         "description": "Settings specific to the SparseMeshes workflow",
         "default": {},
         "additionalProperties": False,
         "properties": {
-            "min-scale": {
-                "description": "Minimum scale at which to fetch sparsevols.\n"
-                               "For individual bodies, the scale may be forced higher\n"
-                               "if needed according to max-analysis-volume.",
+            "bodies": BodyListSchema,
+            "scale": {
+                "description":
+                    "Scale at which to fetch sparsevols.\n"
+                    "Mesh vertices will be converted to scale-0.\n",
                 "type": "integer",
                 "default": 0
             },
-            "max-analysis-volume": {
-                "description": "The above scale will be overridden (to something higher, i.e. lower resolution) \n"
-                               "if the body would still be too large to generate a mesh for, as defined by this setting.\n",
-                "type": "number",
-                "default": 1e9 # 1 GB max
+            "rescale-factor": {
+                "description": "Optionally rescale the vertex positions before storing the mesh.\n",
+                "type": "array",
+                "items": { "type": "number" },
+                "minItems": 3,
+                "maxItems": 3,
+                "default": flow_style([1, 1, 1])
             },
+            "block-shape": {
+                "description": "The mesh will be generated in blocks and the blocks will be stitched together.\n",
+                "type": "array",
+                "items": { "type": "integer" },
+                "minItems": 3,
+                "maxItems": 3,
+                "default": flow_style([-1,-1,-1])
+            },
-            "bodies": BodyListSchema,
             "smoothing-iterations": {
                 "description": "How many iterations of smoothing to apply before decimation",
                 "type": "integer",
@@ -63,15 +70,18 @@ class SparseMeshes(Workflow):
                                " the decimation fraction will be auto-increased for that mesh.",
                 "type": "number",
                 "minimum": 0.0000001,
-                "maximum": 1.0, # 1.0 == disable
-                "default": 0.1
+                "maximum": 1.0,  # 1.0 == disable
+                "default": 0.1
             },
             "format": {
                 "description": "Format to save the meshes in",
                 "type": "string",
-                "enum": ["obj", # Wavefront OBJ (.obj)
-                         "drc", # Draco (compressed) (.drc)
-                         "ngmesh"], # "neuroglancer mesh" format -- a custom binary format. Note: Data is presumed to be 8nm resolution
+                "enum": [
+                    "obj",    # Wavefront OBJ (.obj)
+                    "drc",    # Draco (compressed) (.drc)
+                    "ngmesh"  # "neuroglancer mesh" format -- a custom binary format.
+                              # Note: Data is presumed to be 8nm resolution, so you need to use rescale-factor
+                ],
                 "default": "obj"
             },
             "output-directory": {
@@ -82,140 +92,80 @@ class SparseMeshes(Workflow):
         }
     }


     Schema = copy.deepcopy(Workflow.schema())
     Schema["properties"].update({
         "input": DvidSegmentationVolumeSchema,
         "sparsemeshes": OptionsSchema
     })


     @classmethod
     def schema(cls):
         return SparseMeshes.Schema

-    def _init_service(self):
-        options = self.config["sparsemeshes"]
+    def execute(self):
         input_config = self.config["input"]
         mgr_options = self.config["resource-manager"]

-        self.mgr_client = ResourceManagerClient( mgr_options["server"], mgr_options["port"] )
-        self.input_service = VolumeService.create_from_config( input_config, self.mgr_client )
-        assert isinstance(self.input_service, DvidVolumeService), \
+        mgr_client = ResourceManagerClient(mgr_options["server"], mgr_options["port"])
+        input_service = VolumeService.create_from_config(input_config, mgr_client)
+        assert isinstance(input_service, DvidVolumeService), \
             "Input must be plain dvid source, not scaled, transposed, etc."

-        min_scale = options["min-scale"]
-        max_scale = max(self.input_service.available_scales)
-        assert min_scale <= max_scale, \
-            f"Largest available scale in the input ({max_scale}) is smaller than the min-scale you provided ({min_scale})."
-
-    def execute(self):
-        self._init_service()
-        mgr_client = self.mgr_client

         options = self.config["sparsemeshes"]
-        max_box_voxels = options["max-analysis-volume"]
-        min_scale = options["min-scale"]
-        max_scale = max(self.input_service.available_scales)
+        scale = options["scale"]
+        halo = 1
+        smoothing_iterations = options["smoothing-iterations"]
+        decimation_fraction = options["decimation-fraction"]
+        block_shape = options["block-shape"][::-1]
+        output_dir = options["output-directory"]
+        rescale = options["rescale-factor"]
+        fmt = options["format"]

-        server, uuid, instance = self.input_service.base_service.instance_triple
-        is_supervoxels = self.input_service.base_service.supervoxels
+        server, uuid, instance = input_service.base_service.instance_triple
+        is_supervoxels = input_service.base_service.supervoxels

         bodies = load_body_list(options["bodies"], is_supervoxels)
         logger.info(f"Input is {len(bodies)} bodies")

-        os.makedirs(options["output-directory"], exist_ok=True)
+        os.makedirs(output_dir, exist_ok=True)

         def compute_mesh_and_write(body):
+            set_default_dvid_session_timeout(600.0, 600.0)
             with Timer() as timer:
-                # Fetch the sparsevol to determine the bounding-box size (in scale-0 voxels)
                 try:
                     with mgr_client.access_context(server, True, 1, 0):
-                        # sparsevol-coarse is at scale-6
-                        coords_s6 = fetch_sparsevol_coarse(server, uuid, instance, body, is_supervoxels)
-                except:
-                    return (body, 0, 0, 0, 0.0, timer.seconds, 'error-sparsevol-coarse')
-
-                box_s6 = np.array([coords_s6.min(axis=0), 1+coords_s6.max(axis=0)])
-                box_s0 = (2**6) * box_s6
-                shape_s0 = (box_s0[1] - box_s0[0])
-                box_voxels_s0 = np.prod(shape_s0.astype(float))
-
-                # Determine the scale we'll use.
-                # Solve for 'scale' in the following relationship:
-                #
-                #   box_voxels_s0/((2^scale)^3) <= max_box_voxels
-                #
-                scale = log2(pow(box_voxels_s0 / max_box_voxels, 1/3))
-                scale = max(ceil(scale), min_scale)
-
-                if scale > max_scale:
-                    raise RuntimeError(f"Can't compute mesh for body {body}. Bounding box is {box_s0[:, ::-1].tolist()}, "
-                                       f"which is too large to fit in desired RAM, even at scale {max_scale}")
-
-                try:
-                    with mgr_client.access_context(server, True, 1, 0):
-                        coords = fetch_sparsevol(server, uuid, instance, body, scale=scale, supervoxels=is_supervoxels, dtype=np.int16)
-                except:
-                    return (body, 0, 0, 0, 0.0, timer.seconds, 'error-sparsevol')
-
-                box = box_s0 // (2**scale)
-                coords -= box[0]
-                num_voxels = len(coords)
-
-                shape = box[1] - box[0]
-                vol = np.zeros(shape, np.uint8)
-                vol[(*coords.transpose(),)] = 1
-                del coords
-
-                try:
-                    mesh = Mesh.from_binary_vol(vol, box_s0)
-                except:
-                    return (body, scale, num_voxels, 0, 0.0, timer.seconds, 'error-construction')
-
-                del vol
-                try:
-                    mesh.laplacian_smooth(smoothing_iterations)
-                except:
-                    return (body, scale, num_voxels, 0.0, len(mesh.vertices_zyx), timer.seconds, 'error-smoothing')
-
-                fraction = decimation_fraction
-                if scale > min_scale:
-                    # Since we're starting from a lower resolution than the user requested,
-                    # Reduce the decimation we're applying accordingly.
-                    # Since meshes are 2D surfaces, we approximate the difference in
-                    # vertexes as the SQUARE of the difference in resolution.
-                    fraction *= (2**(scale - min_scale))**2
-                    fraction = min(fraction, 1.0)
-
-                try:
-                    mesh.simplify_openmesh(fraction)
-                except:
-                    return (body, scale, num_voxels, 0.0, len(mesh.vertices_zyx), timer.seconds, 'error-decimation')
-
-                output_path = f'{options["output-directory"]}/{body}.{options["format"]}'
-                mesh.serialize(output_path)
-
-                return (body, scale, num_voxels, fraction, len(mesh.vertices_zyx), timer.seconds, 'success')
-
-        # Run the computation -- scatter first to ensure fair distribution (fixme: does this make a difference?)
-        # And use a lot of partitions to enable work-stealing if some meshes are slow to compute.
-        bodies_bag = db.from_sequence(bodies, npartitions=2000)
-        bodies_bag = self.client.scatter(bodies_bag).result()
-        stats = bodies_bag.map(compute_mesh_and_write).compute()
-
-        # Save stats
-        stats_df = pd.DataFrame(stats, columns=['body', 'scale', 'voxels', 'decimation_fraction', 'vertices', 'total_seconds', 'result'])
-        stats_df.to_csv('mesh-stats.csv', index=False, header=True)

+                        rng = fetch_sparsevol(server, uuid, instance, body, scale, supervoxels=is_supervoxels, format='ranges')
+
+                    boxes, masks = blockwise_masks_from_ranges(rng, block_shape, halo)
+                    m = Mesh.from_binary_blocks(masks, boxes * 2**scale)
+                    m.laplacian_smooth(smoothing_iterations)
+                    m.simplify(decimation_fraction)
+                    m.vertices_zyx *= rescale
+                    output_path = f'{output_dir}/{body}.{fmt}'
+                    m.serialize(output_path)
+                    return (body, len(m.vertices_zyx), timer.seconds, 'success', '')
+                except Exception as ex:
+                    return (body, 0, timer.seconds, 'failed', str(ex))
+
+        futures = self.client.map(compute_mesh_and_write, bodies)
+
+        # Support synchronous testing with a fake 'as_completed' object
+        if hasattr(self.client, 'DEBUG'):
+            ac = as_completed_synchronous(futures, with_results=True)
+        else:
+            ac = distributed.as_completed(futures, with_results=True)
+
+        try:
+            stats = []
+            for f, r in tqdm_proxy(ac, total=len(futures)):
+                stats.append(r)
+                body, vertices, total_seconds, result, err = r
+                if result != "success":
+                    logger.warning(f"Body {body} failed: {err}")
+        finally:
+            stats_df = pd.DataFrame(stats, columns=['body', 'vertices', 'total_seconds', 'result', 'error'])
+            stats_df.to_csv('mesh-stats.csv', index=False, header=True)

         failed_df = stats_df.query('result != "success"')
         if len(failed_df) > 0:
             logger.warning(f"{len(failed_df)} meshes could not be generated. See mesh-stats.csv")
             logger.warning(f"Results:\n{stats_df['result'].value_counts()}")
-
-        scales_histogram = stats_df.query("result == 'success'")['scale'].value_counts().sort_index()
-        logger.info(f"Scales chosen:\n{scales_histogram}")

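For reference, here is a sketch of a workflow config that exercises the new options. All values are illustrative placeholders (server, uuid, body IDs, paths), and the layout is only meant to mirror the "input" and "sparsemeshes" sections of the schema above:

    input:
      dvid:
        server: http://emdata.example.org:8000   # placeholder
        uuid: abc123                             # placeholder
        segmentation-name: segmentation
        supervoxels: false

    sparsemeshes:
      bodies: [12345, 67890]          # or a path to a CSV of body IDs
      scale: 2                        # sparsevols are fetched at scale 2; vertices still come out in scale-0 coordinates
      block-shape: [256, 256, 256]    # mesh in 256^3 blocks, then stitch
      smoothing-iterations: 3
      decimation-fraction: 0.1
      rescale-factor: [8, 8, 8]       # illustrative: convert scale-0 voxel units to nm for 8nm data (e.g. for ngmesh)
      format: obj
      output-directory: meshes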

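And a condensed, standalone sketch of the blockwise meshing path that the new compute_mesh_and_write() follows, with resource-manager throttling, timing, and error handling omitted (the server, uuid, instance, and body values are placeholders):

    from neuclease.dvid import fetch_sparsevol
    from neuclease.dvid.rle import blockwise_masks_from_ranges
    from vol2mesh import Mesh

    server = 'http://emdata.example.org:8000'  # placeholder
    uuid, instance = 'abc123', 'segmentation'  # placeholders
    body, scale = 12345, 2                     # placeholder body, fetched at scale 2

    # Fetch the body's sparsevol as run-length ranges rather than explicit coordinates.
    rng = fetch_sparsevol(server, uuid, instance, body, scale, format='ranges')

    # Carve the ranges into binary mask blocks; the third argument is the halo,
    # which makes neighboring blocks overlap so their meshes can be stitched.
    boxes, masks = blockwise_masks_from_ranges(rng, (256, 256, 256), 1)

    # Mesh each block and stitch. Scaling the boxes by 2**scale puts the
    # vertices back into scale-0 coordinates.
    mesh = Mesh.from_binary_blocks(masks, boxes * 2**scale)
    mesh.laplacian_smooth(3)   # smoothing-iterations
    mesh.simplify(0.1)         # decimation-fraction
    mesh.serialize(f'{body}.obj')

Unlike the old implementation, this path never materializes the body's full bounding box as a dense array, which is why the min-scale / max-analysis-volume safeguards could be deleted.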