
Commit e2c4673

albheim and findmyway authored
Update reward wrappers to be more consistent (#519)
* rewardoverridden -> rewardtransformed
* minor updates
* Update src/ReinforcementLearningEnvironments/src/environments/wrappers/RewardOverriddenEnv.jl

Co-authored-by: Jun Tian <[email protected]>
1 parent b00d9bd commit e2c4673

8 files changed: +52 -20 lines changed
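This commit renames the reward wrapper and moves the mapping function from a positional argument to a keyword. A minimal before/after sketch of the pattern applied across the files below (env stands for any AbstractEnv):

    # before: function passed positionally, applied to reward(env)
    env′ = RewardOverriddenEnv(env, r -> clamp(r, -1, 1))

    # after: explicit keyword argument, same clipping behavior
    env′ = RewardTransformedEnv(env; reward_mapping = r -> clamp(r, -1, 1))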

docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl

Lines changed: 5 additions & 5 deletions
@@ -80,7 +80,7 @@ function atari_env_factory(
     n_replica = nothing,
 )
     init(seed) =
-        RewardOverriddenEnv(
+        RewardTransformedEnv(
             StateCachedEnv(
                 StateTransformedEnv(
                     AtariEnv(;
@@ -101,8 +101,8 @@ function atari_env_factory(
                 ),
                 state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
             )
-        ),
-        r -> clamp(r, -1, 1)
+        );
+        reward_mapping = r -> clamp(r, -1, 1)
     )
 
     if isnothing(n_replica)
@@ -130,7 +130,7 @@ end
 function (hook::TotalOriginalRewardPerEpisode)(
     ::PostActStage,
     agent,
-    env::RewardOverriddenEnv,
+    env::RewardTransformedEnv,
 )
     hook.reward += reward(env.env)
 end
@@ -153,7 +153,7 @@ end
 function (hook::TotalBatchOriginalRewardPerEpisode)(
     ::PostActStage,
     agent,
-    env::MultiThreadEnv{<:RewardOverriddenEnv},
+    env::MultiThreadEnv{<:RewardTransformedEnv},
 )
     for (i, e) in enumerate(env.envs)
         hook.reward[i] += reward(e.env)
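The hook methods above dispatch on the wrapper type so they can log the unclipped score: reward(env) on the wrapper returns the transformed value, while reward(env.env) reaches through to the underlying Atari environment. A sketch of the distinction, assuming atari_env is an already-constructed AtariEnv:

    env′ = RewardTransformedEnv(atari_env; reward_mapping = r -> clamp(r, -1, 1))
    reward(env′)      # clipped reward, what the agent trains on
    reward(env′.env)  # original game score, what the hooks accumulate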

docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl

Lines changed: 5 additions & 5 deletions
@@ -85,7 +85,7 @@ function atari_env_factory(
     n_replica = nothing,
 )
     init(seed) =
-        RewardOverriddenEnv(
+        RewardTransformedEnv(
             StateCachedEnv(
                 StateTransformedEnv(
                     AtariEnv(;
@@ -106,8 +106,8 @@ function atari_env_factory(
                 ),
                 state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
             )
-        ),
-        r -> clamp(r, -1, 1)
+        );
+        reward_mapping = r -> clamp(r, -1, 1)
     )
 
     if isnothing(n_replica)
@@ -135,7 +135,7 @@ end
 function (hook::TotalOriginalRewardPerEpisode)(
     ::PostActStage,
     agent,
-    env::RewardOverriddenEnv,
+    env::RewardTransformedEnv,
 )
     hook.reward += reward(env.env)
 end
@@ -158,7 +158,7 @@ end
 function (hook::TotalBatchOriginalRewardPerEpisode)(
     ::PostActStage,
     agent,
-    env::MultiThreadEnv{<:RewardOverriddenEnv},
+    env::MultiThreadEnv{<:RewardTransformedEnv},
 )
     for (i, e) in enumerate(env.envs)
         hook.reward[i] += reward(e.env)

docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl

Lines changed: 5 additions & 5 deletions
@@ -84,7 +84,7 @@ function atari_env_factory(
     n_replica = nothing,
 )
     init(seed) =
-        RewardOverriddenEnv(
+        RewardTransformedEnv(
             StateCachedEnv(
                 StateTransformedEnv(
                     AtariEnv(;
@@ -105,8 +105,8 @@ function atari_env_factory(
                 ),
                 state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
             )
-        ),
-        r -> clamp(r, -1, 1)
+        );
+        reward_mapping = r -> clamp(r, -1, 1)
     )
 
     if isnothing(n_replica)
@@ -134,7 +134,7 @@ end
 function (hook::TotalOriginalRewardPerEpisode)(
     ::PostActStage,
     agent,
-    env::RewardOverriddenEnv,
+    env::RewardTransformedEnv,
 )
     hook.reward += reward(env.env)
 end
@@ -157,7 +157,7 @@ end
 function (hook::TotalBatchOriginalRewardPerEpisode)(
     ::PostActStage,
     agent,
-    env::MultiThreadEnv{<:RewardOverriddenEnv},
+    env::MultiThreadEnv{<:RewardTransformedEnv},
 )
     for (i, e) in enumerate(env.envs)
         hook.reward[i] += reward(e.env)

docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl

Lines changed: 2 additions & 2 deletions
@@ -24,8 +24,8 @@ function RL.Experiment(
 
     env = GridWorlds.SingleRoomUndirectedModule.SingleRoomUndirected(rng=rng)
     env = GridWorlds.RLBaseEnv(env)
-    env = RLEnvs.StateTransformedEnv(env;state_mapping=x -> vec(Float32.(x)))
-    env = RewardOverriddenEnv(env, x -> x - convert(typeof(x), 0.01))
+    env = RLEnvs.StateTransformedEnv(env; state_mapping=x -> vec(Float32.(x)))
+    env = RewardTransformedEnv(env; reward_mapping = x -> x - convert(typeof(x), 0.01))
     env = MaxTimeoutEnv(env, 240)
 
     ns, na = length(state(env)), length(action_space(env))
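The convert(typeof(x), 0.01) keeps the step penalty in the reward's own element type, so a Float32 reward is not silently promoted to Float64. A small illustration (values chosen only for demonstration):

    r = 1.0f0
    r - convert(typeof(r), 0.01)  # 0.99f0, stays Float32
    r - 0.01                      # promoted to Float64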

src/ReinforcementLearningEnvironments/src/environments/wrappers/RewardOverriddenEnv.jl

Lines changed: 2 additions & 2 deletions
@@ -3,12 +3,12 @@ export RewardOverriddenEnv
 """
     RewardOverriddenEnv(env, f)
 
-Apply `f` on `reward(env)`.
+Apply `f` on `env` to generate a custom reward.
 """
 struct RewardOverriddenEnv{F,E<:AbstractEnv} <: AbstractEnvWrapper
     env::E
     f::F
 end
 
 RLBase.reward(env::RewardOverriddenEnv, args...; kwargs...) =
-    env.f(reward(env.env, args...; kwargs...))
+    env.f(env.env, args...; kwargs...)
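With this change RewardOverriddenEnv passes the wrapped environment itself to f, rather than the precomputed reward, so an override can depend on arbitrary environment state instead of only the native reward. A hypothetical shaping example (CartPoleEnv and state are existing RLEnvs/RLBase names; the shaping term itself is illustrative):

    # replace the native reward with a penalty on the cart's distance from center
    env′ = RewardOverriddenEnv(CartPoleEnv(), e -> -abs(state(e)[1]))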
src/ReinforcementLearningEnvironments/src/environments/wrappers/RewardTransformedEnv.jl

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+export RewardTransformedEnv
+
+"""
+    RewardTransformedEnv(env, f)
+
+Apply `f` on `reward(env)`.
+"""
+struct RewardTransformedEnv{F,E<:AbstractEnv} <: AbstractEnvWrapper
+    env::E
+    reward_mapping::F
+end
+
+RewardTransformedEnv(env; reward_mapping=identity) =
+    RewardTransformedEnv(env, reward_mapping)
+
+RLBase.reward(env::RewardTransformedEnv, args...; kwargs...) =
+    env.reward_mapping(reward(env.env, args...; kwargs...))
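Since reward_mapping defaults to identity, wrapping without the keyword leaves rewards untouched. A quick usage sketch, reusing the TigerProblemEnv from the tests below:

    env′ = RewardTransformedEnv(TigerProblemEnv())                         # rewards unchanged
    env′ = RewardTransformedEnv(TigerProblemEnv(); reward_mapping = sign)  # rewards mapped to -1, 0, 1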

src/ReinforcementLearningEnvironments/src/environments/wrappers/wrappers.jl

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ include("ActionTransformedEnv.jl")
 include("DefaultStateStyle.jl")
 include("MaxTimeoutEnv.jl")
 include("RewardOverriddenEnv.jl")
+include("RewardTransformedEnv.jl")
 include("StateCachedEnv.jl")
 include("StateTransformedEnv.jl")
 include("StochasticEnv.jl")

src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl

Lines changed: 15 additions & 1 deletion
@@ -53,10 +53,24 @@
         @test is_terminated(env′) == false
     end
 
+    @testset "RewardTransformedEnv" begin
+        rng = StableRNG(123)
+        env = TigerProblemEnv(; rng=rng)
+        env′ = RewardTransformedEnv(env; reward_mapping = x -> sign(x))
+
+        RLBase.test_interfaces!(env′)
+        RLBase.test_runnable!(env′)
+
+        while !is_terminated(env′)
+            env′(rand(rng, legal_action_space(env′)))
+            @test reward(env′) ∈ (-1, 0, 1)
+        end
+    end
+
     @testset "RewardOverriddenEnv" begin
         rng = StableRNG(123)
         env = TigerProblemEnv(; rng=rng)
-        env′ = RewardOverriddenEnv(env, x -> sign(x))
+        env′ = RewardOverriddenEnv(env, e -> sign(reward(e)))
 
         RLBase.test_interfaces!(env′)
         RLBase.test_runnable!(env′)
