Skip to content

Commit

Permalink
feature(zjow): add new pipeline agent sac/ddpg/a2c (#637)
Browse files Browse the repository at this point in the history
* polish code

* fix data type error for mujoco

* polish code

* polish code

* Add features

* fix base env manager readyimage

* polish code

* remove NoReturn

* remove NoReturn

* format code

* format code

* polish code

* polish code

* fix logger

* format code

* format code

* change api for ckpt; polish code

* polish code

* format code

* polish code

* fix load bug

* fix bug

* fix dtype error

* polish code

* polish code

* Add dqn agent

* add config

* add bonus/c51.py

* add c51 logit monitor

* add sac dqn agent

* add sac dqn agent demo in dizoo

* polish format

* polish code

* polish code

* fix ddpg bug

* merge nyz c51/dqn config and policy

* fix config

* remove mutistep_trainer

* fix bug

* polish code

* polish code

* polish code

* add Hopper demo

* polish code

* add property best

* add a2c pipeline

* add sac halfcheetah+walker2d

* fix a2c pipeline bug

* fix pipeline bug

* fix bug

* change config

* remove IMPALA pipeline

* format code

* polish code

* polish c51 and add ddpg halfcheetah walker2d

* add dizoo/common for zjow to review

* fix agent best method

* reset dizoo

* delete common

* polish for zjow to review

* polish code

* polish code

* fix bug

* fix bug

* polish c51

* add pg agent

* add pendulum config

* add c51_atari td3_pendulum,bipedalwalker ddpg_pendulum

* polish code

* polish code

* polish code

* add bipedalwalker_ddpg_config

* change config

* change bipedalwalker config and noframeskip

* polish c51-atari name

* add pong spaceinvaders and qbert for dqn

* polish code

* polish code; add env mode

* add rew_clip in ding_env_wrapper

* polish dqn atari

* add a2c continuous action space

* add a2c continuous action space

* add a2c continuous for mujoco

* add a2c continuous for mujoco

* add a2c continuous for mujoco

* add a2c mujoco config; add ppo atari config

* add a2c mujoco config; add ppo atari config

* fix a2c deploy bug

* Add bipedalwalker a2c

* polish code

* polish code

* polish code

* polish code

* polish code

* add pendulum a2c+pg

* add pg bipedalwalker+mujoco

* polish code for wandb sweep

* polish code for wandb sweep

* polish code for wandb sweep

* polish code for a2c mujoco

* add pg pendulum new pipeline

* fix scalar action bug in random collect

* polish pg algorithm

* add bonus pg config

* polish pg config

* polish config

* polish code

* change pendulum pg config

* fix continuous action dim=1 bug

* Add ppof lr scheduler

* polish config

* fix random collect bug for dqn

* polish ppo qbert spaceinvader config

* remove mujoco wrapper

* polish a2c mujoco config; add ppo offpolicy agent pipeline

* Add wandb monitor evaluate return std

* polish deploy method

* format code

* polish code

* polish pg pendulum+hopper config

* fix data shape bug

* fix ppo offpolicy deploy bug

* fix mujoco reward action env clip bug

* fix mujoco reward action env clip bug

* fix deploy env mode bug

* fix env reset bug for deployment and evaluation

* Add ppo offpolicy atari config

* polish config

* polish config code

* polish code; add SQL

* polish code

* change config path

* add compatibility fix for nstep

* polish code

* Add ppo offpolicy continuous policy

* polish config

* add ppo offpolicy general action modeling

* add dependencies

* polish config

* polish deploy

* Add array video helper

* polish deploy

* polish config

* polish setup

* fix config bug

* polish code

* polish code

* polish code

* fix bug in evaluator

* polish code

* fix bug in ckpt_saver order

* fix format

* fix bug in reward shape

* format type

* polish code

* fix nstep error for ppo offpolicy

* fix bug in action shape of cql when dim is 1

* polish code

* delete config not work

* polish code; remove ppof general datatype

* remove useless code

* polish code

* polish code

* fix a2c unittest

* fix advantages_estimator unittest

* fix combination_argmax_sample unittest

* fix unittest bug

* fix wandb logger unittest bug

* polish code

* move config position

* remove useless config

* polish code

* add unittest for montecarlo_return_estimator

* fix bug in termination checker

* polish code


---------

Co-authored-by: zhangpaipai <[email protected]>
Co-authored-by: Ruoyu Gao <[email protected]>
Co-authored-by: Swain <[email protected]>
  • Loading branch information
4 people authored Sep 14, 2023
1 parent a07cde2 commit a37981e
Show file tree
Hide file tree
Showing 114 changed files with 6,014 additions and 339 deletions.
132 changes: 131 additions & 1 deletion ding/bonus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,132 @@
import ding.config
from .a2c import A2CAgent
from .c51 import C51Agent
from .ddpg import DDPGAgent
from .dqn import DQNAgent
from .pg import PGAgent
from .ppof import PPOF
from .td3 import TD3OffPolicyAgent
from .ppo_offpolicy import PPOOffPolicyAgent
from .sac import SACAgent
from .sql import SQLAgent
from .td3 import TD3Agent

supported_algo = dict(
A2C=A2CAgent,
C51=C51Agent,
DDPG=DDPGAgent,
DQN=DQNAgent,
PG=PGAgent,
PPOF=PPOF,
PPOOffPolicy=PPOOffPolicyAgent,
SAC=SACAgent,
SQL=SQLAgent,
TD3=TD3Agent,
)

supported_algo_list = list(supported_algo.keys())


def env_supported(algo: str = None) -> list:
"""
return list of the envs that supported by di-engine.
"""

if algo is not None:
if algo.upper() == "A2C":
return list(ding.config.example.A2C.supported_env.keys())
elif algo.upper() == "C51":
return list(ding.config.example.C51.supported_env.keys())
elif algo.upper() == "DDPG":
return list(ding.config.example.DDPG.supported_env.keys())
elif algo.upper() == "DQN":
return list(ding.config.example.DQN.supported_env.keys())
elif algo.upper() == "PG":
return list(ding.config.example.PG.supported_env.keys())
elif algo.upper() == "PPOF":
return list(ding.config.example.PPOF.supported_env.keys())
elif algo.upper() == "PPOOFFPOLICY":
return list(ding.config.example.PPOOffPolicy.supported_env.keys())
elif algo.upper() == "SAC":
return list(ding.config.example.SAC.supported_env.keys())
elif algo.upper() == "SQL":
return list(ding.config.example.SQL.supported_env.keys())
elif algo.upper() == "TD3":
return list(ding.config.example.TD3.supported_env.keys())
else:
raise ValueError("The algo {} is not supported by di-engine.".format(algo))
else:
supported_env = set()
supported_env.update(ding.config.example.A2C.supported_env.keys())
supported_env.update(ding.config.example.C51.supported_env.keys())
supported_env.update(ding.config.example.DDPG.supported_env.keys())
supported_env.update(ding.config.example.DQN.supported_env.keys())
supported_env.update(ding.config.example.PG.supported_env.keys())
supported_env.update(ding.config.example.PPOF.supported_env.keys())
supported_env.update(ding.config.example.PPOOffPolicy.supported_env.keys())
supported_env.update(ding.config.example.SAC.supported_env.keys())
supported_env.update(ding.config.example.SQL.supported_env.keys())
supported_env.update(ding.config.example.TD3.supported_env.keys())
# return the list of the envs
return list(supported_env)


supported_env = env_supported()


def algo_supported(env_id: str = None) -> list:
"""
return list of the algos that supported by di-engine.
"""
if env_id is not None:
algo = []
if env_id.upper() in [item.upper() for item in ding.config.example.A2C.supported_env.keys()]:
algo.append("A2C")
if env_id.upper() in [item.upper() for item in ding.config.example.C51.supported_env.keys()]:
algo.append("C51")
if env_id.upper() in [item.upper() for item in ding.config.example.DDPG.supported_env.keys()]:
algo.append("DDPG")
if env_id.upper() in [item.upper() for item in ding.config.example.DQN.supported_env.keys()]:
algo.append("DQN")
if env_id.upper() in [item.upper() for item in ding.config.example.PG.supported_env.keys()]:
algo.append("PG")
if env_id.upper() in [item.upper() for item in ding.config.example.PPOF.supported_env.keys()]:
algo.append("PPOF")
if env_id.upper() in [item.upper() for item in ding.config.example.PPOOffPolicy.supported_env.keys()]:
algo.append("PPOOffPolicy")
if env_id.upper() in [item.upper() for item in ding.config.example.SAC.supported_env.keys()]:
algo.append("SAC")
if env_id.upper() in [item.upper() for item in ding.config.example.SQL.supported_env.keys()]:
algo.append("SQL")
if env_id.upper() in [item.upper() for item in ding.config.example.TD3.supported_env.keys()]:
algo.append("TD3")

if len(algo) == 0:
raise ValueError("The env {} is not supported by di-engine.".format(env_id))
return algo
else:
return supported_algo_list


def is_supported(env_id: str = None, algo: str = None) -> bool:
"""
Check if the env-algo pair is supported by di-engine.
"""
if env_id is not None and env_id.upper() in [item.upper() for item in supported_env.keys()]:
if algo is not None and algo.upper() in supported_algo_list:
if env_id.upper() in env_supported(algo):
return True
else:
return False
elif algo is None:
return True
else:
return False
elif env_id is None:
if algo is not None and algo.upper() in supported_algo_list:
return True
elif algo is None:
raise ValueError("Please specify the env or algo.")
else:
return False
else:
return False
Loading

0 comments on commit a37981e

Please sign in to comment.