Skip to content

Commit

Permalink
Merge pull request #179 from gganapavarapu/master
Browse files Browse the repository at this point in the history
Adding Time Series Lime Explainer
  • Loading branch information
vijay-arya authored Jun 20, 2023
2 parents 76b6022 + 608501b commit 4f14691
Show file tree
Hide file tree
Showing 16 changed files with 1,720 additions and 25 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/Build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,17 @@ jobs:
- name: Step 6 - Test TSSaliencyExplainer
run: python ./tests/tssaliency/test_tssaliency.py

# tslime deps are already satisfied.
- name: Step 7 - Test TSLimeExplainer
run: python ./tests/tslime/test_tslime.py

build-imd-on-py38-310:
# The type of runner that the job will run on
runs-on: "${{ matrix.os }}"
strategy:
fail-fast: false
matrix:
# os: [ubuntu-18.04, ubuntu-latest, macos-latest, windows-latest]
# os: [ubuntu-18.04, ubuntu-latest, macos-latest, windows-latest]
os: [ubuntu-20.04]
python-version: ["3.10"]

Expand Down
1 change: 1 addition & 0 deletions aix360/algorithms/tslime/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .tslime import TSLimeExplainer
50 changes: 50 additions & 0 deletions aix360/algorithms/tslime/surrogate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from abc import abstractmethod
import numpy as np
from sklearn.linear_model import LinearRegression


class LinearSurrogateModel:
    """Linear Interpretable Surrogate Model Wrapper.

    Wraps an estimator object and forwards ``fit`` / ``predict`` to it.
    Subclasses must override :meth:`get_weights` to expose the fitted linear
    coefficients of the wrapped estimator.

    Args:
        model: estimator object providing ``fit`` and ``predict`` methods.
    """

    def __init__(self, model):
        self.model = model

    def fit(self, *args, **kwargs):
        """Fit the wrapped model; all arguments are forwarded unchanged.

        Returns:
            LinearSurrogateModel: ``self``, so calls can be chained.
        """
        self.model.fit(*args, **kwargs)
        return self

    def predict(self, *args, **kwargs):
        """Return the wrapped model's prediction for the forwarded arguments."""
        return self.model.predict(*args, **kwargs)

    @abstractmethod
    def get_weights(self):
        """Return the weights (coefficients) of the fitted surrogate.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # This class does not use ABCMeta, so @abstractmethod alone does not
        # stop instantiation. Fail loudly here instead of returning None.
        raise NotImplementedError(
            "{} must implement get_weights().".format(type(self).__name__)
        )


class LinearRegressionSurrogate(LinearSurrogateModel):
    """Surrogate backed by scikit-learn's ``LinearRegression`` estimator."""

    def __init__(self):
        # Build the default (unconfigured) regressor and hand it to the wrapper.
        regressor = LinearRegression()
        super().__init__(regressor)

    def get_weights(self):
        """Return the ``coef_`` attribute of the wrapped regressor."""
        fitted_model = self.model
        return fitted_model.coef_


def linear_surrogate_weights(
    x_perturbations: np.ndarray,
    y_perturbations: np.ndarray,
    surrogate: "LinearSurrogateModel" = None,
):
    """Fit a linear interpretable model on time series perturbations and
    return its weights.

    Both inputs are flattened to 2-D, keeping the first (sample) axis, before
    being passed to ``surrogate.fit``.

    Args:
        x_perturbations (array-like): perturbed inputs; the first axis indexes
            the perturbation samples.
        y_perturbations (array-like): model responses for the perturbed inputs;
            the first axis indexes the perturbation samples.
        surrogate (LinearSurrogateModel): surrogate to fit. If None, a new
            ``LinearRegressionSurrogate`` is used. Defaults to None.

    Returns:
        tuple: ``(surrogate, weights)`` where ``weights`` is the value of
        ``surrogate.get_weights()`` after fitting.

    Raises:
        ValueError: if the two inputs do not have the same number of samples.
    """
    # Accept any array-like (e.g. nested lists), not only ndarrays.
    x_perturbations = np.asarray(x_perturbations)
    y_perturbations = np.asarray(y_perturbations)

    n_samples = x_perturbations.shape[0]
    if y_perturbations.shape[0] != n_samples:
        raise ValueError(
            "x_perturbations and y_perturbations must have the same number of "
            "samples, found {} and {}.".format(n_samples, y_perturbations.shape[0])
        )

    if surrogate is None:
        surrogate = LinearRegressionSurrogate()

    surrogate.fit(
        x_perturbations.reshape(n_samples, -1),
        y_perturbations.reshape(n_samples, -1),
    )

    # retrieve weights
    return surrogate, surrogate.get_weights()
