Implement multi-step update
tridao committed Jun 27, 2024
1 parent 20928c2 commit d4ea5f9
Showing 5 changed files with 101 additions and 49 deletions.
56 changes: 39 additions & 17 deletions causal_conv1d/causal_conv1d_interface.py
@@ -172,42 +172,64 @@ def causal_conv1d_ref(
return out if not return_final_states else (out, final_states_out)


def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, advance_lengths=None):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
advance_lengths: (batch,), dtype int32. Each must be in [0, seqlen]. Values outside this range
will be clipped. If None, advance_lengths will be set to seqlen.
The conv_state will be updated by copying @advance_lengths elements from x to the end of conv_state,
and shifting the rest of the elements to the left.
out: (batch, dim)
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
activation = activation in ["silu", "swish"]
return causal_conv1d_cuda.causal_conv1d_update(
x, conv_state, weight, bias, activation
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
out = causal_conv1d_cuda.causal_conv1d_update(
x, conv_state, weight, bias, activation, advance_lengths
)
if unsqueeze:
out = out.squeeze(-1)
return out
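
A minimal usage sketch of the extended update call (illustrative shapes and values; assumes the package from this repo, with its compiled extension, is installed on a CUDA machine):

import torch
from causal_conv1d.causal_conv1d_interface import causal_conv1d_update

batch, dim, width, seqlen = 2, 64, 4, 5
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
conv_state = torch.randn(batch, dim, width - 1, device="cuda", dtype=torch.float16)  # state_len >= width - 1
weight = torch.randn(dim, width, device="cuda", dtype=torch.float32)
bias = torch.randn(dim, device="cuda", dtype=torch.float32)

# Default: conv_state advances by the full seqlen (updated in place).
out = causal_conv1d_update(x, conv_state, weight, bias, activation="silu")  # (batch, dim, seqlen)

# Per-sequence advance lengths, e.g. when only some of several drafted steps
# are kept; must be int32, on the same device, with values in [0, seqlen].
advance_lengths = torch.tensor([5, 2], dtype=torch.int32, device="cuda")
out = causal_conv1d_update(x, conv_state, weight, bias, activation="silu",
                           advance_lengths=advance_lengths)

# The single-step path still works: (batch, dim) in, (batch, dim) out.
x_step = torch.randn(batch, dim, device="cuda", dtype=torch.float16)
out_step = causal_conv1d_update(x_step, conv_state, weight, bias, activation="silu")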


def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, advance_lengths=None):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
advance_lengths: (batch,), dtype int32. Each must be in [0, seqlen]. Values outside this range
will be clipped. If None, advance_lengths will be set to seqlen.
The conv_state will be updated by copying @advance_lengths elements from x to the end of conv_state,
and shifting the rest of the elements to the left.
out: (batch, dim)
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
batch, dim = x.shape
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
width = weight.shape[1]
assert conv_state.shape == (batch, dim, width)
state_len = conv_state.shape[-1]
assert conv_state.shape == (batch, dim, state_len)
assert weight.shape == (dim, width)
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = x
out = torch.sum(conv_state * weight, dim=-1) # (B D)
if bias is not None:
out += bias
x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
if advance_lengths is None:
advance_lengths = torch.full((batch,), seqlen, dtype=torch.long, device=x.device)
idx = torch.arange(state_len, dtype=torch.long, device=x.device).unsqueeze(0) + advance_lengths.unsqueeze(1)
idx = idx.unsqueeze(1).expand(-1, dim, -1)
conv_state.copy_(x_new.gather(2, idx))
if unsqueeze:
out = out.squeeze(-1)
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
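
A toy illustration of that state-update rule, mirroring the gather above with made-up numbers: with state_len = 4, seqlen = 3, and advance_lengths = 2, the first two elements of the old state are dropped and the first two elements of x are appended.

import torch

state_len = 4
conv_state = torch.tensor([[[0., 1., 2., 3.]]])   # (batch=1, dim=1, state_len)
x = torch.tensor([[[10., 11., 12.]]])             # (batch=1, dim=1, seqlen=3)
advance_lengths = torch.tensor([2])               # advance the state by 2 of the 3 steps

x_new = torch.cat([conv_state, x], dim=-1)        # [0, 1, 2, 3, 10, 11, 12]
idx = torch.arange(state_len).unsqueeze(0) + advance_lengths.unsqueeze(1)
idx = idx.unsqueeze(1).expand(-1, x.shape[1], -1)
print(x_new.gather(2, idx))                       # tensor([[[ 2.,  3., 10., 11.]]])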
25 changes: 21 additions & 4 deletions csrc/causal_conv1d.cpp
@@ -386,7 +386,9 @@ causal_conv1d_update(const at::Tensor &x,
const at::Tensor &conv_state,
const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
bool silu_activation) {
bool silu_activation,
const c10::optional<at::Tensor> &advance_lengths_
) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -400,10 +402,13 @@ causal_conv1d_update(const at::Tensor &x,
const auto sizes = x.sizes();
const int batch_size = sizes[0];
const int dim = sizes[1];
const int seqlen = sizes[2];
const int width = weight.size(-1);
const int conv_state_len = conv_state.size(2);
TORCH_CHECK(conv_state_len >= width - 1);

CHECK_SHAPE(x, batch_size, dim);
CHECK_SHAPE(conv_state, batch_size, dim, width);
CHECK_SHAPE(x, batch_size, dim, seqlen);
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
CHECK_SHAPE(weight, dim, width);

TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
@@ -419,15 +424,27 @@ causal_conv1d_update(const at::Tensor &x,
at::Tensor out = torch::empty_like(x);

ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
silu_activation);
params.conv_state_ptr = conv_state.data_ptr();
params.conv_state_len = conv_state_len;
// All stride are in elements, not bytes.
params.conv_state_batch_stride = conv_state.stride(0);
params.conv_state_c_stride = conv_state.stride(1);
params.conv_state_l_stride = conv_state.stride(2);

if (advance_lengths_.has_value()) {
auto advance_lengths = advance_lengths_.value();
TORCH_CHECK(advance_lengths.scalar_type() == torch::kInt32);
TORCH_CHECK(advance_lengths.is_cuda());
TORCH_CHECK(advance_lengths.stride(-1) == 1);
CHECK_SHAPE(advance_lengths, batch_size);
params.advance_lengths = advance_lengths.data_ptr<int32_t>();
} else {
params.advance_lengths = nullptr;
}

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
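
In Python terms, the checks above mean advance_lengths, when provided, has to be an int32 CUDA tensor of shape (batch,) whose last dimension is contiguous; a sketch with illustrative values:

import torch

batch, seqlen = 2, 5
num_accepted = [5, 3]  # hypothetical per-sequence step counts

advance_lengths = torch.tensor(num_accepted, dtype=torch.int32, device="cuda")
assert advance_lengths.shape == (batch,)
assert advance_lengths.stride(-1) == 1
# Values outside [0, seqlen] are clipped by the kernel, so clamping here is optional.
advance_lengths.clamp_(0, seqlen)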
2 changes: 2 additions & 0 deletions csrc/causal_conv1d.h
@@ -21,6 +21,7 @@ struct ConvParamsBase {
index_t out_c_stride;
index_t out_l_stride;

int conv_state_len;
index_t conv_state_batch_stride;
index_t conv_state_c_stride;
index_t conv_state_l_stride;
@@ -32,6 +33,7 @@
void *__restrict__ out_ptr;

void *__restrict__ conv_state_ptr;
int32_t *__restrict__ advance_lengths;

void *__restrict__ seq_idx_ptr;

50 changes: 27 additions & 23 deletions csrc/causal_conv1d_update.cu
@@ -6,14 +6,6 @@
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#ifndef USE_ROCM
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "causal_conv1d.h"
#include "causal_conv1d_common.h"
#include "static_switch.h"
@@ -39,35 +31,47 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int tidx = threadIdx.x;
const int batch_id = blockIdx.x;
const int channel_id = blockIdx.y * kNThreads + tidx;
if (channel_id >= params.dim) return;

input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+ channel_id * params.x_c_stride;
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+ channel_id * params.out_c_stride;
float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

int state_len = params.conv_state_len;
int advance_len = params.advance_lengths == nullptr ? params.seqlen : max(min(params.advance_lengths[batch_id], int(params.seqlen)), int(0));

float weight_vals[kWidth] = {0};
if (channel_id < params.dim) {
#pragma unroll
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
}
#pragma unroll
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }

float x_vals[kWidth] = {0};
if (channel_id < params.dim) {
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride]); }
if (advance_len > 0) {
for (int update_idx = 0; update_idx < state_len - advance_len; ++update_idx) {
conv_state[update_idx * params.conv_state_l_stride] = conv_state[(update_idx + advance_len) * params.conv_state_l_stride];
}
}
for (int i = 0; i < params.seqlen; ++i) {
input_t x_val = x[i * params.x_l_stride];
if (i < advance_len && state_len - advance_len + i >= 0) {
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
}
x_vals[kWidth - 1] = float(x_val);
float out_val = bias_val;
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
x_vals[kWidth - 1] = float(x[0]);
for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
out[i * params.out_l_stride] = input_t(out_val);
// Shift the input buffer by 1
#pragma unroll
for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
}

float out_val = bias_val;
#pragma unroll
for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
if (channel_id < params.dim) { out[0] = input_t(out_val); }
}

template<int kNThreads, int kWidth, typename input_t, typename weight_t>
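
For intuition, a pure-Python scalar mirror of what each (batch, channel) thread in the kernel does: seed the rolling window from the tail of the state, shift the kept part of the state left by advance_len, then walk the sequence, storing at most advance_len inputs in the tail of the state while sliding a width-sized window for the dot product. A sketch only, with plain lists standing in for the strided pointers:

import math

def conv1d_update_one_channel(conv_state, x, weight, bias=0.0, advance_len=None, silu=False):
    state_len, width, seqlen = len(conv_state), len(weight), len(x)
    advance_len = seqlen if advance_len is None else max(min(advance_len, seqlen), 0)
    # Seed the rolling window from the tail of the not-yet-shifted state.
    window = [conv_state[state_len - (width - 1) + i] for i in range(width - 1)] + [0.0]
    # Shift the kept part of the state left by advance_len.
    for i in range(state_len - advance_len):
        conv_state[i] = conv_state[i + advance_len]
    out = []
    for i in range(seqlen):
        if i < advance_len and state_len - advance_len + i >= 0:
            conv_state[state_len - advance_len + i] = x[i]  # store the accepted inputs
        window[-1] = x[i]
        o = bias + sum(w * v for w, v in zip(weight, window))
        out.append(o / (1 + math.exp(-o)) if silu else o)
        window[:-1] = window[1:]  # slide the window by one step
    return out

For example, conv1d_update_one_channel([0., 1., 2., 3.], [10., 11., 12.], [0., 0., 0., 1.], advance_len=2) returns [10.0, 11.0, 12.0] and leaves the state as [2., 3., 10., 11.], matching the toy gather example for the reference implementation.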
17 changes: 12 additions & 5 deletions tests/test_causal_conv1d.py
@@ -109,11 +109,15 @@ def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, cha
# @pytest.mark.parametrize('silu_activation', [False])
@pytest.mark.parametrize("has_bias", [False, True])
# @pytest.mark.parametrize('has_bias', [True])
@pytest.mark.parametrize("has_advance_lengths", [False, True])
# @pytest.mark.parametrize('has_advance_lengths', [True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
# @pytest.mark.parametrize('seqlen', [4])
@pytest.mark.parametrize("width", [2, 3, 4])
# @pytest.mark.parametrize('width', [2])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# @pytest.mark.parametrize("dim", [2048])
def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype):
def test_causal_conv1d_update(dim, width, seqlen, has_advance_lengths, has_bias, silu_activation, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
@@ -124,17 +128,20 @@ def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype):
batch = 2
# batch = 1
# dim = 64
x = torch.randn(batch, dim, device=device, dtype=itype)
conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)
x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2)
state_len = torch.randint(width - 1, width + 10, (1,)).item()
conv_state = torch.randn(batch, dim, state_len, device=device, dtype=itype)
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
if has_bias:
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
else:
bias = None
conv_state_ref = conv_state.detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation)
advance_lengths = (torch.randint(0, seqlen + 1, (batch,), dtype=torch.int32, device=device)
if has_advance_lengths else None)
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation, advance_lengths=advance_lengths)
out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, advance_lengths=advance_lengths)

print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
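
A related sanity sketch using only the reference implementation (illustrative sizes; assumes the package, including its compiled extension, is importable): advancing by the full seqlen in one call should match seqlen single-step calls.

import torch
from causal_conv1d.causal_conv1d_interface import causal_conv1d_update_ref

torch.manual_seed(0)
batch, dim, width, seqlen = 2, 8, 4, 5
x = torch.randn(batch, dim, seqlen)
conv_state = torch.randn(batch, dim, width - 1)
weight = torch.randn(dim, width)
bias = torch.randn(dim)

state_multi = conv_state.clone()
out_multi = causal_conv1d_update_ref(x, state_multi, weight, bias, activation="silu")

state_step = conv_state.clone()
out_step = torch.stack(
    [causal_conv1d_update_ref(x[:, :, i], state_step, weight, bias, activation="silu")
     for i in range(seqlen)],
    dim=-1,
)

print((out_multi - out_step).abs().max().item())      # expected ~0 (up to float rounding)
print((state_multi - state_step).abs().max().item())  # expected ~0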
