Skip to content

Commit

Permalink
Merge pull request #1 from gizatechxyz/fix-osiris
Browse files Browse the repository at this point in the history
fix generated files
  • Loading branch information
raphaelDkhn authored Dec 6, 2023
2 parents a8868cf + 2e7fad7 commit 1e0e9f5
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 95 deletions.
35 changes: 0 additions & 35 deletions osiris/cairo/data_converter/data_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,38 +49,3 @@ def convert_to_cairo(np_array: np.ndarray, output_file: str, dtype: Dtype):
tensor = create_tensor(dtype, np_array.shape, np_array)
cairo_data = create_cairo_data(output_file, tensor)
cairo_data.dump()


def inputs_gen(inputs: list[Tensor | Sequence]):
"""
Generate and write Cairo file based on the provided inputs .
Args:
inputs (list[Tensor | list[Tensor]]): A list of input tensors or tensor sequences.
name (str): The name of the inputs file.
"""
inputs_name = "inputs"

ModFile().update(inputs_name)

for i, input in enumerate(inputs):
input_data = CairoData(os.path.join(inputs_name, f"input_{i}.cairo"))
match input:
case list():
input_data.buffer = CairoData.sequence_template(
func=f"input_{i}",
dtype=input[0].dtype.value,
refs=get_data_refs(input[0].dtype),
data=get_data_statement_for_sequences(input, input[0].dtype),
shape=[x.shape for x in input],
)
case Tensor():
input_data.buffer = CairoData.base_template(
func=f"input_{i}",
dtype=input.dtype.value,
refs=get_data_refs(input.dtype),
data=get_data_statement(input.data, input.dtype),
shape=input.shape,
)

input_data.dump()
28 changes: 13 additions & 15 deletions osiris/cairo/data_converter/data_statement_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,6 @@ def get_data_refs(dtype: Dtype) -> list[str]:
return refs


class DataStatement(Enum):
U32 = "u32"
I32 = "i32 { mag: {magnitude}, sign: {sign} }"
I8 = "i8 { mag: {magnitude}, sign: {sign} }"
FP8x23 = "FP8x23 { mag: {magnitude}, sign: {sign} }"
FP16x16 = "FP16x16 { mag: {magnitude}, sign: {sign} }"


