diff --git a/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb b/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb new file mode 100644 index 0000000000000..455b5cccd5d52 --- /dev/null +++ b/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb @@ -0,0 +1,854 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fFjof1NgAJwu", + "cellView": "form" + }, + "outputs": [], + "source": [ + "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", + "\n", + "# Licensed to the Apache Software Foundation (ASF) under one\n", + "# or more contributor license agreements. See the NOTICE file\n", + "# distributed with this work for additional information\n", + "# regarding copyright ownership. The ASF licenses this file\n", + "# to you under the Apache License, Version 2.0 (the\n", + "# \"License\"); you may not use this file except in compliance\n", + "# with the License. You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing,\n", + "# software distributed under the License is distributed on an\n", + "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", + "# KIND, either express or implied. See the License for the\n", + "# specific language governing permissions and limitations\n", + "# under the License" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "A8xNRyZMW1yK" + }, + "source": [ + "# Use Apache Beam and Bigtable to enrich data\n", + "\n", + "\n", + " \n", + " \n", + "
\n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HrCtxslBGK8Z" + }, + "source": [ + "This notebook shows how to enrich data by using the Apache Beam [enrichment transform](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) with [Bigtable](https://cloud.google.com/bigtable/docs/overview). The enrichment transform is a turnkey transform in Apache Beam that lets you enrich data by using a key-value lookup. This transform has the following features:\n", + "\n", + "- The transform has a built-in Apache Beam handler that interacts with Bigtable to get data to use in the enrichment.\n", + "- The enrichment transform uses client-side throttling to rate limit the requests. The requests are exponentially backed off with a default retry strategy. You can configure rate limiting to suit your use case." + ] + }, + { + "cell_type": "markdown", + "source": [ + "This notebook demonstrates the following ecommerce use case:\n", + "\n", + "A stream of online transactions from [Pub/Sub](https://cloud.google.com/pubsub/docs/guides) contains the following fields: `sale_id`, `product_id`, `customer_id`, `quantity`, and `price`. Additional customer demographic data is stored in a separate Bigtable cluster. The demographic data is used to enrich the event stream from Pub/Sub. Then, the enriched data is used to predict the next product to recommend to a customer." + ], + "metadata": { + "id": "ltn5zrBiGS9C" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gVCtGOKTHMm4" + }, + "source": [ + "## Before you begin\n", + "Set up your environment and download dependencies." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YDHPlMjZRuY0" + }, + "source": [ + "### Install Apache Beam\n", + "To use the enrichment transform with the built-in Bigtable handler, install the Apache Beam SDK version 2.54.0 or later." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jBakpNZnAhqk" + }, + "outputs": [], + "source": [ + "!pip install torch\n", + "!pip install apache_beam[interactive,gcp]==2.54.0 --quiet" + ] + }, + { + "cell_type": "code", + "source": [ + "import datetime\n", + "import json\n", + "import math\n", + "\n", + "from typing import Any\n", + "from typing import Dict\n", + "\n", + "import torch\n", + "from google.cloud import pubsub_v1\n", + "from google.cloud.bigtable import Client\n", + "from google.cloud.bigtable import column_family\n", + "\n", + "import apache_beam as beam\n", + "import apache_beam.runners.interactive.interactive_beam as ib\n", + "from apache_beam.ml.inference.base import RunInference\n", + "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor\n", + "from apache_beam.options import pipeline_options\n", + "from apache_beam.runners.interactive.interactive_runner import InteractiveRunner\n", + "from apache_beam.transforms.enrichment import Enrichment\n", + "from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler" + ], + "metadata": { + "id": "SiJii48A2Rnb" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "X80jy3FqHjK4" + }, + "source": [ + "### Authenticate with Google Cloud\n", + "This notebook reads data from Pub/Sub and Bigtable. To use your Google Cloud account, authenticate this notebook."
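The next cell uses the Colab authentication helper. If you run this notebook outside of Colab, a minimal alternative (assuming the gcloud CLI is installed; this step is not part of the original notebook) is to create application default credentials, which the Bigtable and Pub/Sub client libraries pick up automatically:

```
# Alternative for non-Colab environments (assumes the gcloud CLI is available):
# create application default credentials for the client libraries to use.
!gcloud auth application-default login
```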
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Kz9sccyGBqz3" + }, + "outputs": [], + "source": [ + "from google.colab import auth\n", + "auth.authenticate_user()" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Replace ``, ``, and `` with the appropriate values for your setup. These fields are used with Bigtable." + ], + "metadata": { + "id": "nAmGgUMt48o9" + } + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "wEXucyi2liij" + }, + "outputs": [], + "source": [ + "PROJECT_ID = \"\"\n", + "INSTANCE_ID = \"\"\n", + "TABLE_ID = \"\"" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Train the model\n", + "\n" + ], + "metadata": { + "id": "RpqZFfFfA_Dt" + } + }, + { + "cell_type": "markdown", + "source": [ + "Create sample data by using the format `[product_id, quantity, price, customer_id, customer_location, recommend_product_id]`." + ], + "metadata": { + "id": "8cUpV7mkB_xE" + } + }, + { + "cell_type": "code", + "source": [ + "data = [\n", + " [3, 5, 127, 9, 'China', 7], [1, 6, 167, 5, 'Peru', 4], [5, 4, 91, 2, 'USA', 8], [7, 2, 52, 1, 'India', 4], [1, 8, 118, 3, 'UK', 8], [4, 6, 132, 8, 'Mexico', 2],\n", + " [6, 3, 154, 6, 'Brazil', 3], [4, 7, 163, 1, 'India', 7], [5, 2, 80, 4, 'Egypt', 9], [9, 4, 107, 7, 'Bangladesh', 1], [2, 9, 192, 8, 'Mexico', 4], [4, 5, 116, 5, 'Peru', 8],\n", + " [8, 1, 195, 1, 'India', 7], [8, 6, 153, 5, 'Peru', 1], [5, 3, 120, 6, 'Brazil', 2], [2, 7, 187, 7, 'Bangladesh', 4], [1, 8, 103, 6, 'Brazil', 8], [2, 9, 181, 1, 'India', 8],\n", + " [6, 5, 166, 3, 'UK', 5], [3, 4, 115, 8, 'Mexico', 1], [4, 7, 170, 4, 'Egypt', 2], [9, 3, 141, 7, 'Bangladesh', 3], [9, 3, 157, 1, 'India', 2], [7, 6, 128, 9, 'China', 1],\n", + " [1, 8, 102, 3, 'UK', 4], [5, 2, 107, 4, 'Egypt', 6], [6, 5, 164, 8, 'Mexico', 9], [4, 7, 188, 5, 'Peru', 1], [8, 1, 184, 1, 'India', 2], [8, 6, 198, 2, 'USA', 5],\n", + " [5, 3, 105, 6, 'Brazil', 7], [2, 7, 162, 7, 'Bangladesh', 7], [1, 8, 133, 9, 'China', 3], [2, 9, 173, 1, 'India', 7], [6, 5, 183, 5, 'Peru', 8], [3, 4, 191, 3, 'UK', 6],\n", + " [4, 7, 123, 2, 'USA', 5], [9, 3, 159, 8, 'Mexico', 2], [9, 3, 146, 4, 'Egypt', 8], [7, 6, 194, 1, 'India', 8], [3, 5, 112, 6, 'Brazil', 1], [4, 6, 101, 7, 'Bangladesh', 2],\n", + " [8, 1, 192, 4, 'Egypt', 4], [7, 2, 196, 5, 'Peru', 6], [9, 4, 124, 9, 'China', 7], [3, 4, 129, 5, 'Peru', 6], [6, 3, 151, 8, 'Mexico', 9], [5, 7, 114, 7, 'Bangladesh', 4],\n", + " [4, 7, 175, 6, 'Brazil', 5], [1, 8, 121, 1, 'India', 2], [4, 6, 187, 2, 'USA', 5], [6, 5, 144, 9, 'China', 9], [9, 4, 103, 5, 'Peru', 3], [5, 3, 84, 3, 'UK', 1],\n", + " [3, 5, 193, 2, 'USA', 4], [4, 7, 135, 1, 'India', 1], [7, 6, 148, 8, 'Mexico', 8], [1, 6, 160, 5, 'Peru', 7], [8, 6, 155, 6, 'Brazil', 9], [5, 7, 183, 7, 'Bangladesh', 2],\n", + " [2, 9, 125, 4, 'Egypt', 4], [6, 3, 111, 9, 'China', 9], [5, 2, 132, 3, 'UK', 3], [4, 5, 104, 7, 'Bangladesh', 7], [2, 7, 177, 8, 'Mexico', 7]]" + ], + "metadata": { + "id": "TpxDHGObBEsj" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "countries_to_id = {'India': 1, 'USA': 2, 'UK': 3, 'Egypt': 4, 'Peru': 5,\n", + " 'Brazil': 6, 'Bangladesh': 7, 'Mexico': 8, 'China': 9}" + ], + "metadata": { + "id": "bQt1cB4-CSBd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Preprocess the data:\n", + "\n", + "1. Convert the lists to tensors.\n", + "2. Separate the features from the expected prediction." 
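To make the preprocessing concrete, here is what the next cell produces for the first sample record (a worked illustration based on the mapping above, not output captured from the notebook):

```
# [3, 5, 127, 9, 'China', 7]  ->  features: tensor([  3.,   5., 127.,   9.,   9.])  # 'China' maps to 9
#                                 label:    tensor(7.)
```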
+ ], + "metadata": { + "id": "Y0Duet4nCdN1" + } + }, + { + "cell_type": "code", + "source": [ + "X = [torch.tensor(item[:4]+[countries_to_id[item[4]]], dtype=torch.float) for item in data]\n", + "Y = [torch.tensor(item[-1], dtype=torch.float) for item in data]" + ], + "metadata": { + "id": "7TT1O7sBCaZN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Define a simple model that has five input features and predicts a single value." + ], + "metadata": { + "id": "q6wB_ZsXDjjd" + } + }, + { + "cell_type": "code", + "source": [ + "def build_model(n_inputs, n_outputs):\n", + " \"\"\"build_model builds and returns a model that takes\n", + " `n_inputs` features and predicts `n_outputs` value\"\"\"\n", + " return torch.nn.Sequential(\n", + " torch.nn.Linear(n_inputs, 8),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(8, 16),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(16, n_outputs))" + ], + "metadata": { + "id": "nphNfhUnESES" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Train the model." + ], + "metadata": { + "id": "_sBSzDllEmCz" + } + }, + { + "cell_type": "code", + "source": [ + "model = build_model(n_inputs=5, n_outputs=1)\n", + "\n", + "loss_fn = torch.nn.MSELoss()\n", + "optimizer = torch.optim.Adam(model.parameters())\n", + "\n", + "for epoch in range(1000):\n", + " print(f'Epoch {epoch}: ---')\n", + " optimizer.zero_grad()\n", + " for i in range(len(X)):\n", + " pred = model(X[i])\n", + " loss = loss_fn(pred, Y[i])\n", + " loss.backward()\n", + " optimizer.step()" + ], + "metadata": { + "id": "CaYrplaPDayp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Save the model to the `STATE_DICT_PATH` variable." + ], + "metadata": { + "id": "_rJYv8fFFPYb" + } + }, + { + "cell_type": "code", + "source": [ + "STATE_DICT_PATH = './model.pth'\n", + "\n", + "torch.save(model.state_dict(), STATE_DICT_PATH)" + ], + "metadata": { + "id": "W4t260o9FURP" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Set up the Bigtable table\n", + "\n", + "Create a sample Bigtable table for this notebook." + ], + "metadata": { + "id": "ouMQZ4sC4zuO" + } + }, + { + "cell_type": "code", + "source": [ + "# Connect to the Bigtable instance. If you don't have admin access, then drop `admin=True`.\n", + "client = Client(project=PROJECT_ID, admin=True)\n", + "instance = client.instance(INSTANCE_ID)\n", + "\n", + "# Create a column family.\n", + "column_family_id = 'demograph'\n", + "max_versions_rule = column_family.MaxVersionsGCRule(2)\n", + "column_families = {column_family_id: max_versions_rule}\n", + "\n", + "# Create a table.\n", + "table = instance.table(TABLE_ID)\n", + "\n", + "# You need admin access to use `.exists()`. If you don't have the admin access, then\n", + "# comment out the if-else block.\n", + "if not table.exists():\n", + " table.create(column_families=column_families)\n", + "else:\n", + " print(\"Table %s already exists in %s:%s\" % (TABLE_ID, PROJECT_ID, INSTANCE_ID))" + ], + "metadata": { + "id": "E7Y4ipuL5kFD" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Add rows to the table for the enrichment example." 
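Each row uses the customer ID as the Bigtable row key and stores the demographic fields as columns in the `demograph` column family. After the next cell runs, the table contains roughly the following data:

| Row key | `demograph:customer_id` | `demograph:customer_name` | `demograph:customer_location` |
|---------|-------------------------|---------------------------|-------------------------------|
| `1`     | `1`                     | `Sam`                     | `India`                       |
| `2`     | `2`                     | `John`                    | `USA`                         |
| `3`     | `3`                     | `Travis`                  | `UK`                          |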
+ ], + "metadata": { + "id": "eQLkSg3p7WAm" + } + }, + { + "cell_type": "code", + "source": [ + "# Define column names for the table.\n", + "customer_id = 'customer_id'\n", + "customer_name = 'customer_name'\n", + "customer_location = 'customer_location'\n", + "\n", + "# The following data is sample data to insert into Bigtable.\n", + "customers = [\n", + " {\n", + " 'customer_id': 1, 'customer_name': 'Sam', 'customer_location': 'India'\n", + " },\n", + " {\n", + " 'customer_id': 2, 'customer_name': 'John', 'customer_location': 'USA'\n", + " },\n", + " {\n", + " 'customer_id': 3, 'customer_name': 'Travis', 'customer_location': 'UK'\n", + " },\n", + "]\n", + "\n", + "for customer in customers:\n", + " row_key = str(customer[customer_id]).encode()\n", + " row = table.direct_row(row_key)\n", + " row.set_cell(\n", + " column_family_id,\n", + " customer_id.encode(),\n", + " str(customer[customer_id]),\n", + " timestamp=datetime.datetime.utcnow())\n", + " row.set_cell(\n", + " column_family_id,\n", + " customer_name.encode(),\n", + " customer[customer_name],\n", + " timestamp=datetime.datetime.utcnow())\n", + " row.set_cell(\n", + " column_family_id,\n", + " customer_location.encode(),\n", + " customer[customer_location],\n", + " timestamp=datetime.datetime.utcnow())\n", + " row.commit()\n", + " print('Inserted row for key: %s' % customer[customer_id])" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LI6oYkZ97Vtu", + "outputId": "c72b28b5-8692-40f5-f8da-85622437d8f7" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Inserted row for key: 1\n", + "Inserted row for key: 2\n", + "Inserted row for key: 3\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Publish messages to Pub/Sub\n", + "\n", + "Use the Pub/Sub Python client to publish messages.\n" + ], + "metadata": { + "id": "pHODouJDwc60" + } + }, + { + "cell_type": "code", + "source": [ + "# Replace with the name of your Pub/Sub topic.\n", + "TOPIC = \"\"\n", + "\n", + "# Replace with the subscription for your topic.\n", + "SUBSCRIPTION = \"\"\n" + ], + "metadata": { + "id": "QKCuwDioxw-f" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "messages = [\n", + " {'sale_id': i, 'customer_id': i, 'product_id': i, 'quantity': i, 'price': i*100}\n", + " for i in range(1,4)\n", + " ]\n", + "\n", + "publisher = pubsub_v1.PublisherClient()\n", + "topic_name = publisher.topic_path(PROJECT_ID, TOPIC)\n", + "\n", + "for message in messages:\n", + " data = json.dumps(message).encode('utf-8')\n", + " publish_future = publisher.publish(topic_name, data)" + ], + "metadata": { + "id": "MaCJwaPexPKZ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Use the Bigtable enrichment handler\n", + "\n", + "The [`BigTableEnrichmentHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.BigTableEnrichmentHandler) is a built-in handler included in the Apache Beam SDK versions 2.54.0 and later." + ], + "metadata": { + "id": "zPSFEMm02omi" + } + }, + { + "cell_type": "markdown", + "source": [ + "To establish a client for the Bigtable enrichment handler, replace ``, ``, and `` with the appropriate values for those fields. 
The `row_key` variable is the field name from the input row that contains the row key to use when querying Bigtable.\n", + "\n", + "You can also configure additional options, such as [`app_profile_id`](https://cloud.google.com/bigtable/docs/app-profiles), [`row_filter`](https://cloud.google.com/python/docs/reference/bigtable/latest/row-filters), and the [`encoding`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.BigTableEnrichmentHandler:~:text=for%20more%20details.-,encoding,-(str)%20%E2%80%93%20encoding) type. The `encoding` type determines how `string` values are converted to `byte` values, and how `byte` values from Bigtable are converted back to `string` values.\n", + "\n", + "The default `encoding` type is `utf-8`.\n", + "\n", + "\n" + ], + "metadata": { + "id": "K41xhvmA5yQk" + } + }, + { + "cell_type": "code", + "source": [ + "row_key = 'customer_id'" + ], + "metadata": { + "id": "3dB26jhI45gd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "bigtable_handler = BigTableEnrichmentHandler(project_id=PROJECT_ID,\n", + " instance_id=INSTANCE_ID,\n", + " table_id=TABLE_ID,\n", + " row_key=row_key)" + ], + "metadata": { + "id": "cr1j_DHK4gA4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "The `BigTableEnrichmentHandler` returns the latest value from the table without its associated timestamp for the `row_key` that you provide. If you want to fetch the `timestamp` associated with the `row_key` value, then pass `include_timestamp=True` to the handler.\n", + "\n", + "**Note:** When exceptions occur, by default, the logging severity is set to warning ([`ExceptionLevel.WARN`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.ExceptionLevel.WARN)). To configure the severity to raise exceptions, set `exception_level` to [`ExceptionLevel.RAISE`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.ExceptionLevel.RAISE). To ignore exceptions, set `exception_level` to [`ExceptionLevel.QUIET`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.ExceptionLevel.QUIET)." + ], + "metadata": { + "id": "yFMcaf8i7TbI" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Use the enrichment transform\n", + "\n", + "To use the [enrichment transform](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.Enrichment), the [`EnrichmentHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.EnrichmentSourceHandler) parameter is required. You can also use configuration parameters to specify a `lambda` for a join function, a timeout, a throttler, and a repeater (retry strategy).\n", + "\n", + "\n", + "* `join_fn`: A lambda function that takes dictionaries as input and returns an enriched row (`Callable[[Dict[str, Any], Dict[str, Any]], beam.Row]`). The enriched row specifies how to join the data fetched from the API.
Defaults to a [cross-join](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.cross_join).\n", + "* `timeout`: The number of seconds to wait for the request to be completed by the API before timing out. Defaults to 30 seconds.\n", + "* `throttler`: Specifies the throttling mechanism. The only supported option is default client-side adaptive throttling.\n", + "* `repeater`: Specifies the retry strategy when errors like `TooManyRequests` and `TimeoutException` occur. Defaults to [`ExponentialBackOffRepeater`](https://beam.apache.org/releases/pydoc/current/apache_beam.io.requestresponse.html#apache_beam.io.requestresponse.ExponentialBackOffRepeater).\n" + ], + "metadata": { + "id": "-Lvo8O2V-0Ey" + } + }, + { + "cell_type": "markdown", + "source": [ + "The following example demonstrates the code needed to add this transform to your pipeline.\n", + "\n", + "\n", + "```\n", + "with beam.Pipeline() as p:\n", + " output = (p\n", + " ...\n", + " | \"Enrich with BigTable\" >> Enrichment(bigtable_handler, timeout=10)\n", + " | \"RunInference\" >> RunInference(model_handler)\n", + " ...\n", + " )\n", + "```\n", + "\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "xJTCfSmiV1kv" + } + }, + { + "cell_type": "markdown", + "source": [ + "To make a prediction, use the following fields: `product_id`, `quantity`, `price`, `customer_id`, and `customer_location`. Retrieve the value of the `customer_location` field from Bigtable.\n", + "\n", + "Because the enrichment transform performs a [`cross_join`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.cross_join) by default, design the custom join to enrich the input data. This design ensures that the join includes only the specified fields." + ], + "metadata": { + "id": "F-xjiP_pHWZr" + } + }, + { + "cell_type": "code", + "source": [ + "def custom_join(left: Dict[str, Any], right: Dict[str, Any]):\n", + " enriched = {}\n", + " enriched['product_id'] = left['product_id']\n", + " enriched['quantity'] = left['quantity']\n", + " enriched['price'] = left['price']\n", + " enriched['customer_id'] = left['customer_id']\n", + " enriched['customer_location'] = right['demograph']['customer_location']\n", + " return beam.Row(**enriched)" + ], + "metadata": { + "id": "8LnCnEPNIPtg" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Use the `PyTorchModelHandlerTensor` interface to run inference\n", + "\n" + ], + "metadata": { + "id": "CX9Cqybu6scV" + } + }, + { + "cell_type": "markdown", + "source": [ + "Because the enrichment transform outputs data in the format `beam.Row`, to make it compatible with the [`PyTorchModelHandlerTensor`](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.pytorch_inference.html#apache_beam.ml.inference.pytorch_inference.PytorchModelHandlerTensor) interface, convert it to `torch.tensor`. Additionally, the enriched field `customer_location` is a `string` type, but the model requires a `float` type. Convert the `customer_location` field to a `float` type." 
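For example, with the sample messages published earlier, the enriched row for customer 1 converts as follows (a worked illustration, not output captured from the notebook):

```
# beam.Row(product_id=1, quantity=1, price=100, customer_id=1, customer_location='India')
#   -> tensor([  1.,   1., 100.,   1.,   1.])  # 'India' maps to 1
```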
+ ], + "metadata": { + "id": "zy5Jl7_gLklX" + } + }, + { + "cell_type": "code", + "source": [ + "def convert_row_to_tensor(element: beam.Row):\n", + " row_dict = element._asdict()\n", + " row_dict['customer_location'] = countries_to_id[row_dict['customer_location']]\n", + " return torch.tensor(list(row_dict.values()), dtype=torch.float)" + ], + "metadata": { + "id": "KBKoB06nL4LF" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Initialize the model handler with the preprocessing function." + ], + "metadata": { + "id": "-tGHyB_vL3rJ" + } + }, + { + "cell_type": "code", + "source": [ + "model_handler = PytorchModelHandlerTensor(state_dict_path=STATE_DICT_PATH,\n", + " model_class=build_model,\n", + " model_params={'n_inputs':5, 'n_outputs':1}\n", + " ).with_preprocess_fn(convert_row_to_tensor)" + ], + "metadata": { + "id": "VqUUEwcU-r2e" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Define a `DoFn` to format the output." + ], + "metadata": { + "id": "vNHI4gVgNec2" + } + }, + { + "cell_type": "code", + "source": [ + "class PostProcessor(beam.DoFn):\n", + " def process(self, element, *args, **kwargs):\n", + " print('Customer %d who bought product %d is recommended to buy product %d' % (element.example[3], element.example[0], math.ceil(element.inference[0])))" + ], + "metadata": { + "id": "rkN-_Yf4Nlwy" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0a1zerXycQ0z" + }, + "source": [ + "## Run the pipeline\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Configure the pipeline to run in streaming mode." + ], + "metadata": { + "id": "WrwY0_gV_IDK" + } + }, + { + "cell_type": "code", + "source": [ + "options = pipeline_options.PipelineOptions()\n", + "options.view_as(pipeline_options.StandardOptions).streaming = True # Streaming mode is set True" + ], + "metadata": { + "id": "t0425sYBsYtB" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Pub/Sub sends the data in bytes. Convert the data to `beam.Row` objects by using a `DoFn`." + ], + "metadata": { + "id": "DBNijQDY_dRe" + } + }, + { + "cell_type": "code", + "source": [ + "class DecodeBytes(beam.DoFn):\n", + " \"\"\"\n", + " The DecodeBytes `DoFn` converts the data read from Pub/Sub to `beam.Row`.\n", + " First, decode the encoded string. Convert the output to\n", + " a `dict` with `json.loads()`, which is used to create a `beam.Row`.\n", + " \"\"\"\n", + " def process(self, element, *args, **kwargs):\n", + " element_dict = json.loads(element.decode('utf-8'))\n", + " yield beam.Row(**element_dict)" + ], + "metadata": { + "id": "sRw9iL8pKP5O" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Use the following code to run the pipeline.\n", + "\n", + "**Note:** Because this pipeline is a streaming pipeline, you need to manually stop the cell. If you don't stop the cell, the pipeline continues to run." 
+ ], + "metadata": { + "id": "xofUJym-_GuB" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "St07XoibcQSb", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "34e0a603-fb77-455c-e40b-d15b672edeb2" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "\n", + " if (typeof window.interactive_beam_jquery == 'undefined') {\n", + " var jqueryScript = document.createElement('script');\n", + " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n", + " jqueryScript.type = 'text/javascript';\n", + " jqueryScript.onload = function() {\n", + " var datatableScript = document.createElement('script');\n", + " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n", + " datatableScript.type = 'text/javascript';\n", + " datatableScript.onload = function() {\n", + " window.interactive_beam_jquery = jQuery.noConflict(true);\n", + " window.interactive_beam_jquery(document).ready(function($){\n", + " \n", + " });\n", + " }\n", + " document.head.appendChild(datatableScript);\n", + " };\n", + " document.head.appendChild(jqueryScript);\n", + " } else {\n", + " window.interactive_beam_jquery(document).ready(function($){\n", + " \n", + " });\n", + " }" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Customer 1 who bought product 1 is recommended to buy product 3\n", + "Customer 2 who bought product 2 is recommended to buy product 5\n", + "Customer 3 who bought product 3 is recommended to buy product 7\n" + ] + } + ], + "source": [ + "with beam.Pipeline(options=options) as p:\n", + " _ = (p\n", + " | \"Read from Pub/Sub\" >> beam.io.ReadFromPubSub(subscription=SUBSCRIPTION)\n", + " | \"ConvertToRow\" >> beam.ParDo(DecodeBytes())\n", + " | \"Enrichment\" >> Enrichment(bigtable_handler, join_fn=custom_join)\n", + " | \"RunInference\" >> RunInference(model_handler)\n", + " | \"Format Output\" >> beam.ParDo(PostProcessor())\n", + " )\n" + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file