-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #252 from AWhitmell/andrew/benchmarking
Add flop counting of scheduled impero.
- Loading branch information
Showing
7 changed files
with
357 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
""" | ||
This file contains the functions used to estimate the total number of | ||
floating point operations performed by a scheduled impero tree. | ||
""" | ||
|
||
import gem.gem as gem | ||
import gem.impero as imp | ||
from functools import singledispatch | ||
import numpy | ||
import math | ||
|
||
|
||
@singledispatch
def statement(tree, parameters):
    """Count the flops of an impero statement (dispatch fallback).

    :arg tree: impero node whose flops are to be counted.
    :arg parameters: object handed through to the registered handlers.
    :raises NotImplementedError: always; this fallback only runs when no
        handler has been registered for the type of ``tree``.
    """
    raise NotImplementedError()
|
||
|
||
@statement.register(imp.Block)
def statement_block(tree, parameters):
    """Flops of a block: the sum of the flops of its child statements.

    :arg tree: the block node; its ``children`` are counted in order.
    :arg parameters: object handed through to the child handlers.
    :returns: total flop count of all children (0 for an empty block).
    """
    total = 0
    for stmt in tree.children:
        total += statement(stmt, parameters)
    return total
|
||
|
||
@statement.register(imp.For)
def statement_for(tree, parameters):
    """Flops of a loop: flops of its single child times the loop extent.

    :arg tree: the loop node; ``tree.index.extent`` must not be ``None``
        and the node must have exactly one child.
    :arg parameters: object handed through to the child handler.
    :returns: flop count of the loop body multiplied by the extent.
    """
    trip_count = tree.index.extent
    assert trip_count is not None
    (body,) = tree.children
    return statement(body, parameters) * trip_count
|
||
|
||
@statement.register(imp.Initialise)
def statement_initialise(tree, parameters):
    """Initialisation statements are counted as costing no flops.

    :arg tree: the initialise node (not inspected).
    :arg parameters: unused.
    :returns: always 0.
    """
    no_flops = 0
    return no_flops
|
||
|
||
@statement.register(imp.Accumulate)
def statement_accumulate(tree, parameters):
    """Flops of an accumulate statement.

    :arg tree: the accumulate node; only the first child of
        ``tree.indexsum`` is costed.
    :arg parameters: object handed through to :func:`expression_flops`.
    :returns: flops of the summand plus one.
    """
    summand = tree.indexsum.children[0]
    # The extra 1 presumably accounts for the addition into the
    # accumulator itself.
    return expression_flops(summand, parameters) + 1
|
||
|
||
@statement.register(imp.Return)
def statement_return(tree, parameters):
    """Flops of a return statement.

    :arg tree: the return node; ``tree.expression`` is costed.
    :arg parameters: object handed through to :func:`expression_flops`.
    :returns: flops of the returned expression plus one.
    """
    expression_cost = expression_flops(tree.expression, parameters)
    return expression_cost + 1
|
||
|
||
@statement.register(imp.ReturnAccumulate)
def statement_returnaccumulate(tree, parameters):
    """Flops of a combined return-and-accumulate statement.

    :arg tree: the node; only the first child of ``tree.indexsum`` is
        costed.
    :arg parameters: object handed through to :func:`expression_flops`.
    :returns: flops of the summand plus one.
    """
    summand = tree.indexsum.children[0]
    return expression_flops(summand, parameters) + 1
|
||
|
||
@statement.register(imp.Evaluate)
def statement_evaluate(tree, parameters):
    """Flops of an evaluate statement.

    The expression is costed with ``top=True`` so that it is counted
    even when it is itself one of ``parameters.temporaries``.

    :arg tree: the evaluate node; ``tree.expression`` is costed.
    :arg parameters: object handed through to :func:`expression_flops`.
    :returns: flops of the evaluated expression.
    """
    return expression_flops(tree.expression, parameters, top=True)
|
||
|
||
@singledispatch
def flops(expr, parameters):
    """Count the flops of a GEM expression node (dispatch fallback).

    :arg expr: expression node to cost.
    :arg parameters: object handed through to the registered handlers.
    :raises NotImplementedError: always; this fallback only runs when no
        handler has been registered for the type of ``expr``.
    """
    node_type = type(expr)
    raise NotImplementedError(f"Don't know how to count flops of {node_type}")
|
||
|
||
@flops.register(gem.Failure)
def flops_failure(expr, parameters):
    """Reject ``Failure`` nodes, which cannot be costed.

    :arg expr: the failure node (not inspected).
    :arg parameters: unused.
    :raises ValueError: always.
    """
    raise ValueError("Not expecting a Failure node")
