diff --git a/.vscode/settings.json b/.vscode/settings.json index 3a060e1..3275f93 100755 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -21,7 +21,9 @@ "ndarray", "numpy", "ONNX", + "onnxconverter", "onnxruntime", + "opset", "packbits", "preprocess", "pretrained", @@ -48,5 +50,8 @@ "[python]": { "editor.defaultFormatter": "ms-python.black-formatter" }, - "python.formatting.provider": "none" + "python.formatting.provider": "none", + "window.autoDetectColorScheme": true, + "workbench.colorTheme": "Default Dark+", + "workbench.preferredDarkColorTheme": "Default Dark+" } \ No newline at end of file diff --git a/python/scripts/export_encoders.ipynb b/python/scripts/export_encoders.ipynb index 029e60a..a9fdaf4 100644 --- a/python/scripts/export_encoders.ipynb +++ b/python/scripts/export_encoders.ipynb @@ -19,75 +19,148 @@ "metadata": {}, "outputs": [], "source": [ - "!pip uninstall -y uform\n", "!pip install --upgrade \"uform[torch]\" coremltools" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "import os\n", - "model_name = \"uform-vl-english-small\"\n", - "output_directory = \"../../\"" + "\n", + "working_directory = \"../..\"\n", + "model_name = \"uform3-image-text-english-small\"\n", + "model_directory = os.path.join(working_directory, \"models\", model_name)\n", + "model_weights_path = os.path.join(model_directory, \"torch_weight.pt\")\n", + "config_path = os.path.join(model_directory, \"config.json\")\n", + "tokenizer_path = os.path.join(model_directory, \"tokenizer.json\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "['image_encoder', 'text_encoder']" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "import uform\n", - "from PIL import Image\n", - "\n", - "model, processor = uform.get_model('unum-cloud/' 
+ model_name)\n", - "text = 'a small red panda in a zoo'\n", - "image = Image.open('../../assets/unum.png')\n", - "\n", - "image_data = processor.preprocess_image(image)\n", - "text_data = processor.preprocess_text(text)\n", - "\n", - "image_features, image_embedding = model.encode_image(image_data, return_features=True)\n", - "text_features, text_embedding = model.encode_text(text_data, return_features=True)\n", + "import torch\n", "\n", - "image_features.shape, text_features.shape, image_embedding.shape, text_embedding.shape" + "state_dict = torch.load(model_weights_path)\n", + "list(state_dict.keys())" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/av/miniconda3/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: dlopen(/Users/av/miniconda3/lib/python3.10/site-packages/torchvision/image.so, 0x0006): Symbol not found: __ZN3c106detail19maybe_wrap_dim_slowExxb\n", + " Referenced from: <0B637046-A38B-3A5C-80C6-E847C27DCCD5> /Users/av/miniconda3/lib/python3.10/site-packages/torchvision/image.so\n", + " Expected in: <3AE92490-D363-3FD7-8532-CB6F5F795BC8> /Users/av/miniconda3/lib/python3.10/site-packages/torch/lib/libc10.dylib\n", + " warn(f\"Failed to load image Python extension: {e}\")\n" + ] + } + ], "source": [ - "model.text_encoder" + "from uform.torch_encoders import ImageEncoder, TextEncoder\n", + "from uform.torch_processors import ImageProcessor, TextProcessor" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(ImageEncoder(dim=384, patch_size=16, image_size=224, num_layers=12, num_heads=6, embedding_dim=256, pooling='cls', num_reg_tokens=0),\n", + " TextEncoder(model_type='bert', dim=768, context_dim=384, vocab_size=30522, padding_idx=0, 
num_layers=4, num_heads=12, embedding_dim=256, multimodal_layers_ids=[2, 3], head_one_neuron=False, pooling='cls', max_position_embeddings=64, dropout_prob=0.1))" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "model.image_encoder" + "image_encoder = ImageEncoder.from_pretrained(config_path, state_dict)\n", + "text_encoder = TextEncoder.from_pretrained(config_path, state_dict)\n", + "image_encoder, text_encoder" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(,\n", + " )" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# Assuming `model` is your loaded model with image_encoder and text_encoder attributes\n", - "for name, module in model.image_encoder.named_children():\n", - " print(f\"First layer of image_encoder: {name}\")\n", - " break # We break after the first layer\n", + "text_processor = TextProcessor(config_path, tokenizer_path)\n", + "image_processor = ImageProcessor(config_path)\n", + "text_processor, image_processor" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([1, 197, 384]),\n", + " torch.Size([1, 64, 768]),\n", + " torch.Size([1, 256]),\n", + " torch.Size([1, 256]))" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import uform\n", + "from PIL import Image\n", + "\n", + "text = 'a small red panda in a zoo'\n", + "image = Image.open('../../assets/unum.png')\n", "\n", - "for name, module in model.text_encoder.named_children():\n", - " print(f\"First layer of text_encoder: {name}\")\n", - " break # We break after the first layer" + "text_data = text_processor(text)\n", + "image_data = image_processor(image)\n", + "\n", + "image_features, image_embedding = 
image_encoder.forward(image_data, return_features=True)\n", + "text_features, text_embedding = text_encoder.forward(text_data, return_features=True)\n", + "\n", + "image_features.shape, text_features.shape, image_embedding.shape, text_embedding.shape" ] }, { @@ -99,9 +172,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "scikit-learn version 1.2.1 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.\n", + "Torch version 2.1.1 has not been tested with coremltools. You may run into unexpected errors. Torch 2.1.0 is the most recent version that has been tested.\n" + ] + } + ], "source": [ "import coremltools as ct\n", "import torch" @@ -109,7 +191,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -137,9 +219,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "((RangeDim(lower_bound=1, upper_bound=64, default=1, symbol=\"is0\"),\n", + " 3,\n", + " 224,\n", + " 224),\n", + " (RangeDim(lower_bound=1, upper_bound=64, default=1, symbol=\"is1\"), 64),\n", + " (RangeDim(lower_bound=1, upper_bound=64, default=1, symbol=\"is2\"), 64))" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "def generalize_first_dimensions(input_shape, upper_bound=64):\n", " if upper_bound == 1:\n", @@ -147,16 +245,16 @@ " input_shape = (ct.RangeDim(lower_bound=1, upper_bound=upper_bound, default=1),) + input_shape[1:]\n", " return input_shape\n", "\n", - "generalize_first_dimensions(image_data.shape), generalize_first_dimensions(text_data[\"input_ids\"].shape), generalize_first_dimensions(text_data[\"attention_mask\"].shape)" + 
"generalize_first_dimensions(image_data[\"images\"].shape), generalize_first_dimensions(text_data[\"input_ids\"].shape), generalize_first_dimensions(text_data[\"attention_mask\"].shape)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ - "image_input = ct.TensorType(name=\"images\", shape=generalize_first_dimensions(image_data.shape, 1))\n", + "image_input = ct.TensorType(name=\"images\", shape=generalize_first_dimensions(image_data[\"images\"].shape, 1))\n", "text_input = ct.TensorType(name=\"input_ids\", shape=generalize_first_dimensions(text_data[\"input_ids\"].shape, 1))\n", "text_attention_input = ct.TensorType(name=\"attention_mask\", shape=generalize_first_dimensions(text_data[\"attention_mask\"].shape, 1))\n", "text_features = ct.TensorType(name=\"features\")\n", @@ -167,23 +265,282 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "ImageEncoder(\n", + " original_name=ImageEncoder\n", + " (patch_embed): Conv2d(original_name=Conv2d)\n", + " (blocks): Sequential(\n", + " original_name=Sequential\n", + " (0): ImageEncoderBlock(\n", + " original_name=ImageEncoderBlock\n", + " (norm1): LayerNorm(original_name=LayerNorm)\n", + " (attn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (ls1): LayerScale(original_name=LayerScale)\n", + " (norm2): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (ls2): LayerScale(original_name=LayerScale)\n", + " )\n", + " (1): ImageEncoderBlock(\n", + " original_name=ImageEncoderBlock\n", + " (norm1): 
LayerNorm(original_name=LayerNorm)\n", + " (attn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (ls1): LayerScale(original_name=LayerScale)\n", + " (norm2): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (ls2): LayerScale(original_name=LayerScale)\n", + " )\n", + " (2): ImageEncoderBlock(\n", + " original_name=ImageEncoderBlock\n", + " (norm1): LayerNorm(original_name=LayerNorm)\n", + " (attn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (ls1): LayerScale(original_name=LayerScale)\n", + " (norm2): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (ls2): LayerScale(original_name=LayerScale)\n", + " )\n", + " (3): ImageEncoderBlock(\n", + " original_name=ImageEncoderBlock\n", + " (norm1): LayerNorm(original_name=LayerNorm)\n", + " (attn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (ls1): LayerScale(original_name=LayerScale)\n", + " (norm2): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (ls2): 
LayerScale(original_name=LayerScale)\n", + " )\n", + " (4): ImageEncoderBlock(\n", + " original_name=ImageEncoderBlock\n", + " (norm1): LayerNorm(original_name=LayerNorm)\n", + " (attn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (ls1): LayerScale(original_name=LayerScale)\n", + " (norm2): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (ls2): LayerScale(original_name=LayerScale)\n", + " )\n", + " (5): ImageEncoderBlock(\n", + " original_name=ImageEncoderBlock\n", + " (norm1): LayerNorm(original_name=LayerNorm)\n", + " (attn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (ls1): LayerScale(original_name=LayerScale)\n", + " (norm2): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (ls2): LayerScale(original_name=LayerScale)\n", + " )\n", + " (6): ImageEncoderBlock(\n", + " original_name=ImageEncoderBlock\n", + " (norm1): LayerNorm(original_name=LayerNorm)\n", + " (attn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (ls1): LayerScale(original_name=LayerScale)\n", + " (norm2): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): 
Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (ls2): LayerScale(original_name=LayerScale)\n", + " )\n", + " (7): ImageEncoderBlock(\n", + " original_name=ImageEncoderBlock\n", + " (norm1): LayerNorm(original_name=LayerNorm)\n", + " (attn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (ls1): LayerScale(original_name=LayerScale)\n", + " (norm2): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (ls2): LayerScale(original_name=LayerScale)\n", + " )\n", + " (8): ImageEncoderBlock(\n", + " original_name=ImageEncoderBlock\n", + " (norm1): LayerNorm(original_name=LayerNorm)\n", + " (attn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (ls1): LayerScale(original_name=LayerScale)\n", + " (norm2): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (ls2): LayerScale(original_name=LayerScale)\n", + " )\n", + " (9): ImageEncoderBlock(\n", + " original_name=ImageEncoderBlock\n", + " (norm1): LayerNorm(original_name=LayerNorm)\n", + " (attn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (ls1): LayerScale(original_name=LayerScale)\n", + " (norm2): 
LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (ls2): LayerScale(original_name=LayerScale)\n", + " )\n", + " (10): ImageEncoderBlock(\n", + " original_name=ImageEncoderBlock\n", + " (norm1): LayerNorm(original_name=LayerNorm)\n", + " (attn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (ls1): LayerScale(original_name=LayerScale)\n", + " (norm2): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (ls2): LayerScale(original_name=LayerScale)\n", + " )\n", + " (11): ImageEncoderBlock(\n", + " original_name=ImageEncoderBlock\n", + " (norm1): LayerNorm(original_name=LayerNorm)\n", + " (attn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (ls1): LayerScale(original_name=LayerScale)\n", + " (norm2): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (ls2): LayerScale(original_name=LayerScale)\n", + " )\n", + " )\n", + " (norm): LayerNorm(original_name=LayerNorm)\n", + " (embedding_projection): Linear(original_name=Linear)\n", + ")" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "module = model.image_encoder\n", + "module = image_encoder\n", "module.eval()\n", 
"module.return_features = True\n", "\n", - "traced_script_module = torch.jit.trace(module, example_inputs=image_data)\n", + "traced_script_module = torch.jit.trace(module, example_inputs=image_data[\"images\"])\n", "traced_script_module" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Tuple detected at graph output. This will be flattened in the converted model.\n", + "Converting PyTorch Frontend ==> MIL Ops: 100%|█████████▉| 453/455 [00:00<00:00, 5950.04 ops/s]\n", + "Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 410.60 passes/s]\n", + "Running MIL default pipeline: 100%|██████████| 69/69 [00:00<00:00, 209.62 passes/s]\n", + "Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 682.56 passes/s]\n" + ] + } + ], "source": [ "coreml_model = ct.convert(\n", " traced_script_module, source=\"pytorch\",\n", @@ -193,16 +550,127 @@ "coreml_model.author = 'Unum Cloud'\n", "coreml_model.license = 'Apache 2.0'\n", "coreml_model.short_description = 'Pocket-Sized Multimodal AI for Content Understanding'\n", - "coreml_model.save(os.path.join(output_directory, \"image_encoder.mlpackage\"))" + "coreml_model.save(os.path.join(model_directory, \"image_encoder.mlpackage\"))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "TextEncoder(\n", + " original_name=TextEncoder\n", + " (word_embeddings): Embedding(original_name=Embedding)\n", + " (position_embeddings): Embedding(original_name=Embedding)\n", + " (layer_norm): LayerNorm(original_name=LayerNorm)\n", + " (dropout): Dropout(original_name=Dropout)\n", + " (blocks): ModuleList(\n", + " original_name=ModuleList\n", + " (0): TextEncoderBlock(\n", + " original_name=TextEncoderBlock\n", + " (norm_attn): 
LayerNorm(original_name=LayerNorm)\n", + " (attention): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (norm_mlp): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (dropout): Dropout(original_name=Dropout)\n", + " )\n", + " (1): TextEncoderBlock(\n", + " original_name=TextEncoderBlock\n", + " (norm_attn): LayerNorm(original_name=LayerNorm)\n", + " (attention): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (norm_mlp): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (dropout): Dropout(original_name=Dropout)\n", + " )\n", + " (2): TextEncoderBlock(\n", + " original_name=TextEncoderBlock\n", + " (norm_attn): LayerNorm(original_name=LayerNorm)\n", + " (attention): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (norm_crossattn): LayerNorm(original_name=LayerNorm)\n", + " (crossattn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (norm_mlp): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + 
" original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (dropout): Dropout(original_name=Dropout)\n", + " )\n", + " (3): TextEncoderBlock(\n", + " original_name=TextEncoderBlock\n", + " (norm_attn): LayerNorm(original_name=LayerNorm)\n", + " (attention): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (norm_crossattn): LayerNorm(original_name=LayerNorm)\n", + " (crossattn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (norm_mlp): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (dropout): Dropout(original_name=Dropout)\n", + " )\n", + " )\n", + " (embedding_projection): Linear(original_name=Linear)\n", + " (matching_head): Linear(original_name=Linear)\n", + " (context_projection): Linear(original_name=Linear)\n", + ")" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "module = model.text_encoder\n", + "module = text_encoder\n", "module.eval()\n", "module.return_features = True\n", "\n", @@ -212,9 +680,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Tuple detected at graph output. 
This will be flattened in the converted model.\n", + "Converting PyTorch Frontend ==> MIL Ops: 0%| | 0/157 [00:00 MIL Ops: 99%|█████████▊| 155/157 [00:00<00:00, 8885.02 ops/s]\n", + "Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 2373.15 passes/s]\n", + "Running MIL default pipeline: 100%|██████████| 69/69 [00:00<00:00, 742.73 passes/s]\n", + "Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 3188.98 passes/s]\n" + ] + } + ], "source": [ "coreml_model = ct.convert(\n", " traced_script_module, source=\"pytorch\",\n", @@ -224,7 +705,7 @@ "coreml_model.author = 'Unum Cloud'\n", "coreml_model.license = 'Apache 2.0'\n", "coreml_model.short_description = 'Pocket-Sized Multimodal AI for Content Understanding'\n", - "coreml_model.save(os.path.join(output_directory, \"text_encoder.mlpackage\"))" + "coreml_model.save(os.path.join(model_directory, \"text_encoder.mlpackage\"))" ] }, { @@ -242,7 +723,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -253,68 +734,104 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "ImageEncoder(dim=384, patch_size=16, image_size=224, num_layers=12, num_heads=6, embedding_dim=256, pooling='cls', num_reg_tokens=0)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "model.image_encoder.eval()\n", - "model.image_encoder.to(dtype=torch.bfloat16)" + "image_encoder.eval()\n", + "image_encoder.to(dtype=torch.bfloat16)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ - "torch.save(model.image_encoder.state_dict(), os.path.join(output_directory, \"image_encoder.pt\"))" + "torch.save(image_encoder.state_dict(), os.path.join(model_directory, \"image_encoder.pt\"))" ] }, { 
"cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ - "save_file(model.image_encoder.state_dict(), os.path.join(output_directory, \"image_encoder.safetensors\"))" + "save_file(image_encoder.state_dict(), os.path.join(model_directory, \"image_encoder.safetensors\"))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "TextEncoder(model_type='bert', dim=768, context_dim=384, vocab_size=30522, padding_idx=0, num_layers=4, num_heads=12, embedding_dim=256, multimodal_layers_ids=[2, 3], head_one_neuron=False, pooling='cls', max_position_embeddings=64, dropout_prob=0.1)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "model.text_encoder.eval()\n", - "model.text_encoder.to(dtype=torch.bfloat16)" + "text_encoder.eval()\n", + "text_encoder.to(dtype=torch.bfloat16)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ - "torch.save(model.text_encoder.state_dict(), os.path.join(output_directory, \"text_encoder.pt\"))" + "torch.save(text_encoder.state_dict(), os.path.join(model_directory, \"text_encoder.pt\"))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ - "save_file(model.text_encoder.state_dict(), os.path.join(output_directory, \"text_encoder.safetensors\"))" + "save_file(text_encoder.state_dict(), os.path.join(model_directory, \"text_encoder.safetensors\"))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([1, 197, 384]),\n", + " torch.Size([1, 64, 768]),\n", + " torch.Size([1, 256]),\n", + " torch.Size([1, 256]))" + ] + }, + "execution_count": 28, + "metadata": {}, + 
"output_type": "execute_result" + } + ], "source": [ - "image_features, image_embedding = model.encode_image(image_data.to(dtype=torch.bfloat16), return_features=True)\n", - "text_features, text_embedding = model.encode_text(text_data, return_features=True)\n", + "image_features, image_embedding = image_encoder.forward(image_data[\"images\"].to(dtype=torch.bfloat16), return_features=True)\n", + "text_features, text_embedding = text_encoder.forward(text_data, return_features=True)\n", "\n", "image_features.shape, text_features.shape, image_embedding.shape, text_embedding.shape" ] @@ -337,7 +854,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -354,11 +871,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ - "module = model.text_encoder\n", + "module = text_encoder\n", "module.eval()\n", "module.return_features = True\n", "module.to(dtype=torch.float32)\n", @@ -366,7 +883,7 @@ "onnx_export(\n", " module,\n", " (text_data[\"input_ids\"], text_data[\"attention_mask\"]), \n", - " os.path.join(output_directory, \"text_encoder.onnx\"), \n", + " os.path.join(model_directory, \"text_encoder.onnx\"), \n", " export_params=True,\n", " opset_version=15,\n", " do_constant_folding=True,\n", @@ -388,19 +905,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ - "module = model.image_encoder\n", + "module = image_encoder\n", "module.eval()\n", "module.return_features = True\n", "module.to(dtype=torch.float32)\n", "\n", "torch.onnx.export(\n", " module,\n", - " image_data, \n", - " os.path.join(output_directory, \"image_encoder.onnx\"), \n", + " image_data[\"images\"], \n", + " os.path.join(model_directory, \"image_encoder.onnx\"), \n", " export_params=True,\n", " opset_version=15,\n", " do_constant_folding=True,\n", @@ -437,7 +954,7 @@ "metadata": {}, 
"outputs": [], "source": [ - "module_path = os.path.join(output_directory, \"text_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"text_encoder.onnx\")\n", "module = onnx.load(module_path)\n", "module_fp16 = float16.convert_float_to_float16(module)\n", "onnx.save(module_fp16, module_path)" @@ -449,7 +966,7 @@ "metadata": {}, "outputs": [], "source": [ - "module_path = os.path.join(output_directory, \"image_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"image_encoder.onnx\")\n", "module = onnx.load(module_path)\n", "module_fp16 = float16.convert_float_to_float16(module)\n", "onnx.save(module_fp16, module_path)" @@ -467,7 +984,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -476,21 +993,37 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 48, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Please consider to run pre-processing before quantization. Refer to example: https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/cpu/ReadMe.md \n" + ] + } + ], "source": [ - "module_path = os.path.join(output_directory, \"text_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"text_encoder.onnx\")\n", "quantize_dynamic(module_path, module_path, weight_type=QuantType.QUInt8)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 49, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Please consider to run pre-processing before quantization. 
Refer to example: https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/cpu/ReadMe.md \n" + ] + } + ], "source": [ - "module_path = os.path.join(output_directory, \"image_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"image_encoder.onnx\")\n", "quantize_dynamic(module_path, module_path, weight_type=QuantType.QUInt8)" ] }, @@ -503,7 +1036,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ @@ -512,7 +1045,7 @@ "from onnx import helper\n", "\n", "# Load the ONNX model\n", - "module_path = os.path.join(output_directory, \"text_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"text_encoder.onnx\")\n", "module = onnx.load(module_path)\n", "\n", "# Get the module's graph\n", @@ -543,7 +1076,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -584,7 +1117,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ @@ -595,21 +1128,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ - "module_path = os.path.join(output_directory, \"text_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"text_encoder.onnx\")\n", "session = ort.InferenceSession(module_path, sess_options=session_options)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 53, "metadata": {}, "outputs": [], "source": [ - "module_path = os.path.join(output_directory, \"image_encoder.onnx\")\n", + "module_path = os.path.join(model_directory, \"image_encoder.onnx\")\n", "session = ort.InferenceSession(module_path, sess_options=session_options)" ] }, @@ -620,6 +1153,34 @@ "# Upload to Hugging Face" ] }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + 
"name": "stderr", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Consider using `hf_transfer` for faster uploads. This solution comes with some limitations. See https://huggingface.co/docs/huggingface_hub/hf_transfer for more details.\n", + "https://huggingface.co/unum-cloud/uform3-image-text-english-small/tree/main/.\n" + ] + } + ], + "source": [ + "!huggingface-cli upload unum-cloud/uform3-image-text-english-small ../../models/uform3-image-text-english-small/ . --exclude=\"torch_weight.pt\"" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/python/scripts/test_encoders.py b/python/scripts/test_encoders.py index d26e4f2..bd26690 100644 --- a/python/scripts/test_encoders.py +++ b/python/scripts/test_encoders.py @@ -27,16 +27,16 @@ torch_models = [ "unum-cloud/uform3-image-text-english-small", - "unum-cloud/uform3-image-text-english-base", - "unum-cloud/uform3-image-text-english-large", - "unum-cloud/uform3-image-text-multilingual-base", + # "unum-cloud/uform3-image-text-english-base", + # "unum-cloud/uform3-image-text-english-large", + # "unum-cloud/uform3-image-text-multilingual-base", ] onnx_models = [ "unum-cloud/uform3-image-text-english-small", - "unum-cloud/uform3-image-text-english-base", - "unum-cloud/uform3-image-text-english-large", - "unum-cloud/uform3-image-text-multilingual-base", + # "unum-cloud/uform3-image-text-english-base", + # "unum-cloud/uform3-image-text-english-large", + # "unum-cloud/uform3-image-text-multilingual-base", ] # Let's check if the HuggingFace Hub API token is set in the environment variable. 
@@ -198,8 +198,8 @@ def test_onnx_one_embedding(model_name: str, device: str): # Test if the model outputs actually make sense cross_references_image_and_text_embeddings( - lambda text: model_text(processor_text(text)), - lambda image: model_image(processor_image(image)), + lambda text: model_text(processor_text(text))[1], + lambda image: model_image(processor_image(image))[1], ) except ExecutionProviderError as e: diff --git a/python/uform/numpy_processors.py b/python/uform/numpy_processors.py index a5faca2..027bc0d 100644 --- a/python/uform/numpy_processors.py +++ b/python/uform/numpy_processors.py @@ -34,7 +34,7 @@ def __call__(self, texts: Union[str, List[str]]) -> Dict[str, np.ndarray]: input_ids = np.full( (len(texts), self._max_seq_len), fill_value=self._pad_token_idx, - dtype=np.int64, + dtype=np.int32, ) attention_mask = np.zeros( diff --git a/python/uform/onnx_encoders.py b/python/uform/onnx_encoders.py index 9f63fa4..a6f27d3 100644 --- a/python/uform/onnx_encoders.py +++ b/python/uform/onnx_encoders.py @@ -64,6 +64,7 @@ def __init__( model_path: str, *, device: Literal["cpu", "cuda"] = "cpu", + return_features: bool = True, ): """ :param model_path: Path to onnx model @@ -73,14 +74,21 @@ def __init__( session_options = ort.SessionOptions() session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + self.return_features = return_features self.session = ort.InferenceSession( model_path, sess_options=session_options, providers=available_providers(device), ) - def __call__(self, images: ndarray) -> Tuple[ndarray, ndarray]: - return self.session.run(None, {"images": images}) + def __call__( + self, images: ndarray, return_features: Optional[bool] = None + ) -> Union[ndarray, Tuple[ndarray, ndarray]]: + features, embeddings = self.session.run(None, {"images": images}) + return_features = return_features if return_features is not None else self.return_features + if return_features: + return features, embeddings + return embeddings 
class TextEncoder: @@ -89,6 +97,7 @@ def __init__( model_path: str, *, device: Literal["cpu", "cuda"] = "cpu", + return_features: bool = True, ): """ :param text_encoder_path: Path to onnx of text encoder @@ -98,11 +107,31 @@ def __init__( session_options = ort.SessionOptions() session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + self.return_features = return_features self.text_encoder_session = ort.InferenceSession( model_path, sess_options=session_options, providers=available_providers(device), ) - def __call__(self, input_ids: ndarray, attention_mask: ndarray) -> Tuple[ndarray, ndarray]: - return self.text_encoder_session.run(None, {"input_ids": input_ids, "attention_mask": attention_mask}) + def __call__( + self, + x: Union[ndarray, dict], + attention_mask: Optional[ndarray] = None, + return_features: Optional[bool] = None, + ) -> Union[ndarray, Tuple[ndarray, ndarray]]: + if isinstance(x, dict): + assert attention_mask is None, "If `x` is a dictionary, then `attention_mask` should be None" + attention_mask = x["attention_mask"] + input_ids = x["input_ids"] + else: + input_ids = x + + features, embeddings = self.text_encoder_session.run( + None, {"input_ids": input_ids, "attention_mask": attention_mask} + ) + + return_features = return_features if return_features is not None else self.return_features + if return_features: + return features, embeddings + return embeddings diff --git a/python/uform/torch_encoders.py b/python/uform/torch_encoders.py index 8ac7c36..0504a74 100644 --- a/python/uform/torch_encoders.py +++ b/python/uform/torch_encoders.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from os import PathLike -from typing import Dict, Optional, Tuple, Union, Callable +from typing import Dict, Optional, Union, Mapping, Any import json import torch @@ -274,7 +274,12 @@ def forward( return embeddings @staticmethod - def from_pretrained(config: Union[PathLike, str, object], model_path: Union[PathLike, str]) -> 
TextEncoder: + def from_pretrained(config: Union[PathLike, str, object], model: Union[PathLike, str]) -> TextEncoder: + """Load the text encoder from the given configuration and model path. + + :param config: the configuration dictionary or path to the JSON configuration file + :param model: the model state dictionary or path to the `.pt` model file + """ if isinstance(config, (PathLike, str)): config = json.load(open(config, "r")) if "text_encoder" in config: @@ -283,9 +288,15 @@ def from_pretrained(config: Union[PathLike, str, object], model_path: Union[Path # We must strip all the non-member attributes before initializing the classes. text_fields = TextEncoder.__dataclass_fields__ config = {k: v for k, v in config.items() if k in text_fields} - - state = torch.load(model_path) encoder = TextEncoder(**config) + + # Load from disk + if isinstance(model, (PathLike, str)): + state = torch.load(model) + else: + state = model + if "text_encoder" in state: + state = state["text_encoder"] encoder.load_state_dict(state) return encoder @@ -351,7 +362,15 @@ def forward(self, x: Tensor, return_features: Optional[bool] = None) -> Tensor: return embeddings @staticmethod - def from_pretrained(config: Union[PathLike, str, object], model_path: Union[PathLike, str]) -> ImageEncoder: + def from_pretrained( + config: Union[PathLike, str, object], + model: Union[PathLike, str, Mapping[str, Any]], + ) -> ImageEncoder: + """Load the image encoder from the given configuration and model path. + + :param config: the configuration dictionary or path to the JSON configuration file + :param model: the model state dictionary or path to the `.pt` model file + """ if isinstance(config, (PathLike, str)): config = json.load(open(config, "r")) if "image_encoder" in config: @@ -360,8 +379,14 @@ def from_pretrained(config: Union[PathLike, str, object], model_path: Union[Path # We must strip all the non-member attributes before initializing the classes.
image_fields = ImageEncoder.__dataclass_fields__ config = {k: v for k, v in config.items() if k in image_fields} - - state = torch.load(model_path) encoder = ImageEncoder(**config) + + # Load from disk + if isinstance(model, (PathLike, str)): + state = torch.load(model) + else: + state = model + if "image_encoder" in state: + state = state["image_encoder"] encoder.load_state_dict(state) return encoder