Skip to content

Commit

Permalink
Merge pull request #252 from AWhitmell/andrew/benchmarking
Browse files Browse the repository at this point in the history
Add flop counting of scheduled impero.
  • Loading branch information
wence- authored Jul 12, 2021
2 parents 8eccb7d + 831e03e commit b7c66f1
Show file tree
Hide file tree
Showing 7 changed files with 357 additions and 14 deletions.
197 changes: 197 additions & 0 deletions gem/flop_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""
This file contains all the necessary functions to accurately count the
total number of floating point operations for a given script.
"""

import gem.gem as gem
import gem.impero as imp
from functools import singledispatch
import numpy
import math


@singledispatch
def statement(tree, parameters):
raise NotImplementedError


@statement.register(imp.Block)
def statement_block(tree, parameters):
flops = sum(statement(child, parameters) for child in tree.children)
return flops


@statement.register(imp.For)
def statement_for(tree, parameters):
extent = tree.index.extent
assert extent is not None
child, = tree.children
flops = statement(child, parameters)
return flops * extent


@statement.register(imp.Initialise)
def statement_initialise(tree, parameters):
return 0


@statement.register(imp.Accumulate)
def statement_accumulate(tree, parameters):
flops = expression_flops(tree.indexsum.children[0], parameters)
return flops + 1


@statement.register(imp.Return)
def statement_return(tree, parameters):
flops = expression_flops(tree.expression, parameters)
return flops + 1


@statement.register(imp.ReturnAccumulate)
def statement_returnaccumulate(tree, parameters):
flops = expression_flops(tree.indexsum.children[0], parameters)
return flops + 1


@statement.register(imp.Evaluate)
def statement_evaluate(tree, parameters):
flops = expression_flops(tree.expression, parameters, top=True)
return flops


@singledispatch
def flops(expr, parameters):
raise NotImplementedError(f"Don't know how to count flops of {type(expr)}")


@flops.register(gem.Failure)
def flops_failure(expr, parameters):
raise ValueError("Not expecting a Failure node")


@flops.register(gem.Variable)
@flops.register(gem.Identity)
@flops.register(gem.Delta)
@flops.register(gem.Zero)
@flops.register(gem.Literal)
@flops.register(gem.Index)
@flops.register(gem.VariableIndex)
def flops_zero(expr, parameters):
# Initial set up of these Gem nodes are of 0 floating point operations.
return 0


@flops.register(gem.LogicalNot)
@flops.register(gem.LogicalAnd)
@flops.register(gem.LogicalOr)
@flops.register(gem.ListTensor)
def flops_zeroplus(expr, parameters):
# These nodes contribute 0 floating point operations, but their children may not.
return 0 + sum(expression_flops(child, parameters)
for child in expr.children)


@flops.register(gem.Product)
def flops_product(expr, parameters):
# Multiplication by -1 is not a flop.
a, b = expr.children
if isinstance(a, gem.Literal) and a.value == -1:
return expression_flops(b, parameters)
elif isinstance(b, gem.Literal) and b.value == -1:
return expression_flops(a, parameters)
else:
return 1 + sum(expression_flops(child, parameters)
for child in expr.children)


@flops.register(gem.Sum)
@flops.register(gem.Division)
@flops.register(gem.Comparison)
@flops.register(gem.MathFunction)
@flops.register(gem.MinValue)
@flops.register(gem.MaxValue)
def flops_oneplus(expr, parameters):
return 1 + sum(expression_flops(child, parameters)
for child in expr.children)


@flops.register(gem.Power)
def flops_power(expr, parameters):
base, exponent = expr.children
base_flops = expression_flops(base, parameters)
if isinstance(exponent, gem.Literal):
exponent = exponent.value
if exponent > 0 and exponent == math.floor(exponent):
return base_flops + int(math.ceil(math.log2(exponent)))
else:
return base_flops + 5 # heuristic
else:
return base_flops + 5 # heuristic


@flops.register(gem.Conditional)
def flops_conditional(expr, parameters):
condition, then, else_ = (expression_flops(child, parameters)
for child in expr.children)
return condition + max(then, else_)


@flops.register(gem.Indexed)
@flops.register(gem.FlexiblyIndexed)
def flops_indexed(expr, parameters):
aggregate = sum(expression_flops(child, parameters)
for child in expr.children)
# Average flops per entry
return aggregate / numpy.product(expr.children[0].shape, dtype=int)


@flops.register(gem.IndexSum)
def flops_indexsum(expr, parameters):
raise ValueError("Not expecting IndexSum")


@flops.register(gem.Inverse)
def flops_inverse(expr, parameters):
n, _ = expr.shape
# 2n^3 + child flop count
return 2*n**3 + sum(expression_flops(child, parameters)
for child in expr.children)


@flops.register(gem.Solve)
def flops_solve(expr, parameters):
n, m = expr.shape
# 2mn + inversion cost of A + children flop count
return 2*n*m + 2*n**3 + sum(expression_flops(child, parameters)
for child in expr.children)


@flops.register(gem.ComponentTensor)
def flops_componenttensor(expr, parameters):
raise ValueError("Not expecting ComponentTensor")


def expression_flops(expression, parameters, top=False):
"""An approximation to flops required for each expression.
:arg expression: GEM expression.
:arg parameters: Useful miscellaneous information.
:arg top: are we at the root?
:returns: flop count for the expression
"""
if not top and expression in parameters.temporaries:
return 0
else:
return flops(expression, parameters)


def count_flops(impero_c):
"""An approximation to flops required for a scheduled impero_c tree.
:arg impero_c: a :class:`~.Impero_C` object.
:returns: approximate flop count for the tree.
"""
try:
return statement(impero_c.tree, impero_c)
except (ValueError, NotImplementedError):
return 0
64 changes: 64 additions & 0 deletions tests/test_flop_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest
import gem.gem as gem
from gem.flop_count import count_flops
from gem.impero_utils import preprocess_gem
from gem.impero_utils import compile_gem


def test_count_flops(expression):
expr, expected = expression
flops = count_flops(expr)
assert flops == expected


@pytest.fixture(params=("expr1", "expr2", "expr3", "expr4"))
def expression(request):
if request.param == "expr1":
expr = gem.Sum(gem.Product(gem.Variable("a", ()), gem.Literal(2)),
gem.Division(gem.Literal(3), gem.Variable("b", ())))
C = gem.Variable("C", (1,))
i, = gem.indices(1)
Ci = C[i]
expr, = preprocess_gem([expr])
assignments = [(Ci, expr)]
expr = compile_gem(assignments, (i,))
# C += a*2 + 3/b
expected = 1 + 3
elif request.param == "expr2":
expr = gem.Comparison(">=", gem.MaxValue(gem.Literal(1), gem.Literal(2)),
gem.MinValue(gem.Literal(3), gem.Literal(1)))
C = gem.Variable("C", (1,))
i, = gem.indices(1)
Ci = C[i]
expr, = preprocess_gem([expr])
assignments = [(Ci, expr)]
expr = compile_gem(assignments, (i,))
# C += max(1, 2) >= min(3, 1)
expected = 1 + 3
elif request.param == "expr3":
expr = gem.Solve(gem.Identity(3), gem.Inverse(gem.Identity(3)))
C = gem.Variable("C", (3, 3))
i, j = gem.indices(2)
Cij = C[i, j]
expr, = preprocess_gem([expr[i, j]])
assignments = [(Cij, expr)]
expr = compile_gem(assignments, (i, j))
# C += solve(Id(3x3), Id(3x3)^{-1})
expected = 9 + 18 + 54 + 54
elif request.param == "expr4":
A = gem.Variable("A", (10, 15))
B = gem.Variable("B", (8, 10))
i, j, k = gem.indices(3)
Aij = A[i, j]
Bki = B[k, i]
Cjk = gem.IndexSum(Aij * Bki, (i,))
expr = Cjk
expr, = preprocess_gem([expr])
assignments = [(gem.Variable("C", (15, 8))[j, k], expr)]
expr = compile_gem(assignments, (j, k))
# Cjk += \sum_i Aij * Bki
expected = 2 * 10 * 8 * 15

else:
raise ValueError("Unexpected expression")
return expr, expected
66 changes: 66 additions & 0 deletions tests/test_impero_loopy_flop_counts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
Tests impero flop counts against loopy.
"""
import pytest
import numpy
import loopy
from tsfc import compile_form
from ufl import (FiniteElement, FunctionSpace, Mesh, TestFunction,
TrialFunction, VectorElement, dx, grad, inner,
interval, triangle, quadrilateral,
TensorProductCell)