def get_data_statement(data: np.ndarray, dtype: Dtype) -> list[str]:
"""
Generate data statements based on the data type.
Expand All @@ -47,13 +39,19 @@ def get_data_statement(data: np.ndarray, dtype: Dtype) -> list[str]:
Returns:
list[str]: The generated data statements.
"""
statement_template = DataStatement[dtype.name].value
return [
statement_template.replace("{magnitude}", str(int(x))).replace(
"{sign}", str(x < 0).lower()
)
for x in data.flatten()
]
match dtype:
case Dtype.U32:
return [f"{int(x)}" for x in data.flatten()]
case Dtype.I32:
return ["i32 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()]
case Dtype.I8:
return ["i8 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()]
case Dtype.FP8x23:
return ["FP8x23 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()]
case Dtype.FP16x16:
return ["FP16x16 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()]
case Dtype.BOOL:
return [str(x).lower() for x in data.flatten()]


def get_data_statement_for_sequences(data: Sequence, dtype: Dtype) -> list[list[str]]:
Expand Down
73 changes: 35 additions & 38 deletions osiris/cairo/file_manager/cairo_data.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from osiris.cairo.file_manager.file import File


class CairoData(File):
def __init__(self, file: str):
super().__init__(file) # Use pathlib's / operator
super().__init__(file)

@classmethod
def base_template(
cls, func: str, dtype: str, refs: list[str], data: list[str], shape: tuple
) -> list[str]:
def base_template(cls, func: str, dtype: str, refs: list[str], data: list[str], shape: tuple) -> list[str]:
"""
Create a base template for data representation in Cairo.
Expand All @@ -27,28 +24,21 @@ def base_template(
"""
template = [
*[f"use {ref};" for ref in refs],
*[""],
*[f"fn {func}() -> Tensor<{dtype}>" + " {"],
*[" let mut shape = ArrayTrait::<usize>::new();"],
*[ ""],
*[f"fn {func}() -> Tensor<{dtype}>"+" {"],
*[ " let mut shape = ArrayTrait::<usize>::new();"],
*[f" shape.append({s});" for s in shape],
*[""],
*[" let mut data = ArrayTrait::new();"],
*[ ""],
*[ " let mut data = ArrayTrait::new();"],
*[f" data.append({d});" for d in data],
*[" TensorTrait::new(shape.span(), data.span())"],
*["}"],
*[ " TensorTrait::new(shape.span(), data.span())"],
*[ "}"],
]

return template

@classmethod
def sequence_template(
cls,
func: str,
dtype: str,
refs: list[str],
data: list[list[str]],
shape: list[tuple],
) -> list[str]:
def sequence_template(cls, func: str, dtype: str, refs: list[str], data: list[list[str]], shape: list[tuple]) -> list[str]:
"""
Create a template for handling tensor sequences in Cairo.
Expand All @@ -65,24 +55,31 @@ def sequence_template(
This method generates a list of strings representing a function in Cairo for handling a sequence
of tensors, each with its own data and shape.
"""

def expand_sequence_init(s: list[tuple], d: list[list[str]]) -> list[str]:
return [
f" let mut shape = ArrayTrait::<usize>::new();"
f" shape.append({s});"
f" let mut data = ArrayTrait::new();"
f" data.append({d});"
f" sequence.append(TensorTrait::new(shape.span(), data.span()));"
for s, d in zip(s, d)
]
snippet = []
for i in range(len(s)):
snippet += [
*[ " let mut shape = ArrayTrait::<usize>::new();"],
*[f" shape.append({s});" for s in s[i]],
*[ ""],
*[ " let mut data = ArrayTrait::new();"],
*[f" data.append({d});" for d in d[i]],
*[ ""],
*[ " sequence.append(TensorTrait::new(shape.span(), data.span()));"],
*[ ""],
]

return snippet

template = [
*[f"use {ref};" for ref in refs],
*[ ""],
*[f"fn {func}() -> Array<Tensor<{dtype}>>"+" {"],
*[ " let mut sequence = ArrayTrait::new();"],
*[ ""],
*expand_sequence_init(shape, data),
*[ " sequence"],
*[ "}"],
]

template = []
template.extend([f"use {ref};" for ref in refs])
template.append("")
template.append(f"fn {func}() -> Array<Tensor<{dtype}>>" + " {")
template.append(" let mut sequence = ArrayTrait::new();")
template.append("")
template.extend(expand_sequence_init(shape, data))
template.append(" sequence")
template.append("}")
return template
12 changes: 5 additions & 7 deletions osiris/cairo/file_manager/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,13 @@ def update(self, name: str):
If it doesn't, the new module statement is appended to the file.
"""
statement = f"mod {name};"
if any(
line.startswith(statement) for line in self.buffer
): # Use generator expression
if any([line.startswith(statement) for line in self.buffer]):
# Use generator expression
return

with self.path.open("a") as f:
f.write(f"{statement}\n")


class File:
def __init__(self, path: str):
Expand All @@ -50,7 +49,7 @@ def __init__(self, path: str):
self.path.parent.mkdir(parents=True, exist_ok=True)
self.buffer = []

if self.path.is_file(): # Use pathlib's is_file method
if self.path.is_file(): # Use pathlib's is_file method
with self.path.open("r") as f:
self.buffer = f.readlines()

Expand All @@ -62,5 +61,4 @@ def dump(self):
properly terminated with a newline character.
"""
with self.path.open("w") as f:
for line in self.buffer:
f.write(f"{line}\n")
f.writelines([f"{line}\n" for line in self.buffer])
1 change: 1 addition & 0 deletions osiris/dtypes/cairo_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ class Dtype(Enum):
I8 = "i8"
I32 = "i32"
U32 = "u32"
BOOL = 'bool'

0 comments on commit 1e0e9f5

Please sign in to comment.