Skip to content

Commit

Permalink
Chore: pre-commit again.
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeyers committed Aug 22, 2024
1 parent a0cf2cc commit 4a86475
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion og_marl/vault_utils/combine_vaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def combine_vaults(rel_dir: str, vault_name: str, vault_uids: Optional[list[str]
# check that a subsampled vault by the same name does not already exist
if check_directory_exists_and_not_empty(f"./{rel_dir}/{new_vault_name}"):
print(
f"Vault '{rel_dir}/{new_vault_name.removesuffix('.vlt')}' already exists. To combine from scratch, please remove the current combined vault from its directory." #noqa
f"Vault '{rel_dir}/{new_vault_name.removesuffix('.vlt')}' already exists. To combine from scratch, please remove the current combined vault from its directory." # noqa
)
return new_vault_name

Expand Down
8 changes: 4 additions & 4 deletions og_marl/vault_utils/download_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import List, Dict

import os
import sys
Expand Down Expand Up @@ -139,7 +139,7 @@
}


def print_download_options() -> None:
def print_download_options() -> Dict[str, Dict]:
print("VAULT_INFO:")
for source in VAULT_INFO.keys():
print(f"\t {source}")
Expand All @@ -163,7 +163,7 @@ def download_and_unzip_vault(
f"{dataset_base_dir}/{dataset_source}/{env_name}/{scenario_name}.vlt"
):
print(
f"Vault '{dataset_base_dir}/{dataset_source}/{env_name}/{scenario_name}' already exists."
f"Vault '{dataset_base_dir}/{dataset_source}/{env_name}/{scenario_name}' already exists." # noqa
)
return f"{dataset_base_dir}/{dataset_source}/{env_name}/{scenario_name}.vlt"

Expand All @@ -178,7 +178,7 @@ def download_and_unzip_vault(
print(
"Dataset from "
+ str(dataset_download_url)
+ " could not be downloaded. Try entering a different URL, or removing the part which auto-downloads."
+ " could not be downloaded. Try entering a different URL, or removing the part which auto-downloads." # noqa
)
return f"{dataset_base_dir}/{dataset_source}/{env_name}/{scenario_name}.vlt"

Expand Down
2 changes: 1 addition & 1 deletion og_marl/vault_utils/subsample_similar.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def subsample_similar(
# check that a subsampled vault by the same name does not already exist
if check_directory_exists_and_not_empty(f"./{new_rel_dir}/{new_vault_name}"):
print(
f"Vault '{new_rel_dir}/{new_vault_name.removesuffix('.vlt')}' already exists. To subsample from scratch, please remove the current subsampled vault from its directory." #noqa
f"Vault '{new_rel_dir}/{new_vault_name.removesuffix('.vlt')}' already exists. To subsample from scratch, please remove the current subsampled vault from its directory." # noqa
)
return

Expand Down
11 changes: 6 additions & 5 deletions og_marl/vault_utils/subsample_smaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from git import Optional

import jax
import pickle
Expand Down Expand Up @@ -107,21 +107,22 @@ def stitch_vault_from_sampled_episodes_(

for start, end in zip(len_start_end_sample[:, 1], len_start_end_sample[:, 2]):
sample_experience = jax.tree_util.tree_map(
lambda x: x[:, int(start) : int(end + 1), ...], experience
lambda x: x[:, int(start) : int(end + 1), ...], # noqa
experience,
)
dest_state = buffer_add(dest_state, sample_experience)

timesteps_written = dest_vault.write(dest_state)

print(timesteps_written)

return timesteps_written
return int(timesteps_written)


def subsample_smaller_vault(
vaults_dir: str,
vault_name: str,
vault_uids: list = None,
vault_uids: Optional[list] = None,
target_number_of_transitions: int = 500000,
) -> str:
# check that the vault to be subsampled exists
Expand All @@ -141,7 +142,7 @@ def subsample_smaller_vault(
# check that a subsampled vault by the same name does not already exist
if check_directory_exists_and_not_empty(f"./{vaults_dir}/{new_vault_name}"):
print(
f"Vault '{vaults_dir}/{new_vault_name.removesuffix('.vlt')}' already exists. To subsample from scratch, please remove the current subsampled vault from its directory." #noqa
f"Vault '{vaults_dir}/{new_vault_name.removesuffix('.vlt')}' already exists. To subsample from scratch, please remove the current subsampled vault from its directory." # noqa
)
return f"./{vaults_dir}/{vault_name}"

Expand Down

0 comments on commit 4a86475

Please sign in to comment.