From 2fcd83b35974dcf47d8059435d70d477c09fad9b Mon Sep 17 00:00:00 2001 From: trivoldus28 Date: Thu, 22 Feb 2024 23:29:29 +0000 Subject: [PATCH 1/2] initial commit --- .../aio-pairwise-alignment-examples/aio.cue | 510 ++++++++ .../alignment/pairwise_alignment.py | 1147 +++++++++++++++++ 2 files changed, 1657 insertions(+) create mode 100644 specs/tri/aio-pairwise-alignment-examples/aio.cue create mode 100644 zetta_utils/mazepa_layer_processing/alignment/pairwise_alignment.py diff --git a/specs/tri/aio-pairwise-alignment-examples/aio.cue b/specs/tri/aio-pairwise-alignment-examples/aio.cue new file mode 100644 index 000000000..40d758a11 --- /dev/null +++ b/specs/tri/aio-pairwise-alignment-examples/aio.cue @@ -0,0 +1,510 @@ +#IMG_PATH: "precomputed://gs://dkronauer-ant-001-raw/brain" +#IMG_RES: [4, 4, 42] +#IMG_SIZE: [102400, 96256, 6112] + +#TEST_SMALL: true +#TEST_LOCAL: true +#CLUSTER_NUM_WORKERS: 16 + +#RUN_ENCODE: true +#RUN_DEFECT: true +#RUN_MASK_ENCODE: true +#RUN_ALIGN: true +#RUN_MISD: true + +#PROJECT: "test-240222" + +#PAIR_FLOW_TMPL: project_folder: _ | *"gs://dkronauer-ant-001-alignment/\(#PROJECT)" + +// #PRECOMPUTED_ENCODINGS: "gs://dkronauer-ant-001-alignment/test-240109-z50-gen-v3-32nm-1/encodings" +// #PRECOMPUTED_MASKED_ENCODINGS: "gs://dkronauer-ant-001-alignment/test-240109-z50-gen-v3-64nm-v2-defect-opening-2/masked_encodings" + +////////////////////////////////// +// DEFECT DETECTOR CONFIG // +////////////////////////////////// +#DEFECT_MODEL_PATH: "gs://zetta_lee_fly_cns_001_models/jit/20221114-defects-step50000.static-1.11.0.jit" +#DEFECT_MODEL_RES: [64, 64, #IMG_RES[2]] +#PAIR_FLOW_TMPL: defect_flow_kwargs: fn: model_path: #DEFECT_MODEL_PATH + +////////////////////////////////// +// MISALIGNMENT DETECTOR CONFIG // +////////////////////////////////// +if true { + #MISD_MODEL_PATH: "gs://zetta-research-nico/training_artifacts/aced_misd_general/3.2.0_dsfactor2_thr1.5_lr0.0001_z1/epoch=44-step=11001-backup.ckpt.model.spec.json" + #MISD_ENCODER: #GENERAL_ENC_MODELS_V2["64"] + #PAIR_FLOW_TMPL: enc_warped_imgs_flow_kwargs: dst_path: "imgs_warped_encoded-v2-z1" + #PAIR_FLOW_TMPL: enc_warped_imgs_flow_kwargs: reencode_tgt: {dst_path: "encodings_misd-v2-z1"} + #PAIR_FLOW_TMPL: misd_flow_kwargs: dst_path: "misalignments-v2-z1" + #MISD_MODEL: {fn: {"@type": "MisalignmentDetector", model_path: #MISD_MODEL_PATH, + apply_sigmoid: true}} + #PAIR_FLOW_TMPL: misd_flow_kwargs: models: [#MISD_MODEL] +} + +if #TEST_SMALL { + #BBOX: {#BBOX_TMPL & {start_coord: [11*4096, 5*4096, 50] + end_coord: [18*4096, 10*4096, 104]}} + #PAIR_FLOW_TMPL: z_offsets: [-1, -2] + + // Don't align bogus pairs (e.g., z=50 to z=49) + #PAIR_FLOW_TMPL: compute_field_flow_kwargs: shrink_bbox_to_z_offsets: true + #PAIR_FLOW_TMPL: invert_field_flow_kwargs: shrink_bbox_to_z_offsets: true + #PAIR_FLOW_TMPL: warp_flow_kwargs: shrink_bbox_to_z_offsets: true + #PAIR_FLOW_TMPL: enc_warped_imgs_flow_kwargs: shrink_bbox_to_z_offsets: true + #PAIR_FLOW_TMPL: misd_flow_kwargs: shrink_bbox_to_z_offsets: true + #ENCODING_MODELS: [...#ENCODING_MODEL_TMPL] + #PAIR_FLOW_TMPL: encoding_flow_kwargs: models: #ENCODING_MODELS + #ENCODING_MODELS: [ + #GEN_THICK_MODELS["512"], + #GEN_THICK_MODELS["256"], + #GEN_THICK_MODELS["128"], + #GEN_THICK_MODELS["64"], + #GEN_THICK_MODELS["32"], + ] + #PAIR_FLOW_TMPL: mask_encodings_flow_kwargs: dst_resolution_list: [ + {dst_resolution: [512, 512, #IMG_RES[2]], fn_kwargs: {opening_width: 2, dilation_width: 2}}, + {dst_resolution: [256, 256, #IMG_RES[2]], fn_kwargs: {opening_width: 2, dilation_width: 2}}, + {dst_resolution: [128, 128, #IMG_RES[2]], fn_kwargs: {opening_width: 2, dilation_width: 2}}, + {dst_resolution: [ 64, 64, #IMG_RES[2]], fn_kwargs: {opening_width: 2, dilation_width: 2}}, + {dst_resolution: [ 32, 32, #IMG_RES[2]], fn_kwargs: {opening_width: 2, dilation_width: 2}}, + ] + #PAIR_FLOW_TMPL: compute_field_flow_kwargs: stages: [ + {dst_resolution: [512, 512, #IMG_RES[2]], fn_kwargs: {sm: 300, num_iter: 700, lr: 0.015}}, + {dst_resolution: [256, 256, #IMG_RES[2]], fn_kwargs: {sm: 150, num_iter: 700, lr: 0.030}}, + {dst_resolution: [128, 128, #IMG_RES[2]], fn_kwargs: {sm: 100, num_iter: 500, lr: 0.050}}, + {dst_resolution: [ 64, 64, #IMG_RES[2]], fn_kwargs: {sm: 50, num_iter: 300, lr: 0.100}}, + {dst_resolution: [ 32, 32, #IMG_RES[2]], fn_kwargs: {sm: 25, num_iter: 200, lr: 0.100}}, + ] + + #PAIR_FLOW_TMPL: run_encoding: #RUN_ENCODE + if #RUN_ENCODE == false + if #PRECOMPUTED_ENCODINGS != "" { + #PAIR_FLOW_TMPL: encoding_flow_kwargs: dst_path: #PRECOMPUTED_ENCODINGS + } + #PAIR_FLOW_TMPL: run_defect: #RUN_DEFECT + #PAIR_FLOW_TMPL: run_binarize_defect: #RUN_DEFECT + if #RUN_MASK_ENCODE == true { + #PAIR_FLOW_TMPL: run_mask_encodings: true + } + if #RUN_MASK_ENCODE == false { + if #PRECOMPUTED_MASKED_ENCODINGS != "" { + #PAIR_FLOW_TMPL: mask_encodings_flow_kwargs: dst_path: #PRECOMPUTED_MASKED_ENCODINGS + } + } + if #RUN_ALIGN { + #PAIR_FLOW_TMPL: run_compute_field: true + #PAIR_FLOW_TMPL: run_invert_field: true + #PAIR_FLOW_TMPL: run_warp: true + } + if #RUN_MISD { + #PAIR_FLOW_TMPL: run_enc_warped_imgs: true + #PAIR_FLOW_TMPL: run_misd: true + } + + /////////////////////////////// + // HACKS FOR RUNNING CUTOUTS // + /////////////////////////////// + + // Reduce processing chunk size for cutouts + #PAIR_FLOW_TMPL: encoding_flow_kwargs: processing_chunk_sizes: [[512, 512, 1]] + #PAIR_FLOW_TMPL: defect_flow_kwargs: processing_chunk_sizes: [[512, 512, 1]] + #PAIR_FLOW_TMPL: binarize_defect_flow_kwargs: processing_chunk_sizes: [[1024, 1024, 1]] + #PAIR_FLOW_TMPL: mask_encodings_flow_kwargs: processing_chunk_sizes: [[1024, 1024, 1]] + #PAIR_FLOW_TMPL: compute_field_flow_kwargs: processing_chunk_sizes: [[2048, 2048, 1]] + #PAIR_FLOW_TMPL: invert_field_flow_kwargs: processing_chunk_sizes: [[2048, 2048, 1]] + #PAIR_FLOW_TMPL: warp_flow_kwargs: processing_chunk_sizes: [[2048, 2048, 1]] + #PAIR_FLOW_TMPL: enc_warped_imgs_flow_kwargs: processing_chunk_sizes: [[2048, 2048, 1]] + #PAIR_FLOW_TMPL: misd_flow_kwargs: processing_chunk_sizes: [[2048, 2048, 1]] + + // Use intermediaries to avoid unaligned writes for testing cutouts + let use_intermediaries = {skip_intermediaries: false} + #TMP_DIRS: [...{...}] + if #TEST_LOCAL { + #TMP_DIRS: [ + {level_intermediaries_dirs: ["file://~/.zetta_utils/tmp0/"]}, + {level_intermediaries_dirs: ["file://~/.zetta_utils/tmp1/"]}, + {level_intermediaries_dirs: ["file://~/.zetta_utils/tmp2/"]}, + {level_intermediaries_dirs: ["file://~/.zetta_utils/tmp3/"]}, + ] + } + if #TEST_LOCAL == false { + #TMP_DIRS: [ + {level_intermediaries_dirs: ["gs://tmp_2w/ant/\(#PROJECT)/tmp0/"]}, + {level_intermediaries_dirs: ["gs://tmp_2w/ant/\(#PROJECT)/tmp1/"]}, + {level_intermediaries_dirs: ["gs://tmp_2w/ant/\(#PROJECT)/tmp2/"]}, + {level_intermediaries_dirs: ["gs://tmp_2w/ant/\(#PROJECT)/tmp3/"]}, + ] + } + #PAIR_FLOW_TMPL: encoding_flow_kwargs: subchunkable_kwargs: use_intermediaries & #TMP_DIRS[0] + #PAIR_FLOW_TMPL: defect_flow_kwargs: subchunkable_kwargs: use_intermediaries & #TMP_DIRS[1] + #PAIR_FLOW_TMPL: binarize_defect_flow_kwargs: subchunkable_kwargs: use_intermediaries & #TMP_DIRS[2] + #PAIR_FLOW_TMPL: mask_encodings_flow_kwargs: subchunkable_kwargs: use_intermediaries & #TMP_DIRS[3] + + // Avoid unaligned writes by adjusting dst offsets + let output_voxel_offset = [#BBOX.start_coord[0], #BBOX.start_coord[1], 0] + let info_voxel_offset_map_ = { + "1024_1024_\(#IMG_RES[2])": [output_voxel_offset[0]/(1024/4), output_voxel_offset[1]/(1024/4), 0] + "512_512_\(#IMG_RES[2])": [output_voxel_offset[0]/(512/4), output_voxel_offset[1]/(512/4), 0] + "256_256_\(#IMG_RES[2])": [output_voxel_offset[0]/(256/4), output_voxel_offset[1]/(256/4), 0] + "128_128_\(#IMG_RES[2])": [output_voxel_offset[0]/(128/4), output_voxel_offset[1]/(128/4), 0] + "64_64_\(#IMG_RES[2])": [output_voxel_offset[0]/(64/4), output_voxel_offset[1]/(64/4), 0] + "32_32_\(#IMG_RES[2])": [output_voxel_offset[0]/(32/4), output_voxel_offset[1]/(32/4), 0] + } + #PAIR_FLOW_TMPL: compute_field_flow_kwargs: dst_factory_kwargs: info_voxel_offset_map: info_voxel_offset_map_ + #PAIR_FLOW_TMPL: warp_flow_kwargs: dst_factory_kwargs: info_voxel_offset_map: info_voxel_offset_map_ + #PAIR_FLOW_TMPL: enc_warped_imgs_flow_kwargs: dst_factory_kwargs: info_voxel_offset_map: info_voxel_offset_map_ + #BBOX_LIST: [#BBOX] +} + + +#RUN_ENCODE: _ | *false +#RUN_DEFECT: _ | *false +#RUN_MASK_ENCODE: _ | *false +#RUN_ALIGN: _ | *false +#RUN_MISD: _ | *false +#PRECOMPUTED_ENCODINGS: _ | *"" +#PRECOMPUTED_MASKED_ENCODINGS: _ | *"" + +#PAIR_FLOW_TMPL: { + "@type": "build_pairwise_alignment_flow" + bbox?: _ + bbox_list?: _ + src_image_path: #IMG_PATH + project_folder?: _ + z_offsets: _ | *[-1, -2] + + run_encoding: _ | *false + encoding_flow_kwargs: #ENCODING_FLOW_SCHEMA & { + dst_path?: _ // defaults to "encodings" + processing_chunk_sizes: _ | *[[4096, 4096, 1], [4096, 4096, 1]] + crop_pad: _ | *[32, 32, 0] + models: _ + } + + run_defect: _ | *false + defect_flow_kwargs: #SUBCHUNKABLE_FLOW_SCHEMA & { + dst_path?: _ // defaults to "defect" + dst_resolution: #DEFECT_MODEL_RES + processing_chunk_sizes: _ | *[[4096, 4096, 1], [512, 512, 1]] + crop_pad: _ | *[512, 512, 0] // good for processing_chunk_size=512 + fn: _ | *{ + "@type": "DefectDetector" + model_path: _ + ds_factor?: _ + tile_size: null // don't use tiling + } + } + + run_binarize_defect: _ | *false + binarize_defect_flow_kwargs: #SUBCHUNKABLE_FLOW_SCHEMA & { + dst_path?: _ // defaults to "defect_binarized" + dst_resolution: #DEFECT_MODEL_RES + processing_chunk_sizes: _ | *[[4096, 4096, 1], [1024, 1024, 1]] + crop_pad: _ | *[128, 128, 0] + fn: _ | *{"@type": "binarize_defect_prediction", "@mode": "partial"} + fn_kwargs: _ | *{ + threshold: 100 + kornia_opening_width: 11 + kornia_dilation_width: 3 + // filter_cc_threshold: 320 + filter_cc_threshold: 240 + // kornia_closing_width: 25 + kornia_closing_width: 30 + } + } + + run_mask_encodings: _ | *false + mask_encodings_flow_kwargs: #MASK_ENCODINGS_FLOW_SCHEMA & { + dst_path?: _ // defaults to "encodings_masked" + fn: _ | *{"@type": "zero_out_src_with_mask", "@mode": "partial"} + mask_resolution: #DEFECT_MODEL_RES + processing_chunk_sizes: _ | *[[4096, 4096, 1], [1024, 1024, 1]] + } + + run_compute_field: _ | *false + compute_field_flow_kwargs: #COMPUTE_FIELD_FLOW_SCHEMA & { + dst_factory_kwargs: { + info_chunk_size: [2048, 2048, 1] + per_scale_config: { + "encoding": "zfpc", + "zfpc_correlated_dims": [true, true, false, false], + "zfpc_tolerance": 0.001953125, + } + } + processing_chunk_sizes: _ | *[[4096, 4096, 1], [2048, 2048, 1]] + crop_pad: [64, 64, 0] + } + + run_invert_field: _ | *false + invert_field_flow_kwargs: #INVERT_FIELD_FLOW_SCHEMA & { + dst_factory_kwargs: { + per_scale_config: { + "encoding": "zfpc", + "zfpc_correlated_dims": [true, true, false, false], + "zfpc_tolerance": 0.001953125, + } + } + processing_chunk_sizes: _ | *[[4096, 4096, 1], [2048, 2048, 1]] + crop_pad: [64, 64, 0] + fn: {"@type": "invert_field", "@mode": "partial"} + } + + run_warp: _ | *false + warp_flow_kwargs: #WARP_FLOW_SCHEMA & { + processing_chunk_sizes: _ | *[[4096, 4096, 1], [2048, 2048, 1]] + crop_pad: [256, 256, 0] + dst_resolution: _ | *[32, 32, #IMG_RES[2]] + } + + run_enc_warped_imgs: _ | *false + enc_warped_imgs_flow_kwargs: #ENCODE_WARPED_IMGS_FLOW_SCHEMA & { + processing_chunk_sizes: _ | *[[4096, 4096, 1], [4096, 4096, 1]] + crop_pad: [32, 32, 0] + model: #MISD_ENCODER + } + + run_misd: _ | *false + misd_flow_kwargs: #MISALIGNMENT_DETECTOR_FLOW_SCHEMA & { + dst_resolution: #MISD_ENCODER.dst_resolution + processing_chunk_sizes: _ | *[[4096, 4096, 1], [2048, 2048, 1]] + crop_pad: [32*6, 32*6, 0] + models: [...#MISD_MODEL_TMPL] // one for each z, or one for all + } +} + +#DEFAULT_LAYER_FACTORY_SCHEMA: { + path?: string + resolution_list?: [...[int, int, int]] // list of res to generate/keep + per_scale_config?: _ // dict of attrs to be added to each scale + ... // any other kwargs for build_cv_layer +} + +#COMMON_FLOW_SCHEMA: { + dst_layer?: _ + dst_path?: string // To be used with dst_factory_kwargs if dst_layer is not provided + dst_factory?: _ // Should be a Callable(), defaults to DEFAULT_LAYER_FACTORY_SCHEMA + dst_factory_kwargs?: _ + ... +} + +#SUBCHUNKABLE_FLOW_SCHEMA: #COMMON_FLOW_SCHEMA & { + dst_resolution?: [int, int, int] + op?: _ + op_kwargs?: _ // kwargs for `op` (not build_subchunkable_apply_flow's op_kwargs) + fn?: _ + fn_kwargs?: _ + processing_chunk_sizes: [...[int, int, int]] + crop_pad?: [int, int, int] // crops for L1 chunks + subchunkable_kwargs?: _ // kwargs for `build_subchunkable_apply_flow` + ... +} + +#ENCODING_MODEL_TMPL: { + path: _ // model path + dst_resolution: _ // output res of this encoder + dst_path?: _ // override flow's dst_path + res_change_mult?: [int, int, int] // defaults to [1, 1, 1] + max_processing_chunk_size?: [int, int] // restricts processing_chunk_size[0:1] + fn: _ | *{ + "@type": "BaseCoarsener" + model_path: path + ds_factor: res_change_mult[0] + tile_size: null // don't use tiling + } + fn_kwargs?: _ // per-model overrides + op_kwargs?: _ // per-model overrides + subchunkable_kwargs?: _ // per-model overrides +} + +#ENCODING_FLOW_SCHEMA: #SUBCHUNKABLE_FLOW_SCHEMA & { + models: [...#ENCODING_MODEL_TMPL] +} + +#MASK_FN_TMPL: { + dst_resolution: _ // output res of this masking step + fn_kwargs?: _ // per-model overrides + src_path?: _ + src_layer?: _ +} + +#MASK_ENCODINGS_FLOW_SCHEMA: #SUBCHUNKABLE_FLOW_SCHEMA & { + dst_resolution_list: [...#MASK_FN_TMPL] + mask_path?: string + mask_layer?: _ + mask_resolution: [int, int, int] +} + +#COMPUTE_FIELD_STAGE_SCHEMA: { + dst_resolution: [int, int, int] + fn: _ | *{"@type": "align_with_online_finetuner", "@mode": "partial"} + fn_kwargs?: _ // args for `fn` + path?: string // shorthand to override src & tgt + ... // any other kwargs for `ComputeFieldStage` +} + +#COMPUTE_FIELD_FLOW_SCHEMA: #COMMON_FLOW_SCHEMA & { + src_layer?: _ + src_path?: string + tgt_layer?: _ + tgt_path?: string // defaults to src_path + stages: [...#COMPUTE_FIELD_STAGE_SCHEMA] + z_offsets?: [...int] // sets by project-wide value if empty + compute_field_multistage_kwargs?: _ // kwargs for `build_compute_field_multistage_flow` + compute_field_stage_kwargs?: _ // kwargs for `ComputeFieldStage` + shrink_bbox_to_z_offsets?: _ | *false +} + +#INVERT_FIELD_FLOW_SCHEMA: #SUBCHUNKABLE_FLOW_SCHEMA & { + dst_path?: string // defaults to "fields_inv" + src_path?: string + z_offsets?: [...int] // sets by project + dst_resolution?: [int, int, int] // defaults to last res in compute field + shrink_bbox_to_z_offsets?: bool +} + +#WARP_FLOW_SCHEMA: #SUBCHUNKABLE_FLOW_SCHEMA & { + dst_path?: string // defaults to "imgs_warped" + dst_resolution: [int, int, int] // output warping res + field_path?: string + field_resolution?: [int, int, int] // defaults to invert_field's res + z_offsets?: [int, int, int] // defaults to project's + dst_resolution?: [int, int, int] // defaults to last res in compute field + shrink_bbox_to_z_offsets?: bool +} + +#REENCODE_TGT_OPTIONS: { + src_path?: _ // input path, defaults to flow's + dst_path: _ // output path +} + +#MISD_ENCODER: #ENCODING_MODEL_TMPL + +#ENCODE_WARPED_IMGS_FLOW_SCHEMA: #SUBCHUNKABLE_FLOW_SCHEMA & { + model: #ENCODING_MODEL_TMPL + dst_path?: string // defaults to "imgs_warped_encoded" + z_offsets?: [...int] // defaults to project's + dst_resolution?: [int, int, int] // defaults to last res in compute field + shrink_bbox_to_z_offsets?: bool + reencode_tgt?: #REENCODE_TGT_OPTIONS // if misd's encoder is different from alignment's encoders +} + +#MISD_MODEL_TMPL: { + dst_resolution?: [int, int, int] // defaults to flow's + fn: _ + max_processing_chunk_size?: [int, int] +} + +#MISALIGNMENT_DETECTOR_FLOW_SCHEMA: #SUBCHUNKABLE_FLOW_SCHEMA & { + models: [...#MISD_MODEL_TMPL] // One per z_offset. Duplicate if only 1 is given. + tgt_layer?: _ + tgt_path?: string // defaults to src_path + shrink_bbox_to_z_offsets?: bool +} + +#BBOX_TMPL: { + "@type": "BBox3D.from_coords" + start_coord: _ + end_coord: _ + resolution: #IMG_RES +} +#BBOX_LIST: _ | *{#BBOX_TMPL & {start_coord: [0, 0, 0] + end_coord: #IMG_SIZE}} + +#TEST_LOCAL: _ | *false +#TOP_LEVEL_FLOW: _ | *#GCP_FLOW +if #TEST_LOCAL { + #TOP_LEVEL_FLOW: #LOCAL_FLOW +} +#LOCAL_FLOW: { + "@type": "mazepa.execute_locally" + num_procs: 4 + semaphores_spec: { + read: num_procs + write: num_procs + cuda: 2 + cpu: num_procs + } + target: _ +} +#GCP_FLOW: { + "@type": "mazepa.execute_on_gcp_with_sqs" + worker_cluster_region: "us-east1" + worker_cluster_project: "zetta-research" + worker_cluster_name: "zutils-x3" + worker_image: "us-east1-docker.pkg.dev/zetta-research/zutils/zetta_utils:tri-240117-test-ant" + worker_resources: { + memory: "21000Mi" // sized for n1-highmem-4 + "nvidia.com/gpu": "1" + } + worker_replicas: #CLUSTER_NUM_WORKERS + local_test: #TEST_LOCAL + target: _ +} +#TOP_LEVEL_FLOW & { + target: #PAIR_FLOW_TMPL & {bbox: #BBOX_LIST[0]} +} + +#GENERAL_ENC_MODELS_V2: { + "32": { + path: "gs://alignment_models/general_encoders_2023/32_32_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [1, 1, 1] // 0 3-3: 32-32 + dst_resolution: [32, 32, #IMG_RES[2]] + max_processing_chunk_size: [4096, 4096] + }, + "64": { + path: "gs://alignment_models/general_encoders_2023/32_64_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [2, 2, 1] // 1 3-4: 32-64 + dst_resolution: [64, 64, #IMG_RES[2]] + max_processing_chunk_size: [2048, 2048] + }, + "128": { + path: "gs://alignment_models/general_encoders_2023/32_128_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [4, 4, 1] // 2 3-5: 32-128 + dst_resolution: [128, 128, #IMG_RES[2]] + max_processing_chunk_size: [1024, 1024] + }, + "256": { + path: "gs://alignment_models/general_encoders_2023/32_256_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [8, 8, 1] // 3 3-6: 32-256 + dst_resolution: [256, 256, #IMG_RES[2]] + max_processing_chunk_size: [512, 512] + }, + "512": { + path: "gs://alignment_models/general_encoders_2023/32_512_C1/2023-11-20.static-2.0.1-model.jit" + res_change_mult: [16, 16, 1] // 4 3-7: 32-512 + dst_resolution: [512, 512, #IMG_RES[2]] + max_processing_chunk_size: [256, 256] + }, +} + +#GEN_THICK_MODELS: { + "32": { + path: "gs://zetta-research-nico/training_artifacts/general_encoder_loss/4.0.1_M3_M3_C1_lr0.0002_locality1.0_similarity0.0_l10.0-0.03_N1x4/last.ckpt.model.spec.json" + res_change_mult: [1, 1, 1] // 0 3-3: 32-32 + dst_resolution: [32, 32, #IMG_RES[2]] + max_processing_chunk_size: [4096, 4096] + }, + "64": { + path: "gs://zetta-research-nico/training_artifacts/general_coarsener_loss/4.0.0_M3_M4_C1_lr0.0002_locality1.0_similarity0.0_l10.0-0.06_N1x4/last.ckpt.model.spec.json" + res_change_mult: [2, 2, 1] // 1 3-4: 32-64 + dst_resolution: [64, 64, #IMG_RES[2]] + max_processing_chunk_size: [2048, 2048] + }, + "128": { + path: "gs://zetta-research-nico/training_artifacts/general_coarsener_loss/4.0.0_M3_M5_C1_lr0.0002_locality1.0_similarity0.0_l10.0-0.08_N1x4/last.ckpt.model.spec.json" + res_change_mult: [4, 4, 1] // 2 3-5: 32-128 + dst_resolution: [128, 128, #IMG_RES[2]] + max_processing_chunk_size: [1024, 1024] + }, + "256": { + path: "gs://zetta-research-nico/training_artifacts/general_coarsener_loss/4.4.0_M3_M6_C1_lr0.0002_locality1.0_similarity0.0_l10.05-0.12_N1x4/last.ckpt.model.spec.json" + res_change_mult: [8, 8, 1] // 3 3-6: 32-256 + dst_resolution: [256, 256, #IMG_RES[2]] + max_processing_chunk_size: [512, 512] + }, + "512": { + path: "gs://zetta-research-nico/training_artifacts/general_coarsener_loss/4.0.0_M3_M7_C1_lr0.0002_locality1.0_similarity0.0_l10.0-0.12_N1x4/last.ckpt.model.spec.json" + res_change_mult: [16, 16, 1] // 4 3-7: 32-512 + dst_resolution: [512, 512, #IMG_RES[2]] + max_processing_chunk_size: [256, 256] + } +} diff --git a/zetta_utils/mazepa_layer_processing/alignment/pairwise_alignment.py b/zetta_utils/mazepa_layer_processing/alignment/pairwise_alignment.py new file mode 100644 index 000000000..fab7dd95b --- /dev/null +++ b/zetta_utils/mazepa_layer_processing/alignment/pairwise_alignment.py @@ -0,0 +1,1147 @@ +from __future__ import annotations + +import copy +import os +from functools import partial +from typing import Any, Callable, Mapping, Sequence + +import attrs +import torch + +from zetta_utils import builder, mazepa +from zetta_utils.geometry import BBox3D, IntVec3D +from zetta_utils.layer.volumetric import VolumetricIndexTranslator, VolumetricLayer +from zetta_utils.layer.volumetric.cloudvol import build_cv_layer +from zetta_utils.mazepa_layer_processing.common import ( + VolumetricCallableOperation, + build_subchunkable_apply_flow, +) +from zetta_utils.tensor_ops.common import compare +from zetta_utils.tensor_ops.convert import to_uint8 +from zetta_utils.tensor_ops.mask import ( + filter_cc, + kornia_closing, + kornia_dilation, + kornia_opening, +) + +from .compute_field_multistage_flow import ( + ComputeFieldMultistageFlowSchema, + ComputeFieldStage, +) +from .warp_operation import WarpOperation + + +def _default_layer_factory( + path: str, + resolution_list: Sequence[Sequence[int]] | None = None, + per_scale_config: Mapping[str, Any] | None = None, + **build_cv_kwargs, +): + if resolution_list is None: + resolution_list = [] + if per_scale_config is None: + per_scale_config = {} + + info_add_scales = None + if resolution_list is not None: + assert isinstance(resolution_list, Sequence) + info_add_scales = [] + for res in resolution_list: + scale = { + "resolution": res, + } | per_scale_config + info_add_scales.append(scale) + info_add_scales = build_cv_kwargs.pop("info_add_scales", info_add_scales) + + return build_cv_layer( + path=path, + info_add_scales=build_cv_kwargs.pop("info_add_scales", info_add_scales), + info_add_scales_mode=build_cv_kwargs.pop("info_add_scales_mode", "replace"), + on_info_exists=build_cv_kwargs.pop("on_info_exists", "overwrite"), + **build_cv_kwargs, + ) + + +def _pad_crop_pads(crop_pad, length): + return [[0, 0, 0]] * (length - 1) + [crop_pad] + + +def _set_volumetric_callable_default_op_kwargs( + res_change_mult=None, + fn_uses_cuda=False, + task_name=None, + overrides=None, +): + if overrides is None: + overrides = {} + kwargs = {} + if task_name is not None: + kwargs["operation_name"] = task_name + if fn_uses_cuda: + kwargs["fn_semaphores"] = ["cuda"] + if res_change_mult is not None: + kwargs["res_change_mult"] = res_change_mult + kwargs |= overrides + return kwargs + + +@mazepa.flow_schema_cls +@attrs.mutable +class SubchunkableFnFlowSchema: + op: Callable + subchunkable_kwargs: Mapping[str, Any] + + def __init__( + self, + dst_resolution: Sequence[int], + processing_chunk_sizes: Sequence[Sequence[int]], + task_name: str = None, + op=None, + op_kwargs=None, + fn=None, + fn_kwargs: Mapping[str, Any] | None = None, + fn_uses_cuda: bool = False, + model_res_change_mult: Sequence[int] | None = None, + model_max_processing_chunk_size: Sequence[int] | None = None, + src_path: str | None = None, + src_layer: VolumetricLayer | None = None, + dst_path: str | None = None, + dst_layer: VolumetricLayer | None = None, + dst_factory: Callable | None = None, + dst_factory_kwargs: Mapping[str, Any] | None = None, + crop_pad: Sequence[int] | None = None, + subchunkable_kwargs=None, + ): + if src_layer is None: + src_layer = build_cv_layer(src_path) + if dst_factory_kwargs is None: + dst_factory_kwargs = {} + if subchunkable_kwargs is None: + subchunkable_kwargs = {} + if op_kwargs is None: + op_kwargs = {} + if fn_kwargs is None: + fn_kwargs = {} + + if dst_layer is None: + if dst_factory is None: + dst_factory = _default_layer_factory + # Provide some essential arguments + dst_factory_kwargs = { + "info_reference_path": src_path, + "resolution_list": [dst_resolution], + } | dst_factory_kwargs + dst_layer = dst_factory(dst_path, **dst_factory_kwargs) + + if crop_pad is None: + crop_pad = [0, 0, 0] + + if model_max_processing_chunk_size is not None: + max_chunk_size = list(model_max_processing_chunk_size) + assert len(max_chunk_size) == 2 or len(max_chunk_size) == 3 + if len(max_chunk_size) == 2: + # pad dimension + max_chunk_size.append(1) + processing_chunk_sizes = copy.deepcopy(processing_chunk_sizes) + processing_chunk_sizes[-1] = [ + min(a, b) for a, b in zip(max_chunk_size, processing_chunk_sizes[-1]) + ] + + self.subchunkable_kwargs = { + "processing_chunk_sizes": processing_chunk_sizes, + "processing_crop_pads": _pad_crop_pads(crop_pad, len(processing_chunk_sizes)), + "skip_intermediaries": True, + "level_intermediaries_dirs": None, + "dst": dst_layer, + "dst_resolution": dst_resolution, + "op_kwargs": {"src": src_layer}, + } | subchunkable_kwargs + + if op is None: + # wrap provided fn with VolumetricCallableOperation + if len(fn_kwargs): + fn = partial(fn, **fn_kwargs) + op_kwargs = _set_volumetric_callable_default_op_kwargs( + res_change_mult=model_res_change_mult, + fn_uses_cuda=fn_uses_cuda, + task_name=task_name, + overrides=op_kwargs, + ) + self.op = VolumetricCallableOperation(fn, **op_kwargs) + else: + if len(op_kwargs): + op = partial(op, **op_kwargs) + self.op = op + + def flow(self, bbox): + flow = build_subchunkable_apply_flow( + bbox=bbox, + op=self.op, + **self.subchunkable_kwargs, + ) + yield flow + + +@mazepa.flow_schema_cls +@attrs.mutable +class EncodingFlowSchema: + subchunkable_flows: Sequence[SubchunkableFnFlowSchema] + + def __init__( + self, + models: Sequence[Any], + dst_path: str | None = None, + dst_factory_kwargs: Mapping[str, Any] | None = None, + subchunkable_kwargs=None, + op_kwargs=None, + fn_kwargs=None, + **kwargs, + ): + if subchunkable_kwargs is None: + subchunkable_kwargs = {} + if op_kwargs is None: + op_kwargs = {} + if fn_kwargs is None: + fn_kwargs = {} + if dst_factory_kwargs is None: + dst_factory_kwargs = {} + + dst_factory_kwargs = { + "resolution_list": [model["dst_resolution"] for model in models], + "info_field_overrides": {"data_type": "int8"}, + } | dst_factory_kwargs + + self.subchunkable_flows = [] + for model in models: + + fn = model["fn"] + model_fn_kwargs = fn_kwargs | model.get("fn_kwargs", {}) + model_subchunkable_kwargs = subchunkable_kwargs | model.get("subchunkable_kwargs", {}) + model_op_kwargs = op_kwargs | model.get("op_kwargs", {}) + dst_resolution = model["dst_resolution"] + dst_path_ = model.get("dst_path", dst_path) + + flow = SubchunkableFnFlowSchema( + fn=fn, + fn_kwargs=model_fn_kwargs, + task_name=f"ImageEncoding-{dst_resolution[0]}", + fn_uses_cuda=True, + dst_resolution=dst_resolution, + model_res_change_mult=model.get("res_change_mult", [1, 1, 1]), + model_max_processing_chunk_size=model.get("max_processing_chunk_size", None), + dst_path=dst_path_, + # dst_factory_kwargs=model_dst_factory_kwargs, + dst_factory_kwargs=dst_factory_kwargs, + subchunkable_kwargs=model_subchunkable_kwargs, + op_kwargs=model_op_kwargs, + **kwargs, + ) + self.subchunkable_flows.append(flow) + + def flow(self, bbox): + for flow in self.subchunkable_flows: + yield flow(bbox) + + +@mazepa.flow_schema_cls +@attrs.mutable +class DefectFlowSchema: + subchunkable_flow: SubchunkableFnFlowSchema + + def __init__( + self, + **kwargs, + ): + self.subchunkable_flow = SubchunkableFnFlowSchema( + task_name="Defect", + fn_uses_cuda=True, + **kwargs, + ) + + def flow(self, bbox): + yield self.subchunkable_flow(bbox) + + +@builder.register("binarize_defect_prediction") +def binarize_defect_prediction( + src: torch.Tensor, + threshold, + kornia_opening_width: int = 0, + kornia_dilation_width: int = 0, + filter_cc_threshold: int = 0, + kornia_closing_width: int = 0, +): + pred = compare(src, mode=">=", value=threshold, binarize=True) + + mask = to_uint8(pred) # kornia errors out with `bool`? + if kornia_opening_width: + # remove thin line from mask + mask = kornia_opening(mask, width=kornia_opening_width) + if kornia_dilation_width: + # grow mask a little + mask = kornia_dilation(mask, width=kornia_dilation_width) + + pred = torch.where(mask > 0, 0, pred) + + if filter_cc_threshold: + # remove small islands that are likely FPs + pred = filter_cc(pred, mode="keep_large", thr=filter_cc_threshold) + if kornia_closing_width: + # connect disconnected folds + pred = kornia_closing(pred, width=kornia_closing_width) + + return to_uint8(pred) + + +@mazepa.flow_schema_cls +@attrs.mutable +class BinarizeDefectFlowSchema: + subchunkable_flow: SubchunkableFnFlowSchema + + def __init__( + self, + **kwargs, + ): + self.subchunkable_flow = SubchunkableFnFlowSchema( + task_name="BinarizeDefect", + **kwargs, + ) + + def flow(self, bbox): + yield self.subchunkable_flow(bbox) + + +@builder.register("zero_out_src_with_mask") +def zero_out_src_with_mask2(src, mask, opening_width=0, dilation_width=0): + # opening_width=2 finds and removes >=2px wide masks + # opening_width=3 finds and removes >=3px wide masks + # dilation_width=2 grows mask by 1px + # dilation_width=3 grows mask by 2px + if opening_width > 0: + mask0 = mask + exclusion_from_dilation = kornia_opening(mask, width=opening_width) + mask = mask & torch.logical_not(exclusion_from_dilation) + if dilation_width > 0: + mask = kornia_dilation(mask, width=dilation_width) + if opening_width > 0: + mask |= mask0 + return torch.where(mask > 0, 0, src) # where(cond, true, false) + + +@mazepa.flow_schema_cls +@attrs.mutable +class MaskEncodingsFlowSchema: + subchunkable_flows: Sequence[SubchunkableFnFlowSchema] + + def __init__( + self, + dst_resolution_list: Sequence[Sequence[int] | Mapping[str, Any]], + fn_kwargs: Mapping[str, Any] | None = None, + src_path: str | None = None, + src_layer: VolumetricLayer | None = None, + mask_path: str | None = None, + mask_layer: VolumetricLayer | None = None, + mask_resolution: Sequence[int] | None = None, + dst_factory_kwargs: Mapping[str, Any] | None = None, + subchunkable_kwargs=None, + **kwargs, + ): + if fn_kwargs is None: + fn_kwargs = {} + if subchunkable_kwargs is None: + subchunkable_kwargs = {} + if dst_factory_kwargs is None: + dst_factory_kwargs = {} + + # assume fn takes src & mask as inputs, add mask src to op_kwargs + if mask_layer is None: + assert mask_path is not None + mask_layer = build_cv_layer( + mask_path, + data_resolution=mask_resolution, + interpolation_mode="mask", + ) + op_kwargs_mask = {"op_kwargs": {"mask": mask_layer}} + subchunkable_kwargs = op_kwargs_mask | subchunkable_kwargs + + dst_factory_kwargs = { + "resolution_list": [model["dst_resolution"] for model in dst_resolution_list], + "info_field_overrides": {"data_type": "int8"}, + } | dst_factory_kwargs + + self.subchunkable_flows = [] + for model in dst_resolution_list: + dst_resolution = model["dst_resolution"] + fn_kwargs_ = fn_kwargs | model.get("fn_kwargs", {}) + + # add src to op_kwargs in subchunkable_kwargs + src_path_ = model.get("src_path", src_path) + src_layer_ = model.get("src_layer", src_layer) + if src_layer_ is None: + src_layer_ = build_cv_layer(src_path_) + subchunkable_kwargs_ = copy.deepcopy(subchunkable_kwargs) + subchunkable_kwargs_["op_kwargs"]["src"] = src_layer_ + + flow = SubchunkableFnFlowSchema( + fn_kwargs=fn_kwargs_, + task_name=f"MaskEncodings-{dst_resolution[0]}", + fn_uses_cuda=False, + dst_resolution=dst_resolution, + model_res_change_mult=(1, 1, 1), + src_path=src_path_, + src_layer=src_layer_, + dst_factory_kwargs=dst_factory_kwargs, + subchunkable_kwargs=subchunkable_kwargs_, + **kwargs, + ) + self.subchunkable_flows.append(flow) + + def flow(self, bbox): + for flow in self.subchunkable_flows: + yield flow(bbox) + + +@mazepa.flow_schema_cls +@attrs.mutable +class ComputeFieldFlowSchema: + flows: Sequence[ComputeFieldMultistageFlowSchema] + z_offsets: Sequence[int] + shrink_bbox_to_z_offsets: bool + z_offset_resolution: int + + def __init__( + self, + stages: Sequence[Mapping[str, Any]], + z_offsets: Sequence[int], + # resume_path: str | None = None, + # resume_resolution: Sequence[int] | None = None, + processing_chunk_sizes: Sequence[Sequence[int]], + shrink_bbox_to_z_offsets: bool = False, + src_path: str | None = None, + src_layer: VolumetricLayer | None = None, + tgt_path: str | None = None, + tgt_layer: VolumetricLayer | None = None, + dst_path: str | None = None, + dst_factory: Callable | None = None, + dst_factory_kwargs: Mapping[str, Any] | None = None, + crop_pad: Sequence[int] | None = None, + compute_field_multistage_kwargs: Mapping[str, Any] | None = None, + compute_field_stage_kwargs: Mapping[str, Any] | None = None, + ): + if len(stages) == 0: + raise RuntimeError("Input `stages` is empty") + + if compute_field_multistage_kwargs is None: + compute_field_multistage_kwargs = {} + if compute_field_stage_kwargs is None: + compute_field_stage_kwargs = {} + if crop_pad is None: + crop_pad = [0, 0, 0] + + self.z_offsets = z_offsets + self.shrink_bbox_to_z_offsets = shrink_bbox_to_z_offsets + + if tgt_path is None: + tgt_path = src_path + if src_layer is None: + src_layer = build_cv_layer(src_path) + if tgt_layer is None: + tgt_layer = build_cv_layer(tgt_path) + if dst_factory is None: + dst_factory = _default_layer_factory + + z_offset_resolution = {stage["dst_resolution"][2] for stage in stages} + if len(z_offset_resolution) > 1: + raise RuntimeError("Inconsistent z resolutions between stages!") + z_offset_resolution = z_offset_resolution.pop() + self.z_offset_resolution = z_offset_resolution + + cf_stages = [] + for stage in stages: + cf_kwargs = {} + cf_kwargs["dst_resolution"] = stage["dst_resolution"] + cf_kwargs["processing_chunk_sizes"] = processing_chunk_sizes + cf_kwargs["processing_crop_pads"] = _pad_crop_pads( + crop_pad, len(processing_chunk_sizes) + ) + cf_kwargs["expand_bbox_processing"] = True + cf_kwargs["shrink_processing_chunk"] = False + cf_kwargs["fn"] = partial(stage["fn"], **stage.get("fn_kwargs", {})) + if "path" in stage: + layer = build_cv_layer(stage["path"]) + cf_kwargs["src"] = layer + cf_kwargs["tgt"] = layer + # override with user values + cf_kwargs |= compute_field_stage_kwargs + cf_kwargs |= stage.get("cf_kwargs", {}) + cf_stage = ComputeFieldStage(**cf_kwargs) + cf_stages.append(cf_stage) + + self.flows = [] + for z_offset in z_offsets: + + dst_path_ = os.path.join(dst_path, str(z_offset)) + dst_kwargs = { + "path": dst_path_, + "info_reference_path": src_path, + "resolution_list": [stage["dst_resolution"] for stage in stages], + "info_field_overrides": { + "type": "image", + "data_type": "float32", + "num_channels": 2, + }, + } | dst_factory_kwargs + dst_layer = dst_factory(**dst_kwargs) + default_tmp_layer_factory = partial( + build_cv_layer, + info_reference_path=dst_path_, + on_info_exists="overwrite", + ) + default_tmp_layer_dir = os.path.join(dst_path_, "tmp") + + ms_kwargs = {} + ms_kwargs["stages"] = cf_stages + ms_kwargs["tmp_layer_dir"] = default_tmp_layer_dir + ms_kwargs["tmp_layer_factory"] = default_tmp_layer_factory + ms_kwargs["src"] = src_layer + ms_kwargs["tgt"] = tgt_layer + ms_kwargs["dst"] = dst_layer + ms_kwargs["tgt_offset"] = [0, 0, z_offset] + ms_kwargs["offset_resolution"] = [1, 1, z_offset_resolution] + # override with user values + ms_kwargs |= compute_field_multistage_kwargs + + flow = ComputeFieldMultistageFlowSchema( + stages=ms_kwargs.pop("stages"), + tmp_layer_dir=ms_kwargs.pop("tmp_layer_dir"), + tmp_layer_factory=ms_kwargs.pop("tmp_layer_factory"), + ) + self.flows.append(partial(flow, **ms_kwargs)) + + def flow(self, bbox): + for flow, z_offset in zip(self.flows, self.z_offsets): + bbox_ = bbox + if self.shrink_bbox_to_z_offsets: + bbox_ = _shrink_bbox_to_z_offset(bbox, z_offset, self.z_offset_resolution) + yield flow(bbox=bbox_) + + +def _shrink_bbox_to_z_offset(bbox, z_offset, z_offset_resolution): + if z_offset < 0: + bbox_ = BBox3D.from_coords( + bbox.start + IntVec3D(0, 0, -z_offset * z_offset_resolution), bbox.end + ) + else: + bbox_ = BBox3D.from_coords( + bbox.start, bbox.end - IntVec3D(0, 0, z_offset * z_offset_resolution) + ) + return bbox_ + + +@mazepa.flow_schema_cls +@attrs.mutable +class InvertFieldFlowSchema: + flows: Sequence[SubchunkableFnFlowSchema] + z_offsets: Sequence[int] + shrink_bbox_to_z_offsets: bool + z_offset_resolution: int + + def __init__( + self, + dst_resolution: Sequence[int], + z_offsets: Sequence[int], + shrink_bbox_to_z_offsets: bool = False, + src_path: str | None = None, + dst_path: str | None = None, + **kwargs, + ): + self.z_offsets = z_offsets + self.shrink_bbox_to_z_offsets = shrink_bbox_to_z_offsets + self.z_offset_resolution = dst_resolution[2] + self.flows = [] + for z_offset in z_offsets: + src_path_ = os.path.join(src_path, str(z_offset)) + dst_path_ = os.path.join(dst_path, str(z_offset)) + flow = SubchunkableFnFlowSchema( + task_name=f"InvertField_z{z_offset}", + dst_resolution=dst_resolution, + src_path=src_path_, + dst_path=dst_path_, + **kwargs, + ) + self.flows.append(flow) + + def flow(self, bbox): + for flow, z_offset in zip(self.flows, self.z_offsets): + bbox_ = bbox + if self.shrink_bbox_to_z_offsets: + bbox_ = _shrink_bbox_to_z_offset(bbox, z_offset, self.z_offset_resolution) + yield flow(bbox=bbox_) + + +@mazepa.flow_schema_cls +@attrs.mutable +class WarpFlowSchema: + flows: Sequence[SubchunkableFnFlowSchema] + z_offsets: Sequence[int] + shrink_bbox_to_z_offsets: bool + z_offset_resolution: int + + def __init__( + self, + dst_resolution: Sequence[int], + z_offsets: Sequence[int], + src_path: str, + field_path: str, + dst_path: str, + shrink_bbox_to_z_offsets: bool = False, + field_resolution: Sequence[int] = None, + subchunkable_kwargs=None, + **kwargs, + ): + if subchunkable_kwargs is None: + subchunkable_kwargs = {} + + self.z_offsets = z_offsets + self.shrink_bbox_to_z_offsets = shrink_bbox_to_z_offsets + self.z_offset_resolution = dst_resolution[2] + + self.flows = [] + for z_offset in z_offsets: + + field_path_ = os.path.join(field_path, str(z_offset)) + dst_path_ = os.path.join(dst_path, str(z_offset)) + src_idx_translator = VolumetricIndexTranslator( + offset=[0, 0, z_offset], resolution=dst_resolution + ) + src_layer = build_cv_layer(src_path, index_procs=[src_idx_translator]) + field_layer = build_cv_layer( + field_path_, data_resolution=field_resolution, interpolation_mode="field" + ) + subchunkable_kwargs_ = { + "op_kwargs": { + "src": src_layer, + "field": field_layer, + } + } | subchunkable_kwargs + + flow = SubchunkableFnFlowSchema( + task_name=f"Warp_{z_offset}", + op=WarpOperation(mode="img"), + dst_resolution=dst_resolution, + fn_uses_cuda=False, + src_path=src_path, + src_layer=src_layer, + dst_path=dst_path_, + subchunkable_kwargs=subchunkable_kwargs_, + **kwargs, + ) + self.flows.append(flow) + + def flow(self, bbox): + for flow, z_offset in zip(self.flows, self.z_offsets): + bbox_ = bbox + if self.shrink_bbox_to_z_offsets: + bbox_ = _shrink_bbox_to_z_offset(bbox, z_offset, self.z_offset_resolution) + yield flow(bbox=bbox_) + + +@mazepa.flow_schema_cls +@attrs.mutable +class EncodeWarpedImgsFlowSchema: + flows: Sequence[SubchunkableFnFlowSchema] + z_offsets: Sequence[int] + shrink_bbox_to_z_offsets: bool + z_offset_resolution: int + + def __init__( + self, + model: Mapping[str, Any], + z_offsets: Sequence[int], + src_path: str, + dst_path: str, + src_resolution: Sequence[int] | None = None, + shrink_bbox_to_z_offsets: bool = False, + dst_factory_kwargs: Mapping[str, Any] | None = None, + reencode_tgt: Mapping[str, Any] | None = None, + **kwargs, + ): + if dst_factory_kwargs is None: + dst_factory_kwargs = {} + + fn = model["fn"] + dst_resolution = model["dst_resolution"] + + self.z_offsets = copy.deepcopy(z_offsets) + self.shrink_bbox_to_z_offsets = shrink_bbox_to_z_offsets + self.z_offset_resolution = dst_resolution[2] + + dst_factory_kwargs = { + "resolution_list": [dst_resolution], + "info_field_overrides": {"data_type": "int8"}, + } | dst_factory_kwargs + + self.flows = [] + for z_offset in self.z_offsets: + + src_path_ = os.path.join(src_path, str(z_offset)) + dst_path_ = os.path.join(dst_path, str(z_offset)) + src_layer = build_cv_layer( + src_path_, data_resolution=src_resolution, interpolation_mode="img" + ) + + flow = SubchunkableFnFlowSchema( + fn=fn, + task_name=f"EncodeWarpedImg_z{z_offset}", + fn_uses_cuda=True, + dst_resolution=dst_resolution, + model_res_change_mult=model.get("res_change_mult", [1, 1, 1]), + model_max_processing_chunk_size=model.get("max_processing_chunk_size", None), + dst_factory_kwargs=dst_factory_kwargs, + src_path=src_path_, + src_layer=src_layer, + dst_path=dst_path_, + **kwargs, + ) + self.flows.append(flow) + + if reencode_tgt is not None: + # re-encode tgt with the given encoder + src_path_ = reencode_tgt["src_path"] + dst_path_ = reencode_tgt["dst_path"] + flow = SubchunkableFnFlowSchema( + fn=fn, + task_name="EncodeWarpedImg_tgt", + fn_uses_cuda=True, + dst_resolution=dst_resolution, + model_res_change_mult=model.get("res_change_mult", [1, 1, 1]), + model_max_processing_chunk_size=model.get("max_processing_chunk_size", None), + dst_factory_kwargs=dst_factory_kwargs, + src_path=src_path_, + dst_path=dst_path_, + **kwargs, + ) + self.flows.append(flow) + self.z_offsets.append(0) + + def flow(self, bbox): + for flow, z_offset in zip(self.flows, self.z_offsets): + bbox_ = bbox + if self.shrink_bbox_to_z_offsets: + bbox_ = _shrink_bbox_to_z_offset(bbox, z_offset, self.z_offset_resolution) + yield flow(bbox=bbox_) + + +@mazepa.flow_schema_cls +@attrs.mutable +class MisalignmentDetectorFlowSchema: + flows: Sequence[SubchunkableFnFlowSchema] + z_offsets: Sequence[int] + shrink_bbox_to_z_offsets: bool + z_offset_resolution: int + + def __init__( + self, + dst_resolution: Sequence[int], + models: Sequence[Mapping[str, Any]], # one per z_offset + z_offsets: Sequence[int], + src_path: str, + dst_path: str, + tgt_path: str | None = None, + tgt_layer: VolumetricLayer | None = None, + shrink_bbox_to_z_offsets: bool = False, + dst_factory_kwargs: Mapping[str, Any] | None = None, + subchunkable_kwargs=None, + **kwargs, + ): + if dst_factory_kwargs is None: + dst_factory_kwargs = {} + if subchunkable_kwargs is None: + subchunkable_kwargs = {} + + self.z_offsets = z_offsets + self.shrink_bbox_to_z_offsets = shrink_bbox_to_z_offsets + self.z_offset_resolution = dst_resolution[2] + + dst_factory_kwargs = { + "resolution_list": [dst_resolution], + "info_field_overrides": {"data_type": "uint8"}, + } | dst_factory_kwargs + + # add tgt to op_kwargs + if tgt_layer is None: + assert tgt_path is not None + tgt_layer = build_cv_layer(tgt_path) + op_kwargs_mask = {"op_kwargs": {"tgt": tgt_layer}} + subchunkable_kwargs = op_kwargs_mask | subchunkable_kwargs + + if len(models) == 1: + models = models * len(z_offsets) + assert len(models) == len(z_offsets) + + self.flows = [] + for z_offset, model in zip(z_offsets, models): + + fn = model["fn"] + dst_resolution_ = model.get("dst_resolution", dst_resolution) + + src_path_ = os.path.join(src_path, str(z_offset)) + dst_path_ = os.path.join(dst_path, str(z_offset)) + src_layer = build_cv_layer(src_path_) + + subchunkable_kwargs_ = copy.deepcopy(subchunkable_kwargs) + subchunkable_kwargs_["op_kwargs"]["src"] = src_layer + + flow = SubchunkableFnFlowSchema( + fn=fn, + task_name=f"Misd_z{z_offset}", + fn_uses_cuda=True, + dst_resolution=dst_resolution_, + model_max_processing_chunk_size=model.get("max_processing_chunk_size", None), + dst_factory_kwargs=dst_factory_kwargs, + src_path=src_path_, + src_layer=src_layer, + dst_path=dst_path_, + subchunkable_kwargs=subchunkable_kwargs_, + **kwargs, + ) + self.flows.append(flow) + + def flow(self, bbox): + for flow, z_offset in zip(self.flows, self.z_offsets): + bbox_ = bbox + if self.shrink_bbox_to_z_offsets: + bbox_ = _shrink_bbox_to_z_offset(bbox, z_offset, self.z_offset_resolution) + yield flow(bbox=bbox_) + + +@builder.register("binarize_misd_mask") +def binarize_misd_mask(src, threshold): + src = compare(src, mode=">=", value=threshold) + return to_uint8(src) + + +@mazepa.flow_schema_cls +@attrs.mutable +class BinarizeMisalignmentFlowSchema: + flows: Sequence[SubchunkableFnFlowSchema] + z_offsets: Sequence[int] + shrink_bbox_to_z_offsets: bool + z_offset_resolution: int + + def __init__( + self, + dst_resolution: Sequence[int], + models: Sequence[Mapping[str, Any]], # one per z_offset + z_offsets: Sequence[int], + src_path: str, + dst_path: str, + shrink_bbox_to_z_offsets: bool = False, + **kwargs, + ): + self.z_offsets = z_offsets + self.shrink_bbox_to_z_offsets = shrink_bbox_to_z_offsets + self.z_offset_resolution = dst_resolution[2] + + if len(models) == 1: + models = models * len(z_offsets) + assert len(models) == len(z_offsets) + + self.flows = [] + for z_offset, model in zip(z_offsets, models): + + fn = model["fn"] + dst_resolution_ = model.get("dst_resolution", dst_resolution) + + src_path_ = os.path.join(src_path, str(z_offset)) + dst_path_ = os.path.join(dst_path, str(z_offset)) + + flow = SubchunkableFnFlowSchema( + fn=fn, + task_name=f"BinarizeMisd_z{z_offset}", + dst_resolution=dst_resolution_, + src_path=src_path_, + dst_path=dst_path_, + **kwargs, + ) + self.flows.append(flow) + + def flow(self, bbox): + for flow, z_offset in zip(self.flows, self.z_offsets): + bbox_ = bbox + if self.shrink_bbox_to_z_offsets: + bbox_ = _shrink_bbox_to_z_offset(bbox, z_offset, self.z_offset_resolution) + yield flow(bbox=bbox_) + + +@builder.register("PairwiseAlignmentFlowSchema") +@mazepa.flow_schema_cls +@attrs.mutable +class PairwiseAlignmentFlowSchema: + encoding_flow: EncodingFlowSchema | None + defect_flow: DefectFlowSchema | None + binarize_defect_flow: BinarizeDefectFlowSchema | None + mask_encodings_flow: MaskEncodingsFlowSchema | None + compute_field_flow: ComputeFieldFlowSchema | None + invert_field_flow: InvertFieldFlowSchema | None + warp_flow: WarpFlowSchema | None + enc_warped_imgs_flow: EncodeWarpedImgsFlowSchema | None + misd_flow: MisalignmentDetectorFlowSchema | None + binarize_misd_flow: MisalignmentDetectorFlowSchema | None + + def flow(self, bbox): + encoding_task = [] + if self.encoding_flow is not None: + encoding_task = self.encoding_flow(bbox) + yield encoding_task + + defect_task = [] + if self.defect_flow is not None: + defect_task = self.defect_flow(bbox) + yield defect_task + + binarized_defect_task = [] + if self.binarize_defect_flow is not None: + yield mazepa.Dependency(defect_task) + binarized_defect_task = self.binarize_defect_flow(bbox) + yield binarized_defect_task + + mask_encodings_task = [] + if self.mask_encodings_flow is not None: + yield mazepa.Dependency(encoding_task) + yield mazepa.Dependency(binarized_defect_task) + mask_encodings_task = self.mask_encodings_flow(bbox) + yield mask_encodings_task + + compute_field_task = [] + if self.compute_field_flow is not None: + yield mazepa.Dependency(encoding_task) + yield mazepa.Dependency(mask_encodings_task) + compute_field_task = self.compute_field_flow(bbox) + yield compute_field_task + + invert_field_task = [] + if self.invert_field_flow is not None: + yield mazepa.Dependency(compute_field_task) + invert_field_task = self.invert_field_flow(bbox) + yield invert_field_task + + warp_task = [] + if self.warp_flow is not None: + yield mazepa.Dependency(invert_field_task) + warp_task = self.warp_flow(bbox) + yield warp_task + + enc_warped_imgs_task = [] + if self.enc_warped_imgs_flow is not None: + yield mazepa.Dependency(warp_task) + enc_warped_imgs_task = self.enc_warped_imgs_flow(bbox) + yield enc_warped_imgs_task + + misd_task = [] + if self.misd_flow is not None: + yield mazepa.Dependency(enc_warped_imgs_task) + misd_task = self.misd_flow(bbox) + yield misd_task + + binarize_misd_task = [] + if self.binarize_misd_flow is not None: + yield mazepa.Dependency(misd_task) + binarize_misd_task = self.binarize_misd_flow(bbox) + yield binarize_misd_task + + +@builder.register("build_pairwise_alignment_flow") +def build_pairwise_alignment_flow( + bbox: BBox3D | None = None, + bbox_list: Sequence[BBox3D] | None = None, + src_image_path: str = "", + project_folder: str = "", + z_offsets: Sequence[int] = (-1,), + run_encoding: bool = False, + encoding_flow_kwargs: Mapping[str, Any] | None = None, + run_defect: bool = False, + skipped_defect: bool = False, + defect_flow_kwargs: Mapping[str, Any] | None = None, + run_binarize_defect: bool = False, + binarize_defect_flow_kwargs: Mapping[str, Any] | None = None, + run_mask_encodings: bool = False, + mask_encodings_flow_kwargs: Mapping[str, Any] | None = None, + run_compute_field: bool = False, + compute_field_subproject: str | None = None, + compute_field_flow_kwargs: Mapping[str, Any] | None = None, + run_invert_field: bool = False, + invert_field_subproject: str | None = None, + invert_field_flow_kwargs: Mapping[str, Any] | None = None, + run_warp: bool = False, + warp_subproject: str | None = None, + warp_flow_kwargs: Mapping[str, Any] | None = None, + run_enc_warped_imgs: bool = False, + enc_warped_imgs_flow_kwargs: Mapping[str, Any] | None = None, + run_misd: bool = False, + misd_flow_kwargs: Mapping[str, Any] | None = None, + run_binarize_misd: bool = False, + binarize_misd_flow_kwargs: Mapping[str, Any] | None = None, +) -> mazepa.Flow: # pylint: disable=too-many-statements) + + if bbox is None and bbox_list is None: + raise RuntimeError("Either `bbox` and `bbox_list` must be provided") + if bbox is not None and bbox_list is not None: + raise RuntimeError("`bbox` and `bbox_list` cannot be both specified") + if bbox_list is None: + bbox_list = [bbox] + + def resolve_path(path, default=None): + if path is None: + path = default + if "gs://" in path: + # path is absolute + return path + # otherwise path is relative + return os.path.join(project_folder, path) + + def set_path(config, key, default, subproject=None): + if subproject is not None: + default = os.path.join(subproject, default) + config[key] = resolve_path(config.get(key, None), default) + + encoding_flow = None + if encoding_flow_kwargs is None: + encoding_flow_kwargs = {} + set_path(encoding_flow_kwargs, "src_path", default=src_image_path) + set_path(encoding_flow_kwargs, "dst_path", default="encodings") + if run_encoding: + encoding_flow = EncodingFlowSchema(**encoding_flow_kwargs) + + defect_flow = None + if defect_flow_kwargs is None: + defect_flow_kwargs = {} + set_path(defect_flow_kwargs, "src_path", default=src_image_path) + set_path(defect_flow_kwargs, "dst_path", default="defect") + if run_defect: + defect_flow = DefectFlowSchema(**defect_flow_kwargs) + + binarize_defect_flow = None + if binarize_defect_flow_kwargs is None: + binarize_defect_flow_kwargs = {} + set_path(binarize_defect_flow_kwargs, "src_path", default=defect_flow_kwargs["dst_path"]) + set_path(binarize_defect_flow_kwargs, "dst_path", default="defect_binarized") + if run_binarize_defect: + binarize_defect_flow = BinarizeDefectFlowSchema(**binarize_defect_flow_kwargs) + + mask_encodings_flow = None + if mask_encodings_flow_kwargs is None: + mask_encodings_flow_kwargs = {} + set_path( + mask_encodings_flow_kwargs, "mask_path", default=binarize_defect_flow_kwargs["dst_path"] + ) + set_path(mask_encodings_flow_kwargs, "src_path", default=encoding_flow_kwargs["dst_path"]) + set_path(mask_encodings_flow_kwargs, "dst_path", default="encodings_masked") + if run_mask_encodings: + mask_encodings_flow = MaskEncodingsFlowSchema(**mask_encodings_flow_kwargs) + + compute_field_flow = None + if compute_field_flow_kwargs is None: + compute_field_flow_kwargs = {} + cf_default_src = ( + encoding_flow_kwargs["dst_path"] + if skipped_defect + else mask_encodings_flow_kwargs["dst_path"] + ) + set_path(compute_field_flow_kwargs, "src_path", default=cf_default_src) + set_path(compute_field_flow_kwargs, "dst_path", "fields_fwd", compute_field_subproject) + compute_field_flow_kwargs = {"z_offsets": z_offsets} | compute_field_flow_kwargs + if run_compute_field: + compute_field_flow = ComputeFieldFlowSchema(**compute_field_flow_kwargs) + + invert_field_flow = None + if invert_field_flow_kwargs is None: + invert_field_flow_kwargs = {} + set_path(invert_field_flow_kwargs, "src_path", default=compute_field_flow_kwargs["dst_path"]) + set_path(invert_field_flow_kwargs, "dst_path", "fields_inv", invert_field_subproject) + if "stages" in compute_field_flow_kwargs and len(compute_field_flow_kwargs["stages"]): + # try to get field_resolution from the previous step + cf_dst_resolution = compute_field_flow_kwargs["stages"][-1]["dst_resolution"] + invert_field_flow_kwargs = {"dst_resolution": cf_dst_resolution} | invert_field_flow_kwargs + invert_field_flow_kwargs = {"z_offsets": z_offsets} | invert_field_flow_kwargs + if run_invert_field: + invert_field_flow = InvertFieldFlowSchema(**invert_field_flow_kwargs) + + warp_flow = None + if warp_flow_kwargs is None: + warp_flow_kwargs = {} + set_path(warp_flow_kwargs, "src_path", default=src_image_path) + set_path(warp_flow_kwargs, "field_path", default=invert_field_flow_kwargs["dst_path"]) + set_path(warp_flow_kwargs, "dst_path", "imgs_warped", warp_subproject) + if "dst_resolution" in invert_field_flow_kwargs: + # try to get field_resolution from the previous step + warp_flow_kwargs = { + "field_resolution": invert_field_flow_kwargs["dst_resolution"] + } | warp_flow_kwargs + warp_flow_kwargs = {"z_offsets": z_offsets} | warp_flow_kwargs + if run_warp: + warp_flow = WarpFlowSchema(**warp_flow_kwargs) + + enc_warped_imgs_flow = None + if enc_warped_imgs_flow_kwargs is None: + enc_warped_imgs_flow_kwargs = {} + set_path(enc_warped_imgs_flow_kwargs, "src_path", default=warp_flow_kwargs["dst_path"]) + set_path(enc_warped_imgs_flow_kwargs, "dst_path", default="imgs_warped_encoded") + if "dst_resolution" in warp_flow_kwargs: + # try to get resolution from the previous step + enc_warped_imgs_flow_kwargs = { + "src_resolution": warp_flow_kwargs["dst_resolution"] + } | enc_warped_imgs_flow_kwargs + enc_warped_imgs_flow_kwargs = {"z_offsets": z_offsets} | enc_warped_imgs_flow_kwargs + if "reencode_tgt" in enc_warped_imgs_flow_kwargs: + set_path(enc_warped_imgs_flow_kwargs["reencode_tgt"], "src_path", default=src_image_path) + set_path(enc_warped_imgs_flow_kwargs["reencode_tgt"], "dst_path", default="encodings_misd") + if run_enc_warped_imgs: + enc_warped_imgs_flow = EncodeWarpedImgsFlowSchema(**enc_warped_imgs_flow_kwargs) + + misd_flow = None + if misd_flow_kwargs is None: + misd_flow_kwargs = {} + set_path(misd_flow_kwargs, "src_path", default=enc_warped_imgs_flow_kwargs["dst_path"]) + set_path(misd_flow_kwargs, "dst_path", default="misalignments") + # determine tgt + if "reencode_tgt" in enc_warped_imgs_flow_kwargs: + set_path( + misd_flow_kwargs, + "tgt_path", + default=enc_warped_imgs_flow_kwargs["reencode_tgt"]["dst_path"], + ) + else: + set_path(misd_flow_kwargs, "tgt_path", default=encoding_flow_kwargs["dst_path"]) + misd_flow_kwargs = {"z_offsets": z_offsets} | misd_flow_kwargs + if run_misd: + misd_flow = MisalignmentDetectorFlowSchema(**misd_flow_kwargs) + + binarize_misd_flow = None + if binarize_misd_flow_kwargs is None: + binarize_misd_flow_kwargs = {} + set_path(binarize_misd_flow_kwargs, "src_path", default=misd_flow_kwargs["dst_path"]) + set_path(binarize_misd_flow_kwargs, "dst_path", default="misalignments_binarized") + binarize_misd_flow_kwargs = {"z_offsets": z_offsets} | binarize_misd_flow_kwargs + if run_binarize_misd: + binarize_misd_flow = BinarizeMisalignmentFlowSchema(**binarize_misd_flow_kwargs) + + flow_schema = PairwiseAlignmentFlowSchema( + encoding_flow=encoding_flow, + defect_flow=defect_flow, + binarize_defect_flow=binarize_defect_flow, + mask_encodings_flow=mask_encodings_flow, + compute_field_flow=compute_field_flow, + invert_field_flow=invert_field_flow, + warp_flow=warp_flow, + enc_warped_imgs_flow=enc_warped_imgs_flow, + misd_flow=misd_flow, + binarize_misd_flow=binarize_misd_flow, + ) + + @mazepa.flow_schema + def run_multi(bbox_list): + yield [flow_schema(bbox) for bbox in bbox_list] + + return run_multi(bbox_list) From c1f2663eb77c8310daed62022648ee1b0139c052 Mon Sep 17 00:00:00 2001 From: trivoldus28 Date: Tue, 27 Feb 2024 22:13:58 +0000 Subject: [PATCH 2/2] Add missing imports in init --- zetta_utils/mazepa_layer_processing/alignment/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/zetta_utils/mazepa_layer_processing/alignment/__init__.py b/zetta_utils/mazepa_layer_processing/alignment/__init__.py index 4af35447d..8a1093903 100644 --- a/zetta_utils/mazepa_layer_processing/alignment/__init__.py +++ b/zetta_utils/mazepa_layer_processing/alignment/__init__.py @@ -6,7 +6,10 @@ ComputeFieldMultistageFlowSchema, build_compute_field_multistage_flow, ) +from .pairwise_alignment import ( + build_pairwise_alignment_flow, +) from . import warp_operation from . import aced_relaxation_flow -from . import annotated_section_copy +from . import annotated_section_copy \ No newline at end of file