
Commit

Add contributing guide file.
zjowowen committed Jun 13, 2024
1 parent abfd595 commit abcb6d2
Showing 5 changed files with 14 additions and 27 deletions.
7 changes: 7 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,7 @@
[Git Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/git_guide.html)

[GitHub Cooperation Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html)

- [Code Style](https://di-engine-docs.readthedocs.io/en/latest/21_code_style/index.html)
- [Unit Test](https://di-engine-docs.readthedocs.io/en/latest/22_test/index.html)
- [Code Review](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html#pr-s-code-review)
6 changes: 3 additions & 3 deletions README.md
@@ -2,7 +2,7 @@

[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

-English | [简体中文(Simplified Chinese)](https://github.com/opendilab/GenerativeRL_Preview/blob/main/README.zh.md)
+English | [简体中文(Simplified Chinese)](https://github.com/opendilab/GenerativeRL/blob/main/README.zh.md)

**GenerativeRL**, short for Generative Reinforcement Learning, is a Python library for solving reinforcement learning (RL) problems using generative models, such as diffusion models and flow models. This library aims to provide a framework for combining the power of generative models with the decision-making capabilities of reinforcement learning algorithms.

@@ -62,8 +62,8 @@ pip install grl
Or, if you want to install from source:

```bash
-git clone https://github.com/opendilab/GenerativeRL_Preview.git
-cd GenerativeRL_Preview
+git clone https://github.com/opendilab/GenerativeRL.git
+cd GenerativeRL
pip install -e .
```
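As a quick post-install sanity check (a suggested step, not part of this diff, and assuming the package installs under the module name `grl`, as `pip install grl` and the repository layout suggest), the editable install can be verified like this:

```python
# Confirm the editable install of `grl` is importable and show where it resolved.
import grl

print(grl.__file__)
```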

6 changes: 3 additions & 3 deletions README.zh.md
@@ -2,7 +2,7 @@

[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

-[英语 (English)](https://github.com/opendilab/GenerativeRL_Preview/blob/main/README.md) | 简体中文
+[英语 (English)](https://github.com/opendilab/GenerativeRL/blob/main/README.md) | 简体中文

**GenerativeRL** 是一个使用生成式模型解决强化学习问题的算法库,支持扩散模型和流模型等不同类型的生成式模型。这个库旨在提供一个框架,将生成式模型的能力与强化学习算法的决策能力相结合。

@@ -59,8 +59,8 @@ pip install grl
或者,如果你想从源码安装:

```bash
-git clone https://github.com/opendilab/GenerativeRL_Preview.git
-cd GenerativeRL_Preview
+git clone https://github.com/opendilab/GenerativeRL.git
+cd GenerativeRL
pip install -e .
```

2 changes: 1 addition & 1 deletion docs/source/tutorials/installation/index.rst
@@ -17,4 +17,4 @@ If you want to try a preview of the latest features, you can install the latest

.. code-block:: console
-$ pip install git+https://github.com/opendilab/GenerativeRL_Preview.git
+$ pip install git+https://github.com/opendilab/GenerativeRL.git
20 changes: 0 additions & 20 deletions grl/algorithms/srpo.py
@@ -385,33 +385,18 @@ def policy(obs: np.ndarray) -> np.ndarray:
lr=config.parameter.behaviour_policy.learning_rate,
)

-# checkpoint = torch.load(
-#     "/root/github/GenerativeRL_Preview/grl_pipelines/d4rl-halfcheetah-srpo/2024-04-17 06:22:21/checkpoint_diffusion_600000.pt"
-# )
-# self.model["SRPOPolicy"].sro.diffusion_model.model.load_state_dict(
-#     checkpoint["diffusion_model"]
-# )
-# behaviour_model_optimizer.load_state_dict(
-#     checkpoint["behaviour_model_optimizer"]
-# )

for train_diffusion_iter in track(
range(config.parameter.behaviour_policy.iterations),
description="Behaviour policy training",
):
data = next(data_generator)
-# data["s"].shape torch.Size([2048, 17]) data["a"].shape torch.Size([2048, 6]) data["r"].shape torch.Size([2048, 1])
behaviour_model_training_loss = self.model[
"SRPOPolicy"
].behaviour_policy_loss(data["a"], data["s"])
behaviour_model_optimizer.zero_grad()
behaviour_model_training_loss.backward()
behaviour_model_optimizer.step()

-# if train_iter == 0 or (train_iter + 1) % config.parameter.evaluation.evaluation_interval == 0:
-#     evaluation_results = evaluate(self.model["SRPOPolicy"], train_iter=train_iter)
-#     wandb_run.log(data=evaluation_results, commit=False)

wandb_run.log(
data=dict(
train_diffusion_iter=train_diffusion_iter,
@@ -444,11 +429,6 @@ def policy(obs: np.ndarray) -> np.ndarray:
lr=config.parameter.critic.learning_rate,
)

-# checkpoint = torch.load(
-#     "/root/github/GenerativeRL_Preview/grl_pipelines/d4rl-halfcheetah-srpo/2024-04-17 06:22:21/checkpoint_critic_600000.pt"
-# )
-# self.model["SRPOPolicy"].critic.q0.load_state_dict(checkpoint["q_model"])
-# self.model["SRPOPolicy"].critic.vf.load_state_dict(checkpoint["v_model"])
data_generator = get_train_data(
DataLoader(
self.dataset,
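For context, the surrounding `srpo.py` code follows the standard PyTorch offline-training pattern: draw a batch from a data generator, compute the behaviour-policy loss, and step the optimizer. Below is a minimal, self-contained sketch of that pattern; the network, loss, and data generator are placeholders (batch shapes mirror the removed debug comment, and the real code calls `self.model["SRPOPolicy"].behaviour_policy_loss(data["a"], data["s"])`).

```python
import torch

# Placeholder behaviour-policy network: maps 17-dim states to 6-dim actions,
# matching the shapes in the removed debug comment (s: [2048, 17], a: [2048, 6]).
policy_net = torch.nn.Linear(17, 6)
optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-4)

def data_generator():
    # Stand-in for get_train_data(DataLoader(...)): yields random batches forever.
    while True:
        yield {"s": torch.randn(2048, 17), "a": torch.randn(2048, 6)}

gen = data_generator()
for train_diffusion_iter in range(10):
    data = next(gen)
    # Placeholder loss; the actual code computes the diffusion behaviour-policy loss.
    loss = torch.nn.functional.mse_loss(policy_net(data["s"]), data["a"])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```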
