Skip to content

Commit

Permalink
Update notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
qubvel committed Aug 23, 2024
1 parent b84bcb6 commit bc8f0b2
Showing 1 changed file with 162 additions and 8 deletions.
170 changes: 162 additions & 8 deletions examples/save_load_model_and_share_with_hf_hub.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand All @@ -82,9 +82,11 @@
"license: mit\n",
"pipeline_tag: image-segmentation\n",
"tags:\n",
"- model_hub_mixin\n",
"- pytorch_model_hub_mixin\n",
"- segmentation-models-pytorch\n",
"- semantic-segmentation\n",
"- pytorch\n",
"- segmentation-models-pytorch\n",
"languages:\n",
"- python\n",
"---\n",
Expand Down Expand Up @@ -157,7 +159,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "075ae026811542bdb4030e53b943efc7",
"model_id": "1d6fe9d868c24175aa5f23a2893a2c21",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -179,13 +181,13 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2921a81d7fd747939b4a425cc17d6104",
"model_id": "2f4f5e4973e44f9a857e89d9ac707b53",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -199,10 +201,10 @@
{
"data": {
"text/plain": [
"CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/9f821c7bc3a12db827c0da96a31f354ec6ba5253', commit_message='Push model using huggingface_hub.', commit_description='', oid='9f821c7bc3a12db827c0da96a31f354ec6ba5253', pr_url=None, pr_revision=None, pr_num=None)"
"CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/4ac3d2925d34cf183dc79a2e21b6e2f4bfe87586', commit_message='Push model using huggingface_hub.', commit_description='', oid='4ac3d2925d34cf183dc79a2e21b6e2f4bfe87586', pr_url=None, pr_revision=None, pr_num=None)"
]
},
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -224,6 +226,158 @@
"\n",
"# see result here https://huggingface.co/qubvel-hf/unet-with-metadata"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save model with preprocessing (using albumentations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install -U albumentations numpy==1.*"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import albumentations as A\n",
"import segmentation_models_pytorch as smp"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# define a preprocessing transform for image that would be used during inference\n",
"preprocessing_transform = A.Compose([\n",
" A.Resize(256, 256),\n",
" A.Normalize()\n",
"])\n",
"\n",
"model = smp.Unet()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1aa3f4db4cd2489baeac3b844977d5a2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/97.8M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-transform/commit/680dad16431fa6efbb25832d33a24056bdf7dc1a', commit_message='Push transform using huggingface_hub.', commit_description='', oid='680dad16431fa6efbb25832d33a24056bdf7dc1a', pr_url=None, pr_revision=None, pr_num=None)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"directory_or_repo_on_the_hub = \"qubvel-hf/unet-with-transform\"\n",
"\n",
"# save the model\n",
"model.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)\n",
"\n",
"# save transform\n",
"preprocessing_transform.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's restore model and preprocessing transform for inference:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading weights from local directory\n",
"Compose([\n",
" Resize(p=1.0, height=256, width=256, interpolation=1),\n",
" Normalize(p=1.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, normalization='standard'),\n",
"], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)\n"
]
}
],
"source": [
"restored_model = smp.from_pretrained(directory_or_repo_on_the_hub)\n",
"restored_transform = A.Compose.from_pretrained(directory_or_repo_on_the_hub)\n",
"\n",
"print(restored_transform)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Compose([\n",
" HorizontalFlip(p=0.5),\n",
" RandomBrightnessContrast(p=0.2, brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), brightness_by_max=True),\n",
" ShiftScaleRotate(p=0.5, shift_limit_x=(-0.0625, 0.0625), shift_limit_y=(-0.0625, 0.0625), scale_limit=(-0.09999999999999998, 0.10000000000000009), rotate_limit=(-45, 45), interpolation=1, border_mode=4, value=0.0, mask_value=0.0, rotate_method='largest_box'),\n",
"], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)\n"
]
}
],
"source": [
"# You can also save training augmentations to the Hub too (and load it back)!\n",
"#! Just make sure to provide key=\"train\" when saving and loading the augmentations.\n",
"\n",
"train_augmentations = A.Compose([\n",
" A.HorizontalFlip(p=0.5),\n",
" A.RandomBrightnessContrast(p=0.2),\n",
" A.ShiftScaleRotate(p=0.5),\n",
"])\n",
"\n",
"train_augmentations.save_pretrained(directory_or_repo_on_the_hub, key=\"train\", push_to_hub=True)\n",
"\n",
"restored_train_augmentations = A.Compose.from_pretrained(directory_or_repo_on_the_hub, key=\"train\")\n",
"print(restored_train_augmentations)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See saved model and `albumentations` configs on the hub: https://huggingface.co/qubvel-hf/unet-with-transform/tree/main"
]
}
],
"metadata": {
Expand Down

0 comments on commit bc8f0b2

Please sign in to comment.