Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Img augmentation for drq and bc agent #67

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 68 additions & 15 deletions serl_launcher/serl_launcher/agents/continuous/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from serl_launcher.networks.actor_critic_nets import Policy
from serl_launcher.networks.mlp import MLP
from serl_launcher.utils.train_utils import _unpack
from serl_launcher.vision.data_augmentations import batched_random_crop
from serl_launcher.vision.data_augmentations import (
batched_random_crop,
batched_color_transform,
)


class BCAgent(flax.struct.PyTreeNode):
Expand All @@ -24,24 +27,61 @@ class BCAgent(flax.struct.PyTreeNode):

def data_augmentation_fn(self, rng, observations):
for pixel_key in self.config["image_keys"]:
observations = observations.copy(
add_or_replace={
pixel_key: batched_random_crop(
observations[pixel_key], rng, padding=4, num_batch_dims=2
)
}
)
if self.config.get("image_aug_random_crop", False):
# apply random crop to the img observations
observations = observations.copy(
add_or_replace={
pixel_key: batched_random_crop(
observations[pixel_key], rng, padding=4, num_batch_dims=2
)
}
)

if self.config.get("image_aug_color_transform", False):
# NOTE: the original image is in uint8, and the color_transform function
# requires float32, thus we need to convert the image to float32 first
# then convert it back to uint8 after the color transformation
observations = observations.copy(
add_or_replace={
pixel_key: jnp.array(observations[pixel_key], dtype=jnp.float32)
/ 255.0,
}
)
observations = observations.copy(
add_or_replace={
pixel_key: batched_color_transform(
observations[pixel_key],
rng,
brightness=0.1,
contrast=0.1,
saturation=0.1,
hue=0.1,
apply_prob=1.0,
to_grayscale_prob=0.0, # don't convert to grayscale
color_jitter_prob=0.5,
shuffle=False, # wont shuffle the color channels
num_batch_dims=2, # 2 images observations
),
}
)
observations = observations.copy(
add_or_replace={
pixel_key: jnp.array(
observations[pixel_key] * 255.0, dtype=jnp.uint8
),
}
)
return observations

@partial(jax.jit, static_argnames="pmap_axis")
def update(self, batch: Batch, pmap_axis: str = None):
if self.config["image_keys"][0] not in batch["next_observations"]:
batch = _unpack(batch)

# rng = self.state.rng
# rng, obs_rng, next_obs_rng = jax.random.split(rng, 3)
# obs = self.data_augmentation_fn(obs_rng, batch["observations"])
# batch = batch.copy(add_or_replace={"observations": obs})
rng = self.state.rng
rng, obs_rng, next_obs_rng = jax.random.split(rng, 3)
obs = self.data_augmentation_fn(obs_rng, batch["observations"])
batch = batch.copy(add_or_replace={"observations": obs})