|
||
|
||
@flops.register(gem.Variable)
@flops.register(gem.Identity)
@flops.register(gem.Delta)
@flops.register(gem.Zero)
@flops.register(gem.Literal)
@flops.register(gem.Index)
@flops.register(gem.VariableIndex)
def flops_zero(expr, parameters):
    """Nodes that are counted as contributing no flops at all.

    :arg expr: the node (not inspected; its children are not visited).
    :arg parameters: unused.
    :returns: always 0.
    """
    # These node types are costed at zero floating point operations.
    no_flops = 0
    return no_flops
|
||
|
||
@flops.register(gem.LogicalNot)
@flops.register(gem.LogicalAnd)
@flops.register(gem.LogicalOr)
@flops.register(gem.ListTensor)
def flops_zeroplus(expr, parameters):
    """Nodes that cost nothing themselves but whose children may.

    :arg expr: the node; each of ``expr.children`` is costed.
    :arg parameters: object handed through to :func:`expression_flops`.
    :returns: the summed flop count of the children.
    """
    children_cost = 0
    for operand in expr.children:
        children_cost += expression_flops(operand, parameters)
    # The node itself adds nothing on top of its operands.
    return 0 + children_cost
|
||
|
||
@flops.register(gem.Product)
def flops_product(expr, parameters):
    """Flops of a binary product.

    Multiplication by the literal -1 is not counted as a flop: in that
    case only the other operand is costed.  Otherwise the product costs
    one flop plus the cost of both operands.

    :arg expr: the product node, with exactly two children.
    :arg parameters: object handed through to :func:`expression_flops`.
    :returns: flop count of the product.
    """
    lhs, rhs = expr.children
    # Check the left operand first, then the right, for a literal -1.
    for candidate, other in ((lhs, rhs), (rhs, lhs)):
        if isinstance(candidate, gem.Literal) and candidate.value == -1:
            return expression_flops(other, parameters)
    operand_cost = sum(expression_flops(operand, parameters)
                       for operand in (lhs, rhs))
    return 1 + operand_cost
|
||
|
||
@flops.register(gem.Sum)
@flops.register(gem.Division)
@flops.register(gem.Comparison)
@flops.register(gem.MathFunction)
@flops.register(gem.MinValue)
@flops.register(gem.MaxValue)
def flops_oneplus(expr, parameters):
    """Nodes that cost one flop on top of the cost of their children.

    :arg expr: the node; each of ``expr.children`` is costed.
    :arg parameters: object handed through to :func:`expression_flops`.
    :returns: one plus the summed flop count of the children.
    """
    children_cost = 0
    for operand in expr.children:
        children_cost += expression_flops(operand, parameters)
    return 1 + children_cost
|
||
|
||
@flops.register(gem.Power)
def flops_power(expr, parameters):
    """Flops of a power expression.

    A literal, positive, integer-valued exponent ``p`` is costed at
    ``ceil(log2(p))`` flops; any other exponent is costed at a flat
    heuristic of 5 flops.  The cost of the base is added in both cases;
    the exponent sub-expression itself is not costed.

    :arg expr: the power node, with children ``(base, exponent)``.
    :arg parameters: object handed through to :func:`expression_flops`.
    :returns: flop count of the power expression.
    """
    base, exponent = expr.children
    base_cost = expression_flops(base, parameters)
    if isinstance(exponent, gem.Literal):
        power = exponent.value
        if power > 0 and power == math.floor(power):
            return base_cost + int(math.ceil(math.log2(power)))
    # Non-literal, non-positive or non-integer exponent.
    return base_cost + 5  # heuristic
|
||
|
||
@flops.register(gem.Conditional)
def flops_conditional(expr, parameters):
    """Flops of a conditional: the condition plus the dearer branch.

    :arg expr: the conditional node, with children
        ``(condition, then, else)``.
    :arg parameters: object handed through to :func:`expression_flops`.
    :returns: flops of the condition plus the larger of the two branch
        costs.
    """
    costs = [expression_flops(child, parameters) for child in expr.children]
    condition_cost, then_cost, else_cost = costs
    return condition_cost + max(then_cost, else_cost)
|
||
|
||
@flops.register(gem.Indexed)
@flops.register(gem.FlexiblyIndexed)
def flops_indexed(expr, parameters):
    """Average flops per entry of an indexed tensor expression.

    The cost of all children is summed and then divided by the number
    of entries of the first child (the product of its shape), giving
    the average cost of producing one entry.

    :arg expr: the indexed node; ``expr.children[0]`` supplies the shape.
    :arg parameters: object handed through to :func:`expression_flops`.
    :returns: aggregate child flops divided by the entry count.
    """
    aggregate = sum(expression_flops(child, parameters)
                    for child in expr.children)
    # numpy.product was a deprecated alias that NumPy 2.0 removed;
    # numpy.prod is the supported spelling and behaves identically.
    entries = numpy.prod(expr.children[0].shape, dtype=int)
    # Average flops per entry
    return aggregate / entries
