Commit

Check onnx model during sanity check
ranlu committed Mar 21, 2024
1 parent efc506b commit 2d0b430
Showing 2 changed files with 44 additions and 1 deletion.
43 changes: 43 additions & 0 deletions dags/chunkflow_dag.py
@@ -127,6 +127,42 @@ def check_patch_parameters(param):
return param


def check_onnx_model(param):
    from cloudfiles import dl
    import onnx
    onnx_path = param.get("ONNX_MODEL_PATH", None)
    if onnx_path:
        # download the model and inspect its input/output tensor shapes
        onnx_file = dl(onnx_path)
        onnx_model = onnx.load_model_from_string(onnx_file["content"])
        inputs = onnx_model.graph.input
        outputs = onnx_model.graph.output
        if len(inputs) > 1:
            slack_message(":u7981:*ERROR: Chunkflow does not support models with multiple inputs!*")
            raise ValueError('Onnx model with multiple inputs')
        if len(outputs) > 1:
            slack_message(":u7981:*WARNING: Model produces multiple outputs, chunkflow only collects the first one*")
        input_shape = [dim.dim_value for dim in inputs[0].type.tensor_type.shape.dim]
        output_shape = [dim.dim_value for dim in outputs[0].type.tensor_type.shape.dim]

        if "INFERENCE_OUTPUT_CHANNELS" not in param:
            slack_message(f"Set `INFERENCE_OUTPUT_CHANNELS` to `{output_shape[1]}`")
            param["INFERENCE_OUTPUT_CHANNELS"] = output_shape[1]
        elif param["INFERENCE_OUTPUT_CHANNELS"] != output_shape[1]:
            slack_message(f":u7981:*ERROR: Specified `INFERENCE_OUTPUT_CHANNELS = {param['INFERENCE_OUTPUT_CHANNELS']}` does not match ONNX output shape `{output_shape}`*")
            raise ValueError('Inference output channel error')

        # the spatial (trailing three) dimensions of input and output must match
        if any(x != y for x, y in zip(input_shape[-3:], output_shape[-3:])):
            slack_message(f":u7981:*ERROR: The input shape `{input_shape[-3:][::-1]}` does not match the output shape `{output_shape[-3:][::-1]}`*")
            raise ValueError('Input shape does not match the output shape')

        if "INPUT_PATCH_SIZE" not in param:
            slack_message(f"Set `INPUT_PATCH_SIZE` to `{input_shape[-3:][::-1]}`")
            param["INPUT_PATCH_SIZE"] = input_shape[-3:][::-1]
        elif any(x != y for x, y in zip(param["INPUT_PATCH_SIZE"], input_shape[-3:][::-1])):
            slack_message(f":u7981:*ERROR: Specified `INPUT_PATCH_SIZE = {param['INPUT_PATCH_SIZE']}` does not match ONNX input shape `{input_shape[-3:][::-1]}`*")
            raise ValueError('Input patch size error')


@mount_secrets
def supply_default_parameters():
from docker_helper import health_check_info
@@ -160,6 +196,13 @@ def check_matching_mip(path1, path2):
slack_message(":u7981:*ERROR: Cannot specify pytorch model and onnx model at the same time*")
raise ValueError('Can only use one backend')

if "ONNX_MODEL_PATH" in param:
try:
check_onnx_model(param)
except Exception:
slack_message(":u7981:*ERROR: Failed to check the ONNX model*")
raise ValueError('Check ONNX model failed')

if param.get("ENABLE_FP16", False):
slack_message(":exclamation:*Enable FP16 inference for TensorRT*")

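For reference, a minimal standalone sketch (not part of this commit) of the shape inspection that check_onnx_model performs, assuming the model has already been downloaded to a local file; "model.onnx" below is a placeholder path:

import onnx

model = onnx.load("model.onnx")  # placeholder: any locally available ONNX model
inputs = model.graph.input
outputs = model.graph.output

# dim_value is 0 for symbolic/dynamic dimensions
input_shape = [d.dim_value for d in inputs[0].type.tensor_type.shape.dim]
output_shape = [d.dim_value for d in outputs[0].type.tensor_type.shape.dim]

print("inputs:", len(inputs), "outputs:", len(outputs))
print("output channels:", output_shape[1])
# the DAG reverses the trailing three dimensions, presumably to report
# the patch size in xyz order rather than the model's zyx order
print("input patch size (reversed):", input_shape[-3:][::-1])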
2 changes: 1 addition & 1 deletion docker/Dockerfile
@@ -42,7 +42,7 @@ RUN savedAptMark="$(apt-mark showmanual)" \
&& CONSTRAINT_URL="https://raw.githubusercontent.com/apache/airflow/constraints-${AIRFLOW_VERSION}/constraints-${PYTHON_VERSION}.txt" \
&& pip install --no-cache-dir -U pip \
&& pip install --no-cache-dir --compile --global-option=build git+https://github.com/seung-lab/chunk_iterator#egg=chunk-iterator \
&& pip install --no-cache-dir igneous-pipeline \
&& pip install --no-cache-dir igneous-pipeline onnx \
&& pip install --no-cache-dir "apache-airflow[celery,postgres,rabbitmq,docker,slack,google,statsd]==${AIRFLOW_VERSION}" --constraint "${CONSTRAINT_URL}" \
&& mkdir -p ${AIRFLOW_HOME}/version \
&& groupadd -r docker \
