Skip to content

Commit

Permalink
Fix FBGEMM_GPU_MEMCHECK in Split optimizers (#3416)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3416

X-link: facebookresearch/FBGEMM#504

As title

Reviewed By: q10

Differential Revision: D66474172

fbshipit-source-id: abfbc54bd3b5ca37cb704c99de2cac86b21bb67e
  • Loading branch information
sryap authored and facebook-github-bot committed Nov 26, 2024
1 parent 51fe9ee commit 357b54c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,21 @@ template <
>
__global__ __launch_bounds__(kMaxThreads)
void split_{{ optimizer }}_update_kernel(
at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> uvm_weights,
at::PackedTensorAccessor64<cache_t, 2, at::RestrictPtrTraits> lxu_cache_weights,
const at::PackedTensorAccessor32<emb_t, 1, at::RestrictPtrTraits> grad_dev_weights,
pta::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
pta::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> uvm_weights,
pta::PackedTensorAccessor64<cache_t, 2, at::RestrictPtrTraits> lxu_cache_weights,
const pta::PackedTensorAccessor32<emb_t, 1, at::RestrictPtrTraits> grad_dev_weights,
// grad_dev_indices is equivalent to sorted_linear_indices_run
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_dev_indices,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_dev_indices,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
weights_placements,
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
sorted_lxu_cache_locations,
const int32_t max_D,
bool stochastic_rounding,
at::PhiloxCudaState stochastic_rounding_philox_args,
{{ args.split_kernel_args | join(", ") }}
{{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}
) {
const auto run_id = blockIdx.x * blockDim.y + threadIdx.y;
if (run_id >= grad_dev_indices.size(0)) {
Expand Down Expand Up @@ -130,21 +130,22 @@ void split_{{ optimizer }}_update_kernel
{{ kThreadGroupSize }},
4 // VEC_WIDTH
>(
at::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights,
at::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights,
at::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights,
const at::PackedTensorAccessor32<{{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights,
pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights,
pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights,
pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights,
const pta::PackedTensorAccessor32<{{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights,
// grad_dev_indices is equivalent to sorted_linear_indices_run
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_dev_indices,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_dev_indices,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
weights_placements,
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
sorted_lxu_cache_locations,
const int32_t max_D,
bool stochastic_rounding,
at::PhiloxCudaState stochastic_rounding_philox_args,
{{ args.split_kernel_args_no_defaults |
replace_pta_namespace() |
replace_placeholder_types(ph_type_combo) |
join(",\n ") |
replace("cache_t", cache_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// clang-format off
#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/utils/tensor_accessor.h"

using Tensor = at::Tensor;
using namespace fbgemm_gpu;
Expand All @@ -24,21 +25,21 @@ template <
>
__global__ __launch_bounds__(kMaxThreads) void
split_{{ optimizer }}_update_kernel(
at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> uvm_weights,
at::PackedTensorAccessor64<cache_t, 2, at::RestrictPtrTraits> lxu_cache_weights,
const at::PackedTensorAccessor32<emb_t, 1, at::RestrictPtrTraits> grad_dev_weights,
pta::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
pta::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> uvm_weights,
pta::PackedTensorAccessor64<cache_t, 2, at::RestrictPtrTraits> lxu_cache_weights,
const pta::PackedTensorAccessor32<emb_t, 1, at::RestrictPtrTraits> grad_dev_weights,
// grad_dev_indices is equivalent to sorted_linear_indices_run
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_dev_indices,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_dev_indices,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
weights_placements,
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
sorted_lxu_cache_locations,
const int32_t max_D,
bool stochastic_rounding,
at::PhiloxCudaState stochastic_rounding_philox_args,
{{ args.split_kernel_args | join(", ") }});
{{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }});

////////////////////////////////////////////////////////////////////////////////
// Auto generated placeholder tensor dispatch macros
Expand Down Expand Up @@ -171,6 +172,10 @@ void split_embedding_{{ optimizer }}_update(
#else
constexpr int kThreadGroupSize = kWarpSize;
#endif
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = "split_{{ optimizer }}_update_kernel";
#endif

DISPATCH_PLACEHOLDER_TYPES(
{%- for ph_name in args.placeholder_tensor_names %}
{{ ph_name + "_dev" }}.scalar_type(),
Expand All @@ -192,21 +197,21 @@ void split_embedding_{{ optimizer }}_update(
at::cuda::getCurrentCUDAStream()
>>>
(
dev_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
flatten_grad_dev_weights.packed_accessor32<emb_t, 1, at::RestrictPtrTraits>(),
flatten_grad_dev_indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
MAKE_PTA_WITH_NAME(func_name, dev_weights, emb_t, 1, 64),
MAKE_PTA_WITH_NAME(func_name, uvm_weights, emb_t, 1, 64),
MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, cache_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name, flatten_grad_dev_weights, emb_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, flatten_grad_dev_indices, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
// Use weights_placements instead of
// sorted_lxu_cache_locations because LXU cache is not
// supported right now
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32),
max_D,
stochastic_rounding,
rng_engine_inputs,
{{ args.split_kernel_arg_constructors | join(", ") }}
{{ args.split_kernel_arg_constructors | make_pta_acc_format("func_name") | join(", ") }}
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}); // DISPATCH_PLACEHOLDER_TYPES
Expand Down

0 comments on commit 357b54c

Please sign in to comment.