diff --git a/aiod_registry/__init__.py b/aiod_registry/__init__.py index b578176..a95319c 100644 --- a/aiod_registry/__init__.py +++ b/aiod_registry/__init__.py @@ -1,2 +1,2 @@ -from aiod_registry.schema import ModelManifest +from aiod_registry.schema import ModelManifest, TASK_NAMES from aiod_registry.utils import get_manifest_paths, load_manifests diff --git a/aiod_registry/manifests/mitonet.json b/aiod_registry/manifests/mitonet.json index f2ce70f..c2c3875 100644 --- a/aiod_registry/manifests/mitonet.json +++ b/aiod_registry/manifests/mitonet.json @@ -1,69 +1,86 @@ { "name": "Mitonet", - "versions": [ - { - "name": "MitoNet v1", - "tasks": [ - { - "task": "mito", + "metadata": { + "description": "MitoNet is a deep learning model for mitochondria segmentation in EM images.", + "authors": [ + { + "name": "Ryan Conrad", + "affiliation": "Center for Molecular Microscopy, Center for Cancer Research, National Cancer Institute, National Institutes of Health, Bethesda, MD 20892, USA" + }, + { + "name": "Kedar Narayan", + "affiliation": "Center for Molecular Microscopy, Center for Cancer Research, National Cancer Institute, National Institutes of Health, Bethesda, MD 20892, USA" + } + ], + "pubs": [ + { + "info": "Main paper that describes model & data", + "url": "https://doi.org/10.1016/j.cels.2022.12.006" + } + ] + }, + "versions": { + "MitoNet v1": { + "tasks": { + "mito": { "location": "https://zenodo.org/record/6861565/files/MitoNet_v1.pth?download=1" } - ] + } }, - { - "name": "MitoNet Mini v1", - "tasks": [ - { - "task": "mito", + "MitoNet Mini v1": { + "tasks": { + "mito": { "location": "https://zenodo.org/record/6861565/files/MitoNet_v1_mini.pth?download=1" } - ] + } } - ], + }, "params": [ { "name": "Plane", + "arg_name": "plane", "value": ["XY", "XZ", "YZ", "All"], "tooltip": "Whether to use all planes (XY, XZ, YZ) or a single plane" }, { "name": "Downsampling", + "arg_name": "downsampling", "value": [1, 2, 4, 8, 16, 32, 64], "tooltip": "Downsampling factor 
for the input image" }, { "name": "Segmentation threshold", - "short_name": "conf_threshold", + "arg_name": "conf_threshold", "value": 0.5, "tooltip": "Confidence threshold for the segmentation" }, { "name": "Center threshold", - "short_name": "center_threshold", + "arg_name": "center_threshold", "value": 0.1, "tooltip": "Confidence threshold for the center" }, { "name": "Minimum distance", - "short_name": "min_distance", + "arg_name": "min_distance", "value": 3, "tooltip": "Minimum distance between object centers" }, { "name": "Maximum objects", - "short_name": "max_objects", + "arg_name": "max_objects", "value": 1000, "tooltip": "Maximum number of objects to segment per class" }, { "name": "Semantic only", - "short_name": "semantic_only", + "arg_name": "semantic_only", "value": false, "tooltip": "Only run semantic segmentation for all classes" }, { "name": "Fine boundaries", - "short_name": "fine_boundaries", + "arg_name": "fine_boundaries", "value": false, "tooltip": "Finer boundaries between objects" } diff --git a/aiod_registry/manifests/sam_test.json b/aiod_registry/manifests/sam.json similarity index 53% rename from aiod_registry/manifests/sam_test.json rename to aiod_registry/manifests/sam.json index 967a07c..0243641 100644 --- a/aiod_registry/manifests/sam_test.json +++ b/aiod_registry/manifests/sam.json @@ -1,122 +1,118 @@ { "name": "Segment Anything", "short_name": "sam", - "versions": [ - { - "name": "default", - "tasks": [ - { - "task": "everything", - "location": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", - "config_path": null + "metadata": { + "description": "Segment Anything is a vision foundation model with flexible prompting.", + "url": "https://segment-anything.com/", + "repo": "https://github.com/facebookresearch/segment-anything", + "pubs": [ + { + "info": "Main paper that describes model & data", + "url": "https://arxiv.org/abs/2304.02643" + } + ] + }, + "versions": { + "default": { + "tasks": { + "everything": { + 
"location": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" } - ] + } }, - { - "name": "vit_h", - "tasks": [ - { - "task": "everything", - "location": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", - "config_path": null + "vit_h": { + "tasks": { + "everything": { + "location": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" } - ] + } }, - { - "name": "vit_l", - "tasks": [ - { - "task": "everything", - "location": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", - "config_path": null + "vit_l": { + "tasks": { + "everything": { + "location": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth" } - ] + } }, - { - "name": "vit_b", - "tasks": [ - { - "task": "Mito", - "location":"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", - "config_path": null + "vit_b": { + "tasks": { + "everything": { + "location":"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" } - ] + } }, - { - "name": "MedSAM", - "tasks": [ - { - "task": "everything", - "location": "https://syncandshare.desy.de/index.php/s/yLfdFbpfEGSHJWY/download/medsam_20230423_vit_b_0.0.1.pth", - "config_path": null + "MedSAM": { + "tasks": { + "everything": { + "location": "https://syncandshare.desy.de/index.php/s/yLfdFbpfEGSHJWY/download/medsam_20230423_vit_b_0.0.1.pth" } - ] + } } - ], + }, "params": [ { "name": "Points per side", - "short_name": "points_per_side", + "arg_name": "points_per_side", "value": 32, "tooltip": "" }, { "name": "Points per batch", - "short_name": "points_per_batch", + "arg_name": "points_per_batch", "value": 64, "tooltip": "" }, { "name": "Pred IoU threshold", - "short_name": "pred_iou_thresh", + "arg_name": "pred_iou_thresh", "value": 0.88, "tooltip": "" }, { "name": "Stability score threshold", - "short_name": "stability_score_thresh", + "arg_name": "stability_score_thresh", "value": 0.95, "tooltip": "" }, { "name": "Stability score 
offset", - "short_name": "stability_score_offset", + "arg_name": "stability_score_offset", "value": 1, "tooltip": "" }, { "name": "Box nms_thresh", - "short_name": "box_nms_thresh", + "arg_name": "box_nms_thresh", "value": 0.7, "tooltip": "" }, { "name": "Crop N layers", - "short_name": "crop_n_layers", + "arg_name": "crop_n_layers", "value": 0, "tooltip": "" }, { "name": "Crop NMS thresh", - "short_name": "crop_nms_thresh", + "arg_name": "crop_nms_thresh", "value": 0.7, "tooltip": "" }, { "name": "Crop overlap ratio", - "short_name": "crop_overlap_ratio", + "arg_name": "crop_overlap_ratio", "value": 0.34133, "tooltip": "" }, { "name": "Crop B points downscale factor", - "short_name": "crop_n_points_downscale_factor", + "arg_name": "crop_n_points_downscale_factor", "value": 0.5, "tooltip": "" }, { "name": "Min mask region area", - "short_name": "min_mask_region_area", + "arg_name": "min_mask_region_area", "value": 3, "tooltip": "" } diff --git a/aiod_registry/manifests/unet_seai.json b/aiod_registry/manifests/unet_seai.json index 8de2444..ae0e046 100644 --- a/aiod_registry/manifests/unet_seai.json +++ b/aiod_registry/manifests/unet_seai.json @@ -1,31 +1,29 @@ { "name": "SEAI U-Net", "short_name": "seai_unet", - "versions": [ - { - "name": "U-Net", - "tasks": [ - { - "task": "mito", - "location": "/nemo/stp/ddt/working/shandc/aiod_models/mito_5nm_intensity_augs_warp.best.969.pt", - "config_path": "/nemo/stp/ddt/working/shandc/aiod_models/mito_5nm_intensity_augs_warp.yml" + "metadata": { + "description": "SEAI U-Net developed on internal Crick EM data" + }, + "versions": { + "U-Net": { + "tasks": { + "mito": { + "location": "/Users/shandc/Documents/ai_ondemand/mito_5nm_intensity_augs_warp.best.969.pt", + "config_path": "/Users/shandc/Documents/ai_ondemand/mito_5nm_intensity_augs_warp.yml" } - ] + } }, - { - "name": "Attention U-Net", - "tasks": [ - { - "task": "mito", + "Attention U-Net": { + "tasks": { + "mito": { "location": 
"/nemo/stp/ddt/working/shandc/aiod_models/Attention_HUNet_3e5_Adam_restart_12_16.best.1266.pt", "config_path": "/nemo/stp/ddt/working/shandc/aiod_models/Attention_HUNet_3e5_Adam_restart_12_16.yml" }, - { - "task": "ne", + "ne": { "location": "/nemo/stp/ddt/working/shandc/aiod_models/Attention_HUNet_NE.best.368.pt", "config_path": "/nemo/stp/ddt/working/shandc/aiod_models/Attention_HUNet_NE.yml" } - ] + } } - ] + } } \ No newline at end of file diff --git a/aiod_registry/schema.json b/aiod_registry/schema.json new file mode 100644 index 0000000..5ad1d12 --- /dev/null +++ b/aiod_registry/schema.json @@ -0,0 +1,417 @@ +{ + "$defs": { + "Author": { + "additionalProperties": false, + "properties": { + "name": { + "title": "Name", + "type": "string" + }, + "affiliation": { + "title": "Affiliation", + "type": "string" + }, + "email": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Email" + }, + "url": { + "anyOf": [ + { + "format": "uri", + "minLength": 1, + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Url" + }, + "github": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Github" + }, + "orcid": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Orcid" + } + }, + "required": [ + "name", + "affiliation" + ], + "title": "Author", + "type": "object" + }, + "Metadata": { + "additionalProperties": false, + "properties": { + "description": { + "description": "A short description of the model to provide context.", + "title": "Description", + "type": "string" + }, + "authors": { + "anyOf": [ + { + "items": { + "$ref": "#/$defs/Author" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Authors" + }, + "pubs": { + "anyOf": [ + { + "items": { + "$ref": "#/$defs/Publication" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": 
null, + "title": "Pubs" + }, + "url": { + "anyOf": [ + { + "format": "uri", + "minLength": 1, + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Url" + }, + "repo": { + "anyOf": [ + { + "format": "uri", + "minLength": 1, + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Repo" + } + }, + "required": [ + "description" + ], + "title": "Metadata", + "type": "object" + }, + "ModelParam": { + "additionalProperties": false, + "properties": { + "name": { + "description": "Name of the parameter. If `arg_name` is not provided, this will be used as the argument name to the underlying model.", + "maxLength": 50, + "minLength": 1, + "title": "Name", + "type": "string" + }, + "arg_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Arg Name" + }, + "value": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + } + ] + }, + "type": "array" + } + ], + "description": "Default parameter value. If a list, the parameters will be treated as dropdown choices, where the first is the default. 
The type of the first element will be used to determine the type of the parameter.", + "title": "Value" + }, + "tooltip": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Tooltip" + } + }, + "required": [ + "name", + "value" + ], + "title": "ModelParam", + "type": "object" + }, + "ModelVersion": { + "additionalProperties": false, + "properties": { + "tasks": { + "patternProperties": { + "^(?i:mito|er|ne|everything)$": { + "$ref": "#/$defs/ModelVersionTask" + } + }, + "title": "Tasks", + "type": "object" + } + }, + "required": [ + "tasks" + ], + "title": "ModelVersion", + "type": "object" + }, + "ModelVersionTask": { + "additionalProperties": false, + "properties": { + "location": { + "description": "Either a url or a filepath (will be skipped if the path does not exist/cannot be read!)", + "title": "Location", + "type": "string" + }, + "config_path": { + "anyOf": [ + { + "format": "path", + "type": "string" + }, + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Config Path" + }, + "params": { + "anyOf": [ + { + "items": { + "$ref": "#/$defs/ModelParam" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Params" + }, + "location_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Location Type" + } + }, + "required": [ + "location" + ], + "title": "ModelVersionTask", + "type": "object" + }, + "Publication": { + "additionalProperties": false, + "properties": { + "info": { + "description": "Information on publication, whether it pertains to the model or the underlying data or something else.", + "title": "Info", + "type": "string" + }, + "url": { + "format": "uri", + "minLength": 1, + "title": "Url", + "type": "string" + }, + "doi": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Doi" + }, + "authors": { + "anyOf": [ + { + 
"items": { + "$ref": "#/$defs/Author" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Authors" + } + }, + "required": [ + "info", + "url" + ], + "title": "Publication", + "type": "object" + } + }, + "additionalProperties": false, + "properties": { + "name": { + "maxLength": 50, + "minLength": 1, + "title": "Name", + "type": "string" + }, + "short_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Short Name" + }, + "versions": { + "additionalProperties": { + "$ref": "#/$defs/ModelVersion" + }, + "title": "Versions", + "type": "object" + }, + "params": { + "anyOf": [ + { + "items": { + "$ref": "#/$defs/ModelParam" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Params" + }, + "config": { + "anyOf": [ + { + "format": "path", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Config" + }, + "metadata": { + "$ref": "#/$defs/Metadata" + } + }, + "required": [ + "name", + "versions", + "metadata" + ], + "title": "ModelManifest", + "type": "object" +} \ No newline at end of file diff --git a/aiod_registry/schema.py b/aiod_registry/schema.py index c03578a..e639e2c 100644 --- a/aiod_registry/schema.py +++ b/aiod_registry/schema.py @@ -1,8 +1,39 @@ -import json from pathlib import Path from typing import Optional, Union +from urllib.parse import urlparse from pydantic import BaseModel, ConfigDict, Field, model_validator, AnyUrl +from typing_extensions import Annotated + +TASK_NAMES = { + "mito": "Mitochondria", + "er": "Endoplasmic Reticulum", + "ne": "Nuclear Envelope", + "everything": "Everything!", +} +task_names = "|".join(TASK_NAMES.keys()) + +# Define custom types/fields +# Centralise to make it easier to change later +# Regex pattern to match task names, ignoring case +Task = Annotated[str, Field(..., pattern=rf"^(?i:{task_names})$")] +ModelName = Annotated[str, Field(..., min_length=1, 
max_length=50)]
+ParamName = Annotated[
+    str,
+    Field(
+        ...,
+        min_length=1,
+        max_length=50,
+        description="Name of the parameter. If `arg_name` is not provided, this will be used as the argument name to the underlying model.",
+    ),
+]
+ParamValue = Annotated[
+    Union[str, int, float, bool, list[Union[str, int, float, bool]]],
+    Field(
+        ...,
+        description="Default parameter value. If a list, the parameters will be treated as dropdown choices, where the first is the default. The type of the first element will be used to determine the type of the parameter.",
+    ),
+]
 
 
 def shorten_name(name: str) -> str:
@@ -13,43 +44,146 @@ class StrictModel(BaseModel):
     model_config = ConfigDict(extra="forbid")
 
 
+class ModelParam(StrictModel):
+    """A single model parameter together with its default value."""
+
+    name: ParamName
+    arg_name: Optional[str] = None
+    value: ParamValue
+    tooltip: Optional[str] = None
+    # Python type of the default value, set by `extract_arg_type`
+    _dtype = None
+
+    @model_validator(mode="after")
+    def create_arg_name(self):
+        """Fall back to `name` when no `arg_name` was provided."""
+        if self.arg_name is None:
+            self.arg_name = self.name
+        return self
+
+    @model_validator(mode="after")
+    def extract_arg_type(self):
+        """Store the type of the default value (first choice for a list)."""
+        if isinstance(self.value, list):
+            # An empty choice list has no default to take a type from; raise
+            # a ValueError (a validation error) rather than a bare IndexError
+            if not self.value:
+                raise ValueError(
+                    f"Parameter '{self.name}' has an empty list of choices!"
+                )
+            self._dtype = type(self.value[0])
+        else:
+            self._dtype = type(self.value)
+        return self
+
+
 class ModelVersionTask(StrictModel):
-    # Regex pattern to match task names, ignoring case
-    task: str = Field(..., pattern=r"^(?i:mito|er|ne|everything)$")
-    location: Union[Path, AnyUrl, str] = Field(
+    location: str = Field(
         ...,
-        description="Either a url or a filepath (will be skipped if the path does not exist/cannot be read)",
+        description="Either a url or a filepath (will be skipped if the path does not exist/cannot be read!)",
     )
     config_path: Optional[Union[Path, str]] = None
+    params: Optional[list[ModelParam]] = None
+    location_type: Optional[str] = None
+
+    @model_validator(mode="after")
+    def get_location_type(self):
+        """Infer `location_type` ("url"/"file") from the scheme of `location`."""
+        # Skip if provided
+        if self.location_type is not None:
+            return self
+        # Otherwise, determine the type
+        res = urlparse(self.location)
+        if res.scheme in ("http", "https"):
+            self.location_type = "url"
+        elif res.scheme in ("file", ""):
+            self.location_type = "file"
+        else:
+            # NOTE: Because of including "" above, it is unlikely this will be reached
+            raise TypeError(
+                f"Cannot determine type (file/url) of location: {self.location}!"
+            )
+        return self
 
 
 class ModelVersion(StrictModel):
-    name: str = Field(..., min_length=1, max_length=50)
-    tasks: list[ModelVersionTask]
+    tasks: dict[Task, ModelVersionTask]
 
 
-class ModelParam(StrictModel):
-    name: str = Field(..., min_length=1, max_length=50)
-    short_name: Optional[str] = None
-    value: Union[str, int, float, bool, list[Union[str, int, float, bool]]]
-    tooltip: Optional[str] = None
+class Author(StrictModel):
+    """Author details; only `name` and `affiliation` are required."""
+
+    name: str
+    affiliation: str
+    email: Optional[str] = None
+    url: Optional[AnyUrl] = None
+    github: Optional[str] = None
+    orcid: Optional[str] = None
 
-    @model_validator(mode="after")
-    def create_short_name(self):
-        if self.short_name is None:
-            self.short_name = shorten_name(self.name)
-        return self
+
+class Publication(StrictModel):
+    """A publication reference; `info` and `url` are required."""
+
+    info: Annotated[
+        str,
+        Field(
+            ...,
+            description="Information on publication, whether it pertains to the model or the underlying data or something else.",
+        ),
+    ]
+    url: AnyUrl
+    doi: Optional[str] = None
+    authors: Optional[list[Author]] = None
+
+
+class Metadata(StrictModel):
+    """Descriptive information about a model; only `description` is required."""
+
+    description: Annotated[
+        str,
+        Field(
+            ...,
+            description="A short description of the model to provide context.",
+        ),
+    ]
+    authors: Optional[list[Author]] = None
+    pubs: Optional[list[Publication]] = None
+    url: Optional[AnyUrl] = None
+    repo: Optional[AnyUrl] = None
 
 
 class ModelManifest(StrictModel):
     name: str = Field(..., min_length=1, max_length=50)
     short_name: Optional[str] = None
-    versions: list[ModelVersion]
+    versions: dict[ModelName, ModelVersion]
     params: Optional[list[ModelParam]] = None
     config: Optional[Path] = None
+    metadata: Metadata
 
     @model_validator(mode="after")
     def create_short_name(self):
         if self.short_name is None:
             self.short_name = shorten_name(self.name)
         return self
+
+    # Embed base model params into each version if not provided
+    @model_validator(mode="after")
+    def fill_empty_params(self):
+        """Give every task without its own `params` the manifest-level ones."""
+        for version in self.versions.values():
+            for task in version.tasks.values():
+                if task.params is None:
+                    task.params = self.params
+        # pydantic uses the value returned by an "after" validator as the
+        # validation result, so the instance must be handed back
+        return self
+
+
+if __name__ == "__main__":
+    import json
+
+    schema_fpath = Path(__file__).parent / "schema.json"
+
+    # Write the schema to file
+    with open(schema_fpath, "w") as f:
+        f.write(json.dumps(ModelManifest.model_json_schema(), indent=2))