Skip to content

Commit

Permalink
Fix/bernoulli (#186)
Browse files Browse the repository at this point in the history
* TF-like Bernoulli mode

* pre-commit

* Default dmc config
  • Loading branch information
belerico authored Jan 12, 2024
1 parent 9cff9a7 commit 2c9c0b3
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 14 deletions.
7 changes: 4 additions & 3 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from lightning.fabric import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor
from torch.distributions import Bernoulli, Distribution, Independent
from torch.distributions import Distribution, Independent
from torch.optim import Optimizer
from torchmetrics import SumMetric

Expand All @@ -27,6 +27,7 @@
from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
from sheeprl.envs.wrappers import RestartOnException
from sheeprl.utils.distribution import (
BernoulliSafeMode,
MSEDistribution,
OneHotCategoricalValidateArgs,
SymlogDistribution,
Expand Down Expand Up @@ -145,7 +146,7 @@ def train(

# Compute the distribution over the terminal steps, if required
pc = Independent(
Bernoulli(logits=world_model.continue_model(latent_states), validate_args=validate_args),
BernoulliSafeMode(logits=world_model.continue_model(latent_states), validate_args=validate_args),
1,
validate_args=validate_args,
)
Expand Down Expand Up @@ -229,7 +230,7 @@ def train(
predicted_values = TwoHotEncodingDistribution(critic(imagined_trajectories), dims=1).mean
predicted_rewards = TwoHotEncodingDistribution(world_model.reward_model(imagined_trajectories), dims=1).mean
continues = Independent(
Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
1,
validate_args=validate_args,
).mode
Expand Down
9 changes: 5 additions & 4 deletions sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer
from omegaconf import DictConfig
from torch import Tensor, nn
from torch.distributions import Bernoulli, Distribution, Independent
from torch.distributions import Distribution, Independent
from torchmetrics import SumMetric

from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel
Expand All @@ -21,6 +21,7 @@
from sheeprl.algos.p2e_dv3.agent import build_agent
from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
from sheeprl.utils.distribution import (
BernoulliSafeMode,
MSEDistribution,
OneHotCategoricalValidateArgs,
SymlogDistribution,
Expand Down Expand Up @@ -161,7 +162,7 @@ def train(

# Compute the distribution over the terminal steps, if required
pc = Independent(
Bernoulli(logits=world_model.continue_model(latent_states.detach()), validate_args=validate_args),
BernoulliSafeMode(logits=world_model.continue_model(latent_states.detach()), validate_args=validate_args),
1,
validate_args=validate_args,
)
Expand Down Expand Up @@ -268,7 +269,7 @@ def train(
# Predict values and continues
predicted_values = TwoHotEncodingDistribution(critic["module"](imagined_trajectories), dims=1).mean
continues = Independent(
Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
1,
validate_args=validate_args,
).mode
Expand Down Expand Up @@ -412,7 +413,7 @@ def train(
predicted_values = TwoHotEncodingDistribution(critic_task(imagined_trajectories), dims=1).mean
predicted_rewards = TwoHotEncodingDistribution(world_model.reward_model(imagined_trajectories), dims=1).mean
continues = Independent(
Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
1,
validate_args=validate_args,
).mode
Expand Down
11 changes: 5 additions & 6 deletions sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ seed: 5

# Environment
env:
num_envs: 1
max_episode_steps: 1000
num_envs: 4
max_episode_steps: -1
id: walker_walk
wrapper:
from_vectors: True
from_vectors: False
from_pixels: True

# Checkpoint
Expand All @@ -34,9 +34,8 @@ algo:
encoder:
- rgb
mlp_keys:
encoder:
- state
learning_starts: 8000
encoder: []
learning_starts: 1024
train_every: 2
dense_units: 512
mlp_layers: 2
Expand Down
12 changes: 11 additions & 1 deletion sheeprl/utils/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical, Distribution, constraints
from torch.distributions import Bernoulli, Categorical, Distribution, constraints
from torch.distributions.kl import _kl_categorical_categorical, register_kl
from torch.distributions.utils import broadcast_all

Expand Down Expand Up @@ -402,3 +402,13 @@ def rsample(self, sample_shape=torch.Size()):
@register_kl(OneHotCategoricalValidateArgs, OneHotCategoricalValidateArgs)
def _kl_onehotcategoricalvalidateargs_onehotcategoricalvalidateargs(p, q):
return _kl_categorical_categorical(p._categorical, q._categorical)


class BernoulliSafeMode(Bernoulli):
    """A Bernoulli distribution whose ``mode`` is always well defined.

    ``torch.distributions.Bernoulli.mode`` is NaN wherever ``probs == 0.5``
    (the mode is ambiguous there). This subclass instead resolves the tie
    deterministically to 0, i.e. ``mode = 1`` iff ``probs > 0.5``, matching
    the TensorFlow-style behavior expected by the continue-model heads.

    Construction is identical to :class:`torch.distributions.Bernoulli`
    (``probs=None, logits=None, validate_args=None``); no ``__init__``
    override is needed since nothing is added to it.
    """

    @property
    def mode(self):
        # Strict threshold at 0.5 so a tie maps to 0 instead of NaN;
        # `.to(self.probs)` casts the bool mask back to probs' dtype/device.
        return (self.probs > 0.5).to(self.probs)

0 comments on commit 2c9c0b3

Please sign in to comment.