Skip to content

Commit

Permalink
Disable InferType if it was done and no changes after previous pass
Browse files Browse the repository at this point in the history
This optimizatin allows to speedup PatternRewriter transformations by
reusing of preious type inferred expression instead of perform
InferType multiple times
  • Loading branch information
elvin-n committed Jan 13, 2025
1 parent 567eeed commit bd02406
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -851,24 +851,32 @@ Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const E
std::unordered_map<DFPatternCallback, bool, ObjectPtrHash, ObjectPtrEqual> done;
do {
last = post;
// We don't have to call InferType if previous pass has not modified anything
// We can just take previous typed state of the expression
bool types_invalidated = true;
for (auto callback : callbacks) {
if (!done[callback]) {
auto before = post;
auto post_typed = post;
callback_ = callback;
if (callback_->require_type) {
post = InferTypeWithModule(post, mod_);
if (callback_->require_type && types_invalidated) {
post_typed = InferTypeWithModule(post, mod_);
}
auto grouper = PatternGrouper();
groups_ = grouper.GroupMatches(callback_->pattern, post);
groups_ = grouper.GroupMatches(callback_->pattern, post_typed);
gid_assignments_ = grouper.GetGIDAssignments();
memo_.clear();
VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
post = this->VisitExpr(post);
post = this->VisitExpr(post_typed);
VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
count++;
if (callback_->rewrite_once) {
bool current_equal = (*structural_equal)(before, post, false, true);
if (!current_equal) {
bool current_equal = (*structural_equal)(before, post, false, true);
if (callback_->require_type && current_equal) {
types_invalidated = false;
post = post_typed;
} else {
types_invalidated = true;
if (callback_->rewrite_once) {
done[callback] = true;
}
}
Expand Down

0 comments on commit bd02406

Please sign in to comment.