Skip to content

Commit

Permalink
Factor out lamda function in collective-select-folder
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711470542
  • Loading branch information
frgossen authored and tensorflower-gardener committed Jan 2, 2025
1 parent 28c0aae commit b77b79b
Showing 1 changed file with 11 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,27 +126,29 @@ std::optional<FoldableSelect> MatchFoldableSelect(HloInstruction* select) {
select->mutable_operand(1), select->mutable_operand(2)};
}

bool SelectPredicateEval(const FoldableSelect& select_match,
const SourceTargetPair& pair) {
int64_t src_id = pair.first;
return select_match.cmp_direction == Comparison::Direction::kEq
? src_id == select_match.constant_id
: src_id != select_match.constant_id;
};

std::optional<bool> StaticallyEvaluatePredicateForAllSourceIDs(
FoldableSelect select_match, SourceTargetPairs pairs) {
const FoldableSelect& select_match, const SourceTargetPairs& pairs) {
// If there are no pairs, the predicate is undefined.
if (pairs.empty()) return std::nullopt;

// Evaluate the select predicate for the first source target pair.
CHECK(select_match.cmp_direction == Comparison::Direction::kEq ||
select_match.cmp_direction == Comparison::Direction::kNe);
auto select_predicate_eval = [&select_match](const SourceTargetPair& pair) {
int64_t src_id = pair.first;
return select_match.cmp_direction == Comparison::Direction::kEq
? src_id == select_match.constant_id
: src_id != select_match.constant_id;
};
bool result_candidate = select_predicate_eval(pairs.front());
bool result_candidate = SelectPredicateEval(select_match, pairs.front());

// Check that the result is the same for all source target pairs. If not,
// we have a contradiction and cannot statically evaluate the predicate. We
// return std::nullopt in this case.
if (!absl::c_all_of(pairs, [&](const SourceTargetPair& it) -> bool {
return result_candidate == select_predicate_eval(it);
return result_candidate == SelectPredicateEval(select_match, it);
})) {
return std::nullopt;
}
Expand Down

0 comments on commit b77b79b

Please sign in to comment.