Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
amend
Browse files Browse the repository at this point in the history
matteobettini committed Jul 10, 2024
1 parent 49d335c commit 0130596
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
@@ -557,7 +557,6 @@ def _collection_loop(self):
initial=self.n_iters_performed,
total=self.config.get_max_n_iters(self.on_policy),
)
sampling_start = time.time()

if not self.config.collect_with_grad:
iterator = iter(self.collector)
@@ -568,6 +567,7 @@ def _collection_loop(self):
for _ in range(
self.n_iters_performed, self.config.get_max_n_iters(self.on_policy)
):
iteration_start = time.time()
if not self.config.collect_with_grad:
batch = next(iterator)
else:
@@ -585,7 +585,7 @@ def _collection_loop(self):
reset_batch = step_mdp(batch[..., -1])

# Logging collection
collection_time = time.time() - sampling_start
collection_time = time.time() - iteration_start
current_frames = batch.numel()
self.total_frames += current_frames
self.mean_return = self.logger.log_collection(
@@ -637,22 +637,8 @@ def _collection_loop(self):
if not self.config.collect_with_grad:
self.collector.update_policy_weights_()

# Timers
# Training timer
training_time = time.time() - training_start
iteration_time = collection_time + training_time
self.total_time += iteration_time
self.logger.log(
{
"timers/collection_time": collection_time,
"timers/training_time": training_time,
"timers/iteration_time": iteration_time,
"timers/total_time": self.total_time,
"counters/current_frames": current_frames,
"counters/total_frames": self.total_frames,
"counters/iter": self.n_iters_performed,
},
step=self.n_iters_performed,
)

# Evaluation
if (
@@ -666,6 +652,20 @@ def _collection_loop(self):
self._evaluation_loop()

# End of step
iteration_time = time.time() - iteration_start
self.total_time += iteration_time
self.logger.log(
{
"timers/collection_time": collection_time,
"timers/training_time": training_time,
"timers/iteration_time": iteration_time,
"timers/total_time": self.total_time,
"counters/current_frames": current_frames,
"counters/total_frames": self.total_frames,
"counters/iter": self.n_iters_performed,
},
step=self.n_iters_performed,
)
self.n_iters_performed += 1
self.logger.commit()
if (
@@ -674,7 +674,6 @@ def _collection_loop(self):
):
self._save_experiment()
pbar.update()
sampling_start = time.time()

if self.config.checkpoint_at_end:
self._save_experiment()

0 comments on commit 0130596

Please sign in to comment.