Skip to content

Commit

Permalink
Batched state updates (#59)
Browse files Browse the repository at this point in the history
* allow for batched state updates

* update poetry lock

* restrict tf version

* revert whitespaces
  • Loading branch information
EthanMarx committed Oct 3, 2024
1 parent 2814849 commit b0b4ec1
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
7 changes: 6 additions & 1 deletion hermes/aeriel/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,13 @@ def infer(
except KeyError:
raise ValueError(f"Missing state {name}")

# sometimes we can have a batched state, in which
# case don't append a batch dimension
if shape[0] == 1 and value.ndim < len(shape):
value = value[None]

# add the update to our running list of updates
state_values.append(value[None])
state_values.append(value)

# if we have more than one state, combine them
# into a single tensor along the channel axis
Expand Down
7 changes: 4 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ protobuf = "^3.17"
requests = "^2.26.0"

# quiver optional dependencies
tensorflow = {version = "^2.3", optional = true}
tensorflow = {version = "<2.14", optional = true}
torch = {version = "^2.0", optional = true}
google-cloud-storage = {version = "^1.38", optional = true }
nvidia-tensorrt = { version = "^8.0", optional = true, source = "ngc" }
Expand Down

0 comments on commit b0b4ec1

Please sign in to comment.