diff --git a/python/shark_turbine/kernel/_support/indexing.py b/python/shark_turbine/kernel/_support/indexing.py index ebf7f219e..7d0fa4b27 100644 --- a/python/shark_turbine/kernel/_support/indexing.py +++ b/python/shark_turbine/kernel/_support/indexing.py @@ -10,6 +10,7 @@ from . import context __all__ = [ + "BoundedSymbolicValue", "KernelBuffer", "Grid", "InputBuffer", @@ -17,6 +18,10 @@ "SymbolDef", "TemporaryBuffer", "sym", + "sym_0", + "sym_1", + "sym_2", + "sym_n1", ] @@ -73,7 +78,37 @@ def ir_type_asm(self) -> str: ############################################################################### -class SymbolDef: +class SymbolExpr: + def is_one(self) -> Optional[bool]: + """Returns True if the symbol is known to be 1. + + Return False if known to be != 1 and None if not known. + """ + raise NotImplementedError + + def is_non_negative(self) -> Optional[bool]: + """Returns True is the symbol is known to be non-negative. + + Returns False if known to be negative and None if not known. + """ + raise NotImplementedError + + def is_positive(self) -> Optional[bool]: + """Returns True is the symbol is known to be greater than zero. + + Returns False if known to be <= 0 and None if not known. + """ + raise NotImplementedError + + def is_negative(self) -> Optional[bool]: + """Returns True is the symbol is known to be greater than zero. + + Returns False if known to be <= 0 and None if not known. + """ + raise NotImplementedError + + +class SymbolDef(SymbolExpr): """Represents a named symbol representing a dimension in a shape.""" ALL_SYMBOLS: ClassVar[dict[str, "SymbolDef"]] = dict() @@ -101,9 +136,153 @@ def __getattr__(self, n): return Expando() + def is_one(self) -> Optional[bool]: + value = IndexingContext.current().get_static_value(self) + if value is None: + return None + return value == 1 + + def is_non_negative(self) -> Optional[bool]: + value = IndexingContext.current().get_static_value(self) + if value is None: + return None + return value >= 0 + + def is_positive(self) -> Optional[bool]: + value = IndexingContext.current().get_static_value(self) + if value is None: + return None + return value > 0 + + def is_negative(self) -> Optional[bool]: + value = IndexingContext.current().get_static_value(self) + if value is None: + return None + return value < 0 + sym = SymbolDef.create_expando() +sym_0 = SymbolDef("0") +sym_1 = SymbolDef("1") +sym_2 = SymbolDef("2") +sym_n1 = SymbolDef("-1") + + +############################################################################### +# Bounded symbolic value. +############################################################################### + +BoundedRangeExprT = tuple[Optional[SymbolExpr], Optional[SymbolExpr]] + + +class _BoundedSymbolicValueMeta(type): + """Meta-class for deriving new bounded symbolic values.""" + + range: BoundedRangeExprT + + def __new__(mcls, name: str, bases, dct, *, range: BoundedRangeExprT): + dct["range"] = range + dct["__qualname__"] = _bounded_symbolic_value_repr(range=range) + new_class = type.__new__(mcls, name, bases, dct) + return new_class + + def __repr__(cls): + return _bounded_symbolic_value_repr(range=cls.range) + + @property + def min_bound(cls) -> Optional[SymbolExpr]: + return cls.range[0] + + @property + def max_bound(cls) -> Optional[SymbolExpr]: + return cls.range[1] + + def bound( + cls: Type[SubtypeT], + min_bound: Optional[SymbolExpr], + max_bound: Optional[SymbolExpr], + ) -> Type[SubtypeT]: + class Bounded(BoundedSymbolicValue, range=(min_bound, max_bound)): + ... + + return Bounded + + def narrow( + cls: Type[SubtypeT], + *, + min_bound: Optional[SymbolExpr] = None, + max_bound: Optional[SymbolExpr] = None, + ) -> Type[SubtypeT]: + class Bounded( + BoundedSymbolicValue, + range=( + min_bound if min_bound is not None else cls.min_bound, + max_bound if max_bound is not None else cls.max_bound, + ), + ): + ... + + return Bounded + + +def _bounded_symbolic_value_repr(*, range: BoundedRangeExprT) -> str: + min_expr, max_expr = range + min_s = repr(min_expr) if min_expr is not None else "*" + max_s = repr(max_expr) if max_expr is not None else "*" + return f"BoundedSymbolicValue({min_s} : {max_s})" + + +class BoundedSymbolicValue( + SymbolExpr, metaclass=_BoundedSymbolicValueMeta, range=(None, None) +): + """Represents a symbolic value that is bounded to a range fixed for the type.""" + + def __init__(self, value: Optional[int] = None): + self.value = value + + def __repr__(self): + return f"{type(self)}({'proxy' if self.value is None else self.value})" + + @property + def static_range(self) -> Optional[tuple[int, int]]: + # TODO: This is a hack until shape derivation is in place. + ctx = IndexingContext.current() + mn, mx = type(self).range + if mn is not None: + mn = ctx.get_static_value(mn) + if mx is not None: + mx = ctx.get_static_value(mx) + if mn is not None and mx is not None: + return mn, mx + else: + return None + + def is_one(self) -> Optional[bool]: + r = self.static_range + if r: + return r[0] == 1 and r[1] == 2 + return None + + def is_non_negative(self) -> Optional[bool]: + r = self.static_range + if r: + return r[0] >= 0 + return None + + def is_positive(self) -> Optional[bool]: + r = self.static_range + if r: + return r[0] > 0 + return None + + def is_negative(self) -> Optional[bool]: + r = self.static_range + if r: + return r[1] < 0 + return None + + ############################################################################### # Grid ############################################################################### @@ -271,7 +450,7 @@ def __repr__(cls): ) -def _is_kernel_buffer_meta_derived(t: type) -> bool: +def is_kernel_buffer_meta_derived(t: type) -> bool: return isinstance(t, _KernelBufferMeta) @@ -361,7 +540,12 @@ class IndexingContext: __tk_context_idname__ = "IndexingContext" def __init__(self): - self.constant_bindings: dict[SymbolDef, int] = {} + self.constant_bindings: dict[SymbolDef, int] = { + sym_0: 0, + sym_1: 1, + sym_2: 2, + sym_n1: -1, + } def bind_constant(self, sym: SymbolDef, value: int): existing = self.constant_bindings.get(sym) @@ -371,7 +555,7 @@ def bind_constant(self, sym: SymbolDef, value: int): ) self.constant_bindings[sym] = value - def get_static_value(self, sym: SymbolDef) -> Optional[int]: + def get_static_value(self, sym: SymbolExpr) -> Optional[int]: """If the symbol can be resolved to a static value, returns it.""" return self.constant_bindings.get(sym) diff --git a/python/shark_turbine/kernel/_support/tracing.py b/python/shark_turbine/kernel/_support/tracing.py index 822abf87b..004075c4d 100644 --- a/python/shark_turbine/kernel/_support/tracing.py +++ b/python/shark_turbine/kernel/_support/tracing.py @@ -7,7 +7,10 @@ import torch.fx as fx from .indexing import ( + BoundedSymbolicValue, + Grid, KernelBuffer, + sym_0, ) from ..lang.types import ( @@ -106,13 +109,23 @@ def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, ite class CompiledContext(BaseContext): - def __init__(self, tracer: KernelTracer): + def __init__(self, tracer: KernelTracer, *, grid_type: Type[Grid]): super().__init__(eager=False) self.tracer = tracer + self.grid_type = grid_type def handle_thread_program_id(self, op, axis: int) -> Index: + grid_shape = self.grid_type.symbolic_shape + if axis < 0 or axis >= len(grid_shape): + raise IndexError( + f"Illegal index into grid of rank {len(grid_shape)}: {axis}" + ) proxy = self.tracer.create_proxy( - "call_function", op, args=(axis,), kwargs={}, type_expr=Index + "call_function", + op, + args=(axis,), + kwargs={}, + type_expr=BoundedSymbolicValue.bound(sym_0, grid_shape[axis]), ) return proxy diff --git a/python/shark_turbine/kernel/compiler/analysis.py b/python/shark_turbine/kernel/compiler/analysis.py new file mode 100644 index 000000000..19ef2ee59 --- /dev/null +++ b/python/shark_turbine/kernel/compiler/analysis.py @@ -0,0 +1,250 @@ +from typing import Any, Optional, Type, Union + +import torch.fx as fx + +from .base import CodegenError + +from .._support.indexing import ( + BoundedSymbolicValue, + IndexingContext, + SymbolDef, + SymbolExpr, +) + + +NormalizedSlice = list[Union[slice, None]] + + +def _symbolize_slice_value(value): + # TODO: I don't like this and wish this happened more automatically somehow. + if isinstance(value, fx.Node): + sym_type = value.type + if sym_type and issubclass(sym_type, BoundedSymbolicValue): + return sym_type() + return value + else: + return value + + +def _norm_slice_spec(rank: int, slice_spec) -> NormalizedSlice: + def _norm_single_slice(s): + if s is None or s is ...: + return s + if isinstance(s, slice): + # Validate. + if s.step == 0: + # A zero step is illegal, but we use it to signal an integer index + # vs a range. + raise IndexError(f"slice with step 0 is illegal (got {s})") + return s + else: + # Promote a raw value to our special 0-step slice. + return slice(s, 0, 0) + + if not isinstance(slice_spec, tuple): + slice_spec = (slice_spec,) + norm_slices = [_norm_single_slice(s) for s in slice_spec] + + # Replace any ellipses with rank-filling None values. + none_count = norm_slices.count(None) + ellipses_count = norm_slices.count(...) + if ellipses_count == 1: + # Expand by the original list of slices less any unit dim insertions. + # If negative, this does nothing and will be caught later upon + # rank validation. + expand_index = norm_slices.index(...) + del norm_slices[expand_index] + expansion_count = (rank + none_count) - len(norm_slices) + for _ in range(expansion_count): + norm_slices.insert(expand_index, slice(None)) + elif ellipses_count > 1: + raise IndexError( + f"Cannot index into a rank expanding referrent with multiple `...` values" + ) + return norm_slices + + +class SliceAnalysis: + """Analyses Python slicing notations such that it can be validated and code generated. + + The numpy page has a good description here: + https://numpy.org/doc/1.26/user/basics.indexing.html + + This class analyzes: + * Basic Indexing + * Slicing and Striding + * Dimensional Indexing Tools + + Note that `None` is interpreted as `np.newaxis` (which we do not support). + + Each element of a slice specification can be: + * An arbitrary Python value representing a single element span + * None to indicate a new unit dimension + * Ellipses to indicate space filling `slice()` + * A `slice` object + + Such a specification is decomposed into a `source_slice` which does not + include any rank broadcasting and a `broadcast_slice` which includes + any rank expansion. Depending on the operation being code generated, + these will be handled differently. All loose Python values will + be promoted into a `slice` object. + + Raises: + IndexError on any violations of Python indexing semantics which + are statically determined during analysis. + """ + + def __init__(self, ref: tuple[SymbolExpr, ...], slice_spec): + self.ref = ref + self.slices = _norm_slice_spec(len(ref), slice_spec) + + # Compute an expanded version of ref that has None values for + # any to-be-inserted unit dims. This will be the same size + # as slices. + self.expanded_ref: list[Optional[SymbolExpr]] = list(ref) + for i in (i for i, entry in enumerate(self.slices) if entry is None): + self.expanded_ref.insert(i, None) + assert len(self.expanded_ref) == len(self.slices) + self._is_symbolic_normalized = False + + def __repr__(self): + return repr(self.slices) + + def normalize_symbolic_ranges( + self, *, allow_reverse_step: bool = False, allow_non_unit_step: bool = False + ): + """Uses the IndexingContext to normalize range for any None fields. + + This fully populates the fields of every slice with either + integers or SymbolExprs. Does not modify slice fields that are + non-None. + + Raises IndexError for any variations that are not supported + or cannot be statically derived. + """ + if self._is_symbolic_normalized: + return + + def norm(dim_expr: SymbolExpr, s: Optional[slice]) -> Optional[slice]: + if s is None: + return s + ctx = IndexingContext.current() + start = s.start + stop = s.stop + step = s.step + + # Set defaults. + if start is None: + start = 0 + if stop is None: + stop = dim_expr + if step is None: + step = 1 + + # Symbolize for analysis. + start_sym = _symbolize_slice_value(start) + stop_sym = _symbolize_slice_value(stop) + step_sym = _symbolize_slice_value(step) + + # Evaluate facts for start. + if isinstance(start_sym, SymbolExpr): + start_is_non_negative = start_sym.is_non_negative() + elif isinstance(start_sym, int): + start_is_non_negative = start_sym >= 0 + else: + raise IndexError( + f"A symbolically evaluable start index is required (got: {start_sym} (type {type(start_sym)}))" + ) + + # Evaluate facts for stop. + if isinstance(stop_sym, SymbolExpr): + stop_is_non_negative = stop_sym.is_non_negative() + stop_is_zero = False + elif isinstance(stop_sym, int): + stop_is_non_negative = stop_sym >= 0 + stop_is_zero = stop_sym == 0 + else: + raise IndexError( + f"A symbolically evaluable stop index is required (got: {stop_sym} (type {type(stop_sym)}))" + ) + + # Evaluate facts for step. + if isinstance(step_sym, SymbolExpr): + reverse_step = step_sym.is_negative() + unit_step = step_sym.is_one() + zero_step = False + elif isinstance(step_sym, int): + reverse_step = step_sym < 0 + unit_step = step_sym == 1 + zero_step = step_sym == 0 + else: + raise IndexError( + f"A symbolically evaluable step is required (got: {step_sym} (type {type(step_sym)}))" + ) + + # Validate step constraints. + if zero_step: + # This is our special marker for a unit (non-range extract). + assert ( + stop_is_zero + ), "slices of non zero stop and zero step should have been filtered" + else: + if not allow_non_unit_step and not unit_step: + raise IndexError( + f"Only unit steps are supported in this context (got slice({start_sym}, {stop_sym}, {step_sym}))" + ) + + if not allow_reverse_step and reverse_step: + raise IndexError( + f"Only forward steps are supported in this context (got slice({start_sym}, {stop_sym}, {step_sym}))" + ) + + # Normalize negative start/stop. + if not start_is_non_negative: + raise IndexError(f"NYI: Negative slice start") + if not stop_is_non_negative: + raise IndexError(f"NYI: Negative slice stop") + + return slice(start, stop, step) + + for i in range(len(self.slices)): + expr = self.expanded_ref[i] + self.slices[i] = norm(expr, self.slices[i]) + self._is_symbolic_normalized = True + + @property + def symbolic_shape(self) -> tuple[Union[int, SymbolExpr]]: + """Resolves the symbolic shape of the result of this slice. + + Forces symbolic normalization if it has not already been done. + Any rank broadcast dimensions will be retained as None. + """ + self.normalize_symbolic_ranges() + + def _item(s: Optional[slice]): + if s is None: + return None + ctx = IndexingContext.current() + # Detect special unit 0-step slices. + if s.stop == 0 and s.step == 0: + return 1 + start = s.start + stop = s.stop + + # TODO: This is a hack to work around that I don't have the full + # symbolic expression support in yet. We should just be asking + # the symbols to evaluate. + if isinstance(start, SymbolExpr): + static_start = ctx.get_static_value(start) + elif isinstance(start, int): + static_start = start + if isinstance(stop, SymbolExpr): + static_stop = ctx.get_static_value(stop) + elif isinstance(stop, int): + static_stop = stop + if static_start is not None and static_stop is not None: + return static_stop - static_start + + raise IndexError(f"NYI: Non-statically resolved symbolic shapes") + + return [_item(s) for s in self.slices] diff --git a/python/shark_turbine/kernel/compiler/base.py b/python/shark_turbine/kernel/compiler/base.py new file mode 100644 index 000000000..1c896313f --- /dev/null +++ b/python/shark_turbine/kernel/compiler/base.py @@ -0,0 +1,6 @@ +class CodegenError(Exception): + ... + + +class ValidationError(CodegenError): + ... diff --git a/python/shark_turbine/kernel/compiler/builder.py b/python/shark_turbine/kernel/compiler/builder.py index e1ed0d9a2..89a149c6d 100644 --- a/python/shark_turbine/kernel/compiler/builder.py +++ b/python/shark_turbine/kernel/compiler/builder.py @@ -1,19 +1,42 @@ from typing import Optional +from .base import ( + CodegenError, +) + from .ir import ( + Attribute, Context, + FloatAttr, + IndexType, + IntegerAttr, + IntegerType, + IrType, Location, Operation, + Value, + VectorType, + arith_d, builtin_d, ) +# TODO: Have a way upstream to check if a floating point type. +FLOAT_TYPES_ASM = { + "bf16", + "f16", + "f32", + "f64", + # TODO: FP8 types. +} + + class ModuleBuilder: def __init__( self, *, context: Optional[Context] = None, - module_op: Optional[Operation] = None + module_op: Optional[Operation] = None, ): if module_op: self.module_op = module_op @@ -24,3 +47,118 @@ def __init__( self.module_op = builtin_d.ModuleOp(loc=Location.unknown(context)) self.body_block = self.module_op.body self.context = self.module_op.context + + +class _ScalarBuilder: + def is_floating_point_type(self, t: IrType) -> bool: + return str(t) in FLOAT_TYPES_ASM + + def is_integer_type(self, t: IrType) -> bool: + return IntegerType.isinstance(t) + + def promote(self, value: Value, to_type: IrType) -> Value: + value_type = value.type + # Short-circuit if already the right type. + if value_type == to_type: + return value + + attr_name = f"promote_{value_type}_to_{to_type}" + try: + handler = getattr(self, attr_name) + except AttributeError: + raise CodegenError( + f"No implemented path to implicitly promote scalar `{value_type}` to `{to_type}` (tried '{attr_name}')" + ) + return handler(value, to_type) + + def zero_attr(self, t: IrType) -> Attribute: + attr_name = f"zero_attr_{t}" + try: + handler = getattr(self, attr_name) + except AttributeError: + raise CodegenError( + f"Cannot derive a zero value for type `{t}` (tried '{attr_name}')" + ) + return handler(t) + + def constant(self, py_value) -> Value: + attr_name = f"py_constant_{type(py_value).__name__}" + try: + handler = getattr(self, attr_name) + except AttributeError: + raise CodegenError( + f"Cannot convert Python value to constant: {py_value} of type {type(py_value)} (tried '{attr_name}')" + ) + return handler(py_value) + + def binary_arithmetic(self, op: str, lhs: Value, rhs: Value) -> Value: + attr_name = f"binary_{op}_{lhs.type}_{rhs.type}" + try: + handler = getattr(self, attr_name) + except AttributeError: + raise CodegenError( + f"Cannot perform binary arithmetic operation '{op}' between {lhs.type} and {rhs.type} (tried '{attr_name}')" + ) + return handler(lhs, rhs) + + def binary_vector_arithmetic(self, op: str, lhs: Value, rhs: Value) -> Value: + lhs_element_type = VectorType(lhs.type).element_type + rhs_element_type = VectorType(rhs.type).element_type + attr_name = f"binary_{op}_{lhs_element_type}_{rhs_element_type}" + try: + handler = getattr(self, attr_name) + except AttributeError: + raise CodegenError( + f"Cannot perform binary arithmetic operation '{op}' between {lhs.type} and {rhs.type} (tried '{attr_name}')" + ) + return handler(lhs, rhs) + + def promote_index_to_f32(self, value: Value, to_type: IrType) -> Value: + i32_type = IntegerType.get_signless(32) + i32 = arith_d.index_cast(i32_type, value) + return arith_d.sitofp(to_type, i32) + + def zero_attr_f32(self, t: IrType) -> Attribute: + return FloatAttr.get(t, 0.0) + + def py_constant_int(self, py_value) -> Value: + # If coming from a stock 'int' Python type with no idea how to convert it, + # there isn't much smart we can do. We conservatively treat 'index' as + # reasonable. + attr = IntegerAttr.get(IndexType.get(), py_value) + return arith_d.constant(attr) + + # Binary index arithmetic. + def binary_add_index_index(self, lhs: Value, rhs: Value) -> Value: + return arith_d.addi(lhs, rhs) + + def binary_mul_index_index(self, lhs: Value, rhs: Value) -> Value: + return arith_d.muli(lhs, rhs) + + def binary_sub_index_index(self, lhs: Value, rhs: Value) -> Value: + return arith_d.subi(lhs, rhs) + + def binary_mod_index_index(self, lhs: Value, rhs: Value) -> Value: + return arith_d.remsi(lhs, rhs) + + def binary_floordiv_index_index(self, lhs: Value, rhs: Value) -> Value: + return arith_d.floordivsi(lhs, rhs) + + # Binary f32 arithmetic + def binary_add_f32_f32(self, lhs: Value, rhs: Value) -> Value: + return arith_d.addf(lhs, rhs) + + def binary_mul_f32_f32(self, lhs: Value, rhs: Value) -> Value: + return arith_d.mulf(lhs, rhs) + + def binary_sub_f32_f32(self, lhs: Value, rhs: Value) -> Value: + return arith_d.subf(lhs, rhs) + + def binary_mod_f32_f32(self, lhs: Value, rhs: Value) -> Value: + return arith_d.remf(lhs, rhs) + + def binary_truediv_f32_f32(self, lhs: Value, rhs: Value) -> Value: + return arith_d.divf(lhs, rhs) + + +ScalarBuilder = _ScalarBuilder() diff --git a/python/shark_turbine/kernel/compiler/ir.py b/python/shark_turbine/kernel/compiler/ir.py index c73408897..84fb01259 100644 --- a/python/shark_turbine/kernel/compiler/ir.py +++ b/python/shark_turbine/kernel/compiler/ir.py @@ -26,5 +26,6 @@ arith as arith_d, builtin as builtin_d, func as func_d, + math as math_d, vector as vector_d, ) diff --git a/python/shark_turbine/kernel/compiler/op_matchers.py b/python/shark_turbine/kernel/compiler/op_matchers.py new file mode 100644 index 000000000..81f738a5e --- /dev/null +++ b/python/shark_turbine/kernel/compiler/op_matchers.py @@ -0,0 +1,64 @@ +from typing import Optional + +import torch +from torch import Tensor + +import functools +import inspect + + +def signature_matcher(f=None, *, arity: Optional[int] = None, original_name: str = ""): + """Transforms a function into a signature matcher. + + The transfored function takes the same args/kwargs as the original, but + it will return an inspect.BoundArguments.arguments when invoked. + + Optional overload selectors can be specified, and if not met, None + will be returned (versus raising an error). + + On argument mismatch, a TypeError will be raised. + """ + if f is None: + return functools.partial( + signature_matcher, arity=arity, original_name=original_name + ) + + sig = inspect.signature(f) + + def wrapped(*args, **kwargs) -> Optional[inspect.BoundArguments]: + if arity is not None and arity != (len(args) + len(kwargs)): + return None + try: + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + return bound_args.arguments + except TypeError as e: + reported_name = original_name or f.__name__ + raise TypeError(f"{reported_name}() {str(e)}") + + return wrapped + + +@signature_matcher(original_name="torch.exp") +def torch_exp(input: Tensor) -> Tensor: + ... + + +@signature_matcher(arity=1, original_name="torch.max") +def torch_max_unary(input: Tensor) -> Tensor: + ... + + +@signature_matcher(original_name="torch.max") +def torch_max(input: Tensor, dim: int, keepdim: bool = False): + ... + + +@signature_matcher(arity=1, original_name="torch.sum") +def torch_sum_unary(input: Tensor) -> Tensor: + ... + + +@signature_matcher(original_name="torch.sum") +def torch_sum(input: Tensor, dim: int, keepdim: bool = False): + ... diff --git a/python/shark_turbine/kernel/compiler/vector_codegen.py b/python/shark_turbine/kernel/compiler/vector_codegen.py index 880dc8578..6828d4871 100644 --- a/python/shark_turbine/kernel/compiler/vector_codegen.py +++ b/python/shark_turbine/kernel/compiler/vector_codegen.py @@ -1,8 +1,10 @@ from typing import Any, Callable, Type, Optional, Sequence, Union from dataclasses import dataclass +import inspect import operator as py_operator +import torch import torch.fx as fx from .._support.indexing import ( @@ -10,7 +12,7 @@ IndexingContext, KernelBuffer, SymbolDef, - _is_kernel_buffer_meta_derived, + is_kernel_buffer_meta_derived, ) from ..lang import ( @@ -19,23 +21,26 @@ from .. import ops +from .analysis import ( + SliceAnalysis, +) + from .builder import ( ModuleBuilder, + ScalarBuilder, +) + +from .base import ( + CodegenError, + ValidationError, ) from .ir import ( - AffineConstantExpr, - AffineExpr, AffineMap, AffineMapAttr, - Attribute, - DenseElementsAttr, - FloatAttr, FunctionType, IndexType, InsertionPoint, - IntegerAttr, - IntegerType, IrType, Location, MemRefType, @@ -44,21 +49,15 @@ VectorType, arith_d, func_d, + math_d, vector_d, ) +from . import op_matchers ArgTypeUnion = Union[SymbolDef, Type[KernelBuffer]] -class CodegenError(Exception): - ... - - -class ValidationError(CodegenError): - ... - - @dataclass class ArgMeta: name: Optional[str] = None @@ -66,6 +65,22 @@ class ArgMeta: grid_index: Optional[int] = None +@dataclass +class NodeAttrs: + # By default, integers are assumed signed. We propagate unsigned as graph + # node attrs. + unsigned: bool = False + + @staticmethod + def load(py_value) -> "NodeAttrs": + if isinstance(py_value, fx.Node): + return NodeAttrs(unsigned=bool(py_value.meta.get("unsigned"))) + return NodeAttrs() + + def store(self, node: fx.Node): + node.meta["unsigned"] = self.unsigned + + class Signature: """Represents a function signature. @@ -102,7 +117,7 @@ def sym_to_dim_asm(s: SymbolDef) -> str: def as_mlir_type(t: ArgTypeUnion) -> FunctionType: if isinstance(t, SymbolDef): return IndexType.get() - elif _is_kernel_buffer_meta_derived(t): + elif is_kernel_buffer_meta_derived(t): kb_t = t # type: KernelBuffer element_type_asm = kb_t.element_type.ir_type_asm() symbolic_shape = kb_t.symbolic_shape @@ -129,7 +144,7 @@ def add_from_graph_placeholders(self, graph: fx.Graph): continue t = node.type meta = ArgMeta(name=node.target, node=node) - if _is_kernel_buffer_meta_derived(t): + if is_kernel_buffer_meta_derived(t): self.add_kernel_buffer(t, meta=meta) elif issubclass(t, SymbolDef): self.add_symbol(t, meta=meta) @@ -150,7 +165,6 @@ class ThreadEmitter: OP_HANDLERS: dict[Any, Callable[["ThreadEmitter", fx.Node], None]] = {} def __init__(self, mb: ModuleBuilder, grid: Grid, sig: Signature): - self.scalar_builder = ScalarBuilder(self) self.nv_map: dict[fx.Node, Value] = {} self.grid_index_map: list[Optional[Value]] = [None] * grid.rank @@ -182,18 +196,18 @@ def __init__(self, mb: ModuleBuilder, grid: Grid, sig: Signature): self.ip = InsertionPoint(self.entry_block) def bind_node_result( - self, node: fx.Node, value: Value, type_expr: Optional[type] = None + self, node: fx.Node, value: Value, *, attrs: Optional[NodeAttrs] = None ): assert node not in self.nv_map, f"Cannot rebind node {node}: already bound" - if type_expr is not None: - node.type = type_expr + if attrs is not None: + attrs.store(node) self.nv_map[node] = value def emit_graph(self, graph: fx.Graph): context = self.context for node in graph.nodes: # TODO: Construct a location for the node. - with Location.unknown(context): + with self.ip, Location.unknown(context): if node.op == "call_function": self.emit_function_call_node(node) @@ -228,9 +242,54 @@ def decorator(f: Callable[["ThreadEmitter", fx.Node], None]): (py_operator.sub, "sub"), (py_operator.mod, "mod"), (py_operator.floordiv, "floordiv"), + (py_operator.truediv, "truediv"), ] +def binary_broadcast(lhs: Value, rhs: Value) -> tuple[bool, Value, Value]: + lhs_type = lhs.type + rhs_type = rhs.type + lhs_is_vector = VectorType.isinstance(lhs_type) + rhs_is_vector = VectorType.isinstance(rhs_type) + if not lhs_is_vector and not rhs_is_vector: + # Not vectors: return as-is. + return False, lhs, rhs + + # Promote to vector. + if not lhs_is_vector: + lhs = vector_d.splat(VectorType([], lhs_type), lhs) + if not rhs_is_vector: + rhs = vector_d.splat(VectorType([], rhs_type), rhs) + lhs_type = VectorType(lhs.type) + rhs_type = VectorType(rhs.type) + + broadcast_shape = lhs_type.shape + rhs_shape = rhs_type.shape + rank = max(len(broadcast_shape), len(rhs_shape)) + while len(broadcast_shape) < rank: + broadcast_shape.insert(0, 1) + while len(rhs_shape) < rank: + rhs_shape.insert(0, 1) + + for i in range(rank): + a = broadcast_shape[i] + b = rhs_shape[i] + if a != b: + if a != 1 and b != 1: + raise CodegenError( + f"Binary operands are not broadcast compatible: {lhs_type}, {rhs_type}" + ) + broadcast_shape[i] = rhs_shape[i] = max(a, b) + + lhs_type = VectorType.get(broadcast_shape, lhs_type.element_type) + rhs_type = VectorType.get(broadcast_shape, rhs_type.element_type) + if lhs_type != lhs.type: + lhs = vector_d.broadcast(lhs_type, lhs) + if rhs_type != rhs.type: + rhs = vector_d.broadcast(rhs_type, rhs) + return True, lhs, rhs + + def _define_arithmetic_handlers(): def register(py_operator, mnemonic): @handle_op(py_operator) @@ -242,7 +301,11 @@ def _(emitter: ThreadEmitter, node: fx.Node): lhs = cast_py_value(emitter, lhs) rhs = cast_py_value(emitter, rhs) - result = emitter.scalar_builder.binary_arithmetic(mnemonic, lhs, rhs) + is_vector, lhs, rhs = binary_broadcast(lhs, rhs) + if is_vector: + result = ScalarBuilder.binary_vector_arithmetic(mnemonic, lhs, rhs) + else: + result = ScalarBuilder.binary_arithmetic(mnemonic, lhs, rhs) emitter.bind_node_result(node, result) for py_operator, mnemonic in BINARY_ARITHMETIC_OPS: @@ -272,140 +335,141 @@ def _(emitter: ThreadEmitter, node: fx.Node): except IndexError as e: raise CodegenError("Grid axis out of bounds") from e - emitter.bind_node_result(node, value, Index) + emitter.bind_node_result(node, value) @handle_op(ops.kernel_buffer_getitem) def _(emitter: ThreadEmitter, node: fx.Node): - raise CodegenError("NYI: kernel_buffer_getitem") + try: + kb, slice_spec = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + + kb_src, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, kb) + sa = SliceAnalysis(kb_py_type.symbolic_shape, slice_spec) + sa.normalize_symbolic_ranges() + vector_shape = sa.symbolic_shape + element_type = kb_ir_type.element_type + vector_type = VectorType.get(vector_shape, element_type) + pad_attr = ScalarBuilder.zero_attr(element_type) + indices = cast_indices(emitter, [s.start for s in sa.slices]) + pad_value = arith_d.constant(pad_attr) + result = vector_d.transfer_read( + vector_type, kb_src, indices, AffineMap.get_identity(len(indices)), pad_value + ) + emitter.bind_node_result(node, result) @handle_op(ops.kernel_buffer_setitem) def _(emitter: ThreadEmitter, node: fx.Node): try: - kb, key, item = node.args + kb, slice_spec, item = node.args except ValueError as e: raise ValidationError("Malformed arguments") from e - kb_dest, kb_type = cast_kernel_buffer(emitter, kb) - dest_rank = kb_type.rank - indices = cast_indices(emitter, key) + kb_dest, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, kb) + dest_rank = kb_ir_type.rank + sa = SliceAnalysis(kb_py_type.symbolic_shape, slice_spec) + sa.normalize_symbolic_ranges() + indices = cast_indices(emitter, [s.start for s in sa.slices]) if dest_rank != len(indices): raise CodegenError( f"Mismatched slice assignment: Expected rank {dest_rank}, got {len(indices)}" ) - insert_vector = cast_vector(emitter, item, element_type=kb_type.element_type) + insert_vector = cast_vector(emitter, item, element_type=kb_ir_type.element_type) insert_type = VectorType(insert_vector.type) insert_rank = insert_type.rank - # This form only supports 0d broadcast or same rank currently. - # TODO: This was specially crafted to make the iota demo work. Need to work - # it out in generality. - if insert_rank != 0 and insert_rank != dest_rank: - raise CodegenError( - f"The shorthand kernel_buffer[...]= assignment syntax only supports same rank assignment or restricted, 0d broadcast" - ) + # Special case rank-0 broadcast. + if insert_rank == 0: + broadcast_type = VectorType.get(dest_rank * [1], kb_ir_type.element_type) + insert_vector = vector_d.broadcast(broadcast_type, insert_vector) - with emitter.ip: - if insert_rank == 0: - broadcast_type = VectorType.get(dest_rank * [1], kb_type.element_type) - insert_vector = vector_d.broadcast(broadcast_type, insert_vector) - permutation_map = AffineMap.get_identity(dest_rank) - vector_d.transfer_write( - None, insert_vector, kb_dest, indices, AffineMapAttr.get(permutation_map) - ) + permutation_map = AffineMap.get_identity(dest_rank) + vector_d.transfer_write( + None, insert_vector, kb_dest, indices, AffineMapAttr.get(permutation_map) + ) ############################################################################### -# Conversion utilities +# Torch and math ops ############################################################################### -class ScalarBuilder: - def __init__(self, emitter: ThreadEmitter): - self.emitter = emitter - - def promote(self, value: Value, to_type: IrType) -> Value: - value_type = value.type - # Short-circuit if already the right type. - if value_type == to_type: - return value - - attr_name = f"promote_{value_type}_to_{to_type}" - try: - handler = getattr(self, attr_name) - except AttributeError: - raise CodegenError( - f"No implemented path to implicitly promote scalar `{value_type}` to `{to_type}` (tried '{attr_name}')" - ) - return handler(value, to_type) +@handle_op(torch.exp) +def _(emitter: ThreadEmitter, node: fx.Node): + args = op_matchers.torch_exp(*node.args, **node.kwargs) + raw_input = args["input"] + input = cast_vector(emitter, raw_input) + result = math_d.exp(input) + emitter.bind_node_result(node, result) - def zero_attr(self, t: IrType) -> Attribute: - attr_name = f"zero_attr_{t}" - try: - handler = getattr(self, attr_name) - except AttributeError: - raise CodegenError( - f"Cannot derive a zero value for type `{t}` (tried '{attr_name}')" - ) - return handler(t) - def constant(self, py_value) -> Value: - attr_name = f"py_constant_{type(py_value).__name__}" - try: - handler = getattr(self, attr_name) - except AttributeError: - raise CodegenError( - f"Cannot convert Python value to constant: {py_value} of type {type(py_value)} (tried '{attr_name}')" +@handle_op(torch.max) +def _(emitter: ThreadEmitter, node: fx.Node): + args = op_matchers.torch_max_unary( + *node.args, **node.kwargs + ) or op_matchers.torch_max(*node.args, **node.kwargs) + + def combiner(element_type: IrType, attrs: NodeAttrs) -> vector_d.CombiningKind: + if ScalarBuilder.is_floating_point_type(element_type): + # Non-NaN propagating. + # TODO: Carry a "fastmath" flag on the emitter and choose between this + # and MAXIMUMF? + return vector_d.CombiningKind.MAXF + elif ScalarBuilder.is_integer_type(element_type): + return ( + vector_d.CombiningKind.MAXUI + if attrs.unsigned + else vector_d.CombiningKind.MAXSI ) - return handler(py_value) - def binary_arithmetic(self, op: str, lhs: Value, rhs: Value) -> Value: - attr_name = f"binary_{op}_{lhs.type}_{rhs.type}" - try: - handler = getattr(self, attr_name) - except AttributeError: - raise CodegenError( - f"Cannot perform binary arithmetic operation '{op}' between {lhs.type} and {rhs.type} (tried '{attr_name}')" - ) - return handler(lhs, rhs) + emit_reduction(emitter, node, args, combiner) - def promote_index_to_f32(self, value: Value, to_type: IrType) -> Value: - with self.emitter.ip: - i32_type = IntegerType.get_signless(32) - i32 = arith_d.index_cast(i32_type, value) - return arith_d.sitofp(to_type, i32) - def zero_attr_f32(self, t: IrType) -> Attribute: - return FloatAttr.get(t, 0.0) +@handle_op(torch.sum) +def _(emitter: ThreadEmitter, node: fx.Node): + args = op_matchers.torch_sum_unary( + *node.args, **node.kwargs + ) or op_matchers.torch_sum(*node.args, **node.kwargs) - def py_constant_int(self, py_value) -> Value: - # If coming from a stock 'int' Python type with no idea how to convert it, - # there isn't much smart we can do. We conservatively treat 'index' as - # reasonable. - with self.emitter.ip: - attr = IntegerAttr.get(IndexType.get(), py_value) - return arith_d.constant(attr) + def combiner(element_type: IrType, attrs: NodeAttrs) -> vector_d.CombiningKind: + return vector_d.CombiningKind.ADD - def binary_add_index_index(self, lhs: Value, rhs: Value) -> Value: - with self.emitter.ip: - return arith_d.addi(lhs, rhs) + emit_reduction(emitter, node, args, combiner) - def binary_mul_index_index(self, lhs: Value, rhs: Value) -> Value: - with self.emitter.ip: - return arith_d.muli(lhs, rhs) - def binary_sub_index_index(self, lhs: Value, rhs: Value) -> Value: - with self.emitter.ip: - return arith_d.subi(lhs, rhs) +def emit_reduction( + emitter: ThreadEmitter, + node: fx.Node, + args: dict, + combiner_callback: Callable[[IrType], vector_d.CombiningKind], +): + # Setup. + raw_input = args["input"] + attrs = NodeAttrs.load(raw_input) + input = cast_vector(emitter, raw_input) + vector_type = VectorType(input.type) + element_type = vector_type.element_type + rank = vector_type.rank + zero = arith_d.constant(ScalarBuilder.zero_attr(element_type)) + combiner = combiner_callback(element_type, attrs) + + if len(args) == 1: + # Reduce to scalar. + scalar_result = vector_d.multi_reduction( + combiner, input, zero, list(range(rank)) + ) + result = vector_d.splat(VectorType.get([], element_type), scalar_result) + emitter.bind_node_result(node, result, attrs=attrs) + else: + # Reduce to vector. + raise CodegenError("NYI: Reduce to vector") - def binary_mod_index_index(self, lhs: Value, rhs: Value) -> Value: - with self.emitter.ip: - return arith_d.remsi(lhs, rhs) - def binary_floordiv_index_index(self, lhs: Value, rhs: Value) -> Value: - with self.emitter.ip: - return arith_d.floordivsi(lhs, rhs) +############################################################################### +# Conversion utilities +############################################################################### def cast_py_value(emitter: ThreadEmitter, value) -> Value: @@ -415,20 +479,40 @@ def cast_py_value(emitter: ThreadEmitter, value) -> Value: except KeyError: raise CodegenError(f"Producer node `{value}` has no IR Value") - return emitter.scalar_builder.constant(value) + return ScalarBuilder.constant(value) + + +def cast_py_lvalue(emitter: ThreadEmitter, py_value) -> tuple[Value, fx.Node]: + if isinstance(py_value, fx.Node): + try: + return emitter.nv_map[py_value], py_value + except KeyError: + raise CodegenError(f"Producer node `{py_value}` has no IR Value") + else: + raise CodegenError( + f"Required a traced node in the graph. Got: {py_value} (type {type(py_value)})" + ) -def cast_kernel_buffer(emitter: ThreadEmitter, kb) -> tuple[Value, MemRefType]: +def cast_kernel_buffer( + emitter: ThreadEmitter, kb +) -> tuple[Value, MemRefType, Type[KernelBuffer]]: """Casts a Python value of type KernelBuffer, which lowers to a MemRefType'd value.""" - value = cast_py_value(emitter, kb) - value_type = value.type + value, node = cast_py_lvalue(emitter, kb) + ir_type = value.type + py_type = node.type - if MemRefType.isinstance(value_type): - return value, MemRefType(value_type) + if not MemRefType.isinstance(ir_type): + raise CodegenError( + f"Expected a KernelBuffer (aka. `memref`) but got `{ir_type}`" + ) - raise CodegenError( - f"Expected a KernelBuffer (aka. `memref`) but got `{value_type}`" - ) + if not issubclass(py_type, KernelBuffer): + raise CodegenError( + f"Expected an lvalue of type KernelBuffer but got '{py_type}' for node {node}" + ) + + return value, MemRefType(ir_type), py_type def cast_indices(emitter: ThreadEmitter, slice) -> list[Value]: @@ -444,10 +528,9 @@ def cast_vector( value = cast_py_value(emitter, value) # Promote scalar types correctly first. - if not ShapedType.isinstance(value.type): - if element_type is not None: - # Implicit scalar type promotion. - value = emitter.scalar_builder.promote(value, element_type) + if element_type and not ShapedType.isinstance(value.type): + # Implicit scalar type promotion. + value = ScalarBuilder.promote(value, element_type) # After scalar promotion, promote to vector. if VectorType.isinstance(value.type): @@ -466,7 +549,4 @@ def cast_vector( # Scalar -> vector. element_type = value.type vector_type = VectorType.get([], element_type) - with emitter.ip: - return vector_d.splat(vector_type, value) - - raise CodegenError(f"Unable to automatically cast type `{value.type}` to a vector") + return vector_d.splat(vector_type, value) diff --git a/python/shark_turbine/kernel/gen/thread.py b/python/shark_turbine/kernel/gen/thread.py index ded5e765e..e49754c11 100644 --- a/python/shark_turbine/kernel/gen/thread.py +++ b/python/shark_turbine/kernel/gen/thread.py @@ -32,7 +32,7 @@ def thread(*symbolic_shape: SymbolDef): def decorator(f: Optional[TCallable] = None) -> "UnconfiguredThread[TCallable]": # Eagerly capture the trace and attach it to the wrapped function. tracer = KernelTracer() - with CompiledContext(tracer) as context: + with CompiledContext(tracer, grid_type=GridType) as context: g = tracer.trace(f) gm = fx.GraphModule(tracer.root, g, f.__name__) diff --git a/tests/kernel/analysis_test.py b/tests/kernel/analysis_test.py new file mode 100644 index 000000000..35c16828e --- /dev/null +++ b/tests/kernel/analysis_test.py @@ -0,0 +1,65 @@ +import logging +import unittest + + +from shark_turbine.kernel._support.indexing import ( + IndexingContext, + sym, +) + +from shark_turbine.kernel.compiler.analysis import ( + SliceAnalysis, + _norm_slice_spec, +) + +M = sym.M +N = sym.N +K = sym.K + + +class SliceAnalysisTest(unittest.TestCase): + def testNorm(self): + self.assertEqual([slice(1, 0, 0)], (_norm_slice_spec(1, 1))) + self.assertEqual([slice(1, 0, 0)], (_norm_slice_spec(1, (1,)))) + self.assertEqual( + [slice(1, None, None)], (_norm_slice_spec(1, (slice(1, None)))) + ) + self.assertEqual([None, slice(2, 0, 0)], (_norm_slice_spec(2, (None, 2)))) + self.assertEqual( + [slice(1, 0, 0), slice(2, 0, 0)], + (_norm_slice_spec(2, (1, ..., 2))), + ) + self.assertEqual( + [slice(1, 0, 0), slice(2, 0, 0)], + (_norm_slice_spec(1, (1, ..., 2))), + ) + self.assertEqual( + [ + None, + slice(None, None, None), + slice(None, None, None), + slice(None, None, None), + slice(None, None, None), + slice(2, 0, 0), + ], + (_norm_slice_spec(5, (None, ..., 2))), + ) + + def testSymbolic(self): + with IndexingContext() as ctx: + ctx.bind_constant(M, 20) + ctx.bind_constant(N, 30) + ctx.bind_constant(K, 5) + + sa = SliceAnalysis((M, N, K), (1, slice(None), slice(2, K), None)) + sa.normalize_symbolic_ranges() + self.assertEqual( + "[slice(1, 0, 0), slice(0, Symbol(N), 1), slice(2, Symbol(K), 1), None]", + repr(sa.slices), + ) + self.assertEqual([1, 30, 3, None], sa.symbolic_shape) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/kernel/indexing_test.py b/tests/kernel/indexing_test.py index b73a11a47..1553717c6 100644 --- a/tests/kernel/indexing_test.py +++ b/tests/kernel/indexing_test.py @@ -65,6 +65,13 @@ def testUsageAndElementTypeInstance(self): T = InputBuffer[M].of(torch.float16) self.assertEqual("InputBuffer[M].of(torch.float16)", repr(T)) + def testBoundedSymbolValue(self): + self.assertEqual("BoundedSymbolicValue(* : *)", (repr(BoundedSymbolicValue))) + B1 = BoundedSymbolicValue.bound(sym_0, None) + self.assertEqual("BoundedSymbolicValue(Symbol(0) : *)", repr(B1)) + B2 = B1.narrow(max_bound=sym_1) + self.assertEqual("BoundedSymbolicValue(Symbol(0) : Symbol(1))", repr(B2)) + if __name__ == "__main__": unittest.main() diff --git a/tests/kernel/vector_codegen_test.py b/tests/kernel/vector_codegen_test.py index 138fc3361..2439bdc8d 100644 --- a/tests/kernel/vector_codegen_test.py +++ b/tests/kernel/vector_codegen_test.py @@ -1,6 +1,7 @@ import logging import unittest +import torch import shark_turbine.kernel as tk from shark_turbine.kernel.compiler import ( @@ -43,6 +44,38 @@ def iota_kernel(out: tk.lang.OutputBuffer[M]): print(mb.module_op.get_asm()) mb.module_op.verify() + def testSoftmaxFx(self): + @tk.gen.thread(M) + def softmax_kernel( + input: tk.lang.KernelBuffer[M, K], output: tk.lang.KernelBuffer[M, K] + ): + row_index = tk.lang.program_id(0) + input_row = input[row_index, :] + numerator = torch.exp(input_row - torch.max(input_row)) + output_row = numerator / torch.sum(numerator) + output[row_index, :] = output_row + + gm = softmax_kernel._trace.gm + print(gm.graph) + mb = builder.ModuleBuilder() + with indexing.IndexingContext() as idxc: + idxc.bind_constant(M, 128) + idxc.bind_constant(K, 64) + + sig = vector_codegen.Signature() + sig.add_from_graph_placeholders(gm.graph) + sig.add_grid(softmax_kernel.grid_type) + print(sig) + try: + emitter = vector_codegen.ThreadEmitter( + mb, softmax_kernel.grid_type, sig + ) + emitter.emit_graph(gm.graph) + finally: + emitter.finish() + print(mb.module_op.get_asm()) + mb.module_op.verify() + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG)