Matmul with DID loop split #3651
base: main
Conversation
!build
!test
fd = Model(d, b, s, e)
out_tensors = fd.execute([inp_tensor, sharded_weight_tensor])
print(f"Output tensor: {out_tensors[0].shape}")
print(f"Output tensor: {out_tensors[0].shape}") |
I'd avoid debug prints to stdout. If you intend to keep it, you may want to consider `logging.debug`, and look into how to do that in a way that's compatible with PyTorch.
expected_out_tensor = unsharded_out_tensor.view([b, s, d, e]).permute(2, 0, 1, 3)[
    rank : rank + 1
]
Can this be done with shard_tensor?
@@ -1313,9 +1314,37 @@ bool hasTrivialAllocationDomain(const TensorView* tv) {
   }
   const std::vector<IterDomain*>& alloc = tv->getMaybeAllocationDomain();
   const std::vector<IterDomain*>& logical = tv->getLogicalDomain();
-  return TensorDomain::noBroadcasts(TensorDomain::noReductions(logical)) ==
+  const auto alloc_no_red_bcast =
I'm unsure about changing `hasTrivialAllocationDomain`, which is used elsewhere. Philosophically, its name suggests it should check whether the allocation is trivial, and having a DID split is conceptually non-trivial.

More practically, how about using `inferShapeOfOutput`? I made it support DID loop split, and it's used by `KernelExecutor` to allocate outputs. I think in `MatmulOp::evaluate` you could do something like:

    sizes, strides = inferShapeOfOutput(tv, ee);
    create a meta tensor with those sizes and strides
    if that_meta_tensor.is_contiguous():
        return the result of at::matmul
    ...

When you change the code to use `at::matmul_out` (which I still think is a change we should land sooner rather than later), you can simplify `MatmulOp::evaluate` further to:

    sizes, strides = inferShapeOfOutput(tv, ee);
    out = at::empty_strided(sizes, strides, ...);
    at::matmul_out(...);
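For concreteness, a minimal sketch of that second variant might look like the following, assuming `inferShapeOfOutput` yields a (sizes, strides) pair and that `a` and `b` are the already-evaluated input tensors; the helper name `evaluateMatmul` and its signature are hypothetical, not taken from the actual nvFuser sources:

```cpp
// Hypothetical sketch only: the wrapper function and how sizes/strides are
// obtained are assumptions; the ATen calls (at::empty_strided,
// at::matmul_out) are real.
#include <ATen/ATen.h>

at::Tensor evaluateMatmul(
    const at::Tensor& a,
    const at::Tensor& b,
    const std::vector<int64_t>& sizes,    // from inferShapeOfOutput
    const std::vector<int64_t>& strides)  // from inferShapeOfOutput
{
  // Allocate the output with the inferred (possibly non-contiguous) layout
  // so the DID loop split's allocation domain is respected...
  at::Tensor out = at::empty_strided(sizes, strides, a.options());
  // ...then let ATen write the matmul result into it in place.
  at::matmul_out(out, a, b);
  return out;
}
```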
This PR modifies `hasTrivialAllocationDomain` to consider whether the TensorView has a DID loop split. In that case, we compare the corresponding IterDomains of the logical and allocation domains across all but the sharded logical axis.

Note: this does not guarantee that `MatmulOp` with a non-trivial stride order will work for DID loop split. I suspect that will require some additional changes to the `MatmulOp::evaluate` method.
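For illustration, the relaxed comparison described above might look roughly like this. This is a sketch under assumptions: `sharded_axis` stands in for the index of the sharded logical axis (however the PR actually computes it), and the surrounding context is the `hasTrivialAllocationDomain` body shown in the diff, where `logical` and `alloc` are already defined:

```cpp
// Sketch of the described check (assumptions noted above). After dropping
// reductions and broadcasts, the logical and allocation domains must match
// on every axis except the sharded one, whose allocation extent holds only
// the per-device slice.
const auto logical_nrb =
    TensorDomain::noBroadcasts(TensorDomain::noReductions(logical));
const auto alloc_nrb =
    TensorDomain::noBroadcasts(TensorDomain::noReductions(alloc));
if (logical_nrb.size() != alloc_nrb.size()) {
  return false;
}
for (size_t i = 0; i < logical_nrb.size(); ++i) {
  if (static_cast<int64_t>(i) == sharded_axis) {
    continue;  // hypothetical: skip the DID-sharded logical axis
  }
  if (logical_nrb[i] != alloc_nrb[i]) {
    return false;
  }
}
return true;
```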