Link: https://github.com/cmeraki/vit.triton/blob/main/vit/kernels/add.py
Author: Romit Jain
Tags: Add
Description: Implements batched element-wise addition of two tensors of shape (B, N, T). The kernel is simple yet on par with PyTorch's speed, which makes it a good starting point for learning Triton.
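For readers new to Triton, the sketch below emulates in plain Python one common way an element-wise kernel splits its work: each program instance handles one fixed-size block of a flattened buffer, and a bounds check (the "mask") protects the final partial block. This is an illustration only, not the linked kernel; `blocked_add`, `block_size`, and the flattening of (B, N, T) into one dimension are assumptions made for the example, and the actual kernel may tile the tensor differently.

```python
def blocked_add(a, b, block_size=4):
    """Emulate a blocked element-wise add over two equal-length flat lists."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same number of elements")
    n = len(a)
    out = [0.0] * n
    num_programs = (n + block_size - 1) // block_size  # ceil(n / block_size)
    for pid in range(num_programs):  # on a GPU these instances run in parallel
        start = pid * block_size
        for offset in range(start, start + block_size):
            if offset < n:  # mask: skip out-of-range slots in the last block
                out[offset] = a[offset] + b[offset]
    return out


# Hypothetical sizes: a (B, N, T) tensor viewed as a flat buffer of B * N * T elements.
B, N, T = 2, 3, 5  # 30 elements, deliberately not a multiple of block_size
a = [float(i) for i in range(B * N * T)]
b = [1.0] * (B * N * T)
result = blocked_add(a, b)
```

In a real Triton kernel the inner loop is replaced by vectorized loads and stores over a range of offsets, with the same bounds check passed as a mask.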
Minimal Usage:
import torch
from vit.kernels import add

device = 'cuda'
dtype = torch.float16
batch_size = 12
num_tokens = 500
dim = 200

# Both inputs share the same shape: (batch_size, num_tokens, dim)
A = torch.randn(batch_size, num_tokens, dim, dtype=dtype, device=device)
B = torch.randn(batch_size, num_tokens, dim, dtype=dtype, device=device)

y_torch = torch.add(A, B)  # PyTorch reference result
y_triton = add(A, B)       # Triton kernel result
Triton Version: v2.3.0
Other Notes:
The kernel does not currently support broadcasting; both inputs must have the same shape.
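Until broadcasting is supported, one possible workaround is to expand both inputs to a common shape before calling the kernel. The sketch below is a hedged example, not part of the repository: `add_with_broadcast` is a hypothetical helper, `torch.add` stands in for the Triton `add` so the snippet runs without a GPU, and it assumes the kernel accepts contiguous inputs of equal shape.

```python
import torch


def add_with_broadcast(a, b, add_fn=torch.add):
    """Expand a and b to a common shape, then call a same-shape add function.

    add_fn is a stand-in (assumption): pass the Triton `add` here in practice.
    """
    a_full, b_full = torch.broadcast_tensors(a, b)
    # broadcast_tensors returns views; contiguous() copies them into real memory
    return add_fn(a_full.contiguous(), b_full.contiguous())


x = torch.ones(2, 3, 4)
bias = torch.arange(4, dtype=torch.float32)  # shape (4,), broadcast over (2, 3)
out = add_with_broadcast(x, bias)
```

Materializing the expanded tensors costs extra memory and a copy, so this trades away part of the kernel's speed in exchange for convenience.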
Id in triton index: 0011