Skip to content

Commit

Permalink
[tk5] Implement slice analysis and sufficient op coverage for a softm…
Browse files Browse the repository at this point in the history
…ax kernel. (nod-ai#222)

* Adds arithmetic binary broadcast.
* Adds f32 arithmetic.
* Handles torch ops exp, max, sum (only reduce to scalar at the moment).
* General python signature validation and unpacking.
* Some initial attribute propagation work needed to handle unsigned
properly.
* A start at full generality for symbolic dimensions and slice analysis.
Needs to be reworked before too long but this gets us basic
functionality on static kernels while leaving room for dynamic to
emerge.
  • Loading branch information
stellaraccident authored Dec 6, 2023
1 parent 6e8eec1 commit 247396c
Show file tree
Hide file tree
Showing 12 changed files with 986 additions and 145 deletions.
192 changes: 188 additions & 4 deletions python/shark_turbine/kernel/_support/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,18 @@
from . import context

# Names exported by `from <this module> import *`.
__all__ = [
    "BoundedSymbolicValue",
    "KernelBuffer",
    "Grid",
    "InputBuffer",
    "OutputBuffer",
    "SymbolDef",
    "TemporaryBuffer",
    "sym",
    # Predefined constant symbols (see the SymbolDef("0") ... SymbolDef("-1")
    # definitions in this module).
    "sym_0",
    "sym_1",
    "sym_2",
    "sym_n1",
]


Expand Down Expand Up @@ -73,7 +78,37 @@ def ir_type_asm(self) -> str:
###############################################################################


class SymbolDef:
class SymbolExpr:
    """Base interface for symbolic expressions that answer tri-state queries.

    Every predicate returns True or False when the answer is known and None
    when it cannot be determined. The implementations here only raise
    NotImplementedError; subclasses are expected to override them.
    """

    def is_one(self) -> Optional[bool]:
        """Returns True if the symbol is known to be 1.

        Returns False if known to be != 1 and None if not known.
        """
        raise NotImplementedError

    def is_non_negative(self) -> Optional[bool]:
        """Returns True if the symbol is known to be non-negative.

        Returns False if known to be negative and None if not known.
        """
        raise NotImplementedError

    def is_positive(self) -> Optional[bool]:
        """Returns True if the symbol is known to be greater than zero.

        Returns False if known to be <= 0 and None if not known.
        """
        raise NotImplementedError

    def is_negative(self) -> Optional[bool]:
        """Returns True if the symbol is known to be less than zero.

        Returns False if known to be >= 0 and None if not known.
        """
        raise NotImplementedError


class SymbolDef(SymbolExpr):
"""Represents a named symbol representing a dimension in a shape."""

ALL_SYMBOLS: ClassVar[dict[str, "SymbolDef"]] = dict()
Expand Down Expand Up @@ -101,9 +136,153 @@ def __getattr__(self, n):

return Expando()

def is_one(self) -> Optional[bool]:
    """Tri-state test for "equals 1" using the current IndexingContext.

    Returns None when the context has no static value for this symbol.
    """
    static = IndexingContext.current().get_static_value(self)
    return None if static is None else static == 1

def is_non_negative(self) -> Optional[bool]:
    """Tri-state test for ">= 0" using the current IndexingContext.

    Returns None when the context has no static value for this symbol.
    """
    static = IndexingContext.current().get_static_value(self)
    return None if static is None else static >= 0

def is_positive(self) -> Optional[bool]:
    """Tri-state test for "> 0" using the current IndexingContext.

    Returns None when the context has no static value for this symbol.
    """
    static = IndexingContext.current().get_static_value(self)
    return None if static is None else static > 0

def is_negative(self) -> Optional[bool]:
    """Tri-state test for "< 0" using the current IndexingContext.

    Returns None when the context has no static value for this symbol.
    """
    static = IndexingContext.current().get_static_value(self)
    return None if static is None else static < 0


# Object returned by SymbolDef.create_expando().
# NOTE(review): presumably yields a SymbolDef per attribute name (e.g. sym.M)
# -- confirm against create_expando's __getattr__.
sym = SymbolDef.create_expando()

# Symbols named after small integer constants. IndexingContext.__init__ binds
# each of them to the matching int (0, 1, 2 and -1 respectively).
sym_0 = SymbolDef("0")
sym_1 = SymbolDef("1")
sym_2 = SymbolDef("2")
sym_n1 = SymbolDef("-1")


###############################################################################
# Bounded symbolic value.
###############################################################################

# (min, max) pair of symbolic bounds. A None entry is rendered as "*" by
# _bounded_symbolic_value_repr.
# NOTE(review): None presumably means "no bound on that side" -- confirm.
BoundedRangeExprT = tuple[Optional[SymbolExpr], Optional[SymbolExpr]]


class _BoundedSymbolicValueMeta(type):
    """Metaclass that attaches a fixed `range` to each bounded symbolic value type.

    The range is passed as a keyword argument of the class statement
    (`class X(Base, range=(lo, hi))`) and stored as a class attribute.
    """

    # (min, max) bounds recorded for the class being created.
    range: BoundedRangeExprT

    def __new__(mcls, name: str, bases, dct, *, range: BoundedRangeExprT):
        # Store the range in the class namespace and make the qualified name
        # spell out the bounds instead of the declared class name.
        dct["range"] = range
        dct["__qualname__"] = _bounded_symbolic_value_repr(range=range)
        return type.__new__(mcls, name, bases, dct)

    def __repr__(cls):
        return _bounded_symbolic_value_repr(range=cls.range)

    @property
    def min_bound(cls) -> Optional[SymbolExpr]:
        """First element of the class range (may be None)."""
        return cls.range[0]

    @property
    def max_bound(cls) -> Optional[SymbolExpr]:
        """Second element of the class range (may be None)."""
        return cls.range[1]

    def bound(
        cls: Type[SubtypeT],
        min_bound: Optional[SymbolExpr],
        max_bound: Optional[SymbolExpr],
    ) -> Type[SubtypeT]:
        """Create a new BoundedSymbolicValue subclass with exactly the given range.

        Note that the new type derives from BoundedSymbolicValue itself, not
        from `cls`.
        """
        requested_range = (min_bound, max_bound)

        class Bounded(BoundedSymbolicValue, range=requested_range):
            pass

        return Bounded

    def narrow(
        cls: Type[SubtypeT],
        *,
        min_bound: Optional[SymbolExpr] = None,
        max_bound: Optional[SymbolExpr] = None,
    ) -> Type[SubtypeT]:
        """Create a new bounded type, replacing only the bounds that are given.

        A bound left as None is inherited from `cls`.
        """
        lower = cls.min_bound if min_bound is None else min_bound
        upper = cls.max_bound if max_bound is None else max_bound

        class Bounded(BoundedSymbolicValue, range=(lower, upper)):
            pass

        return Bounded


def _bounded_symbolic_value_repr(*, range: BoundedRangeExprT) -> str:
    """Render a (min, max) range as `BoundedSymbolicValue(min : max)`.

    A bound that is None is shown as `*`; any other bound is shown via repr().
    """
    lower, upper = range
    lower_text = "*" if lower is None else repr(lower)
    upper_text = "*" if upper is None else repr(upper)
    return f"BoundedSymbolicValue({lower_text} : {upper_text})"


class BoundedSymbolicValue(
    SymbolExpr, metaclass=_BoundedSymbolicValueMeta, range=(None, None)
):
    """Represents a symbolic value that is bounded to a range fixed for the type.

    The predicates follow the tri-state SymbolExpr contract: True or False only
    when the static range proves the answer, None otherwise. Every case that
    previously returned True still returns True. Cases where the range does not
    decide the question now return None instead of a misleading False.
    """

    def __init__(self, value: Optional[int] = None):
        # None marks a proxy (no concrete value); see __repr__.
        self.value = value

    def __repr__(self):
        return f"{type(self)}({'proxy' if self.value is None else self.value})"

    @property
    def static_range(self) -> Optional[tuple[int, int]]:
        """The (min, max) range as ints, or None if either bound is not static.

        Bounds are resolved through the current IndexingContext.
        """
        # TODO: This is a hack until shape derivation is in place.
        ctx = IndexingContext.current()
        mn, mx = type(self).range
        if mn is not None:
            mn = ctx.get_static_value(mn)
        if mx is not None:
            mx = ctx.get_static_value(mx)
        if mn is None or mx is None:
            return None
        return mn, mx

    def is_one(self) -> Optional[bool]:
        """True for the range (1, 2); False if the range excludes 1; else None."""
        r = self.static_range
        if r is None:
            return None
        mn, mx = r
        if mn == 1 and mx == 2:
            return True
        # 1 lies outside the range whether or not the upper bound is inclusive.
        if mn > 1 or mx < 1:
            return False
        return None

    def is_non_negative(self) -> Optional[bool]:
        """True if min >= 0; False if the whole range is below zero; else None."""
        r = self.static_range
        if r is None:
            return None
        mn, mx = r
        if mn >= 0:
            return True
        if mx < 0:
            return False
        return None

    def is_positive(self) -> Optional[bool]:
        """True if min > 0; False if the whole range is <= 0; else None."""
        r = self.static_range
        if r is None:
            return None
        mn, mx = r
        if mn > 0:
            return True
        if mx <= 0:
            return False
        return None

    def is_negative(self) -> Optional[bool]:
        """True if max < 0; False if the whole range is >= 0; else None."""
        r = self.static_range
        if r is None:
            return None
        mn, mx = r
        if mx < 0:
            return True
        if mn >= 0:
            return False
        return None


###############################################################################
# Grid
###############################################################################
Expand Down Expand Up @@ -271,7 +450,7 @@ def __repr__(cls):
)


