
Commit 7725efb
Fix GPU memory increase during training
lifeiteng committed Mar 2, 2023
1 parent 3da3bc2 commit 7725efb
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions valle/models/valle.py
@@ -174,15 +174,15 @@ def __init__(
NUM_AUDIO_TOKENS + 1,
top_k=10,
average="micro",
multidim_average="samplewise",
multidim_average="global",
ignore_index=NUM_AUDIO_TOKENS,
)

self.nar_accuracy_metric = MulticlassAccuracy(
NUM_AUDIO_TOKENS + 1,
top_k=10,
average="micro",
multidim_average="samplewise",
multidim_average="global",
ignore_index=NUM_AUDIO_TOKENS,
)

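Note on the two metric changes above: in torchmetrics, multidim_average="samplewise" makes MulticlassAccuracy return one value per sample, while "global" reduces over everything and returns a single scalar, which pairs with the .item() calls further down. The snippet below is only an illustrative sketch, not code from this repo; NUM_AUDIO_TOKENS = 1024 is an assumed value.

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

NUM_AUDIO_TOKENS = 1024  # assumption: the real constant lives in valle/models/valle.py

# Same arguments as in the diff, one metric per multidim_average mode.
metric_samplewise = MulticlassAccuracy(
    NUM_AUDIO_TOKENS + 1,
    top_k=10,
    average="micro",
    multidim_average="samplewise",
    ignore_index=NUM_AUDIO_TOKENS,
)
metric_global = MulticlassAccuracy(
    NUM_AUDIO_TOKENS + 1,
    top_k=10,
    average="micro",
    multidim_average="global",
    ignore_index=NUM_AUDIO_TOKENS,
)

preds = torch.randn(4, NUM_AUDIO_TOKENS + 1, 100)     # (batch, classes, time)
target = torch.randint(0, NUM_AUDIO_TOKENS, (4, 100))  # (batch, time)

print(metric_samplewise(preds, target).shape)  # torch.Size([4]) -- one value per sample
print(metric_global(preds, target).shape)      # torch.Size([])  -- a single scalar
```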
@@ -275,7 +275,7 @@ def pad_y_eos(y, eos_id):
total_loss = F.cross_entropy(logits, targets, reduction=reduction)
metrics["ArTop10Accuracy"] = self.ar_accuracy_metric(
logits.detach(), targets
- ).mean() * y_lens.sum().type(torch.float32)
+ ).item() * y_lens.sum().type(torch.float32)

# Non-AR Decoders
train_stage = self.rng.choices(
@@ -310,7 +310,7 @@ def pad_y_eos(y, eos_id):
value=logits.min().cpu().item(),
),
targets,
- ).mean()
+ ).item()
* y_lens.sum().type(torch.float32)
)

@@ -540,7 +540,7 @@ def pad_y_eos(y, eos_id):

metrics["ArTop10Accuracy"] = self.ar_accuracy_metric(
logits.detach(), targets
- ).mean() * y_lens.sum().type(torch.float32)
+ ).item() * y_lens.sum().type(torch.float32)

# Non-AR Decoders
train_stage = self.rng.choices(
@@ -578,7 +578,7 @@ def pad_y_eos(y, eos_id):
value=logits.min().cpu().item(),
),
targets,
- ).mean()
+ ).item()
* y_lens.sum().type(torch.float32)
)

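All four .mean() → .item() replacements above follow one pattern: the metric value was previously stored in the metrics dict as a tensor, so any code that accumulates those entries across steps keeps the underlying (GPU) storage referenced, which matches the memory growth this commit fixes; .item() stores a plain Python float instead. Below is a minimal sketch of that pattern under this assumption; fake_metric and the accumulation dict are made-up stand-ins, not code from this repo.

```python
import torch

def fake_metric(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Made-up stand-in for self.ar_accuracy_metric(...): returns a 0-dim tensor.
    return (logits.argmax(dim=-1) == targets).float().mean()

accumulated = {"ArTop10Accuracy": 0.0}
for _ in range(3):
    logits = torch.randn(8, 16, 1025)
    targets = torch.randint(0, 1024, (8, 16))

    # before: the stored value is a tensor, so it (and, on GPU, the device memory
    # behind it) stays referenced for as long as `accumulated` lives
    # accumulated["ArTop10Accuracy"] += fake_metric(logits.detach(), targets).mean()

    # after: .item() yields a Python float, so no tensor is kept alive in the dict
    accumulated["ArTop10Accuracy"] += fake_metric(logits.detach(), targets).item()

print(accumulated["ArTop10Accuracy"])
```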
2 changes: 1 addition & 1 deletion valle/modules/embedding.py
@@ -84,7 +84,7 @@ def extend_pe(self, x):
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
- self.pe = pe.to(device=x.device, dtype=x.dtype)
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

def forward(self, x: torch.Tensor) -> torch.Tensor:
self.extend_pe(x)
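The embedding.py change detaches the cached sinusoidal table so it is treated as a constant outside autograd when forward() adds it to the input. The module below is a simplified stand-in written for illustration; the real SinePositionalEmbedding in valle/modules/embedding.py differs in its details, and only the extend_pe lines shown in the diff and the added .detach() are mirrored here.

```python
import math
import torch
import torch.nn as nn

class SinePositionalEmbedding(nn.Module):
    """Simplified sketch of a lazily extended sinusoidal positional embedding."""

    def __init__(self, dim_model: int):
        super().__init__()
        self.dim_model = dim_model
        self.pe = None  # cached table, rebuilt when a longer sequence arrives

    def extend_pe(self, x: torch.Tensor) -> None:
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            return
        position = torch.arange(x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        pe = torch.zeros(x.size(1), self.dim_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # .detach() is the change in this commit: the cached table is a constant,
        # not a parameter, so it should never be tracked by autograd.
        self.pe = pe.unsqueeze(0).to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.extend_pe(x)
        return x + self.pe[:, : x.size(1)]

emb = SinePositionalEmbedding(dim_model=64)
out = emb(torch.randn(2, 50, 64))
print(out.shape)  # torch.Size([2, 50, 64])
```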