217 changes: 217 additions & 0 deletions aix360/algorithms/tslime/tslime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import warnings
import numpy as np
import pandas as pd
from typing import Union, List, Callable
from aix360.algorithms.tslbbe import TSLocalBBExplainer
from aix360.algorithms.tsutils.tsframe import tsFrame, to_np_array
from aix360.algorithms.tslime.surrogate import (
linear_surrogate_weights,
LinearSurrogateModel,
)
from aix360.algorithms.tsutils.tsperturbers.perturbed_data_generator import (
PerturbedDataGenerator,
)
from aix360.algorithms.tsutils.tsperturbers.tsperturber import (
TSPerturber,
BlockSelector,
)


class TSLimeExplainer(TSLocalBBExplainer):
    """Time Series Local Interpretable Model-agnostic Explainer (TSLime) is a model-agnostic local time series
    explainer. LIME (Locally interpretable Model agnostic explainer) is a popular algorithm for local
    explanation. LIME explains the model behavior by approximating the model response with linear models.
    LIME algorithm specifically assumes tabular data format, where each row is a data point, and columns
    are features. A generalization of LIME algorithm for image data uses super pixel based perturbation.
    TSLIME generalizes LIME algorithm for time series context.

    TSLIME uses time series perturbation methods to produce a local input perturbation, and linear model
    surrogate which best approximates the model response. TSLime produces an interpretable explanation.
    The explanation weights produced by the TSLime explanation indicates model local sensitivity.

    References:
        .. [##] Marco Tulio Ribeiro et al. '"Why Should I Trust You?": Explaining the Predictions of Any Classifier'
            https://arxiv.org/abs/1602.04938
    """

    def __init__(
        self,
        model: Callable,
        input_length: int,
        n_perturbations: int = 2000,
        relevant_history: int = None,
        perturbers: List[Union[TSPerturber, dict]] = None,
        local_interpretable_model: LinearSurrogateModel = None,
        random_seed: int = None,
    ):
        """Initializer for TSLimeExplainer

        Args:
            model (Callable): Callable object produces a prediction as numpy array
                for a given input as numpy array.
            input_length (int): Input (history) length used for input model.
            n_perturbations (int): Number of perturbed instance for TSExplanation. Defaults to 2000.
            relevant_history (int): Interested window size for explanations. The explanation is
                computed for selected latest window of length `relevant_history`. If `input_length=20`
                and `relevant_history=10`, explanation is computed for last 10 time points. If None,
                relevant_history is set to input_length. Defaults to None.
            perturbers (List[TSPerturber, dict]): data perturbation algorithm specification by TSPerturber
                instance or dict. Allowed values for "type" key in dictionary are block-bootstrap, frequency,
                moving-average, shift. Block-bootstrap split the time series into contiguous
                chunks called blocks, for each block noise is estimated and noise is exchanged
                and added to the signal between randomly selected blocks. Moving-average perturbation
                maintains the moving mean of the time series data with the specified window length,
                but add perturbed noise with similar distribution as the data. Frequency
                perturber performs FFT on the noise, and removes random high frequency
                components from the noise estimates. Number of frequencies to be removed
                is specified by the truncate_frequencies argument. Shift perturber adds
                random upward or downward shift in the data value over time continuous
                blocks. If not provided default perturber is block-bootstrap. Defaults to None.
            local_interpretable_model (LinearSurrogateModel): Local interpretable model, a surrogate that
                is to be trained on the given input time series neighborhood. This model is used to provide
                local weights for each time point in the selected timeseries. If None, sklearn's Linear Regression
                model, aix360.algorithms.tslime.surrogate.LinearRegressionSurrogate is used. Defaults to None.
            random_seed (int): random seed to get consistent results. Refer to numpy random state.
                Defaults to None.

        Raises:
            ValueError: if `relevant_history` is given and is not in the
                range ``1..input_length``.
        """
        # Explanation params. Validate up front: a window of 0 or one longer
        # than the input would make the negative slices in _explain_instance
        # select the full input and produce a mis-shaped weight matrix.
        if relevant_history is None:
            relevant_history = input_length
        elif not 0 < relevant_history <= input_length:
            raise ValueError(
                "Error: relevant_history must be between 1 and input_length "
                "({}) but found {}.".format(input_length, relevant_history)
            )

        self.model = model

        if perturbers is None:
            perturbers = [
                dict(type="block-bootstrap"),
            ]

        # NOTE(review): presumably limits perturbation to the last
        # `input_length` observations -- confirm against BlockSelector.
        block_selector = BlockSelector(start=-input_length, end=None)
        perturber = PerturbedDataGenerator(
            perturber_engines=perturbers,
            block_selector=block_selector,
        )
        self._parameters = dict()

        # Input Specification
        self.input_length = input_length

        # Surrogate training params
        self.local_interpretable_model = local_interpretable_model
        self.n_perturbations = n_perturbations
        self.perturber = perturber

        self.relevant_history = relevant_history
        self.random_seed = random_seed

    def get_params(self):
        """Return a shallow copy of the explainer's parameter dictionary."""
        return self._parameters.copy()

    def set_params(self, *argv, **kwargs):
        """Update the parameter dictionary with the given keyword arguments.

        Positional arguments are accepted but ignored.

        Returns:
            TSLimeExplainer: ``self``.
        """
        self._parameters.update(kwargs)
        return self

    def _ts_perturb(self, x):
        """Generate ``n_perturbations`` perturbed versions of ``x``.

        Returns:
            numpy ndarray: the perturbations cast to float.
        """
        # create perturbations; the second return value is not used here
        x_perturbations, _ = self.perturber.fit_transform(
            x, None, n=self.n_perturbations
        )

        return np.asarray(x_perturbations).astype("float")

    def _batch_predict(self, x_perturbations):
        """Score ``x_perturbations`` with the model.

        The model is first called once on the whole array. If that call
        raises, a warning is issued and the model is called on each element
        along the first axis instead, with the results stacked into an array.
        """
        try:
            f_predict_samples = self.model(x_perturbations)
        except Exception as ex:
            # Deliberate best-effort fallback for models that cannot score a batch.
            warnings.warn(
                "Batch scoring failed with error: {}. Scoring sequentially...".format(
                    ex
                )
            )
            f_predict_samples = [
                self.model(x_perturbations[i]) for i in range(x_perturbations.shape[0])
            ]
            f_predict_samples = np.array(f_predict_samples)

        return f_predict_samples

    def explain_instance(self, ts: tsFrame, **explain_params):
        """Explain the prediction made by the time series model at a certain point in time
        (**local explanation**).

        Args:
            ts (tsFrame): Input time series signal in ``tsFrame`` format. This can
                be generated using :py:mod:`aix360.algorithms.tsframe.tsFrame`.
                A ``tsFrame`` is a pandas ``DataFrame`` indexed by ``Timestamp`` objects
                (that is ``DatetimeIndex``). Each column corresponds to an input feature.
            explain_params: Arbitrary explainer parameters.

        Returns:
            explanation (Union[List[Dict], Dict]): Dictionary with keys: input_data, history_weights,
            model_prediction, surrogate_prediction, x_perturbations, y_perturbations.
        """
        return super().explain_instance(ts=ts, ts_related=None, **explain_params)

    def _explain_instance(
        self,
        ts: tsFrame,
        **explain_params,
    ):
        """Compute the TSLime explanation for the last ``input_length`` rows of ``ts``.

        Raises:
            ValueError: if ``ts`` has fewer than ``input_length`` rows.
            RuntimeError: if the model returns None for the perturbed samples.
        """
        # Seed numpy's global RNG for consistent results across calls.
        np.random.seed(self.random_seed)

        ### input validation
        if ts.shape[0] < self.input_length:
            raise ValueError(
                "Error: expecting input length {} but found {}.".format(
                    self.input_length, ts.shape[0]
                )
            )
        xc = ts[-self.input_length :]
        xc = to_np_array(xc)

        ### generate time series perturbations
        x_perturbations = self._ts_perturb(x=xc)

        ### generate y
        y_perturbations = self._batch_predict(x_perturbations)
        if y_perturbations is None:
            raise RuntimeError(
                "Model prediction could not be computed for perturbed samples."
            )

        y_perturbations = np.asarray(y_perturbations).astype("float")

        ### select k time points - relevant_history
        x_perturbations = x_perturbations[
            :, -self.relevant_history :
        ]  # consider only k time points

        xc_relevant = xc[-self.relevant_history :, :].reshape(1, -1)

        ### compute weights using a linear model
        surrogate, history_weights = linear_surrogate_weights(
            surrogate=self.local_interpretable_model,
            x_perturbations=x_perturbations,
            y_perturbations=y_perturbations,
        )

        model_prediction = self._batch_predict(xc)

        surrogate_prediction = surrogate.predict(xc_relevant)
        explanation = {
            "input_data": ts,
            "model_prediction": model_prediction,
            "surrogate_prediction": surrogate_prediction,
            "history_weights": history_weights.reshape(self.relevant_history, -1),
            "x_perturbations": x_perturbations,
            "y_perturbations": y_perturbations,
        }
        return explanation
2 changes: 1 addition & 1 deletion aix360/algorithms/tssaliency/tssaliency.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def explain_instance(self, ts: tsFrame, **explain_params):

def _explain_instance(
self,
ts: Union["tsFrame", np.ndarray],
ts: tsFrame,
**explain_params,
):
# fix seed for consistent results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _transform(
block_swap = self._parameters.get("block_swap")

x_res = [self._residual.copy() for _ in range(n_perturbations)]
margin = self._residual.shape[0] - block_length
margin = self._residual.shape[0] - block_length + 1
for _ in range(block_swap):
if block_selector is None:
from_point = np.random.randint(
Expand Down
10 changes: 5 additions & 5 deletions aix360/algorithms/tsutils/tsperturbers/perturber_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def ts_rolling_mean(
if isinstance(ts, np.ndarray):
if len(ts.shape) == 1:
ts = ts.reshape(-1, 1)
ts = ts.astype("float32")
ts = ts.astype("float")
n_obs, n_vars = ts.shape
den = np.convolve(
np.ones(n_obs), np.ones(window_size, dtype="float32"), "same"
).astype("float32")
np.ones(n_obs), np.ones(window_size, dtype="float"), "same"
).astype("float")
df = np.asarray(
[
np.convolve(ts[:, i], np.ones(window_size), "same") / den
Expand All @@ -51,8 +51,8 @@ def ts_split_mean_residual(
format.
Args:
ts (Union[tsFrame, numpy ndarray]): input time series as dataframe or numpy array
window_size (int): numer of observation for averaging.
ts (Union[tsFrame, numpy ndarray]): input time series as tsFrame or numpy array
window_size (int): number of observations for averaging.
Returns:
tuple (Union[Tuple[numpy ndarray, numpy ndarray], Tuple[tsFrame, tsFrame]]): depending
Expand Down
19 changes: 11 additions & 8 deletions aix360/datasets/sunspots_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ class SunspotDataset:
References:
.. [#1] Andrews, D. F. and Herzberg, A. M., "Data: A Collection of Problems from
Many Fields for the Student and Research Worker,"
New York: Springer-Verlag, 1985.
Many Fields for the Student and Research Worker,"
New York: Springer-Verlag, 1985.
.. [#2] https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/sunspots.html
.. [#3] https://r-data.pmagunia.com/dataset/r-dataset-package-datasets-sunspots
.. [#4] Avishek Pal, PKS Prakash, "Practical Time Series Analysis"
https://github.com/PacktPublishing/Practical-Time-Series-Analysis/
"""

Expand All @@ -30,16 +32,17 @@ def __init__(self):
self.data_file = os.path.realpath(
os.path.join(self.data_folder, "sunspots.csv")
)
sunspots_url = (
"https://r-data.pmagunia.com/system/files/datasets/dataset-61024.csv"
)
sunspots_url = "https://raw.githubusercontent.com/PacktPublishing/Practical-Time-Series-Analysis/master/Data%20Files/monthly-sunspot-number-zurich-17.csv"

if not os.path.exists(self.data_file):
response = requests.get(sunspots_url)
data = pd.read_csv(StringIO(response.text))
data["time"] = pd.to_datetime(
data["time"].apply(self._convert_to_date), format="%Y-%m"
data = pd.read_csv(
StringIO(response.text),
skiprows=0,
nrows=2820,
)
data.columns = ["time", "sunspots"]
data["time"] = pd.to_datetime(data["time"], format="%Y-%m")

data.to_csv(self.data_file, index=False)

Expand Down
6 changes: 6 additions & 0 deletions docs/tslbbe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,9 @@ Time Series Saliency (TSSaliency) Explainer

.. autoclass:: aix360.algorithms.tssaliency.tssaliency.TSSaliencyExplainer
:members:

Time Series Local Interpretable Model-agnostic Explainer (TSLime)
-------------------------------------------------------------------------

.. autoclass:: aix360.algorithms.tslime.tslime.TSLimeExplainer
:members:
4 changes: 4 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ the user through the various steps of the notebook.
- [TSSaliencyExplainer using FordA dataset](./tssaliency/tssaliency_univariate_demo.ipynb)[[on nbviewer](https://nbviewer.org/github/Trusted-AI/AIX360/blob/master/examples/tssaliency/tssaliency_univariate_demo.ipynb)]

- [TSSaliencyExplainer using Climate dataset](./tssaliency/tssaliency_multivariate_demo.ipynb)[[on nbviewer](https://nbviewer.org/github/Trusted-AI/AIX360/blob/master/examples/tssaliency/tssaliency_multivariate_demo.ipynb)]

- [TSLimeExplainer using FordA dataset](./tslime/tslime_univariate_demo.ipynb)[[on nbviewer](https://nbviewer.org/github/Trusted-AI/AIX360/blob/master/examples/tslime/tslime_univariate_demo.ipynb)]

- [TSLimeExplainer using Climate dataset](./tslime/tslime_multivariate_demo.ipynb)[[on nbviewer](https://nbviewer.org/github/Trusted-AI/AIX360/blob/master/examples/tslime/tslime_multivariate_demo.ipynb)]
Loading

0 comments on commit 4f14691

Please sign in to comment.