Batched Addition

Link: https://github.com/cmeraki/vit.triton/blob/main/vit/kernels/add.py

Author: Romit Jain

Tags: Add

Description: Implements batched element-wise addition of two tensors that share the same shape (B, N, T). The kernel is simple, yet it runs on par with PyTorch's addition, which makes it a good first kernel for getting started with Triton.
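To make the idea concrete, here is a minimal CPU-only sketch of how an element-wise add is typically split into fixed-size blocks with a bounds mask, which is the usual pattern in Triton kernels. It is not the code from the linked `add.py`: the function name `blocked_add_reference` and the `block_size` parameter are hypothetical, and the Python loop only emulates what independent GPU programs would do in parallel.

```python
import torch


def blocked_add_reference(a: torch.Tensor, b: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    """Emulate a 1D-blocked element-wise add (hypothetical helper, not the repo's kernel)."""
    if a.shape != b.shape:
        raise ValueError("inputs must have the same shape (no broadcasting)")

    # Treat the (B, N, T) tensors as flat 1D buffers.
    a_flat = a.contiguous().view(-1)
    b_flat = b.contiguous().view(-1)
    out = torch.empty_like(a_flat)

    n = a_flat.numel()
    # Ceil-division: number of blocks needed to cover all n elements.
    num_programs = (n + block_size - 1) // block_size

    for pid in range(num_programs):  # on a GPU, each pid would be a separate program
        offsets = pid * block_size + torch.arange(block_size, device=a.device)
        mask = offsets < n           # guard the last, possibly partial, block
        idx = offsets[mask]
        out[idx] = a_flat[idx] + b_flat[idx]

    return out.view(a.shape)
```

The mask matters whenever the element count is not a multiple of the block size: without it, the last block would read and write past the end of the buffers.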

Minimal Usage:

import torch
from vit.kernels import add

device = 'cuda'
dtype = torch.float16

batch_size = 12
num_tokens = 500
dim = 200

# Both inputs must have the same shape: (batch_size, num_tokens, dim)
A = torch.randn(batch_size, num_tokens, dim, dtype=dtype, device=device)
B = torch.randn(batch_size, num_tokens, dim, dtype=dtype, device=device)

y_torch = torch.add(A, B)  # PyTorch reference result
y_triton = add(A, B)       # Triton kernel result
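The usage above computes both results but does not compare them. A small check along the following lines could confirm that the two agree. The helper name `assert_matches_torch` and its defaults are assumptions made for illustration; it accepts any callable with the same `(a, b)` signature, so it can be tried on CPU with `torch.add` standing in for the Triton kernel.

```python
import torch


def assert_matches_torch(add_fn, shape=(12, 500, 200), dtype=torch.float32, device="cpu"):
    """Compare a candidate add function against torch.add (hypothetical helper)."""
    a = torch.randn(shape, dtype=dtype, device=device)
    b = torch.randn(shape, dtype=dtype, device=device)

    expected = torch.add(a, b)
    actual = add_fn(a, b)

    # Raises AssertionError with a readable diff if shapes, dtypes, or values disagree.
    torch.testing.assert_close(actual, expected)
    return True
```

With the kernel from this entry imported as in the usage above, the call would look like `assert_matches_torch(add, dtype=torch.float16, device="cuda")`.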

Triton Version: v2.3.0

Other Notes:
The kernel does not currently support broadcasting; both inputs must have the same shape.
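Until broadcasting is supported, one possible workaround is to expand both inputs to a common shape before calling a same-shape add. The sketch below is an assumption rather than part of the repository: `add_with_broadcast` is a hypothetical wrapper, and `add_fn` defaults to `torch.add` only so the example runs without a GPU.

```python
import torch


def add_with_broadcast(a: torch.Tensor, b: torch.Tensor, add_fn=torch.add) -> torch.Tensor:
    """Expand inputs to a common shape, then call a same-shape add (hypothetical wrapper)."""
    # broadcast_tensors returns expanded views that share the original storage.
    a_exp, b_exp = torch.broadcast_tensors(a, b)
    # contiguous() materializes the views so a flat-indexing kernel sees real memory.
    return add_fn(a_exp.contiguous(), b_exp.contiguous())
```

Materializing the expanded tensors costs extra memory and an extra copy, so this is a convenience for small inputs, not a substitute for native broadcast support in the kernel.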

Id in triton index: 0011