From 2e7fad7b883ac516750ee1054c3be8e650f501ff Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Wed, 6 Dec 2023 11:25:25 +0200 Subject: [PATCH] fix generated files --- osiris/cairo/data_converter/data_converter.py | 35 --------- .../data_statement_generator.py | 28 ++++--- osiris/cairo/file_manager/cairo_data.py | 73 +++++++++---------- osiris/cairo/file_manager/file.py | 12 ++- osiris/dtypes/cairo_dtypes.py | 1 + 5 files changed, 54 insertions(+), 95 deletions(-) diff --git a/osiris/cairo/data_converter/data_converter.py b/osiris/cairo/data_converter/data_converter.py index c82f757..8c63d62 100644 --- a/osiris/cairo/data_converter/data_converter.py +++ b/osiris/cairo/data_converter/data_converter.py @@ -49,38 +49,3 @@ def convert_to_cairo(np_array: np.ndarray, output_file: str, dtype: Dtype): tensor = create_tensor(dtype, np_array.shape, np_array) cairo_data = create_cairo_data(output_file, tensor) cairo_data.dump() - - -def inputs_gen(inputs: list[Tensor | Sequence]): - """ - Generate and write Cairo file based on the provided inputs . - - Args: - inputs (list[Tensor | list[Tensor]]): A list of input tensors or tensor sequences. - name (str): The name of the inputs file. - """ - inputs_name = "inputs" - - ModFile().update(inputs_name) - - for i, input in enumerate(inputs): - input_data = CairoData(os.path.join(inputs_name, f"input_{i}.cairo")) - match input: - case list(): - input_data.buffer = CairoData.sequence_template( - func=f"input_{i}", - dtype=input[0].dtype.value, - refs=get_data_refs(input[0].dtype), - data=get_data_statement_for_sequences(input, input[0].dtype), - shape=[x.shape for x in input], - ) - case Tensor(): - input_data.buffer = CairoData.base_template( - func=f"input_{i}", - dtype=input.dtype.value, - refs=get_data_refs(input.dtype), - data=get_data_statement(input.data, input.dtype), - shape=input.shape, - ) - - input_data.dump() diff --git a/osiris/cairo/data_converter/data_statement_generator.py b/osiris/cairo/data_converter/data_statement_generator.py index f04d5af..3b67e51 100644 --- a/osiris/cairo/data_converter/data_statement_generator.py +++ b/osiris/cairo/data_converter/data_statement_generator.py @@ -28,14 +28,6 @@ def get_data_refs(dtype: Dtype) -> list[str]: return refs -class DataStatement(Enum): - U32 = "u32" - I32 = "i32 { mag: {magnitude}, sign: {sign} }" - I8 = "i8 { mag: {magnitude}, sign: {sign} }" - FP8x23 = "FP8x23 { mag: {magnitude}, sign: {sign} }" - FP16x16 = "FP16x16 { mag: {magnitude}, sign: {sign} }" - - def get_data_statement(data: np.ndarray, dtype: Dtype) -> list[str]: """ Generate data statements based on the data type. @@ -47,13 +39,19 @@ def get_data_statement(data: np.ndarray, dtype: Dtype) -> list[str]: Returns: list[str]: The generated data statements. """ - statement_template = DataStatement[dtype.name].value - return [ - statement_template.replace("{magnitude}", str(int(x))).replace( - "{sign}", str(x < 0).lower() - ) - for x in data.flatten() - ] + match dtype: + case Dtype.U32: + return [f"{int(x)}" for x in data.flatten()] + case Dtype.I32: + return ["i32 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()] + case Dtype.I8: + return ["i8 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()] + case Dtype.FP8x23: + return ["FP8x23 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()] + case Dtype.FP16x16: + return ["FP16x16 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()] + case Dtype.BOOL: + return [str(x).lower() for x in data.flatten()] def get_data_statement_for_sequences(data: Sequence, dtype: Dtype) -> list[list[str]]: diff --git a/osiris/cairo/file_manager/cairo_data.py b/osiris/cairo/file_manager/cairo_data.py index 412ad07..0fbea87 100644 --- a/osiris/cairo/file_manager/cairo_data.py +++ b/osiris/cairo/file_manager/cairo_data.py @@ -1,14 +1,11 @@ from osiris.cairo.file_manager.file import File - class CairoData(File): def __init__(self, file: str): - super().__init__(file) # Use pathlib's / operator + super().__init__(file) @classmethod - def base_template( - cls, func: str, dtype: str, refs: list[str], data: list[str], shape: tuple - ) -> list[str]: + def base_template(cls, func: str, dtype: str, refs: list[str], data: list[str], shape: tuple) -> list[str]: """ Create a base template for data representation in Cairo. @@ -27,28 +24,21 @@ def base_template( """ template = [ *[f"use {ref};" for ref in refs], - *[""], - *[f"fn {func}() -> Tensor<{dtype}>" + " {"], - *[" let mut shape = ArrayTrait::::new();"], + *[ ""], + *[f"fn {func}() -> Tensor<{dtype}>"+" {"], + *[ " let mut shape = ArrayTrait::::new();"], *[f" shape.append({s});" for s in shape], - *[""], - *[" let mut data = ArrayTrait::new();"], + *[ ""], + *[ " let mut data = ArrayTrait::new();"], *[f" data.append({d});" for d in data], - *[" TensorTrait::new(shape.span(), data.span())"], - *["}"], + *[ " TensorTrait::new(shape.span(), data.span())"], + *[ "}"], ] return template @classmethod - def sequence_template( - cls, - func: str, - dtype: str, - refs: list[str], - data: list[list[str]], - shape: list[tuple], - ) -> list[str]: + def sequence_template(cls, func: str, dtype: str, refs: list[str], data: list[list[str]], shape: list[tuple]) -> list[str]: """ Create a template for handling tensor sequences in Cairo. @@ -65,24 +55,31 @@ def sequence_template( This method generates a list of strings representing a function in Cairo for handling a sequence of tensors, each with its own data and shape. """ - def expand_sequence_init(s: list[tuple], d: list[list[str]]) -> list[str]: - return [ - f" let mut shape = ArrayTrait::::new();" - f" shape.append({s});" - f" let mut data = ArrayTrait::new();" - f" data.append({d});" - f" sequence.append(TensorTrait::new(shape.span(), data.span()));" - for s, d in zip(s, d) - ] + snippet = [] + for i in range(len(s)): + snippet += [ + *[ " let mut shape = ArrayTrait::::new();"], + *[f" shape.append({s});" for s in s[i]], + *[ ""], + *[ " let mut data = ArrayTrait::new();"], + *[f" data.append({d});" for d in d[i]], + *[ ""], + *[ " sequence.append(TensorTrait::new(shape.span(), data.span()));"], + *[ ""], + ] + + return snippet + + template = [ + *[f"use {ref};" for ref in refs], + *[ ""], + *[f"fn {func}() -> Array>"+" {"], + *[ " let mut sequence = ArrayTrait::new();"], + *[ ""], + *expand_sequence_init(shape, data), + *[ " sequence"], + *[ "}"], + ] - template = [] - template.extend([f"use {ref};" for ref in refs]) - template.append("") - template.append(f"fn {func}() -> Array>" + " {") - template.append(" let mut sequence = ArrayTrait::new();") - template.append("") - template.extend(expand_sequence_init(shape, data)) - template.append(" sequence") - template.append("}") return template diff --git a/osiris/cairo/file_manager/file.py b/osiris/cairo/file_manager/file.py index 132eca0..bf321c9 100644 --- a/osiris/cairo/file_manager/file.py +++ b/osiris/cairo/file_manager/file.py @@ -26,14 +26,13 @@ def update(self, name: str): If it doesn't, the new module statement is appended to the file. """ statement = f"mod {name};" - if any( - line.startswith(statement) for line in self.buffer - ): # Use generator expression + if any([line.startswith(statement) for line in self.buffer]): + # Use generator expression return with self.path.open("a") as f: f.write(f"{statement}\n") - + class File: def __init__(self, path: str): @@ -50,7 +49,7 @@ def __init__(self, path: str): self.path.parent.mkdir(parents=True, exist_ok=True) self.buffer = [] - if self.path.is_file(): # Use pathlib's is_file method + if self.path.is_file(): # Use pathlib's is_file method with self.path.open("r") as f: self.buffer = f.readlines() @@ -62,5 +61,4 @@ def dump(self): properly terminated with a newline character. """ with self.path.open("w") as f: - for line in self.buffer: - f.write(f"{line}\n") + f.writelines([f"{line}\n" for line in self.buffer]) diff --git a/osiris/dtypes/cairo_dtypes.py b/osiris/dtypes/cairo_dtypes.py index f86e89f..890b0c8 100644 --- a/osiris/dtypes/cairo_dtypes.py +++ b/osiris/dtypes/cairo_dtypes.py @@ -8,3 +8,4 @@ class Dtype(Enum): I8 = "i8" I32 = "i32" U32 = "u32" + BOOL = 'bool' \ No newline at end of file