def _is_kernel_buffer_meta_derived(t: type) -> bool:
def is_kernel_buffer_meta_derived(t: type) -> bool:
    """Returns True when `t` is an instance of `_KernelBufferMeta`."""
    derived = isinstance(t, _KernelBufferMeta)
    return derived


Expand Down Expand Up @@ -361,7 +540,12 @@ class IndexingContext:
__tk_context_idname__ = "IndexingContext"

def __init__(self):
    """Start with the predefined constant symbols bound to their integer values."""
    builtin_constants = ((sym_0, 0), (sym_1, 1), (sym_2, 2), (sym_n1, -1))
    self.constant_bindings: dict[SymbolDef, int] = dict(builtin_constants)

def bind_constant(self, sym: SymbolDef, value: int):
existing = self.constant_bindings.get(sym)
Expand All @@ -371,7 +555,7 @@ def bind_constant(self, sym: SymbolDef, value: int):
)
self.constant_bindings[sym] = value

def get_static_value(self, sym: SymbolDef) -> Optional[int]:
def get_static_value(self, sym: SymbolExpr) -> Optional[int]:
    """Look up the integer bound to `sym`; returns None when nothing is bound."""
    bindings = self.constant_bindings
    return bindings.get(sym)

Expand Down
17 changes: 15 additions & 2 deletions python/shark_turbine/kernel/_support/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import torch.fx as fx

