CTCLoss: Fix the hang issue caused by barrier divergence (#1087)
xytintel authored Nov 18, 2024
1 parent 3e54add commit f9c7682
Showing 2 changed files with 75 additions and 64 deletions.
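In SYCL, a work-group barrier such as item.barrier(sycl_local_fence) must be reached by every work-item in the work-group; when some work-items return from the kernel before the barrier, the rest of the group can wait on it forever and the kernel hangs. The CTC loss kernels bailed out early for padded batch entries (b >= batch_size_) and for zero-length inputs (input_length == 0) before the barriers inside their time-step loops, which is exactly that divergent pattern. The change below replaces the early returns with a valid flag, gates the per-work-item work on that flag, and only returns after the last barrier. The following is a minimal sketch of the problem and of the flag-based fix; it is illustrative only and not taken from this repository (the kernel names, the local_fence alias, and the work sizes are invented for the example).

#include <sycl/sycl.hpp>

constexpr auto local_fence = sycl::access::fence_space::local_space;

// BAD: work-items with an out-of-range id return before the barrier, so the
// rest of the work-group waits on a barrier that is never fully reached.
void divergent_kernel(sycl::nd_item<1> item, int n) {
  if (static_cast<int>(item.get_global_id(0)) >= n)
    return;                    // early exit before the barrier -> divergence
  item.barrier(local_fence);   // only a subset of the group gets here
  // ... work that relies on the barrier ...
}

// FIX: keep every work-item in the kernel so the barrier sequence is uniform;
// gate the actual work on a flag and return only after the final barrier.
void uniform_kernel(sycl::nd_item<1> item, int n) {
  bool valid = static_cast<int>(item.get_global_id(0)) < n;
  item.barrier(local_fence);   // reached by all work-items
  if (valid) {
    // ... work that relies on the barrier ...
  }
  item.barrier(local_fence);   // still reached by all work-items
  if (!valid)
    return;                    // safe: no barriers after this point
  // ... per-item epilogue (e.g. writing the result) ...
}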
137 changes: 74 additions & 63 deletions src/ATen/native/xpu/sycl/LossCTCKernels.cpp
@@ -47,45 +47,49 @@ struct CTCLossLogAlphaKernelFunctor {
int64_t la_batch_offset = b * la_batch_stride_;
int64_t tg_batch_offset = tg_batch_offsets_[b];

bool valid = true;
if (b >= batch_size_)
return;
valid = false;

// Waiting for support for activeThreadsOnlyBarrier
if (input_length == 0) {
if (tid_x == 0) {
scalar_t log_likelihood = target_length == 0 ? 0 : neginf;
neg_log_likelihood_data_[b] = -log_likelihood;
}
return;
valid = false;
}

// first row (t=0), the three equations for alpha_1 above eq (6)
for (int64_t block_s = 0; block_s < 2 * max_target_length_ + 1;
block_s += item.get_local_range(1)) {
int64_t s = tid_x + block_s;
scalar_t la;
switch (s) {
case 0:
la = log_probs_data_[lp_batch_offset + lp_char_stride_ * BLANK_];
break;
case 1:
la = target_length == 0 ? neginf
: log_probs_data_
[lp_batch_offset +
lp_char_stride_ *
get_target_prime(
targets_data_,
tg_batch_offset,
tg_target_stride_,
1,
BLANK_)];
break;
default:
la = neginf;
if (valid) {
// first row (t=0), the three equations for alpha_1 above eq (6)
for (int64_t block_s = 0; block_s < 2 * max_target_length_ + 1;
block_s += item.get_local_range(1)) {
int64_t s = tid_x + block_s;
scalar_t la;
switch (s) {
case 0:
la = log_probs_data_[lp_batch_offset + lp_char_stride_ * BLANK_];
break;
case 1:
la = target_length == 0 ? neginf
: log_probs_data_
[lp_batch_offset +
lp_char_stride_ *
get_target_prime(
targets_data_,
tg_batch_offset,
tg_target_stride_,
1,
BLANK_)];
break;
default:
la = neginf;
}
if (s < 2 * max_target_length_ + 1)
log_alpha_data_
[la_batch_offset +
/* la_input_stride_ * 0 */ +la_target_stride_ * s] = la;
}
if (s < 2 * max_target_length_ + 1)
log_alpha_data_
[la_batch_offset +
/* la_input_stride_ * 0 */ +la_target_stride_ * s] = la;
}

for (int64_t block_s = 0; block_s < 2 * max_target_length_ + 1;
@@ -95,7 +99,7 @@ struct CTCLossLogAlphaKernelFunctor {
// These two only depend on s, so we can cache them.
int64_t current_char; // l_s in eq (6)
bool have_three; // flag which of the two cases in eq (6) we have
if (s < 2 * target_length + 1 && target_length > 0) {
if (valid && s < 2 * target_length + 1 && target_length > 0) {
current_char = get_target_prime(
targets_data_, tg_batch_offset, tg_target_stride_, s, BLANK_);
have_three =
@@ -112,7 +116,7 @@ struct CTCLossLogAlphaKernelFunctor {
}
for (int64_t t = 1; t < max_input_length_; t++) {
item.barrier(sycl_local_fence);
if ((t < input_length) && (s < 2 * target_length + 1)) {
if (valid && (t < input_length) && (s < 2 * target_length + 1)) {
// only for valid t, s. This is equation (6) and (7), la1, la2, la3
// are the three summands, lamax is the maximum for the logsumexp
// trick.
@@ -154,7 +158,7 @@ struct CTCLossLogAlphaKernelFunctor {
lp_char_stride_ * current_char];
} else {
// otherwise we just set to neginf
if (s < 2 * max_target_length_ + 1)
if (valid && s < 2 * max_target_length_ + 1)
log_alpha_data_
[la_batch_offset + la_input_stride_ * t +
la_target_stride_ * s] = neginf;
@@ -163,6 +167,9 @@ struct CTCLossLogAlphaKernelFunctor {
}
item.barrier(sycl_local_fence);

if (!valid)
return;

// compute the loss (eq (8))
if (tid_x == 0) {
scalar_t l1 = log_alpha_data_
@@ -430,37 +437,41 @@ struct CTCLossBackwardLogBetaKernelFunctor {
int64_t lb_batch_offset = b * lb_batch_stride_;
int64_t tg_batch_offset = tg_batch_offsets_[b];

bool valid = true;

if (b >= batch_size_)
return;
valid = false;

if (input_length == 0)
return;

// "first" row, the beta initialization before eq (10) (t=target_length -
// differs per batch)
for (int64_t block_s =
2 * max_target_length_ - (2 * max_target_length_ % group_size_x);
block_s >= 0;
block_s -= group_size_x) {
int64_t s = tid_x + block_s;
scalar_t lb;
if (s == 2 * target_length) {
lb = log_probs_data_
[lp_batch_offset + (input_length - 1) * lp_input_stride_ +
lp_char_stride_ * BLANK_];
} else if (s == 2 * target_length - 1) { // false for target_length == 0
int64_t current_target_prime = get_target_prime(
targets_data_, tg_batch_offset, tg_target_stride_, s, BLANK_);
lb = log_probs_data_
[lp_batch_offset + (input_length - 1) * lp_input_stride_ +
lp_char_stride_ * current_target_prime];
} else {
lb = neginf;
}
if (s < 2 * max_target_length_ + 1) {
log_beta_data_
[lb_batch_offset + (input_length - 1) * lb_input_stride_ +
lb_target_stride_ * s] = lb;
valid = false;

if (valid) {
// "first" row, the beta initialization before eq (10) (t=target_length -
// differs per batch)
for (int64_t block_s =
2 * max_target_length_ - (2 * max_target_length_ % group_size_x);
block_s >= 0;
block_s -= group_size_x) {
int64_t s = tid_x + block_s;
scalar_t lb;
if (s == 2 * target_length) {
lb = log_probs_data_
[lp_batch_offset + (input_length - 1) * lp_input_stride_ +
lp_char_stride_ * BLANK_];
} else if (s == 2 * target_length - 1) { // false for target_length == 0
int64_t current_target_prime = get_target_prime(
targets_data_, tg_batch_offset, tg_target_stride_, s, BLANK_);
lb = log_probs_data_
[lp_batch_offset + (input_length - 1) * lp_input_stride_ +
lp_char_stride_ * current_target_prime];
} else {
lb = neginf;
}
if (s < 2 * max_target_length_ + 1) {
log_beta_data_
[lb_batch_offset + (input_length - 1) * lb_input_stride_ +
lb_target_stride_ * s] = lb;
}
}
}

@@ -472,7 +483,7 @@ struct CTCLossBackwardLogBetaKernelFunctor {
int64_t s = tid_x + block_s;
int64_t current_target_prime;
bool have_three;
if (s < 2 * target_length + 1 && target_length > 0) {
if (valid && s < 2 * target_length + 1 && target_length > 0) {
current_target_prime = get_target_prime(
targets_data_, tg_batch_offset, tg_target_stride_, s, BLANK_);
have_three =
@@ -491,7 +502,7 @@ struct CTCLossBackwardLogBetaKernelFunctor {
// we did above.
for (int64_t t = max_input_length_ - 2; t >= 0; t--) {
item.barrier(sycl_local_fence);
if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
if (valid && (t < input_length - 1) && (s < 2 * target_length + 1)) {
scalar_t lb1 = log_beta_data_
[lb_batch_offset + lb_input_stride_ * (t + 1) +
lb_target_stride_ * s];
@@ -531,7 +542,7 @@ struct CTCLossBackwardLogBetaKernelFunctor {
[lb_batch_offset + lb_input_stride_ * t + lb_target_stride_ * s] =
lb;
} else if (
(s < 2 * max_target_length_ + 1) &&
(b < batch_size_) && (s < 2 * max_target_length_ + 1) &&
(((target_length == 0) && (s > 0)) ||
(s >= 2 * target_length + 1) || (t >= input_length))) {
log_beta_data_
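Both CTCLossLogAlphaKernelFunctor and CTCLossBackwardLogBetaKernelFunctor get the same treatment: each early return becomes valid = false, the initialization row and the per-time-step updates are additionally guarded by valid (and, for the backward kernel's neginf fill, by b < batch_size_), and work-items with nothing to do only return after the last item.barrier(sycl_local_fence). Condensed, the forward kernel's control flow changes roughly as follows; this is illustrative pseudocode distilled from the hunks above, not a verbatim excerpt, and the elided bodies are marked with comments.

// Before: could hang, because returned work-items never reach the barriers
// inside the t-loop while the rest of their work-group waits on them.
if (b >= batch_size_)
  return;
if (input_length == 0) {
  /* tid_x == 0 writes the degenerate loss */
  return;
}
/* initialize row t = 0 of log_alpha for this work-item's s */
for (int64_t t = 1; t < max_input_length_; t++) {
  item.barrier(sycl_local_fence);   // not reached by the returned items
  /* update log_alpha at (t, s) */
}
item.barrier(sycl_local_fence);
/* tid_x == 0 computes the loss, eq (8) */

// After: every work-item executes the same sequence of barriers; the flag
// only decides whether it does real work in between.
bool valid = true;
if (b >= batch_size_)
  valid = false;
if (input_length == 0) {
  /* tid_x == 0 writes the degenerate loss */
  valid = false;
}
if (valid) {
  /* initialize row t = 0 of log_alpha for this work-item's s */
}
for (int64_t t = 1; t < max_input_length_; t++) {
  item.barrier(sycl_local_fence);   // uniformly executed
  if (valid /* && t, s in range */) {
    /* update log_alpha at (t, s) */
  }
}
item.barrier(sycl_local_fence);
if (!valid)
  return;                           // only after the final barrier
/* tid_x == 0 computes the loss, eq (8) */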
2 changes: 1 addition & 1 deletion test/xpu/test_nn_xpu.py
@@ -1739,7 +1739,7 @@ def helper(self, size, groups, memory_format, is_mixed, device, dtype):
helper(self, (4, 40, 40, 40), 2, torch.channels_last, is_mixed, device, dtype)
helper(self, (2, 30, 50, 50), 3, torch.channels_last, is_mixed, device, dtype)
helper(self, (2, 60, 50, 50), 3, torch.channels_last, is_mixed, device, dtype)
TestNN.test_groupnorm_nhwc = _test_groupnorm_nhwc
TestNN.test_groupnorm_nhwc = None # TODO: Disable it temporarily as PyTorch has reverted the PR: https://github.com/pytorch/pytorch/pull/126635

@parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
@parametrize_test("mode", ["bilinear", "bicubic"])
