Skip to content

Commit

Permalink
pint_matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
abhisrkckl committed Jan 21, 2025
1 parent 76d0193 commit 02bb5c6
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions src/pint/pint_matrix.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" pint_matrix module defines the pint matrix base class, the design matrix . and the covariance matrix
"""

from typing import List, Optional
import numpy as np
from itertools import combinations
import astropy.units as u
Expand Down Expand Up @@ -41,7 +42,7 @@ class PintMatrix:
TODO: 1. add index to label mapping
"""

def __init__(self, matrix, axis_labels):
def __init__(self, matrix: np.ndarray, axis_labels: List[str]):
self.matrix = matrix
self.axis_labels = axis_labels
# Check dimensions
Expand All @@ -59,26 +60,22 @@ def __getitem__(self, key):
)

@property
def ndim(self):
def ndim(self) -> int:
return self.matrix.ndim

@property
def shape(self):
def shape(self) -> tuple:
return self.matrix.shape

@property
def labels(self):
labels = []
labels.extend(self.get_axis_labels(dim) for dim in range(len(self.axis_labels)))
return labels
def labels(self) -> list:
return [self.get_axis_labels(dim) for dim in range(len(self.axis_labels))]

@property
def label_units(self):
units = []
units.extend(self.get_axis_labels(dim) for dim in range(len(self.axis_labels)))
return units
def label_units(self) -> list:
return [self.get_axis_labels(dim) for dim in range(len(self.axis_labels))]

def diag(self, k=0):
def diag(self, k: int = 0) -> np.ndarray:
"""
Extract a diagonal.
Expand All @@ -95,7 +92,7 @@ def diag(self, k=0):
"""
return np.diag(self.matrix, k=k)

def get_label_names(self, axis=None):
def get_label_names(self, axis: Optional[int] = None) -> List[List[str]]:
"""Return only the names of the labels
along the specified axis if requested.
Expand All @@ -120,7 +117,7 @@ def get_label_names(self, axis=None):
labels.extend([x[0] for x in self.get_axis_labels(dim)] for dim in r)
return labels

def get_unique_label_names(self):
def get_unique_label_names(self) -> List[str]:
"""Return all unique label names (there may be duplications between axes).
Returns
Expand Down Expand Up @@ -332,7 +329,7 @@ class DesignMatrix(PintMatrix):
TODO: 1. add index to unit mapping.
"""

def __init__(self, matrix, labels):
def __init__(self, matrix: np.ndarray, labels: List[str]):
super().__init__(matrix, labels)

@property
Expand Down

0 comments on commit 02bb5c6

Please sign in to comment.