Skip to content

Commit

Permalink
fix ut
Browse files Browse the repository at this point in the history
  • Loading branch information
Yancey1989 committed Jul 25, 2023
1 parent 8cf72f4 commit 5833e4f
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions pytorch_blade/tests/disc/ops/test_input_mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ class KVCacheModule(nn.Module):
def forward(self, k_cache: Tensor, k: Tensor, step : Tensor):
k_cache[..., step - k.shape[-2]: step , :].add_(k)
value = k_cache[..., : step, :]
# attention
value = torch.matmul(k, value.transpose(-2, -1))
return k_cache, value

class TestInputMutation(DiscTestCase):
Expand All @@ -40,16 +38,16 @@ def tearDown(self):
def test_inplace_kv(self):
k_cache = torch.zeros(2, 32, 8, device=self.device)
k = torch.ones(2, 1, 8, device=self.device)

m = KVCacheModule()
m.train(False)
step = torch.tensor(1)
opt_func = torch_blade.optimize(m, allow_tracing=True, model_inputs=(k_cache.clone(), k.clone(), step))
expect = m(k_cache.clone(), k.clone(), step)
actual = opt_func(k_cache.clone(), k.clone(), step)
for exp, act in zip(expect, actual):
print(exp)
print(act)
print(exp.cpu())
print(act.cpu())
self.assertTrue(torch.allclose(exp.cpu(), act.cpu()))

if __name__ == "__main__":
Expand Down

0 comments on commit 5833e4f

Please sign in to comment.