Skip to content

Commit

Permalink
let averaged model optional
Browse files Browse the repository at this point in the history
  • Loading branch information
lifeiteng committed Mar 2, 2023
1 parent 2d80d5c commit 3da3bc2
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions valle/bin/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def get_parser():
parser.add_argument(
"--average-period",
type=int,
default=200,
default=0,
help="""Update the averaged model, namely `model_avg`, after processing
this number of batches. `model_avg` is a separate version of model,
in which each floating-point parameter is the average of all the
Expand Down Expand Up @@ -466,6 +466,7 @@ def compute_loss(
info["loss"] = loss.detach().cpu().item()
for metric in metrics:
info[metric] = metrics[metric].detach().cpu().item()
del metrics

return predicts, loss, info

Expand Down Expand Up @@ -629,16 +630,17 @@ def train_one_epoch(
display_and_save_batch(batch, params=params)
raise

if (
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(
params=params,
model_cur=model,
model_avg=model_avg,
)
if params.average_period > 0:
if (
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(
params=params,
model_cur=model,
model_avg=model_avg,
)

if (
params.batch_idx_train > 0
Expand Down Expand Up @@ -813,7 +815,7 @@ def run(rank, world_size, args):

assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
if rank == 0 and params.average_period > 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)

Expand Down

0 comments on commit 3da3bc2

Please sign in to comment.