def loss_fn(params, rng):
rng, key = jax.random.split(rng)
Expand Down Expand Up @@ -98,6 +138,9 @@ def sample_actions(

@jax.jit
def get_debug_metrics(self, batch, **kwargs):
if self.config["image_keys"][0] not in batch["next_observations"]:
batch = _unpack(batch)

dist = self.state.apply_fn(
{"params": self.state.params},
batch["observations"],
Expand All @@ -109,9 +152,9 @@ def get_debug_metrics(self, batch, **kwargs):
mse = ((pi_actions - batch["actions"]) ** 2).sum(-1)

return {
"mse": mse,
"log_probs": log_probs,
"pi_actions": pi_actions,
"eval/mse": mse.mean(),
"eval/log_probs": log_probs.mean(),
"eval/pi_actions": pi_actions.mean(),
}

@classmethod
Expand All @@ -132,6 +175,8 @@ def create(
},
# Optimizer
learning_rate: float = 3e-4,
image_augmentation: Iterable[str] = (),
# image_augmentation: Iterable[str] = ("random_crop", "color_transform"),
):
if encoder_type == "small":
from serl_launcher.vision.small_encoders import SmallEncoder
Expand Down Expand Up @@ -167,6 +212,12 @@ def create(
resnetv1_configs,
)

# NOTE: commented code enables gradient flow for the pretrained encoder
# pretrained_encoder = resnetv1_configs["resnetv1-10-frozen"](
# pre_pooling=False,
# name="pretrained_encoder",
# pooling_method="none",
# )
pretrained_encoder = resnetv1_configs["resnetv1-10-frozen"](
pre_pooling=True,
name="pretrained_encoder",
Expand Down Expand Up @@ -218,6 +269,8 @@ def create(
)
config = dict(
image_keys=image_keys,
image_aug_random_crop=("random_crop" in image_augmentation),
image_aug_color_transform=("color_transform" in image_augmentation),
)

agent = cls(state, config)
Expand Down
61 changes: 53 additions & 8 deletions serl_launcher/serl_launcher/agents/continuous/drq.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from serl_launcher.networks.lagrange import GeqLagrangeMultiplier
from serl_launcher.networks.mlp import MLP
from serl_launcher.utils.train_utils import _unpack, concat_batches
from serl_launcher.vision.data_augmentations import batched_random_crop
from serl_launcher.vision.data_augmentations import (
batched_random_crop,
batched_color_transform,
)


class DrQAgent(SACAgent):
Expand Down Expand Up @@ -50,6 +53,7 @@ def create(
critic_ensemble_size: int = 2,
critic_subsample_size: Optional[int] = None,
image_keys: Iterable[str] = ("image",),
image_augmentation: Iterable[str] = ("random_crop"),
):
networks = {
"actor": actor_def,
Expand Down Expand Up @@ -98,6 +102,8 @@ def create(
target_entropy=target_entropy,
backup_entropy=backup_entropy,
image_keys=image_keys,
image_aug_random_crop=("random_crop" in image_augmentation),
image_aug_color_transform=("color_transform" in image_augmentation),
),
)

Expand Down Expand Up @@ -125,6 +131,7 @@ def create_drq(
critic_subsample_size: Optional[int] = None,
temperature_init: float = 1.0,
image_keys: Iterable[str] = ("image",),
image_augmentation: Iterable[str] = ("random_crop"),
**kwargs,
):
"""
Expand Down Expand Up @@ -231,6 +238,7 @@ def create_drq(
critic_ensemble_size=critic_ensemble_size,
critic_subsample_size=critic_subsample_size,
image_keys=image_keys,
image_augmentation=image_augmentation,
**kwargs,
)

Expand All @@ -243,13 +251,50 @@ def create_drq(

def data_augmentation_fn(self, rng, observations):
for pixel_key in self.config["image_keys"]:
observations = observations.copy(
add_or_replace={
pixel_key: batched_random_crop(
observations[pixel_key], rng, padding=4, num_batch_dims=2
)
}
)
if self.config.get("image_aug_random_crop", False):
# apply random crop to the img observations
observations = observations.copy(
add_or_replace={
pixel_key: batched_random_crop(
observations[pixel_key], rng, padding=4, num_batch_dims=2
)
}
)

if self.config.get("image_aug_color_transform", False):
# NOTE: the original image is in uint8, and the color_transform function
# requires float32, thus we need to convert the image to float32 first
# then convert it back to uint8 after the color transformation
observations = observations.copy(
add_or_replace={
pixel_key: jnp.array(observations[pixel_key], dtype=jnp.float32)
/ 255.0,
}
)
observations = observations.copy(
add_or_replace={
pixel_key: batched_color_transform(
observations[pixel_key],
rng,
brightness=0.1,
contrast=0.1,
saturation=0.1,
hue=0.1,
apply_prob=1.0,
to_grayscale_prob=0.0, # don't convert to grayscale
color_jitter_prob=0.5,
shuffle=False, # wont shuffle the color channels
num_batch_dims=2, # 2 images observations
),
}
)
observations = observations.copy(
add_or_replace={
pixel_key: jnp.array(
observations[pixel_key] * 255.0, dtype=jnp.uint8
),
}
)
return observations

@partial(jax.jit, static_argnames=("utd_ratio", "pmap_axis"))
Expand Down
12 changes: 12 additions & 0 deletions serl_launcher/serl_launcher/networks/reward_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from serl_launcher.vision.resnet_v1 import resnetv1_configs, PreTrainedResNetEncoder
from serl_launcher.common.encoding import EncodingWrapper
from serl_launcher.utils.jax_utils import are_trees_identical
from flax.core.frozen_dict import freeze, unfreeze


Expand Down Expand Up @@ -100,12 +101,23 @@ def load_classifier_func(
Return: a function that takes in an observation
and returns the logits of the classifier.
"""
print("Loading classifier from:", checkpoint_path, "; at step: ", step)
classifier = create_classifier(key, sample, image_keys)
pretrained_chkpt_params = classifier.params
classifier = checkpoints.restore_checkpoint(
checkpoint_path,
target=classifier,
step=step,
)

# restore_checkpoint will not raise error if path not found: https://github.com/google/flax/issues/1631
# To check if checkpoint is loaded correctly
identical = are_trees_identical(pretrained_chkpt_params, classifier.params)
assert not identical, (
"newly loaded params should not be identical to pretrained"
"check if provided checkpoint is loaded correctly"
)

func = lambda obs: classifier.apply_fn(
{"params": classifier.params}, obs, train=False
)
Expand Down
11 changes: 11 additions & 0 deletions serl_launcher/serl_launcher/utils/jax_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import jax
import jax.numpy as jnp


@jax.jit
Expand Down Expand Up @@ -46,6 +47,16 @@ def wrapped(*args, **kwargs):
return wrap_function


def are_trees_identical(tree1, tree2) -> bool:
"""Util function that compares two pytrees element-wise for equality"""

def compare_elements(x, y):
return jnp.array_equal(x, y)

comparison_results = jax.tree_util.tree_map(compare_elements, tree1, tree2)
return jax.tree_util.tree_all(comparison_results)


def init_rng(seed):
global jax_utils_rng
jax_utils_rng = JaxRNG.from_seed(seed)
Expand Down
9 changes: 8 additions & 1 deletion serl_launcher/serl_launcher/utils/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@


def make_bc_agent(
seed, sample_obs, sample_action, image_keys=("image",), encoder_type="small"
seed,
sample_obs,
sample_action,
image_keys=("image",),
encoder_type="small",
image_augmentation=(),
):
return BCAgent.create(
jax.random.PRNGKey(seed),
Expand All @@ -44,6 +49,7 @@ def make_bc_agent(
use_proprio=True,
encoder_type=encoder_type,
image_keys=image_keys,
image_augmentation=image_augmentation,
)


Expand Down Expand Up @@ -112,6 +118,7 @@ def make_drq_agent(
backup_entropy=False,
critic_ensemble_size=10,
critic_subsample_size=2,
image_augmentation=("random_crop"),
)
return agent

Expand Down
57 changes: 56 additions & 1 deletion serl_launcher/serl_launcher/vision/data_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def color_transform(
apply_prob,
shuffle
):
"""Applies color jittering to a single image."""
"""Applies color jittering to a single image. Assume image is in [0, 1]."""
apply_rng, transform_rng = jax.random.split(rng)
perm_rng, b_rng, c_rng, s_rng, h_rng, cj_rng, gs_rng = jax.random.split(
transform_rng, 7
Expand Down Expand Up @@ -298,6 +298,61 @@ def _color_jitter(x):
return jnp.clip(out_apply, 0.0, 1.0)


@partial(
jax.jit,
static_argnames=(
"brightness",
"contrast",
"saturation",
"hue",
"to_grayscale_prob",
"color_jitter_prob",
"apply_prob",
"shuffle",
"num_batch_dims",
),
)
def batched_color_transform(
image,
rng,
*,
brightness,
contrast,
saturation,
hue,
to_grayscale_prob,
color_jitter_prob,
apply_prob,
shuffle,
num_batch_dims: int = 1
):
# Flatten batch dims
original_shape = image.shape
image = jnp.reshape(image, (-1, *image.shape[num_batch_dims:]))
rngs = jax.random.split(rng, image.shape[0])

image = jax.vmap(
lambda i, r: color_transform(
i,
r,
brightness=brightness,
contrast=contrast,
saturation=saturation,
hue=hue,
to_grayscale_prob=to_grayscale_prob,
color_jitter_prob=color_jitter_prob,
apply_prob=apply_prob,
shuffle=shuffle,
),
in_axes=(0, 0),
out_axes=0,
)(image, rngs)

# Restore batch dims
image = jnp.reshape(image, original_shape)
return image


def random_flip(image, rng):
_, flip_rng = jax.random.split(rng)
should_flip_lr = jax.random.uniform(flip_rng, shape=()) <= 0.5
Expand Down
Loading