Skip to content

Commit

Permalink
Update _tensor.py
Browse files Browse the repository at this point in the history
  • Loading branch information
adtzlr committed Jun 9, 2024
1 parent 4419c79 commit e02380a
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/tensortrax/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,15 @@ def T(self):
return transpose(self)

def ravel(self, order="C"):
"Return a contiguous flattened array."
return ravel(self, order=order)

def reshape(self, *shape, order="C"):
"Gives a new shape to an array without changing its data."
return reshape(self, newshape=shape, order=order)

def squeeze(self, axis=None):
"Remove axes of length one."
return squeeze(self, axis=axis)

def dual2real(self, like):
Expand Down Expand Up @@ -534,6 +537,7 @@ def mul(A, B):


def ravel(A, order="C"):
"Return a contiguous flattened array."
if isinstance(A, Tensor):
δtrax = δ(A).shape[len(A.shape) :]
Δtrax = Δ(A).shape[len(A.shape) :]
Expand All @@ -550,6 +554,7 @@ def ravel(A, order="C"):


def squeeze(A, axis=None):
"Remove axes of length one."
if isinstance(A, Tensor):
if axis is None:
if 1 in A.shape:
Expand All @@ -568,6 +573,7 @@ def squeeze(A, axis=None):


def reshape(A, newshape, order="C"):
"Gives a new shape to an array without changing its data."
if isinstance(A, Tensor):
δtrax = δ(A).shape[len(A.shape) :]
Δtrax = Δ(A).shape[len(A.shape) :]
Expand Down Expand Up @@ -749,12 +755,14 @@ def einsum(subscripts, *operands):


def transpose(A):
"Returns an array with axes transposed."
ij = "abcdefghijklmnopqrstuvwxyz"[: len(A.shape)]
ji = ij[::-1]
return einsum(f"{ij}...->{ji}...", A)


def matmul(A, B):
"Matrix product of two arrays."
ik = "ik"[2 - len(A.shape) :]
kj = "kj"[: len(B.shape)]
ij = (ik + kj).replace("k", "")
Expand Down

0 comments on commit e02380a

Please sign in to comment.