diff --git a/nir/ir.py b/nir/ir.py index 30deece..86e17cd 100644 --- a/nir/ir.py +++ b/nir/ir.py @@ -112,6 +112,7 @@ class NIRGraph(NIRNode): nodes: Nodes # List of computational nodes edges: Edges # List of edges between nodes + ensure_validity: bool = True # Whether to check that the graph is valid @property def inputs(self): @@ -176,8 +177,15 @@ def __post_init__(self): self.output_type = { node_key: self.nodes[node_key].output_type for node_key in output_node_keys } - - def _check_types(self): + # check that all nodes have consistent and defined input and output types + if self.ensure_validity: + if not self.is_valid(): + print('invalid graph, attempting to infer types') + self.infer_types() + if not self.is_valid(): + raise ValueError('invalid graph, could not infer types') + + def is_valid(self): """Check that all nodes in the graph have input and output types. Will raise ValueError if any node has no input or output type, or if the types are inconsistent.""" for edge in self.edges: @@ -189,18 +197,21 @@ def _check_types(self): v is None for v in pre_node.output_type.values() ) if undef_out_type: - raise ValueError(f'pre node {edge[0]} has no output type') + print(f'pre node {edge[0]} has no output type') + return False undef_in_type = post_node.input_type is None or any( v is None for v in post_node.input_type.values() ) if undef_in_type: - raise ValueError(f'post node {edge[1]} has no input type') + print(f'post node {edge[1]} has no input type') + return False # make sure the length of types is equal if len(pre_node.output_type) != len(post_node.input_type): pre_repr = f'len({edge[0]}.output)={len(pre_node.output_type)}' post_repr = f'len({edge[1]}.input)={len(post_node.input_type)}' - raise ValueError(f'type length mismatch: {pre_repr} -> {post_repr}') + print(f'type length mismatch: {pre_repr} -> {post_repr}') + return False # make sure the type values match up if len(pre_node.output_type.keys()) == 1: @@ -209,9 +220,11 @@ def _check_types(self): if not np.array_equal(post_input_type, pre_output_type): pre_repr = f'{edge[0]}.output: {pre_output_type}' post_repr = f'{edge[1]}.input: {post_input_type}' - raise ValueError(f'type mismatch: {pre_repr} -> {post_repr}') + print(f'type mismatch: {pre_repr} -> {post_repr}') + return False else: - raise NotImplementedError('multiple input/output types not supported yet') + print('multiple input/output types not supported yet') + return False return True def _forward_type_inference(self, debug=True): diff --git a/nir/read.py b/nir/read.py index cc830b8..7b65a95 100644 --- a/nir/read.py +++ b/nir/read.py @@ -88,10 +88,16 @@ def read_node(node: typing.Any) -> nir.NIRNode: raise ValueError(f"Unknown unit type: {node['type'][()]}") -def read(filename: typing.Union[str, pathlib.Path]) -> nir.NIRGraph: - """Load a NIR from a HDF/conn5 file.""" +def read(filename: typing.Union[str, pathlib.Path], strict=True) -> nir.NIRGraph: + """Load a NIR from a HDF/conn5 file. If strict, only load valid graphs.""" with h5py.File(filename, "r") as f: - return read_node(f["node"]) + graph: nir.NIRGraph = read_node(f["node"]) + if not graph.is_valid(): + print('[WARNING] graph is invalid, attempting nir.NIRGraph.infer_types()') + graph = graph.infer_types() + if not graph.is_valid() and strict: + raise ValueError("Invalid graph, could not read.") + return graph def read_version(filename: typing.Union[str, pathlib.Path]) -> str: diff --git a/nir/write.py b/nir/write.py index 57a0661..d295ee8 100644 --- a/nir/write.py +++ b/nir/write.py @@ -107,8 +107,11 @@ def _convert_node(node: nir.NIRNode) -> dict: raise ValueError(f"Unknown node type: {node}") -def write(filename: typing.Union[str, pathlib.Path], graph: nir.NIRNode) -> None: - """Write a NIR to a HDF5 file.""" +def write(filename: typing.Union[str, pathlib.Path], graph: nir.NIRNode, strict=True) -> None: + """Write a NIR to a HDF5 file. If strict, only allow valid graphs to be written.""" + + if not graph.is_valid() and strict: + raise ValueError("Cannot write an invalid graph. See NIRGraph.is_valid().") def write_recursive(group: h5py.Group, node: dict) -> None: for k, v in node.items(): diff --git a/tests/test_ir.py b/tests/test_ir.py index 30cff5f..5bc8fea 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -431,4 +431,4 @@ def test_conv_type_inference(): } for name, graph in graphs.items(): graph.infer_types() - assert graph._check_types(), f'type inference failed for: {name}' + assert graph.is_valid(), f'type inference failed for: {name}'