
Commit

[PythonAPI] Adding input and output names options to the wrapping function

These two options help work around PyTorch's inconsistent naming of layers in its ONNX exports. This fix allows the user to provide two lists of names for the inputs and outputs of the network.
cmoineau committed Jan 11, 2023
2 parents e1bed92 + 56ffa3f commit 88e9ebe
Showing 3 changed files with 15 additions and 9 deletions.
2 changes: 0 additions & 2 deletions docs/export/CPP_STM32.rst
@@ -1,8 +1,6 @@
Export: C++/STM32
=================

**N2D2-IP only: available upon request.**

Export type: ``CPP_STM32``
C++ export for STM32.

14 changes: 7 additions & 7 deletions docs/quant/pruning.rst
@@ -15,16 +15,16 @@ Example with Python
:members:
:inherited-members:

Example of code to use the *PruneCell* in your scripts:
Example of code to use the :py:class:`n2d2.quantizer.PruneCell` in your scripts:

.. code-block:: python
for cell in model:
### Add Pruning ###
if isinstance(cell, n2d2.cells.Conv) or isinstance(cell, n2d2.cells.Fc):
cell.quantizer = n2d2.quantizer.PruneCell(prune_mode="Static", threshold=0.3, prune_filler="IterNonStruct")
### Add Pruning ###
if isinstance(cell, n2d2.cells.Conv) or isinstance(cell, n2d2.cells.Fc):
cell.quantizer = n2d2.quantizer.PruneCell(prune_mode="Static", threshold=0.3, prune_filler="IterNonStruct")
Some explanations with the differents options of the *PruneCell*:
Some explanations of the different options of the :py:class:`n2d2.quantizer.PruneCell`:

Pruning mode
^^^^^^^^^^^^
@@ -42,7 +42,7 @@ For example, to update every two epochs, write:
n2d2.quantizer.PruneCell(prune_mode="Gradual", threshold=0.3, stepsize=2*DATASET_SIZE)
Where *DATASET_SIZE* is the size of the dataset you are using.
Where ``DATASET_SIZE`` is the size of the dataset you are using.

Pruning filler
^^^^^^^^^^^^^^
@@ -53,7 +53,7 @@ Pruning filler
- IterNonStruct: all weights below the ``delta`` factor are pruned. If this is not enough to reach ``threshold``, all weights below 2 × ``delta`` are pruned, and so on.


**Important**: With *PruneCell*, ``quant_mode`` and ``range`` are not used.
**Important**: With :py:class:`n2d2.quantizer.PruneCell`, ``quant_mode`` and ``range`` are not used.
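
The iterative idea behind ``IterNonStruct`` can be sketched in plain Python. This is a hypothetical illustration of the widening-band scheme described above, not N2D2's implementation; the function name and signature are invented for the example:

```python
def iter_nonstruct_prune(weights, threshold, delta):
    """Zero weights in widening magnitude bands (delta, 2*delta, ...)
    until at least `threshold` fraction of them is pruned.
    Illustrative sketch only, not the N2D2 implementation."""
    pruned = list(weights)
    cutoff = delta
    max_mag = max(abs(w) for w in weights)
    # Keep widening the pruning band until the target fraction is reached
    # (or every weight would fall inside the band).
    while sum(w == 0.0 for w in pruned) / len(pruned) < threshold and cutoff <= max_mag:
        pruned = [0.0 if abs(w) < cutoff else w for w in pruned]
        cutoff += delta
    return pruned
```

With ``delta=0.2`` and ``threshold=0.5``, the first band already zeroes the two smallest weights of ``[0.05, -0.1, 0.4, 0.8]``; a higher threshold forces further passes at 2 × and 3 × ``delta``.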


Example with INI file
8 changes: 8 additions & 0 deletions python/pytorch_to_n2d2/pytorch_interface.py
@@ -331,6 +331,8 @@ def __exit__(self, exc_type, exc_value, traceback):
def wrap(torch_model:torch.nn.Module,
input_size: Union[list, tuple],
opset_version:int=11,
in_names:list=None,
out_names:list=None,
verbose:bool=False) -> Block:
"""Function generating a ``torch.nn.Module`` which embeds a :py:class:`n2d2.cells.DeepNetCell`.
The torch_model is exported to N2D2 via ONNX.
@@ -341,6 +343,10 @@ def wrap(torch_model:torch.nn.Module,
:type input_size: ``list``
:param opset_version: Opset version used to generate the intermediate ONNX file, default=11
:type opset_version: int, optional
:param in_names: Specify custom names for the network inputs
:type in_names: list, optional
:param out_names: Specify custom names for the network outputs
:type out_names: list, optional
:param verbose: Enable the verbose output of torch onnx export, default=False
:type verbose: bool, optional
:return: A custom ``torch.nn.Module`` which embeds a :py:class:`n2d2.cells.DeepNetCell`.
@@ -361,6 +367,8 @@ def wrap(torch_model:torch.nn.Module,
dummy_in,
raw_model_path,
verbose=verbose,
input_names=in_names,
output_names=out_names,
export_params=True,
opset_version=opset_version,
training=torch.onnx.TrainingMode.TRAINING,
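
A usage sketch of the new options (the model, sizes, and names are hypothetical, and the import path of ``wrap`` is assumed from this repository's layout, so it may differ; requires torch and n2d2 installed):

```python
import torch
from n2d2.pytorch import wrap  # import path assumed; see pytorch_interface.py

# Any torch.nn.Module works; this tiny model is purely illustrative.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.Flatten(),
)

# Force deterministic ONNX input/output names instead of PyTorch's
# auto-generated ones, so downstream N2D2 layer names stay stable.
wrapped = wrap(
    model,
    input_size=(1, 3, 32, 32),
    in_names=["input_image"],
    out_names=["features"],
)
```

Without ``in_names``/``out_names``, the names in the intermediate ONNX file depend on PyTorch's export internals and can change between versions, which is the inconsistency this commit addresses.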
