Commit 2a9c928
rename save_fn to save_best_fn to avoid ambiguity (#575)
This PR also introduces `tianshou.utils.deprecation` for a unified deprecation wrapper.
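
The deprecation module itself is not shown in the diffs below, so the following is only a minimal sketch of what a unified deprecation wrapper typically looks like; the helper name, signature, warning text, and the `trainer_stub` shim are assumptions for illustration, not the actual Tianshou code.

```python
# Sketch only: an assumed helper plus a hypothetical keyword shim, not the real API.
import warnings


def deprecation(msg: str) -> None:
    """Emit a DeprecationWarning with a consistent category and stack level."""
    warnings.warn(msg, DeprecationWarning, stacklevel=3)


def trainer_stub(*, save_best_fn=None, save_fn=None, **kwargs):
    """Hypothetical trainer entry point that still accepts the old keyword."""
    if save_fn is not None:
        deprecation("save_fn in trainer is deprecated, please use save_best_fn.")
        save_best_fn = save_best_fn or save_fn
    # ... training loop: call save_best_fn(policy) whenever the best average
    # test reward observed so far improves ...
```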
Trinkle23897 authored Mar 21, 2022
1 parent 10d9190 commit 2a9c928
Showing 69 changed files with 187 additions and 155 deletions.
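
For downstream users the change at each call site is a single keyword rename; the callback itself keeps the same signature. Below is a minimal before/after sketch, assuming the policy, collectors, `args` namespace, `log_path`, `stop_fn`, and logger are set up as in the example scripts that follow:

```python
import os

import torch
from tianshou.trainer import offpolicy_trainer


def save_best_fn(policy):
    # Called by the trainer each time a new best average test reward is reached.
    torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))


# Before this commit the keyword was:  save_fn=save_best_fn
# After this commit it becomes:        save_best_fn=save_best_fn
result = offpolicy_trainer(
    policy, train_collector, test_collector, args.epoch,
    args.step_per_epoch, args.step_per_collect, args.test_num,
    args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn,
    update_per_step=args.update_per_step, logger=logger, test_in_train=False,
)
```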
7 changes: 4 additions & 3 deletions docs/tutorials/tictactoe.rst
@@ -327,7 +327,7 @@ With the above preparation, we are close to the first learned agent. The followi

# ======== callback functions used during training =========

-def save_fn(policy):
+def save_best_fn(policy):
if hasattr(args, 'model_save_path'):
model_save_path = args.model_save_path
else:
@@ -358,8 +358,9 @@ With the above preparation, we are close to the first learned agent. The followi
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
-stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
-logger=logger, test_in_train=False, reward_metric=reward_metric)
+stop_fn=stop_fn, save_best_fn=save_best_fn,
+update_per_step=args.update_per_step, logger=logger,
+test_in_train=False, reward_metric=reward_metric)

agent = policy.policies[args.agent_id - 1]
# let's watch the match!
4 changes: 2 additions & 2 deletions examples/atari/atari_c51.py
@@ -133,7 +133,7 @@ def test_c51(args=get_args()):
else: # wandb
logger.load(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
@@ -206,7 +206,7 @@ def watch():
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False,
4 changes: 2 additions & 2 deletions examples/atari/atari_dqn.py
@@ -165,7 +165,7 @@ def test_dqn(args=get_args()):
else: # wandb
logger.load(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
@@ -244,7 +244,7 @@ def watch():
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False,
4 changes: 2 additions & 2 deletions examples/atari/atari_fqf.py
@@ -150,7 +150,7 @@ def test_fqf(args=get_args()):
else: # wandb
logger.load(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
@@ -223,7 +223,7 @@ def watch():
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False,
4 changes: 2 additions & 2 deletions examples/atari/atari_iqn.py
@@ -145,7 +145,7 @@ def test_iqn(args=get_args()):
else: # wandb
logger.load(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
@@ -218,7 +218,7 @@ def watch():
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False,
4 changes: 2 additions & 2 deletions examples/atari/atari_ppo.py
@@ -209,7 +209,7 @@ def dist(p):
else: # wandb
logger.load(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
@@ -272,7 +272,7 @@ def watch():
args.batch_size,
step_per_collect=args.step_per_collect,
stop_fn=stop_fn,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
test_in_train=False,
resume_from_log=args.resume_id is not None,
4 changes: 2 additions & 2 deletions examples/atari/atari_qrdqn.py
@@ -129,7 +129,7 @@ def test_qrdqn(args=get_args()):
else: # wandb
logger.load(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
@@ -202,7 +202,7 @@ def watch():
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False,
4 changes: 2 additions & 2 deletions examples/atari/atari_rainbow.py
@@ -162,7 +162,7 @@ def test_rainbow(args=get_args()):
else: # wandb
logger.load(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
@@ -246,7 +246,7 @@ def watch():
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False,
4 changes: 2 additions & 2 deletions examples/box2d/acrobot_dualdqn.py
@@ -99,7 +99,7 @@ def test_dqn(args=get_args()):
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
@@ -132,7 +132,7 @@ def test_fn(epoch, env_step):
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger
)

4 changes: 2 additions & 2 deletions examples/box2d/bipedal_hardcore_sac.py
@@ -161,7 +161,7 @@ def test_sac_bipedal(args=get_args()):
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
@@ -180,7 +180,7 @@ def stop_fn(mean_rewards):
update_per_step=args.update_per_step,
test_in_train=False,
stop_fn=stop_fn,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger
)

4 changes: 2 additions & 2 deletions examples/box2d/lunarlander_dqn.py
@@ -100,7 +100,7 @@ def test_dqn(args=get_args()):
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
@@ -127,7 +127,7 @@ def test_fn(epoch, env_step):
stop_fn=stop_fn,
train_fn=train_fn,
test_fn=test_fn,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger
)

4 changes: 2 additions & 2 deletions examples/box2d/mcc_sac.py
@@ -128,7 +128,7 @@ def test_sac(args=get_args()):
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
@@ -146,7 +146,7 @@ def stop_fn(mean_rewards):
args.batch_size,
update_per_step=args.update_per_step,
stop_fn=stop_fn,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger
)

4 changes: 2 additions & 2 deletions examples/inverse/irl_gail.py
@@ -244,7 +244,7 @@ def dist(*logits):
writer.add_text("args", str(args))
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

if not args.watch:
@@ -259,7 +259,7 @@ def save_fn(policy):
args.test_num,
args.batch_size,
step_per_collect=args.step_per_collect,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
test_in_train=False
)
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_a2c.py
@@ -179,7 +179,7 @@ def dist(*logits):
writer.add_text("args", str(args))
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

if not args.watch:
@@ -194,7 +194,7 @@ def save_fn(policy):
args.test_num,
args.batch_size,
step_per_collect=args.step_per_collect,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
test_in_train=False
)
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_ddpg.py
@@ -128,7 +128,7 @@ def test_ddpg(args=get_args()):
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

if not args.watch:
@@ -142,7 +142,7 @@ def save_fn(policy):
args.step_per_collect,
args.test_num,
args.batch_size,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_npg.py
@@ -175,7 +175,7 @@ def dist(*logits):
writer.add_text("args", str(args))
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

if not args.watch:
@@ -190,7 +190,7 @@ def save_fn(policy):
args.test_num,
args.batch_size,
step_per_collect=args.step_per_collect,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
test_in_train=False
)
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_ppo.py
@@ -186,7 +186,7 @@ def dist(*logits):
writer.add_text("args", str(args))
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

if not args.watch:
@@ -201,7 +201,7 @@ def save_fn(policy):
args.test_num,
args.batch_size,
step_per_collect=args.step_per_collect,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
test_in_train=False
)
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_reinforce.py
@@ -158,7 +158,7 @@ def dist(*logits):
writer.add_text("args", str(args))
logger = TensorboardLogger(writer, update_interval=10, train_interval=100)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

if not args.watch:
@@ -173,7 +173,7 @@ def save_fn(policy):
args.test_num,
args.batch_size,
step_per_collect=args.step_per_collect,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
test_in_train=False
)
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_sac.py
@@ -151,7 +151,7 @@ def test_sac(args=get_args()):
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

if not args.watch:
@@ -165,7 +165,7 @@ def save_fn(policy):
args.step_per_collect,
args.test_num,
args.batch_size,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_td3.py
@@ -148,7 +148,7 @@ def test_td3(args=get_args()):
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

if not args.watch:
@@ -162,7 +162,7 @@ def save_fn(policy):
args.step_per_collect,
args.test_num,
args.batch_size,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_trpo.py
@@ -180,7 +180,7 @@ def dist(*logits):
writer.add_text("args", str(args))
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

if not args.watch:
@@ -195,7 +195,7 @@ def save_fn(policy):
args.test_num,
args.batch_size,
step_per_collect=args.step_per_collect,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
test_in_train=False
)
4 changes: 2 additions & 2 deletions examples/offline/atari_bcq.py
@@ -150,7 +150,7 @@ def test_discrete_bcq(args=get_args()):
else: # wandb
logger.load(writer)

-def save_fn(policy):
+def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
@@ -182,7 +182,7 @@ def watch():
args.test_num,
args.batch_size,
stop_fn=stop_fn,
-save_fn=save_fn,
+save_best_fn=save_best_fn,
logger=logger,
)

(The diffs for the remaining changed files were not loaded in this view.)