|
||
|
||
@flops.register(gem.IndexSum)
def flops_indexsum(expr, parameters):
    """Reject ``IndexSum`` nodes, which this counter does not cost.

    :arg expr: the index-sum node (not inspected).
    :arg parameters: unused.
    :raises ValueError: always.
    """
    raise ValueError("Not expecting IndexSum")
|
||
|
||
@flops.register(gem.Inverse)
def flops_inverse(expr, parameters):
    """Flops of a matrix inverse: ``2n^3`` plus the cost of the children.

    :arg expr: the inverse node; ``expr.shape`` must have two entries
        and ``n`` is the first of them.
    :arg parameters: object handed through to :func:`expression_flops`.
    :returns: ``2*n**3`` plus the summed flop count of the children.
    """
    rows, _cols = expr.shape
    inversion_cost = 2 * rows**3
    children_cost = sum(expression_flops(child, parameters)
                        for child in expr.children)
    return inversion_cost + children_cost
|
||
|
||
@flops.register(gem.Solve)
def flops_solve(expr, parameters):
    """Flops of a linear solve.

    Costed as ``2mn`` plus the inversion cost ``2n^3`` plus the cost of
    the children, where ``(n, m)`` is ``expr.shape``.

    :arg expr: the solve node; ``expr.shape`` must have two entries.
    :arg parameters: object handed through to :func:`expression_flops`.
    :returns: flop count of the solve.
    """
    rows, cols = expr.shape
    apply_cost = 2 * rows * cols
    inversion_cost = 2 * rows**3
    children_cost = sum(expression_flops(child, parameters)
                        for child in expr.children)
    return apply_cost + inversion_cost + children_cost
|
||
|
||
@flops.register(gem.ComponentTensor)
def flops_componenttensor(expr, parameters):
    """Reject ``ComponentTensor`` nodes, which this counter does not cost.

    :arg expr: the component-tensor node (not inspected).
    :arg parameters: unused.
    :raises ValueError: always.
    """
    raise ValueError("Not expecting ComponentTensor")
|
||
|
||
def expression_flops(expression, parameters, top=False):
    """An approximation to flops required for each expression.

    An expression that appears in ``parameters.temporaries`` is costed
    at zero unless it is the root (``top=True``); everything else is
    dispatched to :func:`flops`.

    :arg expression: GEM expression.
    :arg parameters: Useful miscellaneous information; must provide a
        ``temporaries`` container.
    :arg top: are we at the root?
    :returns: flop count for the expression
    """
    if top or expression not in parameters.temporaries:
        return flops(expression, parameters)
    # A non-root temporary contributes nothing at this point.
    return 0
|
||
|
||
def count_flops(impero_c):
    """An approximation to flops required for a scheduled impero_c tree.

    Counting is best-effort: if any node raises :class:`ValueError` or
    :class:`NotImplementedError` while being costed, the whole count
    falls back to 0 instead of propagating the error.

    :arg impero_c: a :class:`~.Impero_C` object.
    :returns: approximate flop count for the tree.
    """
    try:
        total = statement(impero_c.tree, impero_c)
    except (ValueError, NotImplementedError):
        total = 0
    return total
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import pytest | ||
import gem.gem as gem | ||
from gem.flop_count import count_flops | ||
from gem.impero_utils import preprocess_gem | ||
from gem.impero_utils import compile_gem | ||
|
||
|
||
def test_count_flops(expression):
    """Check ``count_flops`` against the expected count of each case.

    :arg expression: ``(compiled expression, expected flops)`` pair
        supplied by the ``expression`` fixture.
    """
    impero_c, expected = expression
    assert count_flops(impero_c) == expected
|
||
|
||
@pytest.fixture(params=("expr1", "expr2", "expr3", "expr4"))
def expression(request):
    """Build one compiled test expression and its expected flop count.

    :arg request: pytest request; ``request.param`` names the case.
    :returns: ``(result of compile_gem, expected flop count)``.
    :raises ValueError: if the case name is not recognised.
    """
    def compiled(lhs, rhs, free_indices):
        # Preprocess the right-hand side, then schedule the single
        # assignment ``lhs <- rhs``.
        rhs, = preprocess_gem([rhs])
        return compile_gem([(lhs, rhs)], free_indices)

    case = request.param

    if case == "expr1":
        # C += a*2 + 3/b
        rhs = gem.Sum(gem.Product(gem.Variable("a", ()), gem.Literal(2)),
                      gem.Division(gem.Literal(3), gem.Variable("b", ())))
        i, = gem.indices(1)
        lhs = gem.Variable("C", (1,))[i]
        return compiled(lhs, rhs, (i,)), 1 + 3

    if case == "expr2":
        # C += max(1, 2) >= min(3, 1)
        rhs = gem.Comparison(">=",
                             gem.MaxValue(gem.Literal(1), gem.Literal(2)),
                             gem.MinValue(gem.Literal(3), gem.Literal(1)))
        i, = gem.indices(1)
        lhs = gem.Variable("C", (1,))[i]
        return compiled(lhs, rhs, (i,)), 1 + 3

    if case == "expr3":
        # C += solve(Id(3x3), Id(3x3)^{-1})
        solve = gem.Solve(gem.Identity(3), gem.Inverse(gem.Identity(3)))
        i, j = gem.indices(2)
        lhs = gem.Variable("C", (3, 3))[i, j]
        return compiled(lhs, solve[i, j], (i, j)), 9 + 18 + 54 + 54

    if case == "expr4":
        # Cjk += \sum_i Aij * Bki
        A = gem.Variable("A", (10, 15))
        B = gem.Variable("B", (8, 10))
        i, j, k = gem.indices(3)
        contraction = gem.IndexSum(A[i, j] * B[k, i], (i,))
        lhs = gem.Variable("C", (15, 8))[j, k]
        return compiled(lhs, contraction, (j, k)), 2 * 10 * 8 * 15

    raise ValueError("Unexpected expression")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
