Skip to content

Commit

Permalink
polish readme
Browse files Browse the repository at this point in the history
  • Loading branch information
‘whl’ committed Sep 19, 2024
1 parent 10f6626 commit d5292bf
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ It provides **python-first** and **asynchronous-native** task and middleware abs
- Offline RL algorithms: BCQ, CQL, TD3BC, Decision Transformer, EDAC, Diffuser, Decision Diffuser, SO2
- Model-based RL algorithms: SVG, STEVE, MBPO, DDPPO, DreamerV3
- Exploration algorithms: HER, RND, ICM, NGU
- LLM + RL Algorithms: PPO-max, DPO, PromptPG
- LLM + RL Algorithms: PPO-max, DPO, PromptPG, PromptAWR
- Other algorithms: such as PER, PLR, PCGrad
- MCTS + RL algorithms: AlphaZero, MuZero, please refer to [LightZero](https://github.com/opendilab/LightZero)
- Generative Model + RL algorithms: Diffusion-QL, QGPO, SRPO, please refer to [GenerativeRL](https://github.com/opendilab/GenerativeRL)
Expand Down Expand Up @@ -283,6 +283,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 54 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
| 55 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
| 56 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
| 57 | [AWR](https://arxiv.org/pdf/1910.00177) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/ibc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/prompt_awr.py) | python3 -u tabmwp_awr_config.py |

</details>

Expand Down
66 changes: 66 additions & 0 deletions dizoo/tabmwp/config/tabmwp_awr_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from easydict import EasyDict

tabmwp_prompt_pg_config = dict(
exp_name='tabmwp_prompt_pg_seed0',
env=dict(
collector_env_num=1,
evaluator_env_num=1,
n_evaluator_episode=1,
stop_value=1,
cand_number=16,
train_number=80,
engine='text-davinci-002',
temperature=0.,
max_tokens=512,
top_p=1.,
frequency_penalty=0.,
presence_penalty=0.,
option_inds=["A", "B", "C", "D", "E", "F"],
# The API-key of openai. You can get your key in this website: https://platform.openai.com/
api_key='',
enable_replay=True,
prompt_format='TQ-A',
seed=0,
),
policy=dict(
cuda=True,
shot_number=2,
model=dict(
model_name="bert-base-uncased",
add_linear=True,
freeze_encoder=True,
embedding_size=128,
),
learn=dict(
batch_size=10,
# (bool) Whether to normalize advantage. Default to False.
learning_rate=0.001,
# (float) loss weight of the value network, the weight of policy network is set to 1
entropy_weight=0.001,
weight_decay=5e-3,
grad_norm=0.5,
),
collect=dict(
# (int) collect n_sample data, train model 1 times
n_sample=20,
discount_factor=0.,
),
eval=dict(evaluator=dict(eval_freq=500, )),
),
)
main_config = EasyDict(tabmwp_prompt_pg_config)

tabmwp_prompt_pg_config = dict(
env=dict(
type='tabmwp',
import_names=['dizoo.tabmwp.envs.tabmwp_env'],
),
env_manager=dict(type='base'),
policy=dict(type='prompt_awr'),
replay_buffer=dict(type='naive'),
)
create_config = EasyDict(tabmwp_prompt_pg_config)

if __name__ == '__main__':
from ding.entry import serial_pipeline_onpolicy
serial_pipeline_onpolicy((main_config, create_config), seed=0)

0 comments on commit d5292bf

Please sign in to comment.