diff --git a/programming_examples/basic/passthrough_kernel/aie2.py b/programming_examples/basic/passthrough_kernel/aie2.py index f146341f6f..1a219ee097 100644 --- a/programming_examples/basic/passthrough_kernel/aie2.py +++ b/programming_examples/basic/passthrough_kernel/aie2.py @@ -16,6 +16,7 @@ from aie.api.kernels.binkernel import BinKernel from aie.api.phys.device import NPU1Col1 from aie.api.program import MyProgram +from aie.api.tensor import MyTensorType from aie.api.worker import MyWorker try: @@ -28,19 +29,16 @@ assert vector_size % 4 == 0 line_size = vector_size // 4 - -# TODO: clean up types -inout_type = ((vector_size,), np.uint8) -fifo_memref_type = ((line_size,), np.uint8) +line_type = MyTensorType(np.uint8, (line_size,)) # TODO: rely on depth inference -of_in = MyObjectFifo(2, memref_type=fifo_memref_type) -of_out = MyObjectFifo(2, memref_type=fifo_memref_type) +of_in = MyObjectFifo(2, line_type) +of_out = MyObjectFifo(2, line_type) passthrough_fn = BinKernel( "passThroughLine", "passThrough.cc.o", - [fifo_memref_type, fifo_memref_type, np.int32], + [line_type, line_type, np.int32], ) diff --git a/python/api/dataflow/endpoint.py b/python/api/dataflow/endpoint.py index 5fd048f041..255be139e2 100644 --- a/python/api/dataflow/endpoint.py +++ b/python/api/dataflow/endpoint.py @@ -6,7 +6,7 @@ * join """ -from abc import ABC, abstractmethod +from abc import abstractmethod from ..phys.tile import MyTile from ..resolvable import Resolvable diff --git a/python/api/dataflow/inout/inout.py b/python/api/dataflow/inout/inout.py index 1b3d1d8e76..6edde8f290 100644 --- a/python/api/dataflow/inout/inout.py +++ b/python/api/dataflow/inout/inout.py @@ -4,11 +4,12 @@ """ from abc import abstractmethod +from typing import Callable from .... import ir from ....extras.util import np_dtype_to_mlir_type from ....dialects.aiex import runtime_sequence -from ....dialects.memref import MemRefType +from ...tensor import MyTensorType from ..endpoint import MyObjectFifoEndpoint from ..objectfifo import ObjectFifoHandle from ...phys.tile import MyTile @@ -22,8 +23,8 @@ def get_fifos(self) -> list[ObjectFifoHandle]: ... class MyInOutProgram(InOutProgram): def __init__( self, - sequence_fn, # TODO: needs a type - inout_types, # TODO: needs a type + sequence_fn: Callable[..., None], + inout_types: list[MyTensorType], fifos=list[ObjectFifoHandle], coords: tuple[int, int] = (0, 0), # TODO: how to get default ): @@ -36,7 +37,7 @@ def __init__( def get_tile(self) -> MyTile: assert self.tile != None - return self.tile.op + return self.tile def get_fifos(self) -> list[ObjectFifoHandle]: return self.fifos @@ -45,14 +46,8 @@ def resolve( self, loc: ir.Location = None, ip: ir.InsertionPoint = None, - context: ir.Context = None, ) -> None: - # TODO: cleanup types - my_memref_types = [ - MemRefType.get(shape=t[0], element_type=np_dtype_to_mlir_type(t[1])) - for t in self.inout_types - ] - + my_memref_types = [t.memref_type for t in self.inout_types] @runtime_sequence(*my_memref_types) def sequence(*args, **kwargs): self.sequence_fn(*args, *self.fifos, **kwargs) diff --git a/python/api/dataflow/inout/simplefifoinout.py b/python/api/dataflow/inout/simplefifoinout.py index 927c37c91e..1f6ae01a44 100644 --- a/python/api/dataflow/inout/simplefifoinout.py +++ b/python/api/dataflow/inout/simplefifoinout.py @@ -5,13 +5,14 @@ """ import numpy as np +from typing import Optional from .... import ir -from ....extras.util import np_dtype_to_mlir_type from ....dialects.aiex import runtime_sequence, npu_dma_memcpy_nd, npu_sync, T from .inout import InOutProgram from ...phys.tile import MyTile from ..objectfifo import ObjectFifoHandle +from ...tensor import MyTensorType class SimpleFifoInOutProgram(InOutProgram): @@ -21,14 +22,17 @@ def __init__( bytes_in: int, fifo_out: ObjectFifoHandle, bytes_out: int, - in_sizes: list[int] = None, - in_strides: list[int] = None, - out_sizes: list[int] = None, - out_strides: list[int] = None, - dtype=np.uint8, # TODO: needs type + in_sizes: Optional[list[int]] = None, + in_strides: Optional[list[int]] = None, + out_sizes: Optional[list[int]] = None, + out_strides: Optional[list[int]] = None, + dtype: np.generic = np.uint8, ): assert bytes_in % np.prod(fifo_in.obj_type[0]) == 0 assert bytes_out % np.prod(fifo_out.obj_type[0]) == 0 + assert bytes_in > 0 + assert bytes_out > 0 + self.fifo_in = fifo_in self.fifo_out = fifo_out self.bytes_in = bytes_in @@ -77,7 +81,7 @@ def __init__( def get_tile(self) -> MyTile: assert self.tile != None - return self.tile.op + return self.tile def get_fifos(self) -> list[ObjectFifoHandle]: return [self.fifo_in, self.fifo_out] @@ -86,10 +90,9 @@ def resolve( self, loc: ir.Location = None, ip: ir.InsertionPoint = None, - context: ir.Context = None, ) -> None: - tensor_in_ty = T.memref(self.bytes_in, np_dtype_to_mlir_type(self.dtype)) - tensor_out_ty = T.memref(self.bytes_out, np_dtype_to_mlir_type(self.dtype)) + tensor_in_ty = MyTensorType(self.bytes_in, self.dtype).memref_type + tensor_out_ty = MyTensorType(self.bytes_out, self.dtype).memref_type @runtime_sequence(tensor_in_ty, tensor_out_ty) def sequence(inTensor, outTensor): diff --git a/python/api/dataflow/objectfifo.py b/python/api/dataflow/objectfifo.py index e6162ff2ad..5376f3cf95 100644 --- a/python/api/dataflow/objectfifo.py +++ b/python/api/dataflow/objectfifo.py @@ -7,6 +7,7 @@ # Address circular dependency between MyObjectFifo and ObjectFifoHandle from __future__ import annotations +from typing import Optional from ... import ir from ..._mlir_libs._aie import ObjectFifoSubviewType @@ -21,10 +22,9 @@ ) from ...dialects.aie import object_fifo -from ...dialects.memref import MemRefType from ..resolvable import Resolvable from .endpoint import MyObjectFifoEndpoint - +from ..tensor import MyTensorType class MyObjectFifo(Resolvable): __of_index = 0 @@ -32,15 +32,15 @@ class MyObjectFifo(Resolvable): def __init__( self, depth: int = 1, - memref_type=None, + obj_type: MyTensorType = None, name: str = None, end1: MyObjectFifoEndpoint = None, end2: MyObjectFifoEndpoint = None, - dimensionsToStream=None, # TODO(erika): needs a type - dimensionsFromStreamPerConsumer=None, # TODO(erika): needs a type + dimensionsToStream=list[list[int]], # TODO(erika): needs a type + dimensionsFromStreamPerConsumer=list[list[int]], # TODO(erika): needs a type ): - self.__depth: int = depth - self.__memref_type = memref_type + self.__depth = depth + self.__obj_type = obj_type self.end1: MyObjectFifoEndpoint = end1 self.end2: MyObjectFifoEndpoint = end2 self.dimensionToStream = dimensionsToStream @@ -50,9 +50,9 @@ def __init__( self.name = f"myof{MyObjectFifo.__get_index()}" else: self.name = name - self.__op = None - self.__first = None - self.__second = None + self.__op: Optional[ObjectFifoCreateOp] = None + self.__first: ObjectFifoHandle = ObjectFifoHandle(self, True) + self.__second: ObjectFifoHandle = ObjectFifoHandle(self, False) @classmethod def __get_index(cls) -> int: @@ -67,48 +67,38 @@ def op(self) -> ObjectFifoCreateOp: @property def first(self) -> ObjectFifoHandle: - if self.__first == None: - self.__first = ObjectFifoHandle(self, True) return self.__first - # TODO: type this - @property - def obj_type(self): - return self.__memref_type - @property def second(self) -> ObjectFifoHandle: - if self.__second == None: - self.__second = ObjectFifoHandle(self, False) return self.__second + @property + def obj_type(self) -> MyTensorType: + return self.__obj_type + def resolve( self, loc: ir.Location = None, ip: ir.InsertionPoint = None, - context: ir.Context = None, ) -> None: - if self.__op != None: - return - assert self.end1 != None, "ObjectFifo missing endpoint 1" - assert self.end2 != None, "ObjectFifo missing endpoint 2" - assert self.__memref_type != None, "ObjectFifo missing memref_type" - dtype = np_dtype_to_mlir_type(self.__memref_type[1]) - assert dtype != None - memRef_ty = MemRefType.get(shape=self.__memref_type[0], element_type=dtype) - self.__op = object_fifo( - self.name, - self.end1.get_tile(), - self.end2.get_tile(), - self.__depth, - memRef_ty, - dimensionsToStream=self.dimensionToStream, - dimensionsFromStreamPerConsumer=self.dimensionsFromStreamPerConsumer, - loc=loc, - ip=ip, - ) - - def _set_endpoint(self, endpoint, first=True): + if self.__op == None: + assert self.end1 != None, "ObjectFifo missing endpoint 1" + assert self.end2 != None, "ObjectFifo missing endpoint 2" + assert self.__memref_type != None, "ObjectFifo missing memref_type" + self.__op = object_fifo( + self.name, + self.end1.get_tile().op, + self.end2.get_tile().op, + self.__depth, + self.__obj_type.memref_type, + dimensionsToStream=self.dimensionToStream, + dimensionsFromStreamPerConsumer=self.dimensionsFromStreamPerConsumer, + loc=loc, + ip=ip, + ) + + def _set_endpoint(self, endpoint: MyObjectFifoEndpoint, first: bool = True) -> None: if first: assert self.end1 == None, "ObjectFifo already assigned endpoint 1" self.end1 = endpoint @@ -116,51 +106,40 @@ def _set_endpoint(self, endpoint, first=True): assert self.end2 == None, "ObjectFifo already assigned endpoint 2" self.end2 = endpoint - def _acquire( - self, port: ObjectFifoPort, num_elem: int, loc=None, ip=None, context=None - ): + def _acquire(self, port: ObjectFifoPort, num_elem: int, loc: ir.Location = None, ip: ir.InsertionPoint = None): assert num_elem > 0, "Must consume at least one element" - assert ( - num_elem <= self.__depth - ), "Cannot consume elements to exceed ObjectFifo depth" - dtype = np_dtype_to_mlir_type(self.__memref_type[1]) - assert dtype != None - memRef_ty = MemRefType.get(shape=self.__memref_type[0], element_type=dtype) - subview_t = ObjectFifoSubviewType.get(memRef_ty) - acq = ObjectFifoAcquireOp(subview_t, port, self.name, num_elem) + assert num_elem <= self.__depth, "Cannot consume elements to exceed ObjectFifo depth" + + subview_t = ObjectFifoSubviewType.get(self.__obj_type.memref_type) + acq = ObjectFifoAcquireOp(subview_t, port, self.name, num_elem, loc=loc, ip=ip) objects = [] if acq.size.value == 1: - return ObjectFifoSubviewAccessOp(memRef_ty, acq.subview, acq.size.value - 1) + return ObjectFifoSubviewAccessOp(self.__obj_type.memref_type, acq.subview, acq.size.value - 1, loc=loc, ip=ip) for i in range(acq.size.value): - objects.append(ObjectFifoSubviewAccessOp(memRef_ty, acq.subview, i)) + objects.append(ObjectFifoSubviewAccessOp(self.__obj_type.memref_type, acq.subview, i, loc=loc, ip=ip)) return objects - def _release( - self, port: ObjectFifoPort, num_elem: int, loc=None, ip=None, context=None - ): + def _release(self, port: ObjectFifoPort, num_elem: int, loc: ir.Location = None, ip: ir.InsertionPoint = None): assert num_elem > 0, "Must consume at least one element" - assert ( - num_elem <= self.__depth - ), "Cannot consume elements to exceed ObjectFifo depth" + assert num_elem <= self.__depth, "Cannot consume elements to exceed ObjectFifo depth" objectfifo_release(port, self.name, num_elem, loc=loc, ip=ip) class ObjectFifoHandle(Resolvable): def __init__(self, of: MyObjectFifo, is_first: bool): - self.__port = ObjectFifoPort.Produce if is_first else ObjectFifoPort.Consume + self.__port: ObjectFifoPort = ObjectFifoPort.Produce if is_first else ObjectFifoPort.Consume self.__is_first = is_first self.__object_fifo = of - def acquire(self, num_elem: int, loc=None, ip=None, context=None): - return self.__object_fifo._acquire(self.__port, num_elem) + def acquire(self, num_elem: int, loc: ir.Location = None, ip: ir.InsertionPoint = None): + return self.__object_fifo._acquire(self.__port, num_elem, loc=loc, ip=ip) - def release(self, num_elem: int, loc=None, ip=None, context=None): - return self.__object_fifo._release(self.__port, num_elem) + def release(self, num_elem: int, loc: ir.Location = None, ip: ir.InsertionPoint = None): + return self.__object_fifo._release(self.__port, num_elem, loc=loc, ip=ip) - # TODO: type this @property - def obj_type(self): + def obj_type(self) -> MyTensorType: return self.__object_fifo.obj_type @property @@ -178,6 +157,5 @@ def resolve( self, loc: ir.Location = None, ip: ir.InsertionPoint = None, - context: ir.Context = None, ) -> None: - return self.__object_fifo.resolve(loc, ip, context) + return self.__object_fifo.resolve(loc=loc, ip=ip) diff --git a/python/api/dataflow/objectfifolink.py b/python/api/dataflow/objectfifolink.py index 1a3b5bde8a..82532f1ea2 100644 --- a/python/api/dataflow/objectfifolink.py +++ b/python/api/dataflow/objectfifolink.py @@ -1,19 +1,20 @@ -from ... import ir +from typing import Optional +from ... import ir from ...dialects._aie_ops_gen import ObjectFifoLinkOp from ...dialects.aie import object_fifo_link +from ..tensor import MyTensorType from ..phys.tile import MyTile from .endpoint import MyObjectFifoEndpoint from .objectfifo import ObjectFifoHandle - class MyObjectFifoLink(MyObjectFifoEndpoint): def __init__( self, - seconds: list[ObjectFifoHandle] = None, - firsts: list[ObjectFifoHandle] = None, - coords: tuple[int, int] = None, + seconds: list[ObjectFifoHandle] = [], + firsts: list[ObjectFifoHandle] = [], + coords: Optional[tuple[int, int]] = None, ): column, row = coords self.tile = MyTile(column, row) @@ -22,43 +23,38 @@ def __init__( self.__firsts = [] self.__op = None - self.__memref_type = seconds[0].obj_type + self.__obj_type = seconds[0].obj_type for s in seconds: - assert isinstance(s, ObjectFifoHandle) - assert s.obj_type == self.__memref_type + assert s.obj_type == self.__obj_type s.set_endpoint(self) self.__seconds.append(s) for f in firsts: - assert isinstance(f, ObjectFifoHandle) - assert f.obj_type == self.__memref_type + assert f.obj_type == self.__obj_type f.set_endpoint(self) self.__firsts.append(f) def get_tile(self) -> MyTile: assert self.tile != None - return self.tile.op + return self.tile @property def op(self) -> ObjectFifoLinkOp: assert self.__op != None return self.__op - # TODO: type this @property - def obj_type(self): - return self.__memref_type + def obj_type(self) -> MyTensorType: + return self.__obj_type def resolve( self, loc: ir.Location = None, ip: ir.InsertionPoint = None, - context: ir.Context = None, ) -> None: - if self.__op != None: - return - self.__op = object_fifo_link( - [s.op for s in self.__seconds], - [f.op for f in self.__firsts], - loc=loc, - ip=ip, - ) + if self.__op == None: + self.__op = object_fifo_link( + [s.op for s in self.__seconds], + [f.op for f in self.__firsts], + loc=loc, + ip=ip, + ) diff --git a/python/api/kernels/binkernel.py b/python/api/kernels/binkernel.py index c14739597f..8c7fb8fcaf 100644 --- a/python/api/kernels/binkernel.py +++ b/python/api/kernels/binkernel.py @@ -3,25 +3,25 @@ * docs * types for inout_types """ +import numpy as np +from typing import Optional from ... import ir + from ...extras.util import np_dtype_to_mlir_type, get_arg_types from ...dialects.memref import MemRefType +from ...dialects.func import FuncOp from ...dialects.aie import external_func, call +from ..tensor import MyTensorType from .kernel import MyKernel class BinKernel(MyKernel): - def __init__(self, name: str, bin_name: str, inout_types: list) -> None: - assert isinstance(name, str) - assert len(name) > 0 - assert isinstance(bin_name, str) - assert len(bin_name) > 0 - assert isinstance(inout_types, list) + def __init__(self, name: str, bin_name: str, inout_types: list[MyTensorType, np.generic] = []) -> None: self.__name = name self.__bin_name = bin_name self.__inout_types = inout_types - self.__op = None + self.__op: Optional[FuncOp] = None @property def bin_name(self) -> str: @@ -31,23 +31,22 @@ def resolve( self, loc: ir.Location = None, ip: ir.InsertionPoint = None, - context: ir.Context = None, ) -> None: - assert self.__op == None - resolved_inout_types = [] - for t in self.__inout_types: - try: - dtype = np_dtype_to_mlir_type(t) - except Exception: - dtype = get_arg_types(t) - if dtype is None: - # Interpret as a dummy memref - dtype = MemRefType.get( - shape=t[0], element_type=np_dtype_to_mlir_type(t[1]) - ) - resolved_inout_types.append(dtype) - self.__op = external_func(self.__name, inputs=resolved_inout_types) + if self.__op == None: + resolved_inout_types = [] + for t in self.__inout_types: + try: + dtype = np_dtype_to_mlir_type(t) + except Exception: + dtype = get_arg_types(t) + if dtype is None: + # Interpret as a dummy memref + dtype = MemRefType.get( + shape=t[0], element_type=np_dtype_to_mlir_type(t[1]) + ) + resolved_inout_types.append(dtype) + self.__op = external_func(self.__name, inputs=resolved_inout_types) def __call__(self, *args, **kwargs): - assert self.__op + assert self.__op, "Need to resolve BinKernel before it can be called" call(self.__name, args, kwargs) diff --git a/python/api/kernels/kernel.py b/python/api/kernels/kernel.py index 6feab152ec..9644501841 100644 --- a/python/api/kernels/kernel.py +++ b/python/api/kernels/kernel.py @@ -10,4 +10,5 @@ class MyKernel(Resolvable): @abstractmethod - def __call__(self, *args, **kwargs): ... + def __call__(self, *args, **kwargs) -> None: + pass diff --git a/python/api/phys/device.py b/python/api/phys/device.py index a3a9c3a501..3286e8d82a 100644 --- a/python/api/phys/device.py +++ b/python/api/phys/device.py @@ -14,74 +14,60 @@ } """ -from abc import abstractmethod +from typing import Optional from ... import ir -from ...dialects.aie import AIEDevice, tile - +from ...dialects.aie import AIEDevice, tile, TileOp from ..resolvable import Resolvable from .tile import MyTile class MyDevice(Resolvable): - class MyDeviceTile(Resolvable): - def __init__(self, column: int, row: int) -> None: - assert isinstance(column, int) - assert isinstance(row, int) - self.__column: int = column + """ + Note: this class is abstract because it does not implement Resolve + """ + class __MyDeviceTile(Resolvable): + """ + Interior class for tiles objects owned by a particular device + """ + def __init__(self, col: int, row: int) -> None: + self.__col: int = col self.__row: int = row - self.__op: int = None - - @property - def op(self): - assert self.__op != None - return self.__op + self.__op: Optional[TileOp] = None def resolve( self, loc: ir.Location = None, ip: ir.InsertionPoint = None, - context: ir.Context = None, ) -> None: - if self.__op != None: - return - self.__op = tile(self.__column, self.__row, loc=loc, ip=ip) + if self.__op == None: + self.__op = tile(self.__col, self.__row, loc=loc, ip=ip) - @property - @abstractmethod - def rows(self) -> int: ... - - @property - @abstractmethod - def columns(self) -> int: ... - - @abstractmethod - def resolve_tile( - self, - col: int, - row: int, - loc: ir.Location = None, - ip: ir.InsertionPoint = None, - context: ir.Context = None, - ) -> None: ... + @property + def op(self) -> TileOp: + assert self.__op != None + return self.__op + @op.setter + def op(self, op: TileOp): + assert self.__op == None + self.__op = op -class NPU1Col1(MyDevice): - def __init__(self) -> None: - self.__cols: int = 1 - self.__rows: int = 4 - self.__tiles: list[list[MyDevice.MyDeviceTile]] = [] + def __init__(self, cols: int, rows: int) -> None: + self.__cols = cols + self.__rows = rows + self.__tiles: list[list[MyDevice.__MyDeviceTile]] = [] for c in range(self.__cols): self.__tiles.append([]) for r in range(self.__rows): - self.__tiles[c].append(MyDevice.MyDeviceTile(c, r)) + self.__tiles[c].append(MyDevice.__MyDeviceTile(c, r)) @property def rows(self) -> int: return self.__rows @property - def columns(self) -> int: + def cols(self) -> int: return self.__cols def resolve_tile( @@ -89,15 +75,18 @@ def resolve_tile( tile: MyTile, loc: ir.Location = None, ip: ir.InsertionPoint = None, - context: ir.Context = None, ) -> None: - self.__tiles[tile.col][tile.row].resolve(loc, ip, context) + self.__tiles[tile.col][tile.row].resolve(loc, ip) tile.op = self.__tiles[tile.col][tile.row].op + +class NPU1Col1(MyDevice): + def __init__(self) -> None: + super().__init__(cols=1, rows=4) + def resolve( self, loc: ir.Location = None, ip: ir.InsertionPoint = None, - context: ir.Context = None, ) -> None: return AIEDevice.npu1_1col diff --git a/python/api/phys/tile.py b/python/api/phys/tile.py index 9a3a467d71..26197d1b41 100644 --- a/python/api/phys/tile.py +++ b/python/api/phys/tile.py @@ -4,14 +4,21 @@ * error handling * tile types" """ - -from ... import ir - +from typing import Optional +from ...dialects.aie import TileOp class MyTile: def __init__(self, col: int, row: int) -> None: - assert isinstance(col, int) - assert isinstance(row, int) self.col: int = col self.row: int = row - self.op = None + self.__op: Optional[TileOp] = None + + @property + def op(self) -> TileOp: + assert self.__op != None + return self.__op + + @op.setter + def op(self, op: TileOp): + assert self.__op == None + self.__op = op \ No newline at end of file diff --git a/python/api/program.py b/python/api/program.py index 38236a4849..975b9c04ba 100644 --- a/python/api/program.py +++ b/python/api/program.py @@ -21,13 +21,6 @@ def __init__( inout_program: InOutProgram, links: list[MyObjectFifoLink] = [], ): - assert isinstance(device, MyDevice) - assert worker_programs != None and len(worker_programs) >= 1 - for w in worker_programs: - assert isinstance(w, MyWorker) - assert isinstance(inout_program, InOutProgram) - for l in links: - assert isinstance(l, MyObjectFifoLink) self.__device = device self.__worker_programs = worker_programs self.__inout_program = inout_program diff --git a/python/api/resolvable.py b/python/api/resolvable.py index 1d54aa19a2..84e92dcbe0 100644 --- a/python/api/resolvable.py +++ b/python/api/resolvable.py @@ -14,5 +14,4 @@ def resolve( cls, loc: ir.Location = None, ip: ir.InsertionPoint = None, - context: ir.Context = None, ) -> None: ... diff --git a/python/api/tensor.py b/python/api/tensor.py new file mode 100644 index 0000000000..a002084479 --- /dev/null +++ b/python/api/tensor.py @@ -0,0 +1,31 @@ + +from functools import cached_property +import numpy as np +from typing import Literal + +from ..extras.util import np_dtype_to_mlir_type +from ..dialects.memref import MemRefType + +class MyTensorType: + def __init__(self, dtype: np.generic, shape: np.generic.shape): + self.__dtype = dtype + self.__shape = shape + self.__my_numpy_type = np.ndarray[dtype, Literal[tuple(shape)]] + + @cached_property + def memref_type(self) -> MemRefType: + return MemRefType.get( + shape=self.__shape, element_type=np_dtype_to_mlir_type(self.__dtype) + ) + + @property + def shape(self) -> np.generic.shape: + return self.__shape + + @property + def dtype(self) -> np.generic: + return self.__dtype + + def __eq__(self, other): + # TODO: may want to be equal to numpy datatypes as well?? + return self.__my_numpy_type == other.__my_numpy_type \ No newline at end of file diff --git a/python/api/worker.py b/python/api/worker.py index 5f00034da3..0011a90221 100644 --- a/python/api/worker.py +++ b/python/api/worker.py @@ -6,7 +6,7 @@ """ import sys -from typing import Union +from typing import Callable, Optional, Union from .. import ir from ..dialects.aie import core @@ -20,24 +20,20 @@ class MyWorker(MyObjectFifoEndpoint): def __init__( self, - core_fn, - fn_args: list[Union[ObjectFifoHandle, MyKernel]], + core_fn: Optional[Callable[[Union[ObjectFifoHandle, MyKernel], None]]], + fn_args: list[Union[ObjectFifoHandle, MyKernel]] = [], coords: tuple[int, int] = None, ): column, row = coords self.tile = MyTile(column, row) if core_fn is None: - - def do_nothing_core_fun(): + def do_nothing_core_fun() -> None: for _ in for_(sys.maxsize): yield_([]) - self.core_fn = do_nothing_core_fun else: self.core_fn = core_fn - self.link_with = None - - assert isinstance(fn_args, list) + self.link_with: Optional[str] = None self.fn_args = fn_args bin_names = set() @@ -53,13 +49,12 @@ def do_nothing_core_fun(): def get_tile(self) -> MyTile: assert self.tile != None - return self.tile.op + return self.tile def resolve( self, loc: ir.Location = None, ip: ir.InsertionPoint = None, - context: ir.Context = None, ) -> None: my_tile = self.tile.op my_link = self.link_with