Skip to content

Commit

Permalink
Merge pull request #360 from rsokl/develop
Browse files Browse the repository at this point in the history
Add missing type hints and update docs
  • Loading branch information
rsokl authored Mar 3, 2021
2 parents 82b31e2 + be42765 commit a60481b
Show file tree
Hide file tree
Showing 9 changed files with 368 additions and 136 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[![Documentation Status](https://readthedocs.org/projects/mygrad/badge/?version=latest)](https://mygrad.readthedocs.io/en/latest/?badge=latest)
[![Automated tests status](https://github.com/rsokl/MyGrad/workflows/Tests/badge.svg)](https://github.com/rsokl/MyGrad/actions?query=workflow%3ATests+branch%3Amaster)
[![PyPi version](https://img.shields.io/pypi/v/mygrad.svg)](https://pypi.python.org/pypi/mygrad)
![Python version support](https://img.shields.io/badge/python-3.6%20‐%203.9-blue.svg)
![Python version support](https://img.shields.io/badge/python-3.7%20‐%203.9-blue.svg)

# [MyGrad's Documentation](https://mygrad.readthedocs.io/en/latest/)

Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
"Intended Audience :: Education",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering",
]

Expand Down
45 changes: 25 additions & 20 deletions src/mygrad/indexing_routines/funcs.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
from typing import Optional

import numpy as np

from mygrad import Tensor, asarray
from mygrad.operation_base import _NoValue
from mygrad.typing import ArrayLike

from .ops import Where

__all__ = ["where"]


class _UniqueIdentifier:
def __init__(self, identifier):
self.identifier = identifier

def __repr__(self): # pragma: nocover
return self.identifier


# Module-level sentinel: marks "argument not supplied" for ``where``,
# so that an explicit ``None`` can be distinguished from an omitted value.
not_set = _UniqueIdentifier("not_set")


def where(condition, x=not_set, y=not_set, *, constant=None):
def where(
condition: ArrayLike,
x: ArrayLike = _NoValue,
y: ArrayLike = _NoValue,
*,
constant: Optional[bool] = None
) -> Tensor:
"""
where(condition, [x, y])
Expand All @@ -34,19 +33,25 @@ def where(condition, x=not_set, y=not_set, *, constant=None):
Parameters
----------
condition : array_like, bool
condition : ArrayLike, bool
        Where ``True``, yield ``x``, otherwise yield ``y``. ``x``, ``y``
        and ``condition`` need to be broadcastable to some shape.
x : array_like
x : ArrayLike
        Values from which to choose where ``condition`` is ``True``.
y : array_like
y : ArrayLike
        Values from which to choose where ``condition`` is ``False``.
constant : bool, optional(default=False)
If ``True``, the returned tensor is a constant (it
does not back-propagate a gradient)
constant : Optional[bool]
If ``True``, this tensor is treated as a constant, and thus does not
facilitate back propagation (i.e. ``constant.grad`` will always return
``None``).
Defaults to ``False`` for float-type data.
Defaults to ``True`` for integer-type data.
Integer-type tensors must be constant.
Returns
-------
Expand Down Expand Up @@ -87,10 +92,10 @@ def where(condition, x=not_set, y=not_set, *, constant=None):
[ 0, 2, -1],
[ 0, 3, -1]])
"""
if x is not_set and y is not_set:
if x is _NoValue and y is _NoValue:
return np.where(asarray(condition))

if x is not_set or y is not_set:
if x is _NoValue or y is _NoValue:
raise ValueError("either both or neither of x and y should be given")

return Tensor._op(
Expand Down
32 changes: 24 additions & 8 deletions src/mygrad/linalg/funcs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, Sequence, Union

import numpy as np
from numpy.core.einsumfunc import _parse_einsum_input
Expand Down Expand Up @@ -136,7 +136,11 @@ def matmul(
...


def einsum(*operands, optimize=False, constant=None):
def einsum(
*operands: Union[ArrayLike, str, Sequence[int]],
optimize: bool = False,
constant: Optional[bool] = None,
) -> Tensor:
r"""
einsum(subscripts, *operands)
Expand Down Expand Up @@ -179,8 +183,15 @@ def einsum(*operands, optimize=False, constant=None):
algorithm. Also accepts an explicit contraction list from the
``np.einsum_path`` function. See ``np.einsum_path`` for more details.
constant : bool, optional (default=False)
If True, the resulting Tensor is a constant.
constant : Optional[bool]
If ``True``, this tensor is treated as a constant, and thus does not
facilitate back propagation (i.e. ``constant.grad`` will always return
``None``).
Defaults to ``False`` for float-type data.
Defaults to ``True`` for integer-type data.
Integer-type tensors must be constant.
Returns
-------
Expand Down Expand Up @@ -396,7 +407,7 @@ def einsum(*operands, optimize=False, constant=None):
)


def multi_matmul(tensors, *, constant=None):
def multi_matmul(tensors: ArrayLike, *, constant: Optional[bool] = None) -> Tensor:
"""
Matrix product of two or more tensors calculated in the optimal ordering
Expand All @@ -405,10 +416,15 @@ def multi_matmul(tensors, *, constant=None):
tensors: Sequence[array_like]
The sequence of tensors to be matrix-multiplied.
constant : bool, optional(default=False)
If ``True``, the returned tensor is a constant (it
does not back-propagate a gradient).
constant : Optional[bool]
If ``True``, this tensor is treated as a constant, and thus does not
facilitate back propagation (i.e. ``constant.grad`` will always return
``None``).
Defaults to ``False`` for float-type data.
Defaults to ``True`` for integer-type data.
Integer-type tensors must be constant.
Returns
-------
mygrad.Tensor
Expand Down
Loading

0 comments on commit a60481b

Please sign in to comment.