Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enforce the validity of NIR graphs on creation #61

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions nir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class NIRGraph(NIRNode):

nodes: Nodes # List of computational nodes
edges: Edges # List of edges between nodes
ensure_validity: bool = True # Whether to check that the graph is valid

@property
def inputs(self):
Expand Down Expand Up @@ -176,8 +177,15 @@ def __post_init__(self):
self.output_type = {
node_key: self.nodes[node_key].output_type for node_key in output_node_keys
}

def _check_types(self):
# check that all nodes have consistent and defined input and output types
if self.ensure_validity:
if not self.is_valid():
print('invalid graph, attempting to infer types')
self.infer_types()
if not self.is_valid():
raise ValueError('invalid graph, could not infer types')

def is_valid(self):
"""Check that all nodes in the graph have input and output types. Will raise ValueError
if any node has no input or output type, or if the types are inconsistent."""
for edge in self.edges:
Expand All @@ -189,18 +197,21 @@ def _check_types(self):
v is None for v in pre_node.output_type.values()
)
if undef_out_type:
raise ValueError(f'pre node {edge[0]} has no output type')
print(f'pre node {edge[0]} has no output type')
return False
undef_in_type = post_node.input_type is None or any(
v is None for v in post_node.input_type.values()
)
if undef_in_type:
raise ValueError(f'post node {edge[1]} has no input type')
print(f'post node {edge[1]} has no input type')
return False

# make sure the length of types is equal
if len(pre_node.output_type) != len(post_node.input_type):
pre_repr = f'len({edge[0]}.output)={len(pre_node.output_type)}'
post_repr = f'len({edge[1]}.input)={len(post_node.input_type)}'
raise ValueError(f'type length mismatch: {pre_repr} -> {post_repr}')
print(f'type length mismatch: {pre_repr} -> {post_repr}')
return False

# make sure the type values match up
if len(pre_node.output_type.keys()) == 1:
Expand All @@ -209,9 +220,11 @@ def _check_types(self):
if not np.array_equal(post_input_type, pre_output_type):
pre_repr = f'{edge[0]}.output: {pre_output_type}'
post_repr = f'{edge[1]}.input: {post_input_type}'
raise ValueError(f'type mismatch: {pre_repr} -> {post_repr}')
print(f'type mismatch: {pre_repr} -> {post_repr}')
return False
else:
raise NotImplementedError('multiple input/output types not supported yet')
print('multiple input/output types not supported yet')
return False
return True

def _forward_type_inference(self, debug=True):
Expand Down
12 changes: 9 additions & 3 deletions nir/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,16 @@ def read_node(node: typing.Any) -> nir.NIRNode:
raise ValueError(f"Unknown unit type: {node['type'][()]}")


def read(filename: typing.Union[str, pathlib.Path]) -> nir.NIRGraph:
"""Load a NIR from a HDF/conn5 file."""
def read(filename: typing.Union[str, pathlib.Path], strict=True) -> nir.NIRGraph:
"""Load a NIR from a HDF/conn5 file. If strict, only load valid graphs."""
with h5py.File(filename, "r") as f:
return read_node(f["node"])
graph: nir.NIRGraph = read_node(f["node"])
if not graph.is_valid():
print('[WARNING] graph is invalid, attempting nir.NIRGraph.infer_types()')
graph = graph.infer_types()
if not graph.is_valid() and strict:
raise ValueError("Invalid graph, could not read.")
return graph


def read_version(filename: typing.Union[str, pathlib.Path]) -> str:
Expand Down
7 changes: 5 additions & 2 deletions nir/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,11 @@ def _convert_node(node: nir.NIRNode) -> dict:
raise ValueError(f"Unknown node type: {node}")


def write(filename: typing.Union[str, pathlib.Path], graph: nir.NIRNode) -> None:
"""Write a NIR to a HDF5 file."""
def write(filename: typing.Union[str, pathlib.Path], graph: nir.NIRNode, strict=True) -> None:
"""Write a NIR to a HDF5 file. If strict, only allow valid graphs to be written."""

if not graph.is_valid() and strict:
raise ValueError("Cannot write an invalid graph. See NIRGraph.is_valid().")

def write_recursive(group: h5py.Group, node: dict) -> None:
for k, v in node.items():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,4 +431,4 @@ def test_conv_type_inference():
}
for name, graph in graphs.items():
graph.infer_types()
assert graph._check_types(), f'type inference failed for: {name}'
assert graph.is_valid(), f'type inference failed for: {name}'