-
Notifications
You must be signed in to change notification settings - Fork 825
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generate op table and subgraph invoke functions (#2176)
As the next step in the codegen experiment, we want to generate the invoke calls for each layer. This is slightly challenging with the existing sources, as kernels only expose a registration function, not their individual Eval functions. In an effort to keep the code churn to a minimum, this PR introduces an inference only registration structure and function. It includes just two function pointers: invoke and reset. For this CL, we've only introduced it for FullyConnected. In the code generator, this PR creates a new op_table array in the generated source, with an enum for lookup. It also generates an invoke function for each subgraph, that calls each operator's invoke function. BUG=295174388
- Loading branch information
Showing
17 changed files
with
230 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
""" Provides object representation for the model that is conducive to code | ||
generation using templates. """ | ||
|
||
from typing import List, Sequence | ||
|
||
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb | ||
from tflite_micro.tensorflow.lite.tools import visualize | ||
|
||
|
||
def _to_pascal_case(s: str) -> str: | ||
return s.title().replace('_', '') | ||
|
||
|
||
class OpCode(object): | ||
|
||
def __init__(self, op_code: schema_fb.OperatorCodeT): | ||
self._op_code: schema_fb.OperatorCodeT = op_code | ||
|
||
def name(self) -> str: | ||
if self._op_code.customCode: | ||
return self._op_code.customCode | ||
return visualize.BuiltinCodeToName(self._op_code.builtinCode) | ||
|
||
def register_function(self) -> str: | ||
return "tflite::RegisterInference_{}".format(self.name()) | ||
|
||
def enum_name(self) -> str: | ||
return "k{}".format(_to_pascal_case(self.name())) | ||
|
||
|
||
class Operator(object): | ||
|
||
def __init__(self, model: schema_fb.ModelT, operator: schema_fb.OperatorT): | ||
self._operator: schema_fb.OperatorT = operator | ||
self._op_code: OpCode = OpCode( | ||
model.operatorCodes[self._operator.opcodeIndex]) | ||
|
||
@property | ||
def op_code(self) -> OpCode: | ||
return self._op_code | ||
|
||
|
||
class Subgraph(object): | ||
|
||
def __init__(self, model: schema_fb.ModelT, subgraph: schema_fb.SubGraphT): | ||
self._subgraph: schema_fb.SubGraphT = subgraph | ||
self._operators: List[Operator] = [ | ||
Operator(model, operator) for operator in subgraph.operators | ||
] | ||
|
||
@property | ||
def operators(self) -> Sequence[Operator]: | ||
return self._operators | ||
|
||
|
||
class Graph(object): | ||
|
||
def __init__(self, model: schema_fb.ModelT): | ||
self._subgraphs: List[SubGraph] = [ | ||
Subgraph(model, subgraph) for subgraph in model.subgraphs | ||
] | ||
|
||
@property | ||
def subgraphs(self) -> Sequence[Subgraph]: | ||
return self._subgraphs | ||
|
||
|
||
class OpCodeTable(object): | ||
|
||
def __init__(self, models: Sequence[schema_fb.ModelT]): | ||
op_codes = [] | ||
for model in models: | ||
for op_code in model.operatorCodes: | ||
op_codes.append(OpCode(op_code)) | ||
|
||
self._op_codes: List([OpCode]) = list(set(op_codes)) | ||
|
||
@property | ||
def op_codes(self) -> Sequence[OpCode]: | ||
return self._op_codes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters