refactor(lyd): refactor dt_policy in new pipeline #687

Closed
wants to merge 5 commits into from

AltmanD (Collaborator) commented Jul 18, 2023

Description

Under the new pipeline, DT performs worse than it does under the old pipeline.

AltmanD closed this Jul 18, 2023
AltmanD reopened this Jul 18, 2023
AltmanD marked this pull request as ready for review July 18, 2023 07:54
PaParaZz1 added the algo (Add new algorithm or improve old one) and refactor (refactor module or component) labels Jul 18, 2023
@@ -1174,6 +1174,23 @@ def reset(self):
return self.env.reset()


class AllinObsWrapper(gym.Wrapper):
Member

Add an overview description for this wrapper.

@@ -42,7 +42,8 @@

from .d4pg import D4PGPolicy
from .cql import CQLPolicy, CQLDiscretePolicy
from .decision_transformer import DTPolicy
# from .decision_transformer import DTPolicy
Member

Remove the commented-out code.

class DTPolicy(Policy):
    r"""
    Overview:
        Policy class of DT algorithm in discrete environments.
Member

Add the full name of DT (Decision Transformer) and the paper link.


def _init_learn(self) -> None:
r"""
Overview:
Member

Polish the indentation.

@@ -27,7 +27,7 @@
embed_dim=128,
n_heads=1,
dropout_p=0.1,
- log_dir='/home/puyuan/DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps',
+ log_dir='/mnt/nfs/luyd/DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps',
Member

Don't upload absolute paths.
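
One possible way to follow this, as a minimal sketch: commit only a relative log_dir in the config and resolve it against a base directory at run time. The helper resolve_log_dir and the config dict below are hypothetical illustrations, not part of DI-engine.

```python
from pathlib import Path

# Hypothetical config fragment: only a machine-independent relative path is committed.
config = dict(log_dir="dt_log_1000eps")


def resolve_log_dir(log_dir: str, base: Path) -> Path:
    """Resolve a relative log_dir against a base directory chosen at run time.

    An absolute path is returned unchanged, so local overrides still work.
    """
    path = Path(log_dir)
    return path if path.is_absolute() else base / path


# Resolved where the experiment is launched, not where the config was written.
resolved = resolve_log_dir(config["log_dir"], Path.cwd())
```

With this layout the same config file can run under any user's checkout, and a user-specific location is supplied only at launch time.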

@@ -0,0 +1,2 @@
duration,num_updates,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score
Member

Remove unnecessary log files.


self.context_len = context_len
def __init__(self, cfg: dict) -> None:
dataset_path = cfg.policy.collect.get('data_path', None)
Member

Don't use xxx.get; we must ensure that all configs are fixed after compile_config.
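
A minimal sketch of the pattern this comment seems to ask for: declare every field the code reads, including optional ones, in a default config, merge user overrides once, and then read fields directly. The names DEFAULT_CFG and merge_config, the context_len value, and the data.pkl path are hypothetical; merge_config only stands in for the role of compile_config and is not DI-engine's implementation. Plain dicts are used here instead of the attribute-style access in the diff.

```python
import copy

# Hypothetical defaults: every key the code later reads is declared here,
# so no call site needs a fallback value.
DEFAULT_CFG = {"policy": {"collect": {"data_path": None, "context_len": 20}}}


def merge_config(default: dict, user: dict) -> dict:
    """Recursively overlay user values on a deep copy of the defaults."""
    merged = copy.deepcopy(default)
    for key, value in user.items():
        if isinstance(value, dict) and key in merged and isinstance(merged[key], dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


user_cfg = {"policy": {"collect": {"data_path": "data.pkl"}}}
cfg = merge_config(DEFAULT_CFG, user_cfg)

# Direct access: a missing key now raises KeyError instead of silently
# falling back to a default hidden at the call site.
dataset_path = cfg["policy"]["collect"]["data_path"]
```

Keeping the defaults in one place means a misspelled or missing key fails loudly at merge or access time rather than producing a silent None.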


self.running_rtg = [self.rtg_target / self.rtg_scale] * self.eval_batch_size
self.t = [0] * self.eval_batch_size
self.timesteps = torch.arange(start=0, end=self.max_eval_ep_len, step=1).repeat(self.eval_batch_size, 1).to(self.device)
Member

Specify the device when calling torch.arange.
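
As a minimal sketch of this suggestion: passing device= to torch.arange creates the tensor on the target device directly, instead of building it on the CPU and copying it with .to(...). The batch size and episode length below are illustrative values, not the ones used in the PR.

```python
import torch

# Illustrative values; in the PR these come from the policy config.
eval_batch_size, max_eval_ep_len = 4, 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allocate on the target device directly, then tile one row per eval env.
timesteps = torch.arange(
    start=0, end=max_eval_ep_len, step=1, device=device
).repeat(eval_batch_size, 1)
```

The result has shape (eval_batch_size, max_eval_ep_len) and already lives on device, so no extra host-to-device copy is needed.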

AltmanD changed the title from "Dev dt in new pipeline" to "refactor(lyd): refactor dt_policy in new pipeline" Jul 18, 2023
AltmanD closed this Jul 25, 2023
2 participants