""" | ||
Tests impero flop counts against loopy. | ||
""" | ||
import pytest | ||
import numpy | ||
import loopy | ||
from tsfc import compile_form | ||
from ufl import (FiniteElement, FunctionSpace, Mesh, TestFunction, | ||
TrialFunction, VectorElement, dx, grad, inner, | ||
interval, triangle, quadrilateral, | ||
TensorProductCell) | ||
|
||
|
||
def count_loopy_flops(kernel):
    """Count the floating point operations loopy reports for a kernel.

    The kernel named ``kernel.name`` inside ``kernel.ast`` is retargeted
    to :class:`loopy.CTarget` with two warnings silenced, loopy's op map
    is computed for it, and the float ``add``/``sub``/``mul``/``div``/
    ``func:abs`` entries are summed.

    :arg kernel: object providing ``name`` and ``ast`` attributes.
    :returns: the summed operation count.
    """
    entrypoint = kernel.name
    program = kernel.ast
    retargeted = program[entrypoint].copy(
        target=loopy.CTarget(),
        silenced_warnings=["insn_count_subgroups_upper_bound",
                           "get_x_map_guessing_subgroup_size"])
    program = program.with_kernel(retargeted)
    op_map = loopy.get_op_map(program.with_entrypoints(entrypoint),
                              numpy_types=None,
                              subgroup_size=1)
    counted_ops = ['add', 'sub', 'mul', 'div', 'func:abs']
    float_ops = op_map.filter_by(name=counted_ops, dtype=[float])
    return float_ops.eval_and_sum({})
|
||
|
||
@pytest.fixture(
    params=[
        interval,
        triangle,
        quadrilateral,
        TensorProductCell(triangle, interval),
    ],
    ids=lambda c: c.cellname(),
)
def cell(request):
    """Parametrise over the reference cells; ids are the cell names.

    :arg request: pytest request carrying the current cell.
    :returns: the cell for this parametrisation.
    """
    return request.param
|
||
|
||
@pytest.fixture(
    params=[
        {"mode": "vanilla"},
        {"mode": "spectral"},
    ],
    ids=["vanilla", "spectral"],
)
def parameters(request):
    """Parametrise over the ``vanilla`` and ``spectral`` modes.

    :arg request: pytest request carrying the current parameter dict.
    :returns: the parameter dict for this parametrisation.
    """
    return request.param
|
||
|
||
def test_flop_count(cell, parameters):
    """Compare the kernel's own flop count with loopy's, degrees 1-4.

    For each polynomial degree a mass-plus-stiffness form is compiled
    and ``kernel.flop_count`` is recorded alongside the count obtained
    from :func:`count_loopy_flops`; the two must agree exactly for every
    degree.

    :arg cell: reference cell from the ``cell`` fixture.
    :arg parameters: compiler parameters from the ``parameters`` fixture.
    """
    mesh = Mesh(VectorElement("P", cell, 1))
    impero_counts = []
    loopy_counts = []
    for degree in range(1, 5):
        V = FunctionSpace(mesh, FiniteElement("P", cell, degree))
        u = TrialFunction(V)
        v = TestFunction(V)
        form = inner(u, v)*dx + inner(grad(u), grad(v))*dx
        kernel, = compile_form(form, prefix="form",
                               parameters=parameters,
                               coffee=False)
        # Record the kernel's own count first, then loopy's.
        impero_counts.append(kernel.flop_count)
        loopy_counts.append(count_loopy_flops(kernel))

    impero_counts = numpy.asarray(impero_counts)
    loopy_counts = numpy.asarray(loopy_counts)

    # Element-wise equality: every degree must match exactly.
    assert all(impero_counts == loopy_counts)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.