diff --git a/src/pint/pint_matrix.py b/src/pint/pint_matrix.py index 3ae2ae2dc..89b1e1a63 100644 --- a/src/pint/pint_matrix.py +++ b/src/pint/pint_matrix.py @@ -1,6 +1,7 @@ """ pint_matrix module defines the pint matrix base class, the design matrix . and the covariance matrix """ +from typing import List, Optional import numpy as np from itertools import combinations import astropy.units as u @@ -41,7 +42,7 @@ class PintMatrix: TODO: 1. add index to label mapping """ - def __init__(self, matrix, axis_labels): + def __init__(self, matrix: np.ndarray, axis_labels: List[str]): self.matrix = matrix self.axis_labels = axis_labels # Check dimensions @@ -59,26 +60,22 @@ def __getitem__(self, key): ) @property - def ndim(self): + def ndim(self) -> int: return self.matrix.ndim @property - def shape(self): + def shape(self) -> tuple: return self.matrix.shape @property - def labels(self): - labels = [] - labels.extend(self.get_axis_labels(dim) for dim in range(len(self.axis_labels))) - return labels + def labels(self) -> list: + return [self.get_axis_labels(dim) for dim in range(len(self.axis_labels))] @property - def label_units(self): - units = [] - units.extend(self.get_axis_labels(dim) for dim in range(len(self.axis_labels))) - return units + def label_units(self) -> list: + return [self.get_axis_labels(dim) for dim in range(len(self.axis_labels))] - def diag(self, k=0): + def diag(self, k: int = 0) -> np.ndarray: """ Extract a diagonal. @@ -95,7 +92,7 @@ def diag(self, k=0): """ return np.diag(self.matrix, k=k) - def get_label_names(self, axis=None): + def get_label_names(self, axis: Optional[int] = None) -> List[List[str]]: """Return only the names of the labels along the specified axis if requested. @@ -120,7 +117,7 @@ def get_label_names(self, axis=None): labels.extend([x[0] for x in self.get_axis_labels(dim)] for dim in r) return labels - def get_unique_label_names(self): + def get_unique_label_names(self) -> List[str]: """Return all unique label names (there may be duplications between axes). Returns @@ -332,7 +329,7 @@ class DesignMatrix(PintMatrix): TODO: 1. add index to unit mapping. """ - def __init__(self, matrix, labels): + def __init__(self, matrix: np.ndarray, labels: List[str]): super().__init__(matrix, labels) @property