From 247396cce9f376b53ac130d94af627c58864c9ac Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Wed, 6 Dec 2023 15:54:03 -0800
Subject: [PATCH] [tk5] Implement slice analysis and sufficient op coverage
 for a softmax kernel. (#222)

* Adds arithmetic binary broadcast.
* Adds f32 arithmetic.
* Handles torch ops exp, max, sum (only reduce to scalar at the moment).
* General python signature validation and unpacking.
* Some initial attribute propagation work needed to handle unsigned properly.
* A start at full generality for symbolic dimensions and slice analysis.
  Needs to be reworked before too long, but this gets us basic functionality
  on static kernels while leaving room for dynamic to emerge.
---
 .../shark_turbine/kernel/_support/indexing.py | 192 +++++++++-
 .../shark_turbine/kernel/_support/tracing.py  |  17 +-
 .../shark_turbine/kernel/compiler/analysis.py | 250 +++++++++++++
 python/shark_turbine/kernel/compiler/base.py  |   6 +
 .../shark_turbine/kernel/compiler/builder.py  | 140 ++++++-
 python/shark_turbine/kernel/compiler/ir.py    |   1 +
 .../kernel/compiler/op_matchers.py            |  64 ++++
 .../kernel/compiler/vector_codegen.py         | 354 +++++++++++-------
 python/shark_turbine/kernel/gen/thread.py     |   2 +-
 tests/kernel/analysis_test.py                 |  65 ++++
 tests/kernel/indexing_test.py                 |   7 +
 tests/kernel/vector_codegen_test.py           |  33 ++
 12 files changed, 986 insertions(+), 145 deletions(-)
 create mode 100644 python/shark_turbine/kernel/compiler/analysis.py
 create mode 100644 python/shark_turbine/kernel/compiler/base.py
 create mode 100644 python/shark_turbine/kernel/compiler/op_matchers.py
 create mode 100644 tests/kernel/analysis_test.py

diff --git a/python/shark_turbine/kernel/_support/indexing.py b/python/shark_turbine/kernel/_support/indexing.py
index ebf7f219e..7d0fa4b27 100644
--- a/python/shark_turbine/kernel/_support/indexing.py
+++ b/python/shark_turbine/kernel/_support/indexing.py
@@ -10,6 +10,7 @@ from . import context
 __all__ = [
+    "BoundedSymbolicValue",
     "KernelBuffer",
     "Grid",
     "InputBuffer",
@@ -17,6 +18,10 @@
     "SymbolDef",
     "TemporaryBuffer",
     "sym",
+    "sym_0",
+    "sym_1",
+    "sym_2",
+    "sym_n1",
 ]

@@ -73,7 +78,37 @@ def ir_type_asm(self) -> str:
 ###############################################################################


-class SymbolDef:
+class SymbolExpr:
+    def is_one(self) -> Optional[bool]:
+        """Returns True if the symbol is known to be 1.
+
+        Returns False if known to be != 1 and None if not known.
+        """
+        raise NotImplementedError
+
+    def is_non_negative(self) -> Optional[bool]:
+        """Returns True if the symbol is known to be non-negative.
+
+        Returns False if known to be negative and None if not known.
+        """
+        raise NotImplementedError
+
+    def is_positive(self) -> Optional[bool]:
+        """Returns True if the symbol is known to be greater than zero.
+
+        Returns False if known to be <= 0 and None if not known.
+        """
+        raise NotImplementedError
+
+    def is_negative(self) -> Optional[bool]:
+        """Returns True if the symbol is known to be less than zero.
+
+        Returns False if known to be >= 0 and None if not known.
+ """ + raise NotImplementedError + + +class SymbolDef(SymbolExpr): """Represents a named symbol representing a dimension in a shape.""" ALL_SYMBOLS: ClassVar[dict[str, "SymbolDef"]] = dict() @@ -101,9 +136,153 @@ def __getattr__(self, n): return Expando() + def is_one(self) -> Optional[bool]: + value = IndexingContext.current().get_static_value(self) + if value is None: + return None + return value == 1 + + def is_non_negative(self) -> Optional[bool]: + value = IndexingContext.current().get_static_value(self) + if value is None: + return None + return value >= 0 + + def is_positive(self) -> Optional[bool]: + value = IndexingContext.current().get_static_value(self) + if value is None: + return None + return value > 0 + + def is_negative(self) -> Optional[bool]: + value = IndexingContext.current().get_static_value(self) + if value is None: + return None + return value < 0 + sym = SymbolDef.create_expando() +sym_0 = SymbolDef("0") +sym_1 = SymbolDef("1") +sym_2 = SymbolDef("2") +sym_n1 = SymbolDef("-1") + + +############################################################################### +# Bounded symbolic value. +############################################################################### + +BoundedRangeExprT = tuple[Optional[SymbolExpr], Optional[SymbolExpr]] + + +class _BoundedSymbolicValueMeta(type): + """Meta-class for deriving new bounded symbolic values.""" + + range: BoundedRangeExprT + + def __new__(mcls, name: str, bases, dct, *, range: BoundedRangeExprT): + dct["range"] = range + dct["__qualname__"] = _bounded_symbolic_value_repr(range=range) + new_class = type.__new__(mcls, name, bases, dct) + return new_class + + def __repr__(cls): + return _bounded_symbolic_value_repr(range=cls.range) + + @property + def min_bound(cls) -> Optional[SymbolExpr]: + return cls.range[0] + + @property + def max_bound(cls) -> Optional[SymbolExpr]: + return cls.range[1] + + def bound( + cls: Type[SubtypeT], + min_bound: Optional[SymbolExpr], + max_bound: Optional[SymbolExpr], + ) -> Type[SubtypeT]: + class Bounded(BoundedSymbolicValue, range=(min_bound, max_bound)): + ... + + return Bounded + + def narrow( + cls: Type[SubtypeT], + *, + min_bound: Optional[SymbolExpr] = None, + max_bound: Optional[SymbolExpr] = None, + ) -> Type[SubtypeT]: + class Bounded( + BoundedSymbolicValue, + range=( + min_bound if min_bound is not None else cls.min_bound, + max_bound if max_bound is not None else cls.max_bound, + ), + ): + ... + + return Bounded + + +def _bounded_symbolic_value_repr(*, range: BoundedRangeExprT) -> str: + min_expr, max_expr = range + min_s = repr(min_expr) if min_expr is not None else "*" + max_s = repr(max_expr) if max_expr is not None else "*" + return f"BoundedSymbolicValue({min_s} : {max_s})" + + +class BoundedSymbolicValue( + SymbolExpr, metaclass=_BoundedSymbolicValueMeta, range=(None, None) +): + """Represents a symbolic value that is bounded to a range fixed for the type.""" + + def __init__(self, value: Optional[int] = None): + self.value = value + + def __repr__(self): + return f"{type(self)}({'proxy' if self.value is None else self.value})" + + @property + def static_range(self) -> Optional[tuple[int, int]]: + # TODO: This is a hack until shape derivation is in place. 
+        ctx = IndexingContext.current()
+        mn, mx = type(self).range
+        if mn is not None:
+            mn = ctx.get_static_value(mn)
+        if mx is not None:
+            mx = ctx.get_static_value(mx)
+        if mn is not None and mx is not None:
+            return mn, mx
+        else:
+            return None
+
+    def is_one(self) -> Optional[bool]:
+        r = self.static_range
+        if r:
+            # The upper bound is exclusive, so [1, 2) pins the value to 1.
+            if r[0] == 1 and r[1] == 2:
+                return True
+            if r[0] > 1 or r[1] <= 1:
+                return False
+        return None
+
+    def is_non_negative(self) -> Optional[bool]:
+        r = self.static_range
+        if r:
+            if r[0] >= 0:
+                return True
+            if r[1] <= 0:
+                return False
+        return None
+
+    def is_positive(self) -> Optional[bool]:
+        r = self.static_range
+        if r:
+            if r[0] > 0:
+                return True
+            if r[1] <= 1:
+                return False
+        return None
+
+    def is_negative(self) -> Optional[bool]:
+        r = self.static_range
+        if r:
+            if r[1] <= 0:
+                return True
+            if r[0] >= 0:
+                return False
+        return None
+
+
 ###############################################################################
 # Grid
 ###############################################################################
@@ -271,7 +450,7 @@ def __repr__(cls):
     )


-def _is_kernel_buffer_meta_derived(t: type) -> bool:
+def is_kernel_buffer_meta_derived(t: type) -> bool:
     return isinstance(t, _KernelBufferMeta)


@@ -361,7 +540,12 @@ class IndexingContext:
     __tk_context_idname__ = "IndexingContext"

     def __init__(self):
-        self.constant_bindings: dict[SymbolDef, int] = {}
+        self.constant_bindings: dict[SymbolDef, int] = {
+            sym_0: 0,
+            sym_1: 1,
+            sym_2: 2,
+            sym_n1: -1,
+        }

     def bind_constant(self, sym: SymbolDef, value: int):
         existing = self.constant_bindings.get(sym)
@@ -371,7 +555,7 @@ def bind_constant(self, sym: SymbolDef, value: int):
             )
         self.constant_bindings[sym] = value

-    def get_static_value(self, sym: SymbolDef) -> Optional[int]:
+    def get_static_value(self, sym: SymbolExpr) -> Optional[int]:
         """If the symbol can be resolved to a static value, returns it."""
         return self.constant_bindings.get(sym)

diff --git a/python/shark_turbine/kernel/_support/tracing.py b/python/shark_turbine/kernel/_support/tracing.py
index 822abf87b..004075c4d 100644
--- a/python/shark_turbine/kernel/_support/tracing.py
+++ b/python/shark_turbine/kernel/_support/tracing.py
@@ -7,7 +7,10 @@
 import torch.fx as fx

 from .indexing import (
+    BoundedSymbolicValue,
+    Grid,
     KernelBuffer,
+    sym_0,
 )

 from ..lang.types import (
@@ -106,13 +109,23 @@ def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item):


 class CompiledContext(BaseContext):
-    def __init__(self, tracer: KernelTracer):
+    def __init__(self, tracer: KernelTracer, *, grid_type: Type[Grid]):
         super().__init__(eager=False)
         self.tracer = tracer
+        self.grid_type = grid_type

     def handle_thread_program_id(self, op, axis: int) -> Index:
+        grid_shape = self.grid_type.symbolic_shape
+        if axis < 0 or axis >= len(grid_shape):
+            raise IndexError(
+                f"Illegal index into grid of rank {len(grid_shape)}: {axis}"
+            )
         proxy = self.tracer.create_proxy(
-            "call_function", op, args=(axis,), kwargs={}, type_expr=Index
+            "call_function",
+            op,
+            args=(axis,),
+            kwargs={},
+            type_expr=BoundedSymbolicValue.bound(sym_0, grid_shape[axis]),
         )
         return proxy

diff --git a/python/shark_turbine/kernel/compiler/analysis.py b/python/shark_turbine/kernel/compiler/analysis.py
new file mode 100644
index 000000000..19ef2ee59
--- /dev/null
+++ b/python/shark_turbine/kernel/compiler/analysis.py
@@ -0,0 +1,250 @@
+from typing import Any, Optional, Type, Union
+
+import torch.fx as fx
+
+from .base import CodegenError
+
+from .._support.indexing import (
+    BoundedSymbolicValue,
+    IndexingContext,
+    SymbolDef,
+    SymbolExpr,
+)
+
+
+NormalizedSlice = list[Union[slice, None]]
+
+
+def _symbolize_slice_value(value):
+    # TODO: I don't like this and wish this happened more automatically somehow.
+    if isinstance(value, fx.Node):
+        sym_type = value.type
+        if sym_type and issubclass(sym_type, BoundedSymbolicValue):
+            return sym_type()
+        return value
+    else:
+        return value
+
+
+def _norm_slice_spec(rank: int, slice_spec) -> NormalizedSlice:
+    def _norm_single_slice(s):
+        if s is None or s is ...:
+            return s
+        if isinstance(s, slice):
+            # Validate.
+            if s.step == 0:
+                # A zero step is illegal, but we use it to signal an integer
+                # index vs a range.
+                raise IndexError(f"slice with step 0 is illegal (got {s})")
+            return s
+        else:
+            # Promote a raw value to our special 0-step slice.
+            return slice(s, 0, 0)
+
+    if not isinstance(slice_spec, tuple):
+        slice_spec = (slice_spec,)
+    norm_slices = [_norm_single_slice(s) for s in slice_spec]
+
+    # Replace any ellipsis with rank-filling `slice(None)` values.
+    none_count = norm_slices.count(None)
+    ellipses_count = norm_slices.count(...)
+    if ellipses_count == 1:
+        # Expand by the original list of slices less any unit dim insertions.
+        # If negative, this does nothing and will be caught later upon
+        # rank validation.
+        expand_index = norm_slices.index(...)
+        del norm_slices[expand_index]
+        expansion_count = (rank + none_count) - len(norm_slices)
+        for _ in range(expansion_count):
+            norm_slices.insert(expand_index, slice(None))
+    elif ellipses_count > 1:
+        raise IndexError(
+            "Cannot index into a rank-expanding referent with multiple `...` values"
+        )
+    return norm_slices
+
+
+class SliceAnalysis:
+    """Analyzes Python slicing notation so that it can be validated and code generated.
+
+    The numpy page has a good description here:
+    https://numpy.org/doc/1.26/user/basics.indexing.html
+
+    This class analyzes:
+      * Basic Indexing
+      * Slicing and Striding
+      * Dimensional Indexing Tools
+
+    Note that `None` is interpreted as `np.newaxis` (which we do not support).
+
+    Each element of a slice specification can be:
+      * An arbitrary Python value representing a single element span
+      * None to indicate a new unit dimension
+      * Ellipsis to indicate a space-filling `slice()`
+      * A `slice` object
+
+    Such a specification is decomposed into a `source_slice` which does not
+    include any rank broadcasting and a `broadcast_slice` which includes
+    any rank expansion. Depending on the operation being code generated,
+    these will be handled differently. All loose Python values will
+    be promoted into a `slice` object.
+
+    Raises:
+      IndexError on any violations of Python indexing semantics which
+      are statically determined during analysis.
+    """
+
+    def __init__(self, ref: tuple[SymbolExpr, ...], slice_spec):
+        self.ref = ref
+        self.slices = _norm_slice_spec(len(ref), slice_spec)
+
+        # Compute an expanded version of ref that has None values for
+        # any to-be-inserted unit dims. This will be the same size
+        # as slices.
+        self.expanded_ref: list[Optional[SymbolExpr]] = list(ref)
+        for i in (i for i, entry in enumerate(self.slices) if entry is None):
+            self.expanded_ref.insert(i, None)
+        assert len(self.expanded_ref) == len(self.slices)
+        self._is_symbolic_normalized = False
+
+    def __repr__(self):
+        return repr(self.slices)
+
+    def normalize_symbolic_ranges(
+        self, *, allow_reverse_step: bool = False, allow_non_unit_step: bool = False
+    ):
+        """Uses the IndexingContext to normalize ranges for any None fields.
+
+        This fully populates the fields of every slice with either
+        integers or SymbolExprs. Does not modify slice fields that are
+        non-None.
+
+        Raises IndexError for any variations that are not supported
+        or cannot be statically derived.
+        """
+        if self._is_symbolic_normalized:
+            return
+
+        def norm(dim_expr: SymbolExpr, s: Optional[slice]) -> Optional[slice]:
+            if s is None:
+                return s
+            ctx = IndexingContext.current()
+            start = s.start
+            stop = s.stop
+            step = s.step
+
+            # Set defaults.
+            if start is None:
+                start = 0
+            if stop is None:
+                stop = dim_expr
+            if step is None:
+                step = 1
+
+            # Symbolize for analysis.
+            start_sym = _symbolize_slice_value(start)
+            stop_sym = _symbolize_slice_value(stop)
+            step_sym = _symbolize_slice_value(step)
+
+            # Evaluate facts for start.
+            if isinstance(start_sym, SymbolExpr):
+                start_is_non_negative = start_sym.is_non_negative()
+            elif isinstance(start_sym, int):
+                start_is_non_negative = start_sym >= 0
+            else:
+                raise IndexError(
+                    f"A symbolically evaluable start index is required (got: {start_sym} (type {type(start_sym)}))"
+                )
+
+            # Evaluate facts for stop.
+            if isinstance(stop_sym, SymbolExpr):
+                stop_is_non_negative = stop_sym.is_non_negative()
+                stop_is_zero = False
+            elif isinstance(stop_sym, int):
+                stop_is_non_negative = stop_sym >= 0
+                stop_is_zero = stop_sym == 0
+            else:
+                raise IndexError(
+                    f"A symbolically evaluable stop index is required (got: {stop_sym} (type {type(stop_sym)}))"
+                )
+
+            # Evaluate facts for step.
+            if isinstance(step_sym, SymbolExpr):
+                reverse_step = step_sym.is_negative()
+                unit_step = step_sym.is_one()
+                zero_step = False
+            elif isinstance(step_sym, int):
+                reverse_step = step_sym < 0
+                unit_step = step_sym == 1
+                zero_step = step_sym == 0
+            else:
+                raise IndexError(
+                    f"A symbolically evaluable step is required (got: {step_sym} (type {type(step_sym)}))"
+                )
+
+            # Validate step constraints.
+            if zero_step:
+                # This is our special marker for a unit (non-range) extract.
+                assert (
+                    stop_is_zero
+                ), "slices with non-zero stop and zero step should have been filtered"
+            else:
+                if not allow_non_unit_step and not unit_step:
+                    raise IndexError(
+                        f"Only unit steps are supported in this context (got slice({start_sym}, {stop_sym}, {step_sym}))"
+                    )
+
+                if not allow_reverse_step and reverse_step:
+                    raise IndexError(
+                        f"Only forward steps are supported in this context (got slice({start_sym}, {stop_sym}, {step_sym}))"
+                    )
+
+            # Normalize negative start/stop.
+            if not start_is_non_negative:
+                raise IndexError("NYI: Negative slice start")
+            if not stop_is_non_negative:
+                raise IndexError("NYI: Negative slice stop")
+
+            return slice(start, stop, step)
+
+        for i in range(len(self.slices)):
+            expr = self.expanded_ref[i]
+            self.slices[i] = norm(expr, self.slices[i])
+        self._is_symbolic_normalized = True
+
+    @property
+    def symbolic_shape(self) -> list[Optional[Union[int, SymbolExpr]]]:
+        """Resolves the symbolic shape of the result of this slice.
+
+        Forces symbolic normalization if it has not already been done.
+        Any rank broadcast dimensions will be retained as None.
+        """
+        self.normalize_symbolic_ranges()
+
+        def _item(s: Optional[slice]):
+            if s is None:
+                return None
+            ctx = IndexingContext.current()
+            # Detect special unit 0-step slices.
+            if s.stop == 0 and s.step == 0:
+                return 1
+            start = s.start
+            stop = s.stop
+
+            # TODO: This is a hack to work around not yet having full
+            # symbolic expression support. We should just be asking
+            # the symbols to evaluate.
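+            # e.g. with K bound to 5 in the IndexingContext, slice(2, K, 1)
+            # resolves below to a static extent of 5 - 2 = 3 (see
+            # tests/kernel/analysis_test.py, testSymbolic).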
+            if isinstance(start, SymbolExpr):
+                static_start = ctx.get_static_value(start)
+            elif isinstance(start, int):
+                static_start = start
+            else:
+                static_start = None
+            if isinstance(stop, SymbolExpr):
+                static_stop = ctx.get_static_value(stop)
+            elif isinstance(stop, int):
+                static_stop = stop
+            else:
+                static_stop = None
+            if static_start is not None and static_stop is not None:
+                return static_stop - static_start
+
+            raise IndexError("NYI: Non-statically resolved symbolic shapes")
+
+        return [_item(s) for s in self.slices]
diff --git a/python/shark_turbine/kernel/compiler/base.py b/python/shark_turbine/kernel/compiler/base.py
new file mode 100644
index 000000000..1c896313f
--- /dev/null
+++ b/python/shark_turbine/kernel/compiler/base.py
@@ -0,0 +1,6 @@
+class CodegenError(Exception):
+    ...
+
+
+class ValidationError(CodegenError):
+    ...
diff --git a/python/shark_turbine/kernel/compiler/builder.py b/python/shark_turbine/kernel/compiler/builder.py
index e1ed0d9a2..89a149c6d 100644
--- a/python/shark_turbine/kernel/compiler/builder.py
+++ b/python/shark_turbine/kernel/compiler/builder.py
@@ -1,19 +1,42 @@
 from typing import Optional

+from .base import (
+    CodegenError,
+)
+
 from .ir import (
+    Attribute,
     Context,
+    FloatAttr,
+    IndexType,
+    IntegerAttr,
+    IntegerType,
+    IrType,
     Location,
     Operation,
+    Value,
+    VectorType,
+    arith_d,
     builtin_d,
 )

+# TODO: Have a way upstream to check if a type is a floating point type.
+FLOAT_TYPES_ASM = {
+    "bf16",
+    "f16",
+    "f32",
+    "f64",
+    # TODO: FP8 types.
+}
+
+
 class ModuleBuilder:
     def __init__(
         self,
         *,
         context: Optional[Context] = None,
-        module_op: Optional[Operation] = None
+        module_op: Optional[Operation] = None,
     ):
         if module_op:
             self.module_op = module_op
@@ -24,3 +47,118 @@ def __init__(
             self.module_op = builtin_d.ModuleOp(loc=Location.unknown(context))
         self.body_block = self.module_op.body
         self.context = self.module_op.context
+
+
+class _ScalarBuilder:
+    def is_floating_point_type(self, t: IrType) -> bool:
+        return str(t) in FLOAT_TYPES_ASM
+
+    def is_integer_type(self, t: IrType) -> bool:
+        return IntegerType.isinstance(t)
+
+    def promote(self, value: Value, to_type: IrType) -> Value:
+        value_type = value.type
+        # Short-circuit if already the right type.
+        if value_type == to_type:
+            return value
+
+        attr_name = f"promote_{value_type}_to_{to_type}"
+        try:
+            handler = getattr(self, attr_name)
+        except AttributeError:
+            raise CodegenError(
+                f"No implemented path to implicitly promote scalar `{value_type}` to `{to_type}` (tried '{attr_name}')"
+            )
+        return handler(value, to_type)
+
+    def zero_attr(self, t: IrType) -> Attribute:
+        attr_name = f"zero_attr_{t}"
+        try:
+            handler = getattr(self, attr_name)
+        except AttributeError:
+            raise CodegenError(
+                f"Cannot derive a zero value for type `{t}` (tried '{attr_name}')"
+            )
+        return handler(t)
+
+    def constant(self, py_value) -> Value:
+        attr_name = f"py_constant_{type(py_value).__name__}"
+        try:
+            handler = getattr(self, attr_name)
+        except AttributeError:
+            raise CodegenError(
+                f"Cannot convert Python value to constant: {py_value} of type {type(py_value)} (tried '{attr_name}')"
+            )
+        return handler(py_value)
+
+    def binary_arithmetic(self, op: str, lhs: Value, rhs: Value) -> Value:
+        attr_name = f"binary_{op}_{lhs.type}_{rhs.type}"
+        try:
+            handler = getattr(self, attr_name)
+        except AttributeError:
+            raise CodegenError(
+                f"Cannot perform binary arithmetic operation '{op}' between {lhs.type} and {rhs.type} (tried '{attr_name}')"
+            )
+        return handler(lhs, rhs)
+
+    def binary_vector_arithmetic(self, op: str, lhs: Value, rhs: Value) -> Value:
+        lhs_element_type = VectorType(lhs.type).element_type
+        rhs_element_type = VectorType(rhs.type).element_type
+        attr_name = f"binary_{op}_{lhs_element_type}_{rhs_element_type}"
+        try:
+            handler = getattr(self, attr_name)
+        except AttributeError:
+            raise CodegenError(
+                f"Cannot perform binary arithmetic operation '{op}' between {lhs.type} and {rhs.type} (tried '{attr_name}')"
+            )
+        return handler(lhs, rhs)
+
+    def promote_index_to_f32(self, value: Value, to_type: IrType) -> Value:
+        i32_type = IntegerType.get_signless(32)
+        i32 = arith_d.index_cast(i32_type, value)
+        return arith_d.sitofp(to_type, i32)
+
+    def zero_attr_f32(self, t: IrType) -> Attribute:
+        return FloatAttr.get(t, 0.0)
+
+    def py_constant_int(self, py_value) -> Value:
+        # If coming from a stock Python 'int' with no other type context,
+        # there isn't much smart we can do. We conservatively treat 'index'
+        # as reasonable.
+        attr = IntegerAttr.get(IndexType.get(), py_value)
+        return arith_d.constant(attr)

+    # Binary index arithmetic.
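+    # These handlers are not called directly: binary_arithmetic() and
+    # binary_vector_arithmetic() above resolve them by name via the
+    # binary_{op}_{lhs_type}_{rhs_type} convention, so `a + b` on two
+    # `index` values dispatches to binary_add_index_index below.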
+    def binary_add_index_index(self, lhs: Value, rhs: Value) -> Value:
+        return arith_d.addi(lhs, rhs)
+
+    def binary_mul_index_index(self, lhs: Value, rhs: Value) -> Value:
+        return arith_d.muli(lhs, rhs)
+
+    def binary_sub_index_index(self, lhs: Value, rhs: Value) -> Value:
+        return arith_d.subi(lhs, rhs)
+
+    def binary_mod_index_index(self, lhs: Value, rhs: Value) -> Value:
+        return arith_d.remsi(lhs, rhs)
+
+    def binary_floordiv_index_index(self, lhs: Value, rhs: Value) -> Value:
+        return arith_d.floordivsi(lhs, rhs)
+
+    # Binary f32 arithmetic
+    def binary_add_f32_f32(self, lhs: Value, rhs: Value) -> Value:
+        return arith_d.addf(lhs, rhs)
+
+    def binary_mul_f32_f32(self, lhs: Value, rhs: Value) -> Value:
+        return arith_d.mulf(lhs, rhs)
+
+    def binary_sub_f32_f32(self, lhs: Value, rhs: Value) -> Value:
+        return arith_d.subf(lhs, rhs)
+
+    def binary_mod_f32_f32(self, lhs: Value, rhs: Value) -> Value:
+        return arith_d.remf(lhs, rhs)
+
+    def binary_truediv_f32_f32(self, lhs: Value, rhs: Value) -> Value:
+        return arith_d.divf(lhs, rhs)
+
+
+ScalarBuilder = _ScalarBuilder()
diff --git a/python/shark_turbine/kernel/compiler/ir.py b/python/shark_turbine/kernel/compiler/ir.py
index c73408897..84fb01259 100644
--- a/python/shark_turbine/kernel/compiler/ir.py
+++ b/python/shark_turbine/kernel/compiler/ir.py
@@ -26,5 +26,6 @@
     arith as arith_d,
     builtin as builtin_d,
     func as func_d,
+    math as math_d,
     vector as vector_d,
 )
diff --git a/python/shark_turbine/kernel/compiler/op_matchers.py b/python/shark_turbine/kernel/compiler/op_matchers.py
new file mode 100644
index 000000000..81f738a5e
--- /dev/null
+++ b/python/shark_turbine/kernel/compiler/op_matchers.py
@@ -0,0 +1,64 @@
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+import functools
+import inspect
+
+
+def signature_matcher(f=None, *, arity: Optional[int] = None, original_name: str = ""):
+    """Transforms a function into a signature matcher.
+
+    The transformed function takes the same args/kwargs as the original, but
+    it will return the `inspect.BoundArguments.arguments` mapping when invoked.
+
+    Optional overload selectors can be specified, and if not met, None
+    will be returned (versus raising an error).
+
+    On argument mismatch, a TypeError will be raised.
+    """
+    if f is None:
+        return functools.partial(
+            signature_matcher, arity=arity, original_name=original_name
+        )
+
+    sig = inspect.signature(f)
+
+    def wrapped(*args, **kwargs) -> Optional[dict]:
+        if arity is not None and arity != (len(args) + len(kwargs)):
+            return None
+        try:
+            bound_args = sig.bind(*args, **kwargs)
+            bound_args.apply_defaults()
+            return bound_args.arguments
+        except TypeError as e:
+            reported_name = original_name or f.__name__
+            raise TypeError(f"{reported_name}() {str(e)}")
+
+    return wrapped
+
+
+@signature_matcher(original_name="torch.exp")
+def torch_exp(input: Tensor) -> Tensor:
+    ...
+
+
+@signature_matcher(arity=1, original_name="torch.max")
+def torch_max_unary(input: Tensor) -> Tensor:
+    ...
+
+
+@signature_matcher(original_name="torch.max")
+def torch_max(input: Tensor, dim: int, keepdim: bool = False):
+    ...
+
+
+@signature_matcher(arity=1, original_name="torch.sum")
+def torch_sum_unary(input: Tensor) -> Tensor:
+    ...
+
+
+@signature_matcher(original_name="torch.sum")
+def torch_sum(input: Tensor, dim: int, keepdim: bool = False):
+    ...
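[Editorial sketch, not part of the patch: how the matchers above behave when
chained for overload selection, as the emitter does below; `t` is an arbitrary
placeholder tensor.]

    import torch
    from shark_turbine.kernel.compiler import op_matchers

    t = torch.ones(4)
    op_matchers.torch_sum_unary(t)     # -> {"input": t}
    op_matchers.torch_sum_unary(t, 0)  # -> None (arity selector rejects 2 args)
    # Overload chaining, as used in vector_codegen.py:
    args = op_matchers.torch_sum_unary(t, 0) or op_matchers.torch_sum(t, 0)
    # -> {"input": t, "dim": 0, "keepdim": False}
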
diff --git a/python/shark_turbine/kernel/compiler/vector_codegen.py b/python/shark_turbine/kernel/compiler/vector_codegen.py
index 880dc8578..6828d4871 100644
--- a/python/shark_turbine/kernel/compiler/vector_codegen.py
+++ b/python/shark_turbine/kernel/compiler/vector_codegen.py
@@ -1,8 +1,10 @@
 from typing import Any, Callable, Type, Optional, Sequence, Union
 from dataclasses import dataclass

+import inspect
 import operator as py_operator

+import torch
 import torch.fx as fx

 from .._support.indexing import (
@@ -10,7 +12,7 @@
     IndexingContext,
     KernelBuffer,
     SymbolDef,
-    _is_kernel_buffer_meta_derived,
+    is_kernel_buffer_meta_derived,
 )

 from ..lang import (
@@ -19,23 +21,26 @@

 from .. import ops

+from .analysis import (
+    SliceAnalysis,
+)
+
 from .builder import (
     ModuleBuilder,
+    ScalarBuilder,
+)
+
+from .base import (
+    CodegenError,
+    ValidationError,
 )

 from .ir import (
-    AffineConstantExpr,
-    AffineExpr,
     AffineMap,
     AffineMapAttr,
-    Attribute,
-    DenseElementsAttr,
-    FloatAttr,
     FunctionType,
     IndexType,
     InsertionPoint,
-    IntegerAttr,
-    IntegerType,
     IrType,
     Location,
     MemRefType,
@@ -44,21 +49,15 @@
     VectorType,
     arith_d,
     func_d,
+    math_d,
     vector_d,
 )

+from . import op_matchers

 ArgTypeUnion = Union[SymbolDef, Type[KernelBuffer]]


-class CodegenError(Exception):
-    ...
-
-
-class ValidationError(CodegenError):
-    ...
-
-
 @dataclass
 class ArgMeta:
     name: Optional[str] = None
@@ -66,6 +65,22 @@ class ArgMeta:
     grid_index: Optional[int] = None


+@dataclass
+class NodeAttrs:
+    # By default, integers are assumed signed. We propagate unsigned as graph
+    # node attrs.
+    unsigned: bool = False
+
+    @staticmethod
+    def load(py_value) -> "NodeAttrs":
+        if isinstance(py_value, fx.Node):
+            return NodeAttrs(unsigned=bool(py_value.meta.get("unsigned")))
+        return NodeAttrs()
+
+    def store(self, node: fx.Node):
+        node.meta["unsigned"] = self.unsigned
+
+
 class Signature:
     """Represents a function signature.

@@ -102,7 +117,7 @@ def sym_to_dim_asm(s: SymbolDef) -> str:
     def as_mlir_type(t: ArgTypeUnion) -> FunctionType:
         if isinstance(t, SymbolDef):
             return IndexType.get()
-        elif _is_kernel_buffer_meta_derived(t):
+        elif is_kernel_buffer_meta_derived(t):
             kb_t = t  # type: KernelBuffer
             element_type_asm = kb_t.element_type.ir_type_asm()
             symbolic_shape = kb_t.symbolic_shape
@@ -129,7 +144,7 @@ def add_from_graph_placeholders(self, graph: fx.Graph):
             continue
         t = node.type
         meta = ArgMeta(name=node.target, node=node)
-        if _is_kernel_buffer_meta_derived(t):
+        if is_kernel_buffer_meta_derived(t):
             self.add_kernel_buffer(t, meta=meta)
         elif issubclass(t, SymbolDef):
             self.add_symbol(t, meta=meta)
@@ -150,7 +165,6 @@ class ThreadEmitter:
     OP_HANDLERS: dict[Any, Callable[["ThreadEmitter", fx.Node], None]] = {}

     def __init__(self, mb: ModuleBuilder, grid: Grid, sig: Signature):
-        self.scalar_builder = ScalarBuilder(self)
         self.nv_map: dict[fx.Node, Value] = {}
         self.grid_index_map: list[Optional[Value]] = [None] * grid.rank
@@ -182,18 +196,18 @@ def __init__(self, mb: ModuleBuilder, grid: Grid, sig: Signature):
         self.ip = InsertionPoint(self.entry_block)

     def bind_node_result(
-        self, node: fx.Node, value: Value, type_expr: Optional[type] = None
+        self, node: fx.Node, value: Value, *, attrs: Optional[NodeAttrs] = None
     ):
         assert node not in self.nv_map, f"Cannot rebind node {node}: already bound"
-        if type_expr is not None:
-            node.type = type_expr
+        if attrs is not None:
+            attrs.store(node)
         self.nv_map[node] = value

     def emit_graph(self, graph: fx.Graph):
         context = self.context
         for node in graph.nodes:
             # TODO: Construct a location for the node.
-            with Location.unknown(context):
+            with self.ip, Location.unknown(context):
                 if node.op == "call_function":
                     self.emit_function_call_node(node)
@@ -228,9 +242,54 @@ def decorator(f: Callable[["ThreadEmitter", fx.Node], None]):
     (py_operator.sub, "sub"),
     (py_operator.mod, "mod"),
     (py_operator.floordiv, "floordiv"),
+    (py_operator.truediv, "truediv"),
 ]


+def binary_broadcast(lhs: Value, rhs: Value) -> tuple[bool, Value, Value]:
+    lhs_type = lhs.type
+    rhs_type = rhs.type
+    lhs_is_vector = VectorType.isinstance(lhs_type)
+    rhs_is_vector = VectorType.isinstance(rhs_type)
+    if not lhs_is_vector and not rhs_is_vector:
+        # Not vectors: return as-is.
+        return False, lhs, rhs
+
+    # Promote to vector.
+    if not lhs_is_vector:
+        lhs = vector_d.splat(VectorType.get([], lhs_type), lhs)
+    if not rhs_is_vector:
+        rhs = vector_d.splat(VectorType.get([], rhs_type), rhs)
+    lhs_type = VectorType(lhs.type)
+    rhs_type = VectorType(rhs.type)
+
+    broadcast_shape = lhs_type.shape
+    rhs_shape = rhs_type.shape
+    rank = max(len(broadcast_shape), len(rhs_shape))
+    while len(broadcast_shape) < rank:
+        broadcast_shape.insert(0, 1)
+    while len(rhs_shape) < rank:
+        rhs_shape.insert(0, 1)
+
+    for i in range(rank):
+        a = broadcast_shape[i]
+        b = rhs_shape[i]
+        if a != b:
+            if a != 1 and b != 1:
+                raise CodegenError(
+                    f"Binary operands are not broadcast compatible: {lhs_type}, {rhs_type}"
+                )
+            broadcast_shape[i] = rhs_shape[i] = max(a, b)
+
+    lhs_type = VectorType.get(broadcast_shape, lhs_type.element_type)
+    rhs_type = VectorType.get(broadcast_shape, rhs_type.element_type)
+    if lhs_type != lhs.type:
+        lhs = vector_d.broadcast(lhs_type, lhs)
+    if rhs_type != rhs.type:
+        rhs = vector_d.broadcast(rhs_type, rhs)
+    return True, lhs, rhs
+
+
 def _define_arithmetic_handlers():
     def register(py_operator, mnemonic):
         @handle_op(py_operator)
@@ -242,7 +301,11 @@ def _(emitter: ThreadEmitter, node: fx.Node):
             lhs = cast_py_value(emitter, lhs)
             rhs = cast_py_value(emitter, rhs)
-            result = emitter.scalar_builder.binary_arithmetic(mnemonic, lhs, rhs)
+            is_vector, lhs, rhs = binary_broadcast(lhs, rhs)
+            if is_vector:
+                result = ScalarBuilder.binary_vector_arithmetic(mnemonic, lhs, rhs)
+            else:
+                result = ScalarBuilder.binary_arithmetic(mnemonic, lhs, rhs)
             emitter.bind_node_result(node, result)

     for py_operator, mnemonic in BINARY_ARITHMETIC_OPS:
@@ -272,140 +335,141 @@ def _(emitter: ThreadEmitter, node: fx.Node):
     except IndexError as e:
         raise CodegenError("Grid axis out of bounds") from e

-    emitter.bind_node_result(node, value, Index)
+    emitter.bind_node_result(node, value)


 @handle_op(ops.kernel_buffer_getitem)
 def _(emitter: ThreadEmitter, node: fx.Node):
-    raise CodegenError("NYI: kernel_buffer_getitem")
+    try:
+        kb, slice_spec = node.args
+    except ValueError as e:
+        raise ValidationError("Malformed arguments") from e
+
+    kb_src, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, kb)
+    sa = SliceAnalysis(kb_py_type.symbolic_shape, slice_spec)
+    sa.normalize_symbolic_ranges()
+    vector_shape = sa.symbolic_shape
+    element_type = kb_ir_type.element_type
+    vector_type = VectorType.get(vector_shape, element_type)
+    pad_attr = ScalarBuilder.zero_attr(element_type)
+    indices = cast_indices(emitter, [s.start for s in sa.slices])
+    pad_value = arith_d.constant(pad_attr)
+    result = vector_d.transfer_read(
+        vector_type, kb_src, indices, AffineMap.get_identity(len(indices)), pad_value
+    )
+    emitter.bind_node_result(node, result)


 @handle_op(ops.kernel_buffer_setitem)
 def _(emitter: ThreadEmitter, node: fx.Node):
     try:
-        kb, key, item = node.args
+        kb, slice_spec, item = node.args
     except ValueError as e:
         raise ValidationError("Malformed arguments") from e

-    kb_dest, kb_type = cast_kernel_buffer(emitter, kb)
-    dest_rank = kb_type.rank
-    indices = cast_indices(emitter, key)
+    kb_dest, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, kb)
+    dest_rank = kb_ir_type.rank
+    sa = SliceAnalysis(kb_py_type.symbolic_shape, slice_spec)
+    sa.normalize_symbolic_ranges()
+    indices = cast_indices(emitter, [s.start for s in sa.slices])
     if dest_rank != len(indices):
         raise CodegenError(
             f"Mismatched slice assignment: Expected rank {dest_rank}, got {len(indices)}"
         )
-    insert_vector = cast_vector(emitter, item, element_type=kb_type.element_type)
+    insert_vector = cast_vector(emitter, item, element_type=kb_ir_type.element_type)
     insert_type = VectorType(insert_vector.type)
     insert_rank = insert_type.rank

-    # This form only supports 0d broadcast or same rank currently.
-    # TODO: This was specially crafted to make the iota demo work. Need to work
-    # it out in generality.
-    if insert_rank != 0 and insert_rank != dest_rank:
-        raise CodegenError(
-            f"The shorthand kernel_buffer[...]= assignment syntax only supports same rank assignment or restricted, 0d broadcast"
-        )
+    # Special case rank-0 broadcast.
+    if insert_rank == 0:
+        broadcast_type = VectorType.get(dest_rank * [1], kb_ir_type.element_type)
+        insert_vector = vector_d.broadcast(broadcast_type, insert_vector)

-    with emitter.ip:
-        if insert_rank == 0:
-            broadcast_type = VectorType.get(dest_rank * [1], kb_type.element_type)
-            insert_vector = vector_d.broadcast(broadcast_type, insert_vector)
-        permutation_map = AffineMap.get_identity(dest_rank)
-        vector_d.transfer_write(
-            None, insert_vector, kb_dest, indices, AffineMapAttr.get(permutation_map)
-        )
+    permutation_map = AffineMap.get_identity(dest_rank)
+    vector_d.transfer_write(
+        None, insert_vector, kb_dest, indices, AffineMapAttr.get(permutation_map)
+    )


 ###############################################################################
-# Conversion utilities
+# Torch and math ops
 ###############################################################################


-class ScalarBuilder:
-    def __init__(self, emitter: ThreadEmitter):
-        self.emitter = emitter
-
-    def promote(self, value: Value, to_type: IrType) -> Value:
-        value_type = value.type
-        # Short-circuit if already the right type.
-        if value_type == to_type:
-            return value
-
-        attr_name = f"promote_{value_type}_to_{to_type}"
-        try:
-            handler = getattr(self, attr_name)
-        except AttributeError:
-            raise CodegenError(
-                f"No implemented path to implicitly promote scalar `{value_type}` to `{to_type}` (tried '{attr_name}')"
-            )
-        return handler(value, to_type)
+@handle_op(torch.exp)
+def _(emitter: ThreadEmitter, node: fx.Node):
+    args = op_matchers.torch_exp(*node.args, **node.kwargs)
+    raw_input = args["input"]
+    input = cast_vector(emitter, raw_input)
+    result = math_d.exp(input)
+    emitter.bind_node_result(node, result)

-    def zero_attr(self, t: IrType) -> Attribute:
-        attr_name = f"zero_attr_{t}"
-        try:
-            handler = getattr(self, attr_name)
-        except AttributeError:
-            raise CodegenError(
-                f"Cannot derive a zero value for type `{t}` (tried '{attr_name}')"
-            )
-        return handler(t)

-    def constant(self, py_value) -> Value:
-        attr_name = f"py_constant_{type(py_value).__name__}"
-        try:
-            handler = getattr(self, attr_name)
-        except AttributeError:
-            raise CodegenError(
-                f"Cannot convert Python value to constant: {py_value} of type {type(py_value)} (tried '{attr_name}')"
+
+@handle_op(torch.max)
+def _(emitter: ThreadEmitter, node: fx.Node):
+    args = op_matchers.torch_max_unary(
+        *node.args, **node.kwargs
+    ) or op_matchers.torch_max(*node.args, **node.kwargs)
+
+    def combiner(element_type: IrType, attrs: NodeAttrs) -> vector_d.CombiningKind:
+        if ScalarBuilder.is_floating_point_type(element_type):
+            # Non-NaN propagating.
+            # TODO: Carry a "fastmath" flag on the emitter and choose between this
+            # and MAXIMUMF?
+            return vector_d.CombiningKind.MAXF
+        elif ScalarBuilder.is_integer_type(element_type):
+            return (
+                vector_d.CombiningKind.MAXUI
+                if attrs.unsigned
+                else vector_d.CombiningKind.MAXSI
             )
-        return handler(py_value)
+        else:
+            raise CodegenError(
+                f"No max combining kind for element type {element_type}"
+            )

-    def binary_arithmetic(self, op: str, lhs: Value, rhs: Value) -> Value:
-        attr_name = f"binary_{op}_{lhs.type}_{rhs.type}"
-        try:
-            handler = getattr(self, attr_name)
-        except AttributeError:
-            raise CodegenError(
-                f"Cannot perform binary arithmetic operation '{op}' between {lhs.type} and {rhs.type} (tried '{attr_name}')"
-            )
-        return handler(lhs, rhs)
+    emit_reduction(emitter, node, args, combiner)

-    def promote_index_to_f32(self, value: Value, to_type: IrType) -> Value:
-        with self.emitter.ip:
-            i32_type = IntegerType.get_signless(32)
-            i32 = arith_d.index_cast(i32_type, value)
-            return arith_d.sitofp(to_type, i32)

-    def zero_attr_f32(self, t: IrType) -> Attribute:
-        return FloatAttr.get(t, 0.0)
+
+@handle_op(torch.sum)
+def _(emitter: ThreadEmitter, node: fx.Node):
+    args = op_matchers.torch_sum_unary(
+        *node.args, **node.kwargs
+    ) or op_matchers.torch_sum(*node.args, **node.kwargs)

-    def py_constant_int(self, py_value) -> Value:
-        # If coming from a stock 'int' Python type with no idea how to convert it,
-        # there isn't much smart we can do. We conservatively treat 'index' as
-        # reasonable.
-        with self.emitter.ip:
-            attr = IntegerAttr.get(IndexType.get(), py_value)
-            return arith_d.constant(attr)
+    def combiner(element_type: IrType, attrs: NodeAttrs) -> vector_d.CombiningKind:
+        return vector_d.CombiningKind.ADD

-    def binary_add_index_index(self, lhs: Value, rhs: Value) -> Value:
-        with self.emitter.ip:
-            return arith_d.addi(lhs, rhs)
+    emit_reduction(emitter, node, args, combiner)

-    def binary_mul_index_index(self, lhs: Value, rhs: Value) -> Value:
-        with self.emitter.ip:
-            return arith_d.muli(lhs, rhs)

-    def binary_sub_index_index(self, lhs: Value, rhs: Value) -> Value:
-        with self.emitter.ip:
-            return arith_d.subi(lhs, rhs)
+def emit_reduction(
+    emitter: ThreadEmitter,
+    node: fx.Node,
+    args: dict,
+    combiner_callback: Callable[[IrType, NodeAttrs], vector_d.CombiningKind],
+):
+    # Setup.
+    raw_input = args["input"]
+    attrs = NodeAttrs.load(raw_input)
+    input = cast_vector(emitter, raw_input)
+    vector_type = VectorType(input.type)
+    element_type = vector_type.element_type
+    rank = vector_type.rank
+    zero = arith_d.constant(ScalarBuilder.zero_attr(element_type))
+    combiner = combiner_callback(element_type, attrs)
+
+    if len(args) == 1:
+        # Reduce to scalar.
+        scalar_result = vector_d.multi_reduction(
+            combiner, input, zero, list(range(rank))
+        )
+        result = vector_d.splat(VectorType.get([], element_type), scalar_result)
+        emitter.bind_node_result(node, result, attrs=attrs)
+    else:
+        # Reduce to vector.
+        raise CodegenError("NYI: Reduce to vector")

-    def binary_mod_index_index(self, lhs: Value, rhs: Value) -> Value:
-        with self.emitter.ip:
-            return arith_d.remsi(lhs, rhs)

-    def binary_floordiv_index_index(self, lhs: Value, rhs: Value) -> Value:
-        with self.emitter.ip:
-            return arith_d.floordivsi(lhs, rhs)
+
+###############################################################################
+# Conversion utilities
+###############################################################################


 def cast_py_value(emitter: ThreadEmitter, value) -> Value:
@@ -415,20 +479,40 @@ def cast_py_value(emitter: ThreadEmitter, value) -> Value:
     except KeyError:
         raise CodegenError(f"Producer node `{value}` has no IR Value")

-    return emitter.scalar_builder.constant(value)
+    return ScalarBuilder.constant(value)
+
+
+def cast_py_lvalue(emitter: ThreadEmitter, py_value) -> tuple[Value, fx.Node]:
+    if isinstance(py_value, fx.Node):
+        try:
+            return emitter.nv_map[py_value], py_value
+        except KeyError:
+            raise CodegenError(f"Producer node `{py_value}` has no IR Value")
+    else:
+        raise CodegenError(
+            f"Required a traced node in the graph. Got: {py_value} (type {type(py_value)})"
+        )


-def cast_kernel_buffer(emitter: ThreadEmitter, kb) -> tuple[Value, MemRefType]:
+def cast_kernel_buffer(
+    emitter: ThreadEmitter, kb
+) -> tuple[Value, MemRefType, Type[KernelBuffer]]:
     """Casts a Python value of type KernelBuffer, which lowers to a MemRefType'd value."""
-    value = cast_py_value(emitter, kb)
-    value_type = value.type
+    value, node = cast_py_lvalue(emitter, kb)
+    ir_type = value.type
+    py_type = node.type

-    if MemRefType.isinstance(value_type):
-        return value, MemRefType(value_type)
+    if not MemRefType.isinstance(ir_type):
+        raise CodegenError(
+            f"Expected a KernelBuffer (aka. `memref`) but got `{ir_type}`"
+        )

-    raise CodegenError(
-        f"Expected a KernelBuffer (aka. `memref`) but got `{value_type}`"
-    )
+    if not issubclass(py_type, KernelBuffer):
+        raise CodegenError(
+            f"Expected an lvalue of type KernelBuffer but got '{py_type}' for node {node}"
+        )
+
+    return value, MemRefType(ir_type), py_type


 def cast_indices(emitter: ThreadEmitter, slice) -> list[Value]:
@@ -444,10 +528,9 @@ def cast_vector(
     value = cast_py_value(emitter, value)

     # Promote scalar types correctly first.
-    if not ShapedType.isinstance(value.type):
-        if element_type is not None:
-            # Implicit scalar type promotion.
-            value = emitter.scalar_builder.promote(value, element_type)
+    if element_type and not ShapedType.isinstance(value.type):
+        # Implicit scalar type promotion.
+        value = ScalarBuilder.promote(value, element_type)

     # After scalar promotion, promote to vector.
     if VectorType.isinstance(value.type):
@@ -466,7 +549,4 @@ def cast_vector(
     # Scalar -> vector.
     element_type = value.type
     vector_type = VectorType.get([], element_type)
-    with emitter.ip:
-        return vector_d.splat(vector_type, value)
-
-    raise CodegenError(f"Unable to automatically cast type `{value.type}` to a vector")
+    return vector_d.splat(vector_type, value)
diff --git a/python/shark_turbine/kernel/gen/thread.py b/python/shark_turbine/kernel/gen/thread.py
index ded5e765e..e49754c11 100644
--- a/python/shark_turbine/kernel/gen/thread.py
+++ b/python/shark_turbine/kernel/gen/thread.py
@@ -32,7 +32,7 @@ def thread(*symbolic_shape: SymbolDef):
     def decorator(f: Optional[TCallable] = None) -> "UnconfiguredThread[TCallable]":
         # Eagerly capture the trace and attach it to the wrapped function.
         tracer = KernelTracer()
-        with CompiledContext(tracer) as context:
+        with CompiledContext(tracer, grid_type=GridType) as context:
             g = tracer.trace(f)
             gm = fx.GraphModule(tracer.root, g, f.__name__)
diff --git a/tests/kernel/analysis_test.py b/tests/kernel/analysis_test.py
new file mode 100644
index 000000000..35c16828e
--- /dev/null
+++ b/tests/kernel/analysis_test.py
@@ -0,0 +1,65 @@
+import logging
+import unittest
+
+
+from shark_turbine.kernel._support.indexing import (
+    IndexingContext,
+    sym,
+)
+
+from shark_turbine.kernel.compiler.analysis import (
+    SliceAnalysis,
+    _norm_slice_spec,
+)
+
+M = sym.M
+N = sym.N
+K = sym.K
+
+
+class SliceAnalysisTest(unittest.TestCase):
+    def testNorm(self):
+        self.assertEqual([slice(1, 0, 0)], (_norm_slice_spec(1, 1)))
+        self.assertEqual([slice(1, 0, 0)], (_norm_slice_spec(1, (1,))))
+        self.assertEqual(
+            [slice(1, None, None)], (_norm_slice_spec(1, (slice(1, None))))
+        )
+        self.assertEqual([None, slice(2, 0, 0)], (_norm_slice_spec(2, (None, 2))))
+        self.assertEqual(
+            [slice(1, 0, 0), slice(2, 0, 0)],
+            (_norm_slice_spec(2, (1, ..., 2))),
+        )
+        self.assertEqual(
+            [slice(1, 0, 0), slice(2, 0, 0)],
+            (_norm_slice_spec(1, (1, ..., 2))),
+        )
+        self.assertEqual(
+            [
+                None,
+                slice(None, None, None),
+                slice(None, None, None),
+                slice(None, None, None),
+                slice(None, None, None),
+                slice(2, 0, 0),
+            ],
+            (_norm_slice_spec(5, (None, ..., 2))),
+        )
+
+    def testSymbolic(self):
+        with IndexingContext() as ctx:
+            ctx.bind_constant(M, 20)
+            ctx.bind_constant(N, 30)
+            ctx.bind_constant(K, 5)
+
+            sa = SliceAnalysis((M, N, K), (1, slice(None), slice(2, K), None))
+            sa.normalize_symbolic_ranges()
+            self.assertEqual(
+                "[slice(1, 0, 0), slice(0, Symbol(N), 1), slice(2, Symbol(K), 1), None]",
+                repr(sa.slices),
+            )
+            self.assertEqual([1, 30, 3, None], sa.symbolic_shape)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()
diff --git a/tests/kernel/indexing_test.py b/tests/kernel/indexing_test.py
index b73a11a47..1553717c6 100644
--- a/tests/kernel/indexing_test.py
+++ b/tests/kernel/indexing_test.py
@@ -65,6 +65,13 @@ def testUsageAndElementTypeInstance(self):
         T = InputBuffer[M].of(torch.float16)
         self.assertEqual("InputBuffer[M].of(torch.float16)", repr(T))

+    def testBoundedSymbolValue(self):
+        self.assertEqual("BoundedSymbolicValue(* : *)", (repr(BoundedSymbolicValue)))
+        B1 = BoundedSymbolicValue.bound(sym_0, None)
+        self.assertEqual("BoundedSymbolicValue(Symbol(0) : *)", repr(B1))
+        B2 = B1.narrow(max_bound=sym_1)
+        self.assertEqual("BoundedSymbolicValue(Symbol(0) : Symbol(1))", repr(B2))
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/kernel/vector_codegen_test.py b/tests/kernel/vector_codegen_test.py
index 138fc3361..2439bdc8d 100644
--- a/tests/kernel/vector_codegen_test.py
+++ b/tests/kernel/vector_codegen_test.py
@@ -1,6 +1,7 @@
 import logging
 import unittest

+import torch
 import shark_turbine.kernel as tk

 from shark_turbine.kernel.compiler import (
@@ -43,6 +44,38 @@ def iota_kernel(out: tk.lang.OutputBuffer[M]):
         print(mb.module_op.get_asm())
         mb.module_op.verify()

+    def testSoftmaxFx(self):
+        @tk.gen.thread(M)
+        def softmax_kernel(
+            input: tk.lang.KernelBuffer[M, K], output: tk.lang.KernelBuffer[M, K]
+        ):
+            row_index = tk.lang.program_id(0)
+            input_row = input[row_index, :]
+            numerator = torch.exp(input_row - torch.max(input_row))
+            output_row = numerator / torch.sum(numerator)
+            output[row_index, :] = output_row
+
+        gm = softmax_kernel._trace.gm
+        print(gm.graph)
+        mb = builder.ModuleBuilder()
+        with indexing.IndexingContext() as idxc:
+            idxc.bind_constant(M, 128)
+            idxc.bind_constant(K, 64)
+
+            sig = vector_codegen.Signature()
+            sig.add_from_graph_placeholders(gm.graph)
+            sig.add_grid(softmax_kernel.grid_type)
+            print(sig)
+            emitter = vector_codegen.ThreadEmitter(
+                mb, softmax_kernel.grid_type, sig
+            )
+            try:
+                emitter.emit_graph(gm.graph)
+            finally:
+                emitter.finish()
+            print(mb.module_op.get_asm())
+            mb.module_op.verify()
+

 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
     unittest.main()
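
[Editorial sketch, not part of the patch: a minimal illustration of the bounded
symbolic value machinery this change introduces, mirroring the binding used in
testSoftmaxFx. The printed results assume the corrected is_* semantics above.]

    from shark_turbine.kernel._support.indexing import (
        BoundedSymbolicValue,
        IndexingContext,
        sym,
        sym_0,
    )

    M = sym.M

    with IndexingContext() as ctx:
        ctx.bind_constant(M, 128)
        # The type attached to program_id proxies during tracing:
        PidType = BoundedSymbolicValue.bound(sym_0, M)
        pid = PidType()                # proxy value in [0, 128)
        print(pid.static_range)        # (0, 128)
        print(pid.is_non_negative())   # True
        print(pid.is_one())            # None: 1 is in range but not pinned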