diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst
index 58e69906a..da9562306 100644
--- a/docs/tutorials/tictactoe.rst
+++ b/docs/tutorials/tictactoe.rst
@@ -327,7 +327,7 @@ With the above preparation, we are close to the first learned agent. The followi

     # ======== callback functions used during training =========

-    def save_fn(policy):
+    def save_best_fn(policy):
         if hasattr(args, 'model_save_path'):
             model_save_path = args.model_save_path
         else:
@@ -358,8 +358,9 @@ With the above preparation, we are close to the first learned agent. The followi
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
-        logger=logger, test_in_train=False, reward_metric=reward_metric)
+        stop_fn=stop_fn, save_best_fn=save_best_fn,
+        update_per_step=args.update_per_step, logger=logger,
+        test_in_train=False, reward_metric=reward_metric)
     agent = policy.policies[args.agent_id - 1]

     # let's watch the match!
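The hook itself keeps the exact signature of the old ``save_fn``; only the keyword under which it is handed to the trainer changes. For reference, a sketch of how such a callback looks after the rename (the fallback save path in the ``else`` branch is illustrative only, since the hunk above truncates the original body):

    def save_best_fn(policy):
        # invoked each time the best undiscounted average test reward improves
        if hasattr(args, 'model_save_path'):
            model_save_path = args.model_save_path
        else:
            # hypothetical default location; the tutorial defines its own
            model_save_path = os.path.join('log', 'tic_tac_toe', 'dqn', 'policy.pth')
        # in the multi-agent setup only the learned agent's parameters are saved
        torch.save(policy.policies[args.agent_id - 1].state_dict(), model_save_path)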
diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py
index 71c52e0d3..c6c58e7ca 100644
--- a/examples/atari/atari_c51.py
+++ b/examples/atari/atari_c51.py
@@ -133,7 +133,7 @@ def test_c51(args=get_args()):
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
@@ -206,7 +206,7 @@ def watch():
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
         test_in_train=False,
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py
index c5aacd2b7..7935d9b4b 100644
--- a/examples/atari/atari_dqn.py
+++ b/examples/atari/atari_dqn.py
@@ -165,7 +165,7 @@ def test_dqn(args=get_args()):
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
@@ -244,7 +244,7 @@ def watch():
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
         test_in_train=False,
diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py
index c4c1b36c8..5d848bc9f 100644
--- a/examples/atari/atari_fqf.py
+++ b/examples/atari/atari_fqf.py
@@ -150,7 +150,7 @@ def test_fqf(args=get_args()):
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
@@ -223,7 +223,7 @@ def watch():
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
         test_in_train=False,
diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py
index 49eb6e9d5..17d407a89 100644
--- a/examples/atari/atari_iqn.py
+++ b/examples/atari/atari_iqn.py
@@ -145,7 +145,7 @@ def test_iqn(args=get_args()):
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
@@ -218,7 +218,7 @@ def watch():
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
         test_in_train=False,
diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py
index 8c70bf66e..22eff1034 100644
--- a/examples/atari/atari_ppo.py
+++ b/examples/atari/atari_ppo.py
@@ -209,7 +209,7 @@ def dist(p):
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
@@ -272,7 +272,7 @@ def watch():
         args.batch_size,
         step_per_collect=args.step_per_collect,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         test_in_train=False,
         resume_from_log=args.resume_id is not None,
diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py
index 72915efa0..1fa03467e 100644
--- a/examples/atari/atari_qrdqn.py
+++ b/examples/atari/atari_qrdqn.py
@@ -129,7 +129,7 @@ def test_qrdqn(args=get_args()):
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
@@ -202,7 +202,7 @@ def watch():
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
         test_in_train=False,
diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py
index 9b254507f..913df677a 100644
--- a/examples/atari/atari_rainbow.py
+++ b/examples/atari/atari_rainbow.py
@@ -162,7 +162,7 @@ def test_rainbow(args=get_args()):
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
@@ -246,7 +246,7 @@ def watch():
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
         test_in_train=False,
diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py
index 76246fd3e..8d3bd46b0 100644
--- a/examples/box2d/acrobot_dualdqn.py
+++ b/examples/box2d/acrobot_dualdqn.py
@@ -99,7 +99,7 @@ def test_dqn(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -132,7 +132,7 @@ def test_fn(epoch, env_step):
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py
index 1d2d7f1c1..f440c8000 100644
--- a/examples/box2d/bipedal_hardcore_sac.py
+++ b/examples/box2d/bipedal_hardcore_sac.py
@@ -161,7 +161,7 @@ def test_sac_bipedal(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -180,7 +180,7 @@ def stop_fn(mean_rewards):
         update_per_step=args.update_per_step,
         test_in_train=False,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py
index 88f4c397b..cd1b2c2c5 100644
--- a/examples/box2d/lunarlander_dqn.py
+++ b/examples/box2d/lunarlander_dqn.py
@@ -100,7 +100,7 @@ def test_dqn(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -127,7 +127,7 @@ def test_fn(epoch, env_step):
         stop_fn=stop_fn,
         train_fn=train_fn,
         test_fn=test_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py
index 0638e8f61..48436bf67 100644
--- a/examples/box2d/mcc_sac.py
+++ b/examples/box2d/mcc_sac.py
@@ -128,7 +128,7 @@ def test_sac(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -146,7 +146,7 @@ def stop_fn(mean_rewards):
         args.batch_size,
         update_per_step=args.update_per_step,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py
index 1e4c0389c..1b9d2da73 100644
--- a/examples/inverse/irl_gail.py
+++ b/examples/inverse/irl_gail.py
@@ -244,7 +244,7 @@ def dist(*logits):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     if not args.watch:
@@ -259,7 +259,7 @@ def save_fn(policy):
         args.test_num,
         args.batch_size,
         step_per_collect=args.step_per_collect,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         test_in_train=False
     )
diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py
index 02978697b..b8fc3370c 100755
--- a/examples/mujoco/mujoco_a2c.py
+++ b/examples/mujoco/mujoco_a2c.py
@@ -179,7 +179,7 @@ def dist(*logits):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     if not args.watch:
@@ -194,7 +194,7 @@ def save_fn(policy):
         args.test_num,
         args.batch_size,
         step_per_collect=args.step_per_collect,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         test_in_train=False
     )
diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py
index 8d436b573..51955fe71 100755
--- a/examples/mujoco/mujoco_ddpg.py
+++ b/examples/mujoco/mujoco_ddpg.py
@@ -128,7 +128,7 @@ def test_ddpg(args=get_args()):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     if not args.watch:
@@ -142,7 +142,7 @@ def save_fn(policy):
         args.step_per_collect,
         args.test_num,
         args.batch_size,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
         test_in_train=False
diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py
index 23883a119..fd681bca3 100755
--- a/examples/mujoco/mujoco_npg.py
+++ b/examples/mujoco/mujoco_npg.py
@@ -175,7 +175,7 @@ def dist(*logits):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     if not args.watch:
@@ -190,7 +190,7 @@ def save_fn(policy):
         args.test_num,
         args.batch_size,
         step_per_collect=args.step_per_collect,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         test_in_train=False
     )
diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py
index 01dc5aa3f..392f1c22a 100755
--- a/examples/mujoco/mujoco_ppo.py
+++ b/examples/mujoco/mujoco_ppo.py
@@ -186,7 +186,7 @@ def dist(*logits):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     if not args.watch:
@@ -201,7 +201,7 @@ def save_fn(policy):
         args.test_num,
         args.batch_size,
         step_per_collect=args.step_per_collect,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         test_in_train=False
     )
diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py
index 914b46251..7d331af43 100755
--- a/examples/mujoco/mujoco_reinforce.py
+++ b/examples/mujoco/mujoco_reinforce.py
@@ -158,7 +158,7 @@ def dist(*logits):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer, update_interval=10, train_interval=100)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     if not args.watch:
@@ -173,7 +173,7 @@ def save_fn(policy):
         args.test_num,
         args.batch_size,
         step_per_collect=args.step_per_collect,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         test_in_train=False
     )
diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
index cb764f473..eb2afe70f 100755
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -151,7 +151,7 @@ def test_sac(args=get_args()):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     if not args.watch:
@@ -165,7 +165,7 @@ def save_fn(policy):
         args.step_per_collect,
         args.test_num,
         args.batch_size,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
         test_in_train=False
diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py
index 9e0ca0d82..d2a9bd7bc 100755
--- a/examples/mujoco/mujoco_td3.py
+++ b/examples/mujoco/mujoco_td3.py
@@ -148,7 +148,7 @@ def test_td3(args=get_args()):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     if not args.watch:
@@ -162,7 +162,7 @@ def save_fn(policy):
         args.step_per_collect,
         args.test_num,
         args.batch_size,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
         test_in_train=False
diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py
index aef324fd5..dd2ce5334 100755
--- a/examples/mujoco/mujoco_trpo.py
+++ b/examples/mujoco/mujoco_trpo.py
@@ -180,7 +180,7 @@ def dist(*logits):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     if not args.watch:
@@ -195,7 +195,7 @@ def save_fn(policy):
         args.test_num,
         args.batch_size,
         step_per_collect=args.step_per_collect,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         test_in_train=False
     )
diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py
index 4ccb52df1..f398f64e9 100644
--- a/examples/offline/atari_bcq.py
+++ b/examples/offline/atari_bcq.py
@@ -150,7 +150,7 @@ def test_discrete_bcq(args=get_args()):
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
@@ -182,7 +182,7 @@ def watch():
         args.test_num,
         args.batch_size,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
     )
diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py
index 2515002e4..c8300b3fe 100644
--- a/examples/offline/atari_cql.py
+++ b/examples/offline/atari_cql.py
@@ -135,7 +135,7 @@ def test_discrete_cql(args=get_args()):
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
@@ -167,7 +167,7 @@ def watch():
         args.test_num,
         args.batch_size,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
     )
diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py
index a249928e3..49ef9c3a1 100644
--- a/examples/offline/atari_crr.py
+++ b/examples/offline/atari_crr.py
@@ -155,7 +155,7 @@ def test_discrete_crr(args=get_args()):
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
@@ -186,7 +186,7 @@ def watch():
         args.test_num,
         args.batch_size,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
     )
diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py
index b71303d7e..348f9920d 100644
--- a/examples/offline/atari_il.py
+++ b/examples/offline/atari_il.py
@@ -120,7 +120,7 @@ def test_il(args=get_args()):
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
@@ -151,7 +151,7 @@ def watch():
         args.test_num,
         args.batch_size,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
     )
diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py
index 38da5c104..434763e6d 100644
--- a/examples/offline/d4rl_bcq.py
+++ b/examples/offline/d4rl_bcq.py
@@ -196,7 +196,7 @@ def test_bcq():
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def watch():
@@ -237,7 +237,7 @@ def watch():
         args.step_per_epoch,
         args.test_num,
         args.batch_size,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
     )
     pprint.pprint(result)
diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py
index 952737aab..2b7f5ffde 100644
--- a/examples/offline/d4rl_cql.py
+++ b/examples/offline/d4rl_cql.py
@@ -191,7 +191,7 @@ def test_cql():
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def watch():
@@ -232,7 +232,7 @@ def watch():
         args.step_per_epoch,
         args.test_num,
         args.batch_size,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
     )
     pprint.pprint(result)
diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py
index 208e71ef7..710441a48 100644
--- a/examples/offline/d4rl_il.py
+++ b/examples/offline/d4rl_il.py
@@ -133,7 +133,7 @@ def test_il():
     else:  # wandb
         logger.load(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def watch():
@@ -174,7 +174,7 @@ def watch():
         args.step_per_epoch,
         args.test_num,
         args.batch_size,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
     )
     pprint.pprint(result)
diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py
index 53eafae20..9f5aab835 100644
--- a/examples/vizdoom/vizdoom_c51.py
+++ b/examples/vizdoom/vizdoom_c51.py
@@ -125,7 +125,7 @@ def test_c51(args=get_args()):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -200,7 +200,7 @@ def watch():
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
         test_in_train=False
diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py
index ec88cb914..9a219a3ee 100644
--- a/examples/vizdoom/vizdoom_ppo.py
+++ b/examples/vizdoom/vizdoom_ppo.py
@@ -202,7 +202,7 @@ def dist(p):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -261,7 +261,7 @@ def watch():
         args.batch_size,
         step_per_collect=args.step_per_collect,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         test_in_train=False
     )
diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index 84885ffb5..cc9aafe5d 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -111,7 +111,7 @@ def test_ddpg(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -129,7 +129,7 @@ def stop_fn(mean_rewards):
         args.batch_size,
         update_per_step=args.update_per_step,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
     assert stop_fn(result['best_reward'])
diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py
index 483335053..b7ec50c82 100644
--- a/test/continuous/test_npg.py
+++ b/test/continuous/test_npg.py
@@ -134,7 +134,7 @@ def dist(*logits):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -152,7 +152,7 @@ def stop_fn(mean_rewards):
         args.batch_size,
         step_per_collect=args.step_per_collect,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
     assert stop_fn(result['best_reward'])
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index f187b0f68..e0858edf1 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -129,7 +129,7 @@ def dist(*logits):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer, save_interval=args.save_interval)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -168,7 +168,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
         args.batch_size,
         episode_per_collect=args.episode_per_collect,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         resume_from_log=args.resume,
         save_checkpoint_fn=save_checkpoint_fn
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index 8930b0839..a204e55e8 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -138,7 +138,7 @@ def test_sac_with_il(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -156,7 +156,7 @@ def stop_fn(mean_rewards):
         args.batch_size,
         update_per_step=args.update_per_step,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
     assert stop_fn(result['best_reward'])
@@ -198,7 +198,7 @@ def stop_fn(mean_rewards):
         args.test_num,
         args.batch_size,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
     assert stop_fn(result['best_reward'])
diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py
index d7ee186fc..ac557ba12 100644
--- a/test/continuous/test_td3.py
+++ b/test/continuous/test_td3.py
@@ -129,7 +129,7 @@ def test_td3(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -147,7 +147,7 @@ def stop_fn(mean_rewards):
         args.batch_size,
         update_per_step=args.update_per_step,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
     )
     for epoch, epoch_stat, info in trainer:
diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py
index 92228b781..cc8b39f61 100644
--- a/test/continuous/test_trpo.py
+++ b/test/continuous/test_trpo.py
@@ -138,7 +138,7 @@ def dist(*logits):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -156,7 +156,7 @@ def stop_fn(mean_rewards):
         args.batch_size,
         step_per_collect=args.step_per_collect,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
     assert stop_fn(result['best_reward'])
diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py
index 977aafab2..b777e2e03 100644
--- a/test/discrete/test_a2c_with_il.py
+++ b/test/discrete/test_a2c_with_il.py
@@ -96,7 +96,7 @@ def test_a2c_with_il(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -114,7 +114,7 @@ def stop_fn(mean_rewards):
         args.batch_size,
         episode_per_collect=args.episode_per_collect,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
     assert stop_fn(result['best_reward'])
@@ -152,7 +152,7 @@ def stop_fn(mean_rewards):
         args.test_num,
         args.batch_size,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
     assert stop_fn(result['best_reward'])
diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py
index fa1db2ffd..993c4a80e 100644
--- a/test/discrete/test_c51.py
+++ b/test/discrete/test_c51.py
@@ -118,7 +118,7 @@ def test_c51(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer, save_interval=args.save_interval)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -183,7 +183,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         resume_from_log=args.resume,
         save_checkpoint_fn=save_checkpoint_fn
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index d0a9f0833..2644dc998 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -109,7 +109,7 @@ def test_dqn(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -143,7 +143,7 @@ def test_fn(epoch, env_step):
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
     )
     assert stop_fn(result['best_reward'])
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index 496420a97..36ff5fa76 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -96,7 +96,7 @@ def test_drqn(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -122,7 +122,7 @@ def test_fn(epoch, env_step):
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
     assert stop_fn(result['best_reward'])
diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py
index 73cc2bd93..e25c42997 100644
--- a/test/discrete/test_fqf.py
+++ b/test/discrete/test_fqf.py
@@ -126,7 +126,7 @@ def test_fqf(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -159,7 +159,7 @@ def test_fn(epoch, env_step):
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step
     )
diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py
index c7e074c45..725c9a9d5 100644
--- a/test/discrete/test_iqn.py
+++ b/test/discrete/test_iqn.py
@@ -120,7 +120,7 @@ def test_iqn(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -153,7 +153,7 @@ def test_fn(epoch, env_step):
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step
     )
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index c3210b2ab..1f5007f7a 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -97,7 +97,7 @@ def test_pg(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -115,7 +115,7 @@ def stop_fn(mean_rewards):
         args.batch_size,
         episode_per_collect=args.episode_per_collect,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
     assert stop_fn(result['best_reward'])
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index c0b7e1b88..b7dba97c9 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -122,7 +122,7 @@ def test_ppo(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -140,7 +140,7 @@ def stop_fn(mean_rewards):
         args.batch_size,
         step_per_collect=args.step_per_collect,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
     assert stop_fn(result['best_reward'])
diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py
index eaa4dc0bd..9c699b4e3 100644
--- a/test/discrete/test_qrdqn.py
+++ b/test/discrete/test_qrdqn.py
@@ -113,7 +113,7 @@ def test_qrdqn(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -146,7 +146,7 @@ def test_fn(epoch, env_step):
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
     )
diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py
index 1e4d0a37b..5e2345300 100644
--- a/test/discrete/test_rainbow.py
+++ b/test/discrete/test_rainbow.py
@@ -132,7 +132,7 @@ def noisy_linear(x, y):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer, save_interval=args.save_interval)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -207,7 +207,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         resume_from_log=args.resume,
         save_checkpoint_fn=save_checkpoint_fn
diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py
index d53268288..6593c9864 100644
--- a/test/discrete/test_sac.py
+++ b/test/discrete/test_sac.py
@@ -114,7 +114,7 @@ def test_discrete_sac(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -131,7 +131,7 @@ def stop_fn(mean_rewards):
         args.test_num,
         args.batch_size,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
         test_in_train=False
diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py
index d5d9aaa37..fba0b5523 100644
--- a/test/modelbased/test_dqn_icm.py
+++ b/test/modelbased/test_dqn_icm.py
@@ -148,7 +148,7 @@ def test_dqn_icm(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -182,7 +182,7 @@ def test_fn(epoch, env_step):
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
     )
     assert stop_fn(result['best_reward'])
diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py
index f87f51a0e..6efd96277 100644
--- a/test/modelbased/test_ppo_icm.py
+++ b/test/modelbased/test_ppo_icm.py
@@ -154,7 +154,7 @@ def test_ppo(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -172,7 +172,7 @@ def stop_fn(mean_rewards):
         args.batch_size,
         step_per_collect=args.step_per_collect,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
     assert stop_fn(result['best_reward'])
diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py
index 12abf8bd7..2e01723f3 100644
--- a/test/offline/gather_cartpole_data.py
+++ b/test/offline/gather_cartpole_data.py
@@ -117,7 +117,7 @@ def gather_data():
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -150,7 +150,7 @@ def test_fn(epoch, env_step):
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         update_per_step=args.update_per_step,
     )
diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py
index 23864587d..f60ec3696 100644
--- a/test/offline/gather_pendulum_data.py
+++ b/test/offline/gather_pendulum_data.py
@@ -147,7 +147,7 @@ def gather_data():
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -164,7 +164,7 @@ def stop_fn(mean_rewards):
         args.test_num,
         args.batch_size,
         update_per_step=args.update_per_step,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         stop_fn=stop_fn,
         logger=logger,
     )
diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py
index ca3c2c974..dccb2ec90 100644
--- a/test/offline/test_bcq.py
+++ b/test/offline/test_bcq.py
@@ -181,7 +181,7 @@ def test_bcq(args=get_args()):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -206,7 +206,7 @@ def watch():
         args.step_per_epoch,
         args.test_num,
         args.batch_size,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         stop_fn=stop_fn,
         logger=logger,
    )
diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py
index 91a2784df..6d52bca7a 100644
--- a/test/offline/test_cql.py
+++ b/test/offline/test_cql.py
@@ -178,7 +178,7 @@ def test_cql(args=get_args()):
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -203,7 +203,7 @@ def watch():
         args.step_per_epoch,
         args.test_num,
         args.batch_size,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         stop_fn=stop_fn,
         logger=logger,
     )
diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py
index 179311766..51380fb62 100644
--- a/test/offline/test_discrete_bcq.py
+++ b/test/offline/test_discrete_bcq.py
@@ -108,7 +108,7 @@ def test_discrete_bcq(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer, save_interval=args.save_interval)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -144,7 +144,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
         args.test_num,
         args.batch_size,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         resume_from_log=args.resume,
         save_checkpoint_fn=save_checkpoint_fn
diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py
index 2f275e8b1..eca810fb6 100644
--- a/test/offline/test_discrete_cql.py
+++ b/test/offline/test_discrete_cql.py
@@ -103,7 +103,7 @@ def test_discrete_cql(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -118,7 +118,7 @@ def stop_fn(mean_rewards):
         args.test_num,
         args.batch_size,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py
index d58178093..9f47b32e7 100644
--- a/test/offline/test_discrete_crr.py
+++ b/test/offline/test_discrete_crr.py
@@ -106,7 +106,7 @@ def test_discrete_crr(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -121,7 +121,7 @@ def stop_fn(mean_rewards):
         args.test_num,
         args.batch_size,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger
     )
diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py
index f68b7ae74..3ccaeaca7 100644
--- a/test/offline/test_gail.py
+++ b/test/offline/test_gail.py
@@ -167,7 +167,7 @@ def dist(*logits):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer, save_interval=args.save_interval)

-    def save_fn(policy):
+    def save_best_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

     def stop_fn(mean_rewards):
@@ -206,7 +206,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
         args.batch_size,
         episode_per_collect=args.episode_per_collect,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         resume_from_log=args.resume,
         save_checkpoint_fn=save_checkpoint_fn,
diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py
index 8ce458788..ecc5e0222 100644
--- a/test/pettingzoo/pistonball.py
+++ b/test/pettingzoo/pistonball.py
@@ -135,7 +135,7 @@ def train_agent(
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         pass

     def stop_fn(mean_rewards):
@@ -163,7 +163,7 @@ def reward_metric(rews):
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         update_per_step=args.update_per_step,
         logger=logger,
         test_in_train=False,
diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py
index 7a7ffe341..33072b50f 100644
--- a/test/pettingzoo/pistonball_continuous.py
+++ b/test/pettingzoo/pistonball_continuous.py
@@ -230,7 +230,7 @@ def train_agent(
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         pass

     def stop_fn(mean_rewards):
@@ -257,7 +257,7 @@ def reward_metric(rews):
         args.batch_size,
         episode_per_collect=args.episode_per_collect,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         logger=logger,
         resume_from_log=args.resume
     )
diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py
index 2bc6a72f1..69af6a35a 100644
--- a/test/pettingzoo/tic_tac_toe.py
+++ b/test/pettingzoo/tic_tac_toe.py
@@ -178,7 +178,7 @@ def train_agent(
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)

-    def save_fn(policy):
+    def save_best_fn(policy):
         if hasattr(args, 'model_save_path'):
             model_save_path = args.model_save_path
         else:
@@ -214,7 +214,7 @@ def reward_metric(rews):
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
-        save_fn=save_fn,
+        save_best_fn=save_best_fn,
         update_per_step=args.update_per_step,
         logger=logger,
         test_in_train=False,
diff --git a/tianshou/__init__.py b/tianshou/__init__.py
index cd337d188..5a9abbf1b 100644
--- a/tianshou/__init__.py
+++ b/tianshou/__init__.py
@@ -1,6 +1,6 @@
 from tianshou import data, env, exploration, policy, trainer, utils

-__version__ = "0.4.6.post1"
+__version__ = "0.4.7"

 __all__ = [
     "env",
diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py
index 3ae1b1618..b861a15d5 100644
--- a/tianshou/env/worker/base.py
+++ b/tianshou/env/worker/base.py
@@ -1,10 +1,11 @@
-import warnings
 from abc import ABC, abstractmethod
 from typing import Any, Callable, List, Optional, Tuple, Union

 import gym
 import numpy as np

+from tianshou.utils import deprecation
+

 class EnvWorker(ABC):
     """An abstract worker for an environment."""
@@ -33,7 +34,7 @@ def send(self, action: Optional[np.ndarray]) -> None:
         function is determined by such kind of different signal.
         """
         if hasattr(self, "send_action"):
-            warnings.warn(
+            deprecation(
                 "send_action will soon be deprecated. "
                 "Please use send and recv for your own EnvWorker."
             )
@@ -54,7 +55,7 @@ def recv(
         info).
         """
         if hasattr(self, "get_result"):
-            warnings.warn(
+            deprecation(
                 "get_result will soon be deprecated. "
                 "Please use send and recv for your own EnvWorker."
             )
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index fca1036ff..60525f67b 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -9,7 +9,7 @@
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.policy import BasePolicy
 from tianshou.trainer.utils import gather_info, test_episode
-from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
+from tianshou.utils import BaseLogger, LazyLogger, MovAvg, deprecation, tqdm_config


 class BaseTrainer(ABC):
@@ -48,9 +48,9 @@ class BaseTrainer(ABC):
     :param function test_fn: a hook called at the beginning of testing in each
         epoch. It can be used to perform custom additional operations, with the
         signature ``f(num_epoch: int, step_idx: int) -> None``.
-    :param function save_fn: a hook called when the undiscounted average mean
+    :param function save_best_fn: a hook called when the undiscounted average mean
         reward in evaluation phase gets better, with the signature
-        ``f(policy: BasePolicy) -> None``.
+        ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
     :param function save_checkpoint_fn: a function to save training process,
         with the signature ``f(epoch: int, env_step: int, gradient_step: int)
         -> None``; you can save whatever you want.
@@ -137,14 +137,23 @@ def __init__(
         train_fn: Optional[Callable[[int, int], None]] = None,
         test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
-        save_fn: Optional[Callable[[BasePolicy], None]] = None,
+        save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
         save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
         resume_from_log: bool = False,
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
         test_in_train: bool = True,
+        save_fn: Optional[Callable[[BasePolicy], None]] = None,
     ):
+        if save_fn:
+            deprecation(
+                "save_fn in trainer is marked as deprecated and will be "
+                "removed in the future. Please use save_best_fn instead."
+            )
+            assert save_best_fn is None
+            save_best_fn = save_fn
+
         self.policy = policy
         self.buffer = buffer
@@ -176,7 +185,7 @@ def __init__(
         self.train_fn = train_fn
         self.test_fn = test_fn
         self.stop_fn = stop_fn
-        self.save_fn = save_fn
+        self.save_best_fn = save_best_fn
         self.save_checkpoint_fn = save_checkpoint_fn

         self.reward_metric = reward_metric
@@ -220,8 +229,8 @@ def reset(self) -> None:
             self.best_epoch = self.start_epoch
             self.best_reward, self.best_reward_std = \
                 test_result["rew"], test_result["rew_std"]
-            if self.save_fn:
-                self.save_fn(self.policy)
+            if self.save_best_fn:
+                self.save_best_fn(self.policy)

         self.epoch = self.start_epoch
         self.stop_fn_flag = False
@@ -240,8 +249,6 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]:

         # iterator exhaustion check
         if self.epoch >= self.max_epoch:
-            if self.test_collector is None and self.save_fn:
-                self.save_fn(self.policy)
             raise StopIteration

         # exit flag 1, when stop_fn succeeds in train_step or test_step
@@ -321,8 +328,8 @@ def test_step(self) -> Tuple[Dict[str, Any], bool]:
             self.best_epoch = self.epoch
             self.best_reward = float(rew)
             self.best_reward_std = rew_std
-            if self.save_fn:
-                self.save_fn(self.policy)
+            if self.save_best_fn:
+                self.save_best_fn(self.policy)
         if self.verbose:
             print(
                 f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
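For downstream code the change is a pure keyword rename: thanks to the shim above, the old ``save_fn`` keyword is still accepted, funnelled into ``save_best_fn`` and reported through ``deprecation()``. A minimal sketch of both call styles against the off-policy trainer, reusing the ``policy``, collectors, ``args`` namespace, ``stop_fn``, ``save_best_fn`` and ``logger`` set up in any of the example scripts above (so not runnable on its own):

    from tianshou.trainer import offpolicy_trainer

    # new keyword, Tianshou >= 0.4.7
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.step_per_collect, args.test_num,
        args.batch_size, update_per_step=args.update_per_step,
        stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger)

    # old keyword: still works, but is routed through the deprecation branch
    # above and will be removed in a future release
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.step_per_collect, args.test_num,
        args.batch_size, update_per_step=args.update_per_step,
        stop_fn=stop_fn, save_fn=save_best_fn, logger=logger)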
diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py
index 890429a8b..82d9aa32a 100644
--- a/tianshou/trainer/offline.py
+++ b/tianshou/trainer/offline.py
@@ -28,9 +28,9 @@ class OfflineTrainer(BaseTrainer):
         epoch. It can be used to perform custom additional operations, with the
         signature ``f(num_epoch: int, step_idx: int) -> None``.
-    :param function save_fn: a hook called when the undiscounted average mean
+    :param function save_best_fn: a hook called when the undiscounted average mean
         reward in evaluation phase gets better, with the signature
-        ``f(policy: BasePolicy) -> None``.
+        ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
     :param function save_checkpoint_fn: a function to save training process,
         with the signature ``f(epoch: int, env_step: int, gradient_step: int)
         -> None``; you can save whatever you want. Because offline-RL doesn't have
@@ -64,12 +64,13 @@ def __init__(
         batch_size: int,
         test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
-        save_fn: Optional[Callable[[BasePolicy], None]] = None,
+        save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
         save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
         resume_from_log: bool = False,
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
+        **kwargs: Any,
     ):
         super().__init__(
             learning_type="offline",
@@ -83,12 +84,13 @@ def __init__(
             batch_size=batch_size,
             test_fn=test_fn,
             stop_fn=stop_fn,
-            save_fn=save_fn,
+            save_best_fn=save_best_fn,
             save_checkpoint_fn=save_checkpoint_fn,
             resume_from_log=resume_from_log,
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
+            **kwargs,
         )

     def policy_update_fn(
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index c3580397a..e7be852a7 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -37,9 +37,9 @@ class OffpolicyTrainer(BaseTrainer):
     :param function test_fn: a hook called at the beginning of testing in each
         epoch. It can be used to perform custom additional operations, with the
         signature ``f(num_epoch: int, step_idx: int) -> None``.
-    :param function save_fn: a hook called when the undiscounted average mean
+    :param function save_best_fn: a hook called when the undiscounted average mean
         reward in evaluation phase gets better, with the signature
-        ``f(policy: BasePolicy) -> None``.
+        ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
     :param function save_checkpoint_fn: a function to save training process,
         with the signature ``f(epoch: int, env_step: int, gradient_step: int)
         -> None``; you can save whatever you want.
@@ -77,13 +77,14 @@ def __init__(
         train_fn: Optional[Callable[[int, int], None]] = None,
         test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
-        save_fn: Optional[Callable[[BasePolicy], None]] = None,
+        save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
         save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
         resume_from_log: bool = False,
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
         test_in_train: bool = True,
+        **kwargs: Any,
     ):
         super().__init__(
             learning_type="offpolicy",
@@ -99,13 +100,14 @@ def __init__(
             train_fn=train_fn,
             test_fn=test_fn,
             stop_fn=stop_fn,
-            save_fn=save_fn,
+            save_best_fn=save_best_fn,
             save_checkpoint_fn=save_checkpoint_fn,
             resume_from_log=resume_from_log,
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
             test_in_train=test_in_train,
+            **kwargs,
         )

     def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index 46b195a70..a2234e710 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -39,9 +39,9 @@ class OnpolicyTrainer(BaseTrainer):
     :param function test_fn: a hook called at the beginning of testing in each
         epoch. It can be used to perform custom additional operations, with the
         signature ``f(num_epoch: int, step_idx: int) -> None``.
-    :param function save_fn: a hook called when the undiscounted average mean
+    :param function save_best_fn: a hook called when the undiscounted average mean
         reward in evaluation phase gets better, with the signature
-        ``f(policy: BasePolicy) -> None``.
+        ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
     :param function save_checkpoint_fn: a function to save training process,
         with the signature ``f(epoch: int, env_step: int, gradient_step: int)
         -> None``; you can save whatever you want.
@@ -85,13 +85,14 @@ def __init__(
         train_fn: Optional[Callable[[int, int], None]] = None,
         test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
-        save_fn: Optional[Callable[[BasePolicy], None]] = None,
+        save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
         save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
         resume_from_log: bool = False,
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
         test_in_train: bool = True,
+        **kwargs: Any,
     ):
         super().__init__(
             learning_type="onpolicy",
@@ -108,13 +109,14 @@ def __init__(
             train_fn=train_fn,
             test_fn=test_fn,
             stop_fn=stop_fn,
-            save_fn=save_fn,
+            save_best_fn=save_best_fn,
             save_checkpoint_fn=save_checkpoint_fn,
             resume_from_log=resume_from_log,
             reward_metric=reward_metric,
             logger=logger,
             verbose=verbose,
             test_in_train=test_in_train,
+            **kwargs,
         )

     def policy_update_fn(
diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py
index 25ceda186..2d7ea4906 100644
--- a/tianshou/utils/__init__.py
+++ b/tianshou/utils/__init__.py
@@ -5,8 +5,16 @@
 from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
 from tianshou.utils.logger.wandb import WandbLogger
 from tianshou.utils.statistics import MovAvg, RunningMeanStd
+from tianshou.utils.warning import deprecation

 __all__ = [
-    "MovAvg", "RunningMeanStd", "tqdm_config", "BaseLogger", "TensorboardLogger",
-    "BasicLogger", "LazyLogger", "WandbLogger"
+    "MovAvg",
+    "RunningMeanStd",
+    "tqdm_config",
+    "BaseLogger",
+    "TensorboardLogger",
+    "BasicLogger",
+    "LazyLogger",
+    "WandbLogger",
+    "deprecation",
 ]
diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py
index 469d32765..bc43bae1a 100644
--- a/tianshou/utils/logger/tensorboard.py
+++ b/tianshou/utils/logger/tensorboard.py
@@ -1,10 +1,10 @@
-import warnings
 from typing import Any, Callable, Optional, Tuple

 from tensorboard.backend.event_processing import event_accumulator
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.utils.logger.base import LOG_DATA_TYPE, BaseLogger
+from tianshou.utils.warning import deprecation


 class TensorboardLogger(BaseLogger):
@@ -81,7 +81,8 @@ class BasicLogger(TensorboardLogger):
     """

     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        warnings.warn(
-            "Deprecated soon: BasicLogger has renamed to TensorboardLogger in #427."
+        deprecation(
+            "Class BasicLogger is marked as deprecated and will be removed soon. "
+            "Please use TensorboardLogger instead."
         )
         super().__init__(*args, **kwargs)
diff --git a/tianshou/utils/warning.py b/tianshou/utils/warning.py
new file mode 100644
index 000000000..93c5ccec3
--- /dev/null
+++ b/tianshou/utils/warning.py
@@ -0,0 +1,8 @@
+import warnings
+
+warnings.simplefilter("once", DeprecationWarning)
+
+
+def deprecation(msg: str) -> None:
+    """Deprecation warning wrapper."""
+    warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
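The new helper is intentionally tiny: it tags the message as a ``DeprecationWarning``, points it at the caller via ``stacklevel=2``, and, because of the module-level ``simplefilter("once", DeprecationWarning)`` installed on import, each distinct message is emitted only once per process. A small, self-contained usage sketch (``old_api``/``new_api`` are made-up names for illustration):

    from tianshou.utils import deprecation

    def new_api() -> int:
        return 42

    def old_api() -> int:
        # the warning is attributed to the caller's line thanks to stacklevel=2
        deprecation("old_api is deprecated, please use new_api instead.")
        return new_api()

    if __name__ == "__main__":
        old_api()  # emits one DeprecationWarning pointing at this call site
        old_api()  # silent: the "once" filter suppresses repeats of the same message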