from .indexing import (
BoundedSymbolicValue,
Grid,
KernelBuffer,
sym_0,
)

from ..lang.types import (
Expand Down Expand Up @@ -106,13 +109,23 @@ def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, ite


class CompiledContext(BaseContext):
def __init__(self, tracer: KernelTracer):
def __init__(self, tracer: KernelTracer, *, grid_type: Type[Grid]):
    """Set up a non-eager context that records `tracer` and `grid_type`."""
    super().__init__(eager=False)
    self.tracer, self.grid_type = tracer, grid_type

def handle_thread_program_id(self, op, axis: int) -> Index:
    """Trace a program-id query along `axis` of the grid.

    Raises IndexError when `axis` is negative or not below the grid's rank.
    Otherwise creates a `call_function` proxy whose type expression is a
    BoundedSymbolicValue bounded by `sym_0` and the grid's symbolic shape entry
    for that axis.
    """
    grid_shape = self.grid_type.symbolic_shape
    rank = len(grid_shape)
    if axis < 0 or axis >= rank:
        raise IndexError(f"Illegal index into grid of rank {rank}: {axis}")
    bounded_type = BoundedSymbolicValue.bound(sym_0, grid_shape[axis])
    return self.tracer.create_proxy(
        "call_function",
        op,
        args=(axis,),
        kwargs={},
        type_expr=bounded_type,
    )

Expand Down
Loading

0 comments on commit 247396c

Please sign in to comment.