diff --git a/hermes/aeriel/client/client.py b/hermes/aeriel/client/client.py index 0398872..4606e69 100644 --- a/hermes/aeriel/client/client.py +++ b/hermes/aeriel/client/client.py @@ -497,8 +497,13 @@ def infer( except KeyError: raise ValueError(f"Missing state {name}") + # sometimes we can have a batched state, in which + # case don't append a batch dimension + if shape[0] == 1 and value.ndim < len(shape): + value = value[None] + # add the update to our running list of updates - state_values.append(value[None]) + state_values.append(value) # if we have more than one state, combine them # into a single tensor along the channel axis diff --git a/poetry.lock b/poetry.lock index 46d2caf..fddc468 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2093,7 +2093,7 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] name = "onnx" version = "1.15.0" description = "Open Neural Network Exchange" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "onnx-1.15.0-cp310-cp310-macosx_10_12_universal2.whl", hash = "sha256:51cacb6aafba308aaf462252ced562111f6991cdc7bc57a6c554c3519453a8ff"}, @@ -2484,6 +2484,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3479,9 +3480,9 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] gcs = ["google-cloud-storage"] tensorflow = ["tensorflow"] tensorrt = ["nvidia-tensorrt"] -torch = ["torch", "urllib3"] +torch = ["onnx", "torch", "urllib3"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.12" -content-hash = "a32cc5daeb679b9541a77a31e1d5c26350a891d21b6c5bae4043b5417c093214" +content-hash = "ead8ff14ef32e6eacd3b6717b1736acdf541b70c30b3ac0f810a61fe6e48d493" diff --git a/pyproject.toml b/pyproject.toml index 86d6d81..1fb7320 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ protobuf = "^3.17" requests = "^2.26.0" # quiver optional dependencies -tensorflow = {version = "^2.3", optional = true} +tensorflow = {version = "<2.14", optional = true} torch = {version = "^2.0", optional = true} google-cloud-storage = {version = "^1.38", optional = true } nvidia-tensorrt = { version = "^8.0", optional = true, source = "ngc" }