Commit 021d2a7: Updating layer norm impl
szaman19 committed Jun 25, 2024 (Verified: signed with the committer's verified signature)
Parent: 9e44581
Showing 1 changed file with 10 additions and 7 deletions.
src/layers/regularizers/layer_norm.cu: 10 additions & 7 deletions
@@ -1,5 +1,5 @@
////////////////////////////////////////////////////////////////////////////////
-// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
+// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory.
// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
// the CONTRIBUTORS file. <[email protected]>
@@ -28,8 +28,8 @@
#include "layer_norm_kernels.cuh"
#include "lbann/comm_impl.hpp"
#include "lbann/layers/regularizers/layer_norm.hpp"
#include "lbann/optimizers/optimizer.hpp"
#include "lbann/layers/regularizers/layer_norm_impl.hpp"
#include "lbann/optimizers/optimizer.hpp"
#include "lbann/utils/gpu/helpers.hpp"

#ifdef LBANN_HAS_DISTCONV
@@ -556,12 +556,12 @@ void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::bp_compute()
template <typename TensorDataType, data_layout Layout, El::Device Device>
void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
{
-#ifdef LBANN_HAS_DISTCONV
+#ifdef LBANN_HAS_DISTCONV
  if (this->distconv_enabled()) {
    this->get_distconv_adapter().fp_compute();
    return;
  }
-#endif // LBANN_HAS_DISTCONV
+#endif // LBANN_HAS_DISTCONV

  int weight_idx = 0;
  const TensorDataType* scale_weights = nullptr;
@@ -575,6 +575,7 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
  El::Int norm_size, global_norm_size, num_norm, norm_stride;
  this->get_normdims(norm_size, global_norm_size, num_norm, norm_stride);

+<<<<<<< HEAD
<<<<<<< HEAD
#ifdef LBANN_HAS_DISTCONV
  if (this->distconv_enabled()) {
@@ -585,6 +586,8 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
=======
>>>>>>> f02146109 (Updated implementation with updating statistics tensors)
+=======
+>>>>>>> ecac28c9f (Updating layer norm impl)
  fp_impl(*this->get_comm(),
          this->m_epsilon,
          norm_size,
@@ -601,13 +604,13 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
template <typename TensorDataType, data_layout Layout, El::Device Device>
void layer_norm_layer<TensorDataType, Layout, Device>::bp_compute()
{
-#ifdef LBANN_HAS_DISTCONV
+#ifdef LBANN_HAS_DISTCONV
  if (this->distconv_enabled()) {
    this->get_distconv_adapter().bp_compute();
    return;
  }
-#endif // LBANN_HAS_DISTCONV
+#endif // LBANN_HAS_DISTCONV
  // Obtain optional buffers
  const TensorDataType* scale_weights = nullptr;
  TensorDataType* scale_grad = nullptr;
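
Note that the committed hunks still contain merge-conflict markers (<<<<<<<, =======, >>>>>>>), so the file as of this commit would not compile until those markers are resolved.

For context, the non-distconv path falls through to fp_impl, which receives the LBANN communicator, m_epsilon, and the normalization dimensions. The standalone C++ sketch below illustrates the math a layer-norm forward pass of this shape presumably performs: a per-sample mean and variance over norm_size elements, normalization with epsilon for numerical stability, and an optional elementwise scale and bias. It is an illustration under assumed names and data layout, not LBANN code, and it omits the multi-rank statistics reduction that the communicator argument implies.

// layer_norm_sketch.cpp: hypothetical single-process illustration (not LBANN code)
#include <cmath>
#include <cstdio>
#include <vector>

// Normalize each of num_norm rows of length norm_size:
// y = (x - mean) / sqrt(var + epsilon), then optional scale/bias.
void layer_norm_forward(const std::vector<float>& x,
                        std::vector<float>& y,
                        int num_norm,
                        int norm_size,
                        float epsilon,
                        const float* scale = nullptr,
                        const float* bias = nullptr)
{
  for (int i = 0; i < num_norm; ++i) {
    const float* xi = &x[i * norm_size];
    float* yi = &y[i * norm_size];
    // Per-sample statistics over the normalized dimension
    float mean = 0.f, sqmean = 0.f;
    for (int j = 0; j < norm_size; ++j) {
      mean += xi[j];
      sqmean += xi[j] * xi[j];
    }
    mean /= norm_size;
    sqmean /= norm_size;
    const float var = sqmean - mean * mean;
    const float inv_stdev = 1.f / std::sqrt(var + epsilon);
    for (int j = 0; j < norm_size; ++j) {
      float v = (xi[j] - mean) * inv_stdev;
      if (scale) { v *= scale[j]; }
      if (bias) { v += bias[j]; }
      yi[j] = v;
    }
  }
}

int main()
{
  // Two samples, three features each
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  std::vector<float> y(x.size());
  layer_norm_forward(x, y, /*num_norm=*/2, /*norm_size=*/3, /*epsilon=*/1e-5f);
  for (float v : y) { std::printf("%g ", v); } // approx: -1.22 0 1.22 -1.22 0 1.22
  std::printf("\n");
  return 0;
}

In the real kernel, fp_impl presumably also writes the per-sample statistics into the layer's statistics tensors (the parent commit message mentions "updating statistics tensors") and reduces them across ranks via the communicator; the sketch keeps only the local math.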
