Skip to content

Commit

Permalink
XeGPU Flash Attention implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
charithaintc committed Mar 27, 2024
1 parent 385bcd2 commit 15a181e
Show file tree
Hide file tree
Showing 3 changed files with 1,084 additions and 9 deletions.
21 changes: 12 additions & 9 deletions lib/ExecutionEngine/ImexRunnerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,15 +211,18 @@ void _mlir_ciface_printMaxError(UnrankedMemRefType<T> *M,
DynamicMemRefType<T> DN = DynamicMemRefType<T>(*N);
DynamicMemRefIterator<T> i = DM.begin();
DynamicMemRefIterator<T> j = DN.begin();
std::pair<float, DynamicMemRefIterator<T>> max_rel_err_idx{0.0, DM.begin()};
std::pair<float, DynamicMemRefIterator<T>> max_abs_err_idx{0.0, DM.begin()};
for (; i != DM.end() && j != DN.end(); ++i, ++j) {
const float delta = getFloat(*i) - getFloat(*j);
const float delta_abs = fabs(delta);
if (delta > max_abs_err_idx.first) {
max_abs_err_idx = {delta_abs, i};
max_rel_err_idx = {delta, i};
}
std::pair<double, DynamicMemRefIterator<T>> max_rel_err_idx{0.0, DM.begin()};
std::pair<double, DynamicMemRefIterator<T>> max_abs_err_idx{0.0, DM.begin()};
uint64_t idx = 0;
for (; i != DM.end() && j != DN.end(); ++i, ++j, ++idx) {
const double i_val = getFloat(*i);
const double j_val = getFloat(*j);
const double delta = fabs(i_val - j_val);
const double rel_error = delta / fmax(fabs(i_val), fabs(j_val));
if (delta > max_abs_err_idx.first)
max_abs_err_idx = {delta, i};
if (rel_error > max_rel_err_idx.first)
max_rel_err_idx = {rel_error, i};
}
std::cout << "Max absolute error " << max_abs_err_idx.first
<< " at idx=" << std::distance(DM.begin(), max_abs_err_idx.second)
Expand Down
Loading

0 comments on commit 15a181e

Please sign in to comment.