def count_loopy_flops(kernel):
name = kernel.name
program = kernel.ast
program = program.with_kernel(
program[name].copy(
target=loopy.CTarget(),
silenced_warnings=["insn_count_subgroups_upper_bound",
"get_x_map_guessing_subgroup_size"])
)
op_map = loopy.get_op_map(program
.with_entrypoints(kernel.name),
numpy_types=None,
subgroup_size=1)
return op_map.filter_by(name=['add', 'sub', 'mul', 'div',
'func:abs'],
dtype=[float]).eval_and_sum({})


@pytest.fixture(params=[interval, triangle, quadrilateral,
TensorProductCell(triangle, interval)],
ids=lambda cell: cell.cellname())
def cell(request):
return request.param


@pytest.fixture(params=[{"mode": "vanilla"},
{"mode": "spectral"}],
ids=["vanilla", "spectral"])
def parameters(request):
return request.param


def test_flop_count(cell, parameters):
mesh = Mesh(VectorElement("P", cell, 1))
loopy_flops = []
new_flops = []
for k in range(1, 5):
V = FunctionSpace(mesh, FiniteElement("P", cell, k))
u = TrialFunction(V)
v = TestFunction(V)
a = inner(u, v)*dx + inner(grad(u), grad(v))*dx
kernel, = compile_form(a, prefix="form",
parameters=parameters,
coffee=False)
# Record new flops here, and compare asymptotics and
# approximate order of magnitude.
new_flops.append(kernel.flop_count)
loopy_flops.append(count_loopy_flops(kernel))

