Skip to content

Commit

Permalink
Update _evaluate.py
Browse files Browse the repository at this point in the history
  • Loading branch information
adtzlr committed Jun 9, 2024
1 parent 5590d4a commit b60b025
Showing 1 changed file with 91 additions and 5 deletions.
96 changes: 91 additions & 5 deletions src/tensortrax/_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,16 @@ def function(fun, wrt=0, ntrax=0, parallel=False):
>>> return mu / 2 * (J ** (-2 / 3) * I1 - 3)
>>>
>>> np.random.seed(125161)
>>> F = (np.eye(3) + np.random.rand(50, 8, 3, 3) / 10).T
>>> F = (np.eye(3) + np.random.rand(20, 8, 3, 3) / 10).T
>>>
>>> F.shape
(3, 3, 8, 50)
(3, 3, 8, 20)
>>> W = tr.function(fun, wrt=0, ntrax=2)(F)
>>> W = tr.function(fun, wrt="F", ntrax=2)(F=F)
>>>
>>> W.shape
(8, 20)
"""

@wraps(fun)
Expand Down Expand Up @@ -213,7 +213,50 @@ def kernel(args, kwargs):


def gradient(fun, wrt=0, ntrax=0, parallel=False, full_output=False, sym=False):
"Evaluate the gradient of a scalar-valued function."
r"""Evaluate the gradient of a scalar-valued function.
Parameters
----------
fun : callable
The function to be evaluated.
wrt : int or str, optional
The input argument which will be treated as :class:`~tensortrax.Tensor` (default
    is 0). The gradient is evaluated with respect to this argument.
ntrax : int, optional
Number of elementwise-operating trailing axes (batch dimensions). Default is 0.
parallel : bool, optional
    Flag to evaluate the gradient in parallel (threaded). Default is False.
Returns
-------
ndarray
NumPy array containing the gradient result.
Examples
--------
>>> import numpy as np
>>> import tensortrax as tr
>>> import tensortrax.math as tm
>>>
>>> def fun(F, mu=1):
>>> C = F.T @ F
>>> I1 = tm.trace(C)
>>> J = tm.linalg.det(F)
>>> return mu / 2 * (J ** (-2 / 3) * I1 - 3)
>>>
>>> np.random.seed(125161)
>>> F = (np.eye(3) + np.random.rand(20, 8, 3, 3) / 10).T
>>>
>>> F.shape
(3, 3, 8, 20)
>>> dWdF = tr.gradient(fun, wrt=0, ntrax=2)(F)
>>> dWdF = tr.gradient(fun, wrt="F", ntrax=2)(F=F)
>>>
>>> dWdF.shape
(3, 3, 8, 20)
"""

@wraps(fun)
def evaluate_gradient(*args, **kwargs):
Expand Down Expand Up @@ -245,7 +288,50 @@ def kernel(args, kwargs):


def hessian(fun, wrt=0, ntrax=0, parallel=False, full_output=False, sym=False):
"Evaluate the hessian of a scalar-valued function."
r"""Evaluate the Hessian of a scalar-valued function.
Parameters
----------
fun : callable
The function to be evaluated.
wrt : int or str, optional
The input argument which will be treated as :class:`~tensortrax.Tensor` (default
    is 0). The Hessian is evaluated with respect to this argument.
ntrax : int, optional
Number of elementwise-operating trailing axes (batch dimensions). Default is 0.
parallel : bool, optional
    Flag to evaluate the Hessian in parallel (threaded). Default is False.
Returns
-------
ndarray
NumPy array containing the Hessian result.
Examples
--------
>>> import numpy as np
>>> import tensortrax as tr
>>> import tensortrax.math as tm
>>>
>>> def fun(F, mu=1):
>>> C = F.T @ F
>>> I1 = tm.trace(C)
>>> J = tm.linalg.det(F)
>>> return mu / 2 * (J ** (-2 / 3) * I1 - 3)
>>>
>>> np.random.seed(125161)
>>> F = (np.eye(3) + np.random.rand(20, 8, 3, 3) / 10).T
>>>
>>> F.shape
(3, 3, 8, 20)
>>> d2WdFdF = tr.hessian(fun, wrt=0, ntrax=2)(F)
>>> d2WdFdF = tr.hessian(fun, wrt="F", ntrax=2)(F=F)
>>>
>>> d2WdFdF.shape
(3, 3, 3, 3, 8, 20)
"""

@wraps(fun)
def evaluate_hessian(*args, **kwargs):
Expand Down

0 comments on commit b60b025

Please sign in to comment.