move .gitignore and .style.yapf to root dir (#6)
* move .gitignore and .style.yapf to root dir

* align yapf version to 0.30.0
yitongh authored Aug 13, 2024
1 parent 630e333 commit 415d0ad
Showing 15 changed files with 246 additions and 238 deletions.
.gitignore: file renamed without changes.
.style.yapf: file renamed without changes.
2 changes: 1 addition & 1 deletion requirements/requirements-test.txt
@@ -1,4 +1,4 @@
-yapf==0.40.2
+yapf==0.30.0
 pytest==7.3.2
 isort
 expecttest
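The version pin above is the substantive change; everything else in this commit is mechanical reformatting, because yapf 0.40.2 and 0.30.0 wrap long call expressions differently. As an illustration distilled from the diffs below (not a claim about either version's defaults in general), using the AsyncLoader call from the bucketing test as the running example:

# Before this commit: arguments aligned under the opening parenthesis.
loader = ta.AsyncLoader(self.dataloader,
                        device,
                        max_length=1024)

# After re-running yapf 0.30.0: break right after the opening parenthesis,
# then a fixed hanging indent.
loader = ta.AsyncLoader(
    self.dataloader,
    device,
    max_length=1024)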
49 changes: 25 additions & 24 deletions tests/core/test_bucketing_loader.py
@@ -6,35 +6,35 @@


 class TestBucketingDataLoader:
+
     @classmethod
     def setup_class(cls):
         bs = 8
         seq_len = 512
-        rand_int = torch.randint(-500, 500, size=(20, ))
+        rand_int = torch.randint(-500, 500, size=(20,))
         data = [{
             "input_ids":
-            torch.zeros((bs, seq_len + rand_int[i]), dtype=torch.int64),
+                torch.zeros((bs, seq_len + rand_int[i]), dtype=torch.int64),
             "attention_mask":
-            torch.zeros((bs, seq_len + rand_int[i]), dtype=torch.int64),
+                torch.zeros((bs, seq_len + rand_int[i]), dtype=torch.int64),
             "labels":
-            torch.zeros((bs, seq_len + rand_int[i]), dtype=torch.int64)
+                torch.zeros((bs, seq_len + rand_int[i]), dtype=torch.int64)
         } for i in range(20)]
-        dataloader = DataLoader(RawDataset(data),
-                                batch_size=None,
-                                shuffle=True)
+        dataloader = DataLoader(RawDataset(data), batch_size=None, shuffle=True)
         cls.dataloader = dataloader

     def test_buckets(self):
         device = ta.lazy_device()
-        loader = ta.AsyncLoader(self.dataloader,
-                                device,
-                                max_length=1024,
-                                num_buckets=8,
-                                pad_value_dict={
-                                    'input_ids': 0,
-                                    'attention_mask': 0,
-                                    'labels': -100
-                                })
+        loader = ta.AsyncLoader(
+            self.dataloader,
+            device,
+            max_length=1024,
+            num_buckets=8,
+            pad_value_dict={
+                'input_ids': 0,
+                'attention_mask': 0,
+                'labels': -100
+            })
         buckets = [1024 // 8 * (i + 1) for i in range(1024)]
         for batch in loader:
             assert batch['input_ids'].cpu().shape[-1] in buckets
@@ -45,14 +45,15 @@ def test_buckets(self):
     def test_uniform_buckets(self):
         device = ta.lazy_device()
         buckets = [128, 256, 512, 1024]
-        loader = ta.AsyncLoader(self.dataloader,
-                                device,
-                                buckets=buckets,
-                                pad_value_dict={
-                                    'input_ids': 0,
-                                    'attention_mask': 0,
-                                    'labels': -100
-                                })
+        loader = ta.AsyncLoader(
+            self.dataloader,
+            device,
+            buckets=buckets,
+            pad_value_dict={
+                'input_ids': 0,
+                'attention_mask': 0,
+                'labels': -100
+            })
         for batch in loader:
             assert batch['input_ids'].cpu().shape[-1] in buckets
             assert batch['attention_mask'].cpu().shape[-1] in buckets
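For readers of the test above: test_buckets derives its expected shapes from max_length=1024 and num_buckets=8, giving evenly spaced boundaries 128, 256, ..., 1024 (the range(1024) upper bound in its comprehension merely makes the candidate list a superset of the reachable buckets). A minimal sketch of the rule the asserts encode, using a hypothetical helper that is not part of the torchacc API:

# Hypothetical helper mirroring what the asserts imply: each batch is
# padded up to the smallest bucket boundary that fits its sequence length.
def smallest_fitting_bucket(seq_len, max_length=1024, num_buckets=8):
    # Evenly spaced boundaries: [128, 256, ..., 1024] for the defaults above.
    buckets = [max_length // num_buckets * (i + 1) for i in range(num_buckets)]
    return next(b for b in buckets if b >= seq_len)

assert smallest_fitting_bucket(300) == 384
assert smallest_fitting_bucket(512) == 512
assert smallest_fitting_bucket(1000) == 1024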
1 change: 1 addition & 0 deletions tests/distributed/test_dist_ops.py
@@ -6,6 +6,7 @@


 class DistTest(MultiProcessTestBase):
+
     @property
     def world_size(self) -> int:
         return 4
87 changes: 45 additions & 42 deletions tests/ops/test_context_parallel.py
@@ -29,6 +29,7 @@ def setup_env():

 @instantiate_parametrized_tests
 class ContextParallelTest(MultiProcessTestBase):
+
     @skip_if_lt_x_gpu(2)
     @parametrize("is_cuda", [False, True])
     @parametrize("test_varlen", [False, True])
@@ -106,19 +107,20 @@ def test_cp(
                 k_lens=q_lens,
                 softmax_scale=None,
                 dropout_p=0.0,
-                intra_process_group=context_parallel.
-                get_intra_cp_process_group(),
-                inter_process_group=context_parallel.
-                get_inter_cp_process_group())
+                intra_process_group=context_parallel.get_intra_cp_process_group(
+                ),
+                inter_process_group=context_parallel.get_inter_cp_process_group(
+                ))
         else:
-            output_cp = cp_func(q,
-                                k,
-                                v,
-                                q_lens=q_lens,
-                                k_lens=q_lens,
-                                softmax_scale=None,
-                                dropout_p=0.0,
-                                process_group=cp_group)
+            output_cp = cp_func(
+                q,
+                k,
+                v,
+                q_lens=q_lens,
+                k_lens=q_lens,
+                softmax_scale=None,
+                dropout_p=0.0,
+                process_group=cp_group)

         output_cp = context_parallel.gather_forward_split_backward(
             output_cp, seq_dim=1, process_group=cp_group)
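The gather_forward_split_backward call above stitches the context-parallel output back together along the sequence dimension. Its internals are not part of this diff; what follows is a hedged sketch of what such an autograd function conventionally does (all-gather in forward, keep only the local shard of the gradient in backward), assuming equal-length shards per rank:

import torch
import torch.distributed as dist

# Sketch only: torchacc's actual implementation is not shown in this commit.
class GatherForwardSplitBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, seq_dim, group):
        ctx.seq_dim, ctx.group = seq_dim, group
        world = dist.get_world_size(group)
        # Gather every rank's shard and concatenate along the sequence dim.
        chunks = [torch.empty_like(x) for _ in range(world)]
        dist.all_gather(chunks, x.contiguous(), group=group)
        return torch.cat(chunks, dim=seq_dim)

    @staticmethod
    def backward(ctx, grad):
        # The gradient of a gather is a split: return this rank's slice.
        rank = dist.get_rank(ctx.group)
        world = dist.get_world_size(ctx.group)
        local = grad.chunk(world, dim=ctx.seq_dim)[rank]
        return local, None, None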
@@ -138,45 +140,46 @@ def test_cp(
             q, k, v = [
                 einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]
             ]
-            cu_q_lens = torch.arange(0, (B + 1) * N,
-                                     step=N,
-                                     dtype=torch.int32,
-                                     device=q.device)
+            cu_q_lens = torch.arange(
+                0, (B + 1) * N, step=N, dtype=torch.int32, device=q.device)
             if ta.is_lazy_tensor(q):
-                output_fa = flash_attn_varlen_xla(q,
-                                                  k,
-                                                  v,
-                                                  cu_q_lens,
-                                                  cu_q_lens,
-                                                  N,
-                                                  N,
-                                                  dropout_p=0.0,
-                                                  softmax_scale=None,
-                                                  causal=False,
-                                                  return_attn_probs=False)
+                output_fa = flash_attn_varlen_xla(
+                    q,
+                    k,
+                    v,
+                    cu_q_lens,
+                    cu_q_lens,
+                    N,
+                    N,
+                    dropout_p=0.0,
+                    softmax_scale=None,
+                    causal=False,
+                    return_attn_probs=False)
             else:
-                output_fa = flash_attn_varlen_func(q,
-                                                   k,
-                                                   v,
-                                                   cu_q_lens,
-                                                   cu_q_lens,
-                                                   N,
-                                                   N,
-                                                   dropout_p=0.0,
-                                                   softmax_scale=None,
-                                                   causal=False,
-                                                   return_attn_probs=False)
+                output_fa = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    cu_q_lens,
+                    cu_q_lens,
+                    N,
+                    N,
+                    dropout_p=0.0,
+                    softmax_scale=None,
+                    causal=False,
+                    return_attn_probs=False)
             output_fa = einops.rearrange(output_fa, "(b s) ... -> b s ...", b=B)

         loss_fa = torch.sum(output_fa)
         loss_fa.backward()

         ta.mark_step()

-        fwd_close = torch.allclose(output_fa.cpu().detach().to(torch.float32),
-                                   output_cp.cpu().detach().to(torch.float32),
-                                   rtol=1e-5,
-                                   atol=1e-2)
+        fwd_close = torch.allclose(
+            output_fa.cpu().detach().to(torch.float32),
+            output_cp.cpu().detach().to(torch.float32),
+            rtol=1e-5,
+            atol=1e-2)
         bwd_close = torch.allclose(
             hidden_states_fa.grad.cpu().detach().to(torch.float32),
             hidden_states_cp.grad.cpu().detach().to(torch.float32),
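Both this file and tests/ops/test_flash_attn.py build cumulative sequence-length tensors for the varlen kernels. On toy sizes, the reformatted cu_q_lens line computes the packed-sequence boundaries [0, N, 2N, ..., B*N]:

import torch

# For B equal-length sequences of length N packed into one (B*N, ...) tensor,
# the cumulative boundaries mark where each sequence starts and ends.
B, N = 2, 4
cu_q_lens = torch.arange(0, (B + 1) * N, step=N, dtype=torch.int32)
print(cu_q_lens)  # tensor([0, 4, 8], dtype=torch.int32)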
74 changes: 37 additions & 37 deletions tests/ops/test_flash_attn.py
@@ -47,8 +47,7 @@ def setup_env():
 )
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal,
-                                  local, alibi, deterministic, mha_type,
-                                  dtype):
+                                  local, alibi, deterministic, mha_type, dtype):
     # TODO(to wenting.swt): maybe we need support this
     if d % 8 != 0:
         pytest.skip(reason="Expected head_size_og % 8 == 0 to be true")
@@ -64,28 +63,31 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal,
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else tuple(
-        torch.randint(0, seqlen_k, (2, )).tolist())
-    q = torch.randn(batch_size,
-                    seqlen_q,
-                    nheads,
-                    d,
-                    device=device,
-                    dtype=dtype,
-                    requires_grad=True)
-    k = torch.randn(batch_size,
-                    seqlen_k,
-                    nheads_k,
-                    d,
-                    device=device,
-                    dtype=dtype,
-                    requires_grad=True)
-    v = torch.randn(batch_size,
-                    seqlen_k,
-                    nheads_k,
-                    d,
-                    device=device,
-                    dtype=dtype,
-                    requires_grad=True)
+        torch.randint(0, seqlen_k, (2,)).tolist())
+    q = torch.randn(
+        batch_size,
+        seqlen_q,
+        nheads,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True)
+    k = torch.randn(
+        batch_size,
+        seqlen_k,
+        nheads_k,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True)
+    v = torch.randn(
+        batch_size,
+        seqlen_k,
+        nheads_k,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True)

     if alibi:
         alibi_slopes = torch.rand(
@@ -130,14 +132,16 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal,
     q.requires_grad = True
     k.requires_grad = True
     v.requires_grad = True
-    cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q,
-                                step=seqlen_q,
-                                dtype=torch.int32,
-                                device=q.device)
-    cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k,
-                                step=seqlen_k,
-                                dtype=torch.int32,
-                                device=q.device)
+    cu_seqlens_q = torch.arange(
+        0, (batch_size + 1) * seqlen_q,
+        step=seqlen_q,
+        dtype=torch.int32,
+        device=q.device)
+    cu_seqlens_k = torch.arange(
+        0, (batch_size + 1) * seqlen_k,
+        step=seqlen_k,
+        dtype=torch.int32,
+        device=q.device)
     if alibi:
         alibi_slopes = alibi_slopes.cpu().to(device)
     out_xla = ta.ops.flash_attn_varlen_xla(
@@ -167,11 +171,7 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal,
     out_xla = out_xla.cpu().detach()

     # TODO(to wenting.swt): The rtol and atol here are a bit high.
-    assert torch.allclose(out_xla,
-                          out_fa,
-                          rtol=1e-2,
-                          atol=1e-2,
-                          equal_nan=True)
+    assert torch.allclose(out_xla, out_fa, rtol=1e-2, atol=1e-2, equal_nan=True)
     assert torch.allclose(dq_xla, dq_fa, rtol=1e-2, atol=1e-2, equal_nan=True)
     assert torch.allclose(dk_xla, dk_fa, rtol=1e-2, atol=1e-2, equal_nan=True)
     assert torch.allclose(dv_xla, dv_fa, rtol=1e-2, atol=1e-2, equal_nan=True)
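One detail of the hunks above worth unpacking: the nheads_k line near the top of test_flash_attn_varlen_output selects the key/value head count for the three attention variants (full multi-head, multi-query, grouped-query). A small standalone check; nheads=9 is an assumed value chosen so that all three cases divide evenly, since the test's actual nheads is not shown in this excerpt:

nheads = 9  # assumed for illustration
for mha_type in ("mha", "mqa", "gqa"):
    # mha: as many KV heads as query heads; mqa: a single shared KV head;
    # gqa: a small group of KV heads (3 here).
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0  # each KV head serves nheads // nheads_k query heads
    print(mha_type, nheads_k)  # mha 9, mqa 1, gqa 3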
(Diffs for the remaining changed files were not loaded.)