From 940785b12367d8f2a76bdde70fdb66e38e77c9be Mon Sep 17 00:00:00 2001
From: Takuya Narihira
Date: Sun, 28 Mar 2021 09:00:56 +0900
Subject: [PATCH] Add dot matmul shortcut

---
 python/src/nnabla/_arithmetic_ops.pyx       | 23 +++++++++
 python/src/nnabla/_variable.pyx             |  3 ++
 python/src/nnabla/functions.py              |  3 +-
 python/src/nnabla/numpy_compat_functions.py | 53 +++++++++++++++++++++
 4 files changed, 81 insertions(+), 1 deletion(-)
 create mode 100644 python/src/nnabla/numpy_compat_functions.py

diff --git a/python/src/nnabla/_arithmetic_ops.pyx b/python/src/nnabla/_arithmetic_ops.pyx
index 36fe3e1ca..7ddd88d8c 100644
--- a/python/src/nnabla/_arithmetic_ops.pyx
+++ b/python/src/nnabla/_arithmetic_ops.pyx
@@ -209,3 +209,26 @@ cdef object pow(object x, object y, object z):
             return F.r_pow_scalar(y, x)
         else:
             return x ** y
+
+
+cdef object matmul(object x, object y):
+    """
+    Matrix multiplication
+
+    Implements the matmul operator expression ``x @ y``.
+    When both of ``x`` and ``y`` are either :obj:`~nnabla.Variable` or
+    :obj:`~nnabla.NdArray`, :func:`~nnabla.functions.affine` is
+    internally called.
+
+    Args:
+        x (~nnabla.Variable or ~nnabla.NdArray): Left operand. It must be 2-dimensional.
+        y (~nnabla.Variable or ~nnabla.NdArray): Right operand. It must be 2-dimensional.
+
+    Returns: :class:`~nnabla.Variable` or :class:`~nnabla.NdArray`.
+
+    """
+    import nnabla.functions as F
+    assert isinstance(x, (NdArray, Variable)), "x must be a Variable or an NdArray."
+    assert isinstance(y, (NdArray, Variable)), "y must be a Variable or an NdArray."
+    assert x.ndim == 2 and y.ndim == 2, "Both of x and y must be matrices."
+    return F.affine(x, y)
diff --git a/python/src/nnabla/_variable.pyx b/python/src/nnabla/_variable.pyx
index f49c5cc22..839812182 100644
--- a/python/src/nnabla/_variable.pyx
+++ b/python/src/nnabla/_variable.pyx
@@ -1062,6 +1062,9 @@ cdef class Variable:
     def __pow__(x, y, z):
         return AOP.pow(x, y, z)
 
+    def __matmul__(x, y):
+        return AOP.matmul(x, y)
+
     def __iadd__(self, x):
         import nnabla.functions as F
         if isinstance(x, (NdArray, Variable)):
diff --git a/python/src/nnabla/functions.py b/python/src/nnabla/functions.py
index bd297ff68..351a7d156 100644
--- a/python/src/nnabla/functions.py
+++ b/python/src/nnabla/functions.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
+# Copyright (c) 2017-2021 Sony Corporation. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ import nnabla as nn
 import numpy as np
 
 from .normalization_functions import *
+from .numpy_compat_functions import *
 
 
 def sum(x, axis=None, keepdims=False):
diff --git a/python/src/nnabla/numpy_compat_functions.py b/python/src/nnabla/numpy_compat_functions.py
new file mode 100644
index 000000000..799c67c13
--- /dev/null
+++ b/python/src/nnabla/numpy_compat_functions.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2021 Sony Corporation. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+def dot(a, b, out=None):
+    '''
+    A compatible operation with ``numpy.dot``.
+
+    Note:
+        Any operation between nnabla's Variable/NdArray and numpy array is not supported.
+
+    Args:
+        a (Variable, NdArray or scalar): Left input array.
+        b (Variable, NdArray or scalar): Right input array.
+        out: Not supported so far. It must be ``None``.
+
+    Returns:
+        ~nnabla.Variable, ~nnabla.NdArray or scalar: The product of ``a`` and ``b``.
+
+    Raises:
+        ValueError: If both inputs are arrays whose ndim combination is not handled.
+    '''
+    import nnabla as nn
+    import nnabla.functions as F
+    assert out is None, "The `out` option is not supported."
+
+    array_types = (nn.NdArray, nn.Variable)
+    if isinstance(a, array_types) and isinstance(b, array_types):
+        if a.ndim == 1 and b.ndim == 1:  # inner product of two vectors
+            return F.sum(a * b)
+        if a.ndim == 2 and b.ndim >= 2:  # NOTE(review): for b.ndim > 2 confirm affine matches numpy.dot
+            return F.affine(a, b)
+        if a.ndim == 0 or b.ndim == 0:  # a 0-d operand is handled by plain multiplication
+            return a * b
+        if a.ndim >= 2 and b.ndim == 1:  # b becomes a column so affine contracts a's last axis
+            h = F.affine(a, F.reshape(b, (-1, 1)), base_axis=a.ndim - 1)
+            return F.reshape(h, h.shape[:-1])  # drop the trailing axis added for the affine call
+        raise ValueError(f'Undefined configuration: a.ndim={a.ndim}, b.ndim:{b.ndim}')
+
+    return a * b