new_flops = numpy.asarray(new_flops)
loopy_flops = numpy.asarray(loopy_flops)

assert all(new_flops == loopy_flops)
5 changes: 2 additions & 3 deletions tests/test_sum_factorisation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy
import pytest

from coffee.visitors import EstimateFlops

from ufl import (Mesh, FunctionSpace, FiniteElement, VectorElement,
TestFunction, TrialFunction, TensorProductCell,
EnrichedElement, HCurlElement, HDivElement,
Expand Down Expand Up @@ -68,7 +66,8 @@ def split_vector_laplace(cell, degree):

def count_flops(form):
kernel, = compile_form(form, parameters=dict(mode='spectral'))
return EstimateFlops().visit(kernel.ast)
flops = kernel.flop_count
return flops


@pytest.mark.parametrize(('cell', 'order'),
Expand Down
4 changes: 3 additions & 1 deletion tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import gem
import gem.impero_utils as impero_utils
from gem.flop_count import count_flops

import FIAT
from FIAT.reference_element import TensorProductCell
Expand Down Expand Up @@ -240,6 +241,7 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
index_ordering = tuple(quadrature_indices) + split_argument_indices
try:
impero_c = impero_utils.compile_gem(assignments, index_ordering, remove_zeros=True)
flop_count = count_flops(impero_c)
except impero_utils.NoopError:
# No operations, construct empty kernel
return builder.construct_empty_kernel(kernel_name)
Expand All @@ -265,7 +267,7 @@ def name_multiindex(multiindex, name):
for multiindex, name in zip(argument_multiindices, ['j', 'k']):
name_multiindex(multiindex, name)

return builder.construct_kernel(kernel_name, impero_c, index_names, quad_rule)
return builder.construct_kernel(kernel_name, impero_c, index_names, quad_rule, flop_count=flop_count)


def compile_expression_dual_evaluation(expression, to_element, *,
Expand Down
Loading

0 comments on commit b7c66f1

Please sign in to comment.