File tree 8 files changed +52
-20
lines changed
docs/experiments/experiments/DQN
src/ReinforcementLearningEnvironments
src/environments/wrappers
test/environments/wrappers
8 files changed +52
-20
lines changed Original file line number Diff line number Diff line change @@ -80,7 +80,7 @@ function atari_env_factory(
80
80
n_replica = nothing ,
81
81
)
82
82
init (seed) =
83
- RewardOverriddenEnv (
83
+ RewardTransformedEnv (
84
84
StateCachedEnv (
85
85
StateTransformedEnv (
86
86
AtariEnv (;
@@ -101,8 +101,8 @@ function atari_env_factory(
101
101
),
102
102
state_space_mapping= _ -> Space (fill (0 .. 256 , state_size... , n_frames))
103
103
)
104
- ),
105
- r -> clamp (r, - 1 , 1 )
104
+ );
105
+ reward_mapping = r -> clamp (r, - 1 , 1 )
106
106
)
107
107
108
108
if isnothing (n_replica)
130
130
function (hook:: TotalOriginalRewardPerEpisode )(
131
131
:: PostActStage ,
132
132
agent,
133
- env:: RewardOverriddenEnv ,
133
+ env:: RewardTransformedEnv ,
134
134
)
135
135
hook. reward += reward (env. env)
136
136
end
153
153
function (hook:: TotalBatchOriginalRewardPerEpisode )(
154
154
:: PostActStage ,
155
155
agent,
156
- env:: MultiThreadEnv{<:RewardOverriddenEnv } ,
156
+ env:: MultiThreadEnv{<:RewardTransformedEnv } ,
157
157
)
158
158
for (i, e) in enumerate (env. envs)
159
159
hook. reward[i] += reward (e. env)
Original file line number Diff line number Diff line change @@ -85,7 +85,7 @@ function atari_env_factory(
85
85
n_replica = nothing ,
86
86
)
87
87
init (seed) =
88
- RewardOverriddenEnv (
88
+ RewardTransformedEnv (
89
89
StateCachedEnv (
90
90
StateTransformedEnv (
91
91
AtariEnv (;
@@ -106,8 +106,8 @@ function atari_env_factory(
106
106
),
107
107
state_space_mapping= _ -> Space (fill (0 .. 256 , state_size... , n_frames))
108
108
)
109
- ),
110
- r -> clamp (r, - 1 , 1 )
109
+ );
110
+ reward_mapping = r -> clamp (r, - 1 , 1 )
111
111
)
112
112
113
113
if isnothing (n_replica)
135
135
function (hook:: TotalOriginalRewardPerEpisode )(
136
136
:: PostActStage ,
137
137
agent,
138
- env:: RewardOverriddenEnv ,
138
+ env:: RewardTransformedEnv ,
139
139
)
140
140
hook. reward += reward (env. env)
141
141
end
158
158
function (hook:: TotalBatchOriginalRewardPerEpisode )(
159
159
:: PostActStage ,
160
160
agent,
161
- env:: MultiThreadEnv{<:RewardOverriddenEnv } ,
161
+ env:: MultiThreadEnv{<:RewardTransformedEnv } ,
162
162
)
163
163
for (i, e) in enumerate (env. envs)
164
164
hook. reward[i] += reward (e. env)
Original file line number Diff line number Diff line change @@ -84,7 +84,7 @@ function atari_env_factory(
84
84
n_replica = nothing ,
85
85
)
86
86
init (seed) =
87
- RewardOverriddenEnv (
87
+ RewardTransformedEnv (
88
88
StateCachedEnv (
89
89
StateTransformedEnv (
90
90
AtariEnv (;
@@ -105,8 +105,8 @@ function atari_env_factory(
105
105
),
106
106
state_space_mapping= _ -> Space (fill (0 .. 256 , state_size... , n_frames))
107
107
)
108
- ),
109
- r -> clamp (r, - 1 , 1 )
108
+ );
109
+ reward_mapping = r -> clamp (r, - 1 , 1 )
110
110
)
111
111
112
112
if isnothing (n_replica)
134
134
function (hook:: TotalOriginalRewardPerEpisode )(
135
135
:: PostActStage ,
136
136
agent,
137
- env:: RewardOverriddenEnv ,
137
+ env:: RewardTransformedEnv ,
138
138
)
139
139
hook. reward += reward (env. env)
140
140
end
157
157
function (hook:: TotalBatchOriginalRewardPerEpisode )(
158
158
:: PostActStage ,
159
159
agent,
160
- env:: MultiThreadEnv{<:RewardOverriddenEnv } ,
160
+ env:: MultiThreadEnv{<:RewardTransformedEnv } ,
161
161
)
162
162
for (i, e) in enumerate (env. envs)
163
163
hook. reward[i] += reward (e. env)
Original file line number Diff line number Diff line change @@ -24,8 +24,8 @@ function RL.Experiment(
24
24
25
25
env = GridWorlds. SingleRoomUndirectedModule. SingleRoomUndirected (rng= rng)
26
26
env = GridWorlds. RLBaseEnv (env)
27
- env = RLEnvs. StateTransformedEnv (env;state_mapping= x -> vec (Float32 .(x)))
28
- env = RewardOverriddenEnv (env, x -> x - convert (typeof (x), 0.01 ))
27
+ env = RLEnvs. StateTransformedEnv (env; state_mapping= x -> vec (Float32 .(x)))
28
+ env = RewardTransformedEnv (env; reward_mapping = x -> x - convert (typeof (x), 0.01 ))
29
29
env = MaxTimeoutEnv (env, 240 )
30
30
31
31
ns, na = length (state (env)), length (action_space (env))
Original file line number Diff line number Diff line change @@ -3,12 +3,12 @@ export RewardOverriddenEnv
3
3
"""
4
4
RewardOverriddenEnv(env, f)
5
5
6
- Apply `f` on `reward( env)` .
6
+ Apply `f` on `env` to generate a custom reward .
7
7
"""
8
8
struct RewardOverriddenEnv{F,E<: AbstractEnv } <: AbstractEnvWrapper
9
9
env:: E
10
10
f:: F
11
11
end
12
12
13
13
RLBase. reward (env:: RewardOverriddenEnv , args... ; kwargs... ) =
14
- env. f (reward ( env. env, args... ; kwargs... ) )
14
+ env. f (env. env, args... ; kwargs... )
Original file line number Diff line number Diff line change
1
+ export RewardTransformedEnv
2
+
3
+ """
4
+ RewardTransformedEnv(env, f)
5
+
6
+ Apply `f` on `reward(env)`.
7
+ """
8
+ struct RewardTransformedEnv{F,E<: AbstractEnv } <: AbstractEnvWrapper
9
+ env:: E
10
+ reward_mapping:: F
11
+ end
12
+
13
+ RewardTransformedEnv (env; reward_mapping= identity) =
14
+ RewardTransformedEnv (env, reward_mapping)
15
+
16
+ RLBase. reward (env:: RewardTransformedEnv , args... ; kwargs... ) =
17
+ env. reward_mapping (reward (env. env, args... ; kwargs... ))
Original file line number Diff line number Diff line change @@ -30,6 +30,7 @@ include("ActionTransformedEnv.jl")
30
30
include (" DefaultStateStyle.jl" )
31
31
include (" MaxTimeoutEnv.jl" )
32
32
include (" RewardOverriddenEnv.jl" )
33
+ include (" RewardTransformedEnv.jl" )
33
34
include (" StateCachedEnv.jl" )
34
35
include (" StateTransformedEnv.jl" )
35
36
include (" StochasticEnv.jl" )
Original file line number Diff line number Diff line change 53
53
@test is_terminated (env′) == false
54
54
end
55
55
56
+ @testset " RewardTransformedEnv" begin
57
+ rng = StableRNG (123 )
58
+ env = TigerProblemEnv (; rng= rng)
59
+ env′ = RewardTransformedEnv (env; reward_mapping = x -> sign (x))
60
+
61
+ RLBase. test_interfaces! (env′)
62
+ RLBase. test_runnable! (env′)
63
+
64
+ while ! is_terminated (env′)
65
+ env′ (rand (rng, legal_action_space (env′)))
66
+ @test reward (env′) ∈ (- 1 , 0 , 1 )
67
+ end
68
+ end
69
+
56
70
@testset " RewardOverriddenEnv" begin
57
71
rng = StableRNG (123 )
58
72
env = TigerProblemEnv (; rng= rng)
59
- env′ = RewardOverriddenEnv (env, x -> sign (x ))
73
+ env′ = RewardOverriddenEnv (env, e -> sign (reward (e) ))
60
74
61
75
RLBase. test_interfaces! (env′)
62
76
RLBase. test_runnable! (env′)
You can’t perform that action at this time.
0 commit comments