diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index adfb0d73..d3b5da17 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -7,7 +7,8 @@ from tensorflow_probability.substrates.jax.distributions import ( MultivariateNormalDiagPlusLowRankCovariance as MVNLowRank, - MultivariateNormalFullCovariance as MVN) + MultivariateNormalFullCovariance as MVN, +) from jax.tree_util import tree_map from jaxtyping import Array, Float @@ -16,6 +17,7 @@ from dynamax.parameters import ParameterProperties from dynamax.types import PRNGKey, Scalar + class ParamsLGSSMInitial(NamedTuple): r"""Parameters of the initial distribution @@ -45,22 +47,30 @@ class ParamsLGSSMDynamics(NamedTuple): :param cov: dynamics covariance $Q$ """ - weights: Union[ParameterProperties, - Float[Array, "state_dim state_dim"], - Float[Array, "ntime state_dim state_dim"]] - - bias: Union[ParameterProperties, - Float[Array, "state_dim"], - Float[Array, "ntime state_dim"]] - - input_weights: Union[ParameterProperties, - Float[Array, "state_dim input_dim"], - Float[Array, "ntime state_dim input_dim"]] - - cov: Union[ParameterProperties, - Float[Array, "state_dim state_dim"], - Float[Array, "ntime state_dim state_dim"], - Float[Array, "state_dim_triu"]] + weights: Union[ + ParameterProperties, + Float[Array, "state_dim state_dim"], + Float[Array, "ntime state_dim state_dim"], + ] + + bias: Union[ + ParameterProperties, + Float[Array, "state_dim"], + Float[Array, "ntime state_dim"], + ] + + input_weights: Union[ + ParameterProperties, + Float[Array, "state_dim input_dim"], + Float[Array, "ntime state_dim input_dim"], + ] + + cov: Union[ + ParameterProperties, + Float[Array, "state_dim state_dim"], + Float[Array, "ntime state_dim state_dim"], + Float[Array, "state_dim_triu"], + ] class ParamsLGSSMEmissions(NamedTuple): @@ -76,24 +86,32 @@ class ParamsLGSSMEmissions(NamedTuple): :param cov: emission covariance $R$ """ - weights: Union[ParameterProperties, - Float[Array, "emission_dim state_dim"], - Float[Array, "ntime emission_dim state_dim"]] - - bias: Union[ParameterProperties, - Float[Array, "emission_dim"], - Float[Array, "ntime emission_dim"]] - - input_weights: Union[ParameterProperties, - Float[Array, "emission_dim input_dim"], - Float[Array, "ntime emission_dim input_dim"]] - - cov: Union[ParameterProperties, - Float[Array, "emission_dim emission_dim"], - Float[Array, "ntime emission_dim emission_dim"], - Float[Array, "emission_dim"], - Float[Array, "ntime emission_dim"], - Float[Array, "emission_dim_triu"]] + weights: Union[ + ParameterProperties, + Float[Array, "emission_dim state_dim"], + Float[Array, "ntime emission_dim state_dim"], + ] + + bias: Union[ + ParameterProperties, + Float[Array, "emission_dim"], + Float[Array, "ntime emission_dim"], + ] + + input_weights: Union[ + ParameterProperties, + Float[Array, "emission_dim input_dim"], + Float[Array, "ntime emission_dim input_dim"], + ] + + cov: Union[ + ParameterProperties, + Float[Array, "emission_dim emission_dim"], + Float[Array, "ntime emission_dim emission_dim"], + Float[Array, "emission_dim"], + Float[Array, "ntime emission_dim"], + Float[Array, "emission_dim_triu"], + ] class ParamsLGSSM(NamedTuple): @@ -145,6 +163,7 @@ class PosteriorGSSMSmoothed(NamedTuple): # Helper functions + def _get_one_param(x, dim, t): """Helper function to get one parameter at time t.""" if callable(x): @@ -154,6 +173,7 @@ def _get_one_param(x, dim, t): else: return x + def _get_params(params, 
num_timesteps, t): """Helper function to get parameters at time t.""" assert not callable(params.emissions.cov), "Emission covariance cannot be a callable." @@ -166,9 +186,9 @@ def _get_params(params, num_timesteps, t): D = _get_one_param(params.emissions.input_weights, 2, t) d = _get_one_param(params.emissions.bias, 1, t) - if len(params.emissions.cov.shape) == 1: + if len(params.emissions.cov.shape) == 1: R = _get_one_param(params.emissions.cov, 1, t) - elif len(params.emissions.cov.shape) > 2: + elif len(params.emissions.cov.shape) > 2: R = _get_one_param(params.emissions.cov, 2, t) elif params.emissions.cov.shape[0] != num_timesteps: R = _get_one_param(params.emissions.cov, 2, t) @@ -179,7 +199,8 @@ def _get_params(params, num_timesteps, t): warnings.warn( "Emission covariance has shape (N,N) where N is the number of timesteps. " "The covariance will be interpreted as static and non-diagonal. To " - "specify a dynamic and diagonal covariance, pass it as a 3D array.") + "specify a dynamic and diagonal covariance, pass it as a 3D array." + ) return F, B, b, Q, H, D, d, R @@ -187,39 +208,40 @@ def _get_params(params, num_timesteps, t): _zeros_if_none = lambda x, shape: x if x is not None else jnp.zeros(shape) -def make_lgssm_params(initial_mean, - initial_cov, - dynamics_weights, - dynamics_cov, - emissions_weights, - emissions_cov, - dynamics_bias=None, - dynamics_input_weights=None, - emissions_bias=None, - emissions_input_weights=None): +def make_lgssm_params( + initial_mean, + initial_cov, + dynamics_weights, + dynamics_cov, + emissions_weights, + emissions_cov, + dynamics_bias=None, + dynamics_input_weights=None, + emissions_bias=None, + emissions_input_weights=None, +): """Helper function to construct a ParamsLGSSM object from arguments.""" state_dim = len(initial_mean) emission_dim = emissions_cov.shape[-1] - input_dim = max(dynamics_input_weights.shape[-1] if dynamics_input_weights is not None else 0, - emissions_input_weights.shape[-1] if emissions_input_weights is not None else 0) + input_dim = max( + dynamics_input_weights.shape[-1] if dynamics_input_weights is not None else 0, + emissions_input_weights.shape[-1] if emissions_input_weights is not None else 0, + ) params = ParamsLGSSM( - initial=ParamsLGSSMInitial( - mean=initial_mean, - cov=initial_cov - ), + initial=ParamsLGSSMInitial(mean=initial_mean, cov=initial_cov), dynamics=ParamsLGSSMDynamics( weights=dynamics_weights, - bias=_zeros_if_none(dynamics_bias,state_dim), + bias=_zeros_if_none(dynamics_bias, state_dim), input_weights=_zeros_if_none(dynamics_input_weights, (state_dim, input_dim)), - cov=dynamics_cov + cov=dynamics_cov, ), emissions=ParamsLGSSMEmissions( weights=emissions_weights, bias=_zeros_if_none(emissions_bias, emission_dim), input_weights=_zeros_if_none(emissions_input_weights, (emission_dim, input_dim)), - cov=emissions_cov - ) + cov=emissions_cov, + ), ) return params @@ -278,20 +300,20 @@ def _condition_on(m, P, H, D, d, R, u, y): if R.ndim == 2: S = R + H @ P @ H.T K = psd_solve(S, H @ P).T - else: + else: # Optimization using Woodbury identity with A=R, U=H@chol(P), V=U.T, C=I # (see https://en.wikipedia.org/wiki/Woodbury_matrix_identity) I = jnp.eye(P.shape[0]) U = H @ jnp.linalg.cholesky(P) X = U / R[:, None] - S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T) + S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T) """ # Could alternatively use U=H and C=P R_inv = jnp.diag(1.0 / R) P_inv = psd_solve(P, jnp.eye(P.shape[0])) S_inv = R_inv - R_inv @ H @ psd_solve(P_inv + H.T @ R_inv @ 
H, H.T @ R_inv) """ - K = P @ H.T @ S_inv + K = P @ H.T @ S_inv S = jnp.diag(R) + H @ P @ H.T Sigma_cond = P - K @ S @ K.T @@ -324,20 +346,20 @@ def preprocess_params_and_inputs(params, num_timesteps, inputs): emissions_bias = _zeros_if_none(params.emissions.bias, (emission_dim,)) full_params = ParamsLGSSM( - initial=ParamsLGSSMInitial( - mean=params.initial.mean, - cov=params.initial.cov), + initial=ParamsLGSSMInitial(mean=params.initial.mean, cov=params.initial.cov), dynamics=ParamsLGSSMDynamics( weights=params.dynamics.weights, bias=dynamics_bias, input_weights=dynamics_input_weights, - cov=params.dynamics.cov), + cov=params.dynamics.cov, + ), emissions=ParamsLGSSMEmissions( weights=params.emissions.weights, bias=emissions_bias, input_weights=emissions_input_weights, - cov=params.emissions.cov) - ) + cov=params.emissions.cov, + ), + ) return full_params, inputs @@ -350,28 +372,26 @@ def wrapper(*args, **kwargs): # Extract the arguments by name bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() - params = bound_args.arguments['params'] - emissions = bound_args.arguments['emissions'] - inputs = bound_args.arguments['inputs'] + params = bound_args.arguments["params"] + emissions = bound_args.arguments["emissions"] + inputs = bound_args.arguments["inputs"] num_timesteps = len(emissions) full_params, inputs = preprocess_params_and_inputs(params, num_timesteps, inputs) return f(full_params, emissions, inputs=inputs) - return wrapper - + return wrapper def lgssm_joint_sample( params: ParamsLGSSM, key: PRNGKey, num_timesteps: int, - inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None -)-> Tuple[Float[Array, "num_timesteps state_dim"], - Float[Array, "num_timesteps emission_dim"]]: + inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None, +) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]: r"""Sample from the joint distribution to produce state and emission trajectories. - + Args: params: model parameters inputs: optional array of inputs. @@ -388,9 +408,9 @@ def _sample_transition(key, F, B, b, Q, x_tm1, u): def _sample_emission(key, H, D, d, R, x, u): mean = H @ x + D @ u + d - R = jnp.diag(R) if R.ndim==1 else R + R = jnp.diag(R) if R.ndim == 1 else R return MVN(mean, R).sample(seed=key) - + def _sample_initial(key, params, inputs): key1, key2 = jr.split(key) @@ -417,7 +437,7 @@ def _step(prev_state, args): # Sample the initial state key1, key2 = jr.split(key) - + initial_state, initial_emission = _sample_initial(key1, params, inputs) # Sample the remaining emissions and states @@ -437,8 +457,8 @@ def _step(prev_state, args): @preprocess_args def lgssm_filter( params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + emissions: Float[Array, "ntime emission_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMFiltered: r"""Run a Kalman filter to produce the marginal likelihood and filtered state estimates. 
@@ -456,13 +476,12 @@ def lgssm_filter( def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y): m = H @ pred_mean + D @ u + d - if R.ndim==2: + if R.ndim == 2: S = R + H @ pred_cov @ H.T return MVN(m, S).log_prob(y) else: L = H @ jnp.linalg.cholesky(pred_cov) return MVNLowRank(m, R, L).log_prob(y) - def _step(carry, t): ll, pred_mean, pred_cov = carry @@ -493,7 +512,7 @@ def _step(carry, t): def lgssm_smoother( params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMSmoothed: r"""Run forward-filtering, backward-smoother to compute expectations under the posterior distribution on latent states. Technically, this @@ -560,10 +579,9 @@ def _step(carry, args): def lgssm_posterior_sample( key: PRNGKey, params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None, - jitter: Optional[Scalar]=0 - + emissions: Float[Array, "ntime emission_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]] = None, + jitter: Optional[Scalar] = 0, ) -> Float[Array, "ntime state_dim"]: r"""Run forward-filtering, backward-sampling to draw samples from $p(z_{1:T} \mid y_{1:T}, u_{1:T})$. diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index 8e88c6bf..9d5faadd 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -12,7 +12,12 @@ from dynamax.ssm import SSM from dynamax.linear_gaussian_ssm.inference import lgssm_filter, lgssm_smoother, lgssm_posterior_sample -from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions +from dynamax.linear_gaussian_ssm.inference import ( + ParamsLGSSM, + ParamsLGSSMInitial, + ParamsLGSSMDynamics, + ParamsLGSSMEmissions, +) from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed from dynamax.parameters import ParameterProperties, ParameterSet from dynamax.types import PRNGKey, Scalar @@ -22,8 +27,10 @@ from dynamax.utils.distributions import mniw_posterior_update, niw_posterior_update from dynamax.utils.utils import pytree_stack, psd_solve + class SuffStatsLGSSM(Protocol): """A :class:`NamedTuple` with sufficient statistics for LGSSM parameter estimation.""" + pass @@ -63,13 +70,14 @@ class LinearGaussianSSM(SSM): :param has_emissions_bias: Whether model contains an offset term $d$. Defaults to True. """ + def __init__( self, state_dim: int, emission_dim: int, - input_dim: int=0, - has_dynamics_bias: bool=True, - has_emissions_bias: bool=True + input_dim: int = 0, + has_dynamics_bias: bool = True, + has_emissions_bias: bool = True, ): self.state_dim = state_dim self.emission_dim = emission_dim @@ -87,8 +95,8 @@ def inputs_shape(self): def initialize( self, - key: PRNGKey =jr.PRNGKey(0), - initial_mean: Optional[Float[Array, "state_dim"]]=None, + key: PRNGKey = jr.PRNGKey(0), + initial_mean: Optional[Float[Array, "state_dim"]] = None, initial_covariance=None, dynamics_weights=None, dynamics_bias=None, @@ -97,7 +105,7 @@ def initialize( emission_weights=None, emission_bias=None, emission_input_weights=None, - emission_covariance=None + emission_covariance=None, ) -> Tuple[ParamsLGSSM, ParamsLGSSM]: r"""Initialize model parameters that are set to None, and their corresponding properties. 
@@ -137,41 +145,47 @@ def initialize( params = ParamsLGSSM( initial=ParamsLGSSMInitial( mean=default(initial_mean, _initial_mean), - cov=default(initial_covariance, _initial_covariance)), + cov=default(initial_covariance, _initial_covariance), + ), dynamics=ParamsLGSSMDynamics( weights=default(dynamics_weights, _dynamics_weights), bias=default(dynamics_bias, _dynamics_bias), input_weights=default(dynamics_input_weights, _dynamics_input_weights), - cov=default(dynamics_covariance, _dynamics_covariance)), + cov=default(dynamics_covariance, _dynamics_covariance), + ), emissions=ParamsLGSSMEmissions( weights=default(emission_weights, _emission_weights), bias=default(emission_bias, _emission_bias), input_weights=default(emission_input_weights, _emission_input_weights), - cov=default(emission_covariance, _emission_covariance)) - ) + cov=default(emission_covariance, _emission_covariance), + ), + ) # The keys of param_props must match those of params! props = ParamsLGSSM( initial=ParamsLGSSMInitial( mean=ParameterProperties(), - cov=ParameterProperties(constrainer=RealToPSDBijector())), + cov=ParameterProperties(constrainer=RealToPSDBijector()), + ), dynamics=ParamsLGSSMDynamics( weights=ParameterProperties(), bias=ParameterProperties(), input_weights=ParameterProperties(), - cov=ParameterProperties(constrainer=RealToPSDBijector())), + cov=ParameterProperties(constrainer=RealToPSDBijector()), + ), emissions=ParamsLGSSMEmissions( weights=ParameterProperties(), bias=ParameterProperties(), input_weights=ParameterProperties(), - cov=ParameterProperties(constrainer=RealToPSDBijector())) - ) + cov=ParameterProperties(constrainer=RealToPSDBijector()), + ), + ) return params, props def initial_distribution( self, params: ParamsLGSSM, - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> tfd.Distribution: return MVN(params.initial.mean, params.initial.cov) @@ -179,7 +193,7 @@ def transition_distribution( self, params: ParamsLGSSM, state: Float[Array, "state_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> tfd.Distribution: inputs = inputs if inputs is not None else jnp.zeros(self.input_dim) mean = params.dynamics.weights @ state + params.dynamics.input_weights @ inputs @@ -191,7 +205,7 @@ def emission_distribution( self, params: ParamsLGSSM, state: Float[Array, "state_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> tfd.Distribution: inputs = inputs if inputs is not None else jnp.zeros(self.input_dim) mean = params.emissions.weights @ state + params.emissions.input_weights @ inputs @@ -203,7 +217,7 @@ def marginal_log_prob( self, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]] = None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> Scalar: filtered_posterior = lgssm_filter(params, emissions, inputs) return filtered_posterior.marginal_loglik @@ -212,7 +226,7 @@ def filter( self, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]] = None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMFiltered: return lgssm_filter(params, emissions, inputs) @@ -220,7 +234,7 @@ def smoother( self, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]] = 
None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMSmoothed: return lgssm_smoother(params, emissions, inputs) @@ -229,7 +243,7 @@ def posterior_sample( key: PRNGKey, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> Float[Array, "ntime state_dim"]: return lgssm_posterior_sample(key, params, emissions, inputs) @@ -237,7 +251,7 @@ def posterior_predictive( self, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> Tuple[Float[Array, "ntime emission_dim"], Float[Array, "ntime emission_dim"]]: r"""Compute marginal posterior predictive smoothing distribution for each observation. @@ -257,18 +271,23 @@ def posterior_predictive( emission_dim = R.shape[0] smoothed_emissions = posterior.smoothed_means @ H.T + b smoothed_emissions_cov = H @ posterior.smoothed_covariances @ H.T + R - smoothed_emissions_std = jnp.sqrt( - jnp.array([smoothed_emissions_cov[:, i, i] for i in range(emission_dim)])) + smoothed_emissions_std = jnp.sqrt(jnp.array([smoothed_emissions_cov[:, i, i] for i in range(emission_dim)])) return smoothed_emissions, smoothed_emissions_std # Expectation-maximization (EM) code def e_step( self, params: ParamsLGSSM, - emissions: Union[Float[Array, "num_timesteps emission_dim"], - Float[Array, "num_batches num_timesteps emission_dim"]], - inputs: Optional[Union[Float[Array, "num_timesteps input_dim"], - Float[Array, "num_batches num_timesteps input_dim"]]]=None, + emissions: Union[ + Float[Array, "num_timesteps emission_dim"], + Float[Array, "num_batches num_timesteps emission_dim"], + ], + inputs: Optional[ + Union[ + Float[Array, "num_timesteps input_dim"], + Float[Array, "num_batches num_timesteps input_dim"], + ] + ] = None, ) -> Tuple[SuffStatsLGSSM, Scalar]: num_timesteps = emissions.shape[0] if inputs is None: @@ -301,18 +320,17 @@ def e_step( # let zp[t] = [x[t], u[t]] for t = 0...T-2 # let xn[t] = x[t+1] for t = 0...T-2 sum_zpzpT = jnp.block([[Exp.T @ Exp, Exp.T @ up], [up.T @ Exp, up.T @ up]]) - sum_zpzpT = sum_zpzpT.at[:self.state_dim, :self.state_dim].add(Vxp.sum(0)) + sum_zpzpT = sum_zpzpT.at[: self.state_dim, : self.state_dim].add(Vxp.sum(0)) sum_zpxnT = jnp.block([[Expxn.sum(0)], [up.T @ Exn]]) sum_xnxnT = Vxn.sum(0) + Exn.T @ Exn dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1) if not self.has_dynamics_bias: - dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, - num_timesteps - 1) + dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, num_timesteps - 1) # more expected sufficient statistics for the emissions # let z[t] = [x[t], u[t]] for t = 0...T-1 sum_zzT = jnp.block([[Ex.T @ Ex, Ex.T @ u], [u.T @ Ex, u.T @ u]]) - sum_zzT = sum_zzT.at[:self.state_dim, :self.state_dim].add(Vx.sum(0)) + sum_zzT = sum_zzT.at[: self.state_dim, : self.state_dim].add(Vx.sum(0)) sum_zyT = jnp.block([[Ex.T @ y], [u.T @ y]]) sum_yyT = emissions.T @ emissions emission_stats = (sum_zzT, sum_zyT, sum_yyT, num_timesteps) @@ -321,22 +339,12 @@ def e_step( return (init_stats, dynamics_stats, emission_stats), posterior.marginal_loglik - - def initialize_m_step_state( - self, - params: ParamsLGSSM, - props: ParamsLGSSM - ) -> Any: + def initialize_m_step_state(self, params: ParamsLGSSM, props: ParamsLGSSM) -> Any: return None def m_step( - self, - 
params: ParamsLGSSM, - props: ParamsLGSSM, - batch_stats: SuffStatsLGSSM, - m_step_state: Any + self, params: ParamsLGSSM, props: ParamsLGSSM, batch_stats: SuffStatsLGSSM, m_step_state: Any ) -> Tuple[ParamsLGSSM, Any]: - def fit_linear_regression(ExxT, ExyT, EyyT, N): # Solve a linear regression given sufficient statistics W = psd_solve(ExxT, ExyT).T @@ -353,19 +361,17 @@ def fit_linear_regression(ExxT, ExyT, EyyT, N): m = sum_x0 / N FB, Q = fit_linear_regression(*dynamics_stats) - F = FB[:, :self.state_dim] - B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \ - else (FB[:, self.state_dim:], None) + F = FB[:, : self.state_dim] + B, b = (FB[:, self.state_dim : -1], FB[:, -1]) if self.has_dynamics_bias else (FB[:, self.state_dim :], None) HD, R = fit_linear_regression(*emission_stats) - H = HD[:, :self.state_dim] - D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \ - else (HD[:, self.state_dim:], None) + H = HD[:, : self.state_dim] + D, d = (HD[:, self.state_dim : -1], HD[:, -1]) if self.has_emissions_bias else (HD[:, self.state_dim :], None) params = ParamsLGSSM( initial=ParamsLGSSMInitial(mean=m, cov=S), dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q), - emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R) + emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R), ) return params, m_step_state @@ -387,40 +393,51 @@ class LinearGaussianConjugateSSM(LinearGaussianSSM): :param has_emissions_bias: Whether model contains an offset term d. Defaults to True. """ - def __init__(self, - state_dim, - emission_dim, - input_dim=0, - has_dynamics_bias=True, - has_emissions_bias=True, - **kw_priors): - super().__init__(state_dim=state_dim, emission_dim=emission_dim, input_dim=input_dim, - has_dynamics_bias=has_dynamics_bias, has_emissions_bias=has_emissions_bias) + + def __init__( + self, state_dim, emission_dim, input_dim=0, has_dynamics_bias=True, has_emissions_bias=True, **kw_priors + ): + super().__init__( + state_dim=state_dim, + emission_dim=emission_dim, + input_dim=input_dim, + has_dynamics_bias=has_dynamics_bias, + has_emissions_bias=has_emissions_bias, + ) # Initialize prior distributions def default_prior(arg, default): return kw_priors[arg] if arg in kw_priors else default self.initial_prior = default_prior( - 'initial_prior', - NIW(loc=jnp.zeros(self.state_dim), - mean_concentration=1., + "initial_prior", + NIW( + loc=jnp.zeros(self.state_dim), + mean_concentration=1.0, df=self.state_dim + 0.1, - scale=jnp.eye(self.state_dim))) + scale=jnp.eye(self.state_dim), + ), + ) self.dynamics_prior = default_prior( - 'dynamics_prior', - MNIW(loc=jnp.zeros((self.state_dim, self.state_dim + self.input_dim + self.has_dynamics_bias)), - col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_dynamics_bias), - df=self.state_dim + 0.1, - scale=jnp.eye(self.state_dim))) + "dynamics_prior", + MNIW( + loc=jnp.zeros((self.state_dim, self.state_dim + self.input_dim + self.has_dynamics_bias)), + col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_dynamics_bias), + df=self.state_dim + 0.1, + scale=jnp.eye(self.state_dim), + ), + ) self.emission_prior = default_prior( - 'emission_prior', - MNIW(loc=jnp.zeros((self.emission_dim, self.state_dim + self.input_dim + self.has_emissions_bias)), - col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_emissions_bias), - df=self.emission_dim + 0.1, - scale=jnp.eye(self.emission_dim))) + "emission_prior", + MNIW( + 
loc=jnp.zeros((self.emission_dim, self.state_dim + self.input_dim + self.has_emissions_bias)), + col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_emissions_bias), + df=self.emission_dim + 0.1, + scale=jnp.eye(self.emission_dim), + ), + ) @property def emission_shape(self): @@ -430,39 +447,23 @@ def emission_shape(self): def covariates_shape(self): return dict(inputs=(self.input_dim,)) if self.input_dim > 0 else dict() - def log_prior( - self, - params: ParamsLGSSM - ) -> Scalar: + def log_prior(self, params: ParamsLGSSM) -> Scalar: lp = self.initial_prior.log_prob((params.initial.cov, params.initial.mean)) # dynamics dynamics_bias = params.dynamics.bias if self.has_dynamics_bias else jnp.zeros((self.state_dim, 0)) - dynamics_matrix = jnp.column_stack((params.dynamics.weights, - params.dynamics.input_weights, - dynamics_bias)) + dynamics_matrix = jnp.column_stack((params.dynamics.weights, params.dynamics.input_weights, dynamics_bias)) lp += self.dynamics_prior.log_prob((params.dynamics.cov, dynamics_matrix)) emission_bias = params.emissions.bias if self.has_emissions_bias else jnp.zeros((self.emission_dim, 0)) - emission_matrix = jnp.column_stack((params.emissions.weights, - params.emissions.input_weights, - emission_bias)) + emission_matrix = jnp.column_stack((params.emissions.weights, params.emissions.input_weights, emission_bias)) lp += self.emission_prior.log_prob((params.emissions.cov, emission_matrix)) return lp - def initialize_m_step_state( - self, - params: ParamsLGSSM, - props: ParamsLGSSM - ) -> Any: + def initialize_m_step_state(self, params: ParamsLGSSM, props: ParamsLGSSM) -> Any: return None - def m_step( - self, - params: ParamsLGSSM, - props: ParamsLGSSM, - batch_stats: SuffStatsLGSSM, - m_step_state: Any): + def m_step(self, params: ParamsLGSSM, props: ParamsLGSSM, batch_stats: SuffStatsLGSSM, m_step_state: Any): # Sum the statistics across all batches stats = tree_map(partial(jnp.sum, axis=0), batch_stats) init_stats, dynamics_stats, emission_stats = stats @@ -473,20 +474,26 @@ def m_step( dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats) Q, FB = dynamics_posterior.mode() - F = FB[:, :self.state_dim] - B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \ - else (FB[:, self.state_dim:], jnp.zeros(self.state_dim)) + F = FB[:, : self.state_dim] + B, b = ( + (FB[:, self.state_dim : -1], FB[:, -1]) + if self.has_dynamics_bias + else (FB[:, self.state_dim :], jnp.zeros(self.state_dim)) + ) emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats) R, HD = emission_posterior.mode() - H = HD[:, :self.state_dim] - D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \ - else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim)) + H = HD[:, : self.state_dim] + D, d = ( + (HD[:, self.state_dim : -1], HD[:, -1]) + if self.has_emissions_bias + else (HD[:, self.state_dim :], jnp.zeros(self.emission_dim)) + ) params = ParamsLGSSM( initial=ParamsLGSSMInitial(mean=m, cov=S), dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q), - emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R) + emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R), ) return params, m_step_state @@ -496,7 +503,7 @@ def fit_blocked_gibbs( initial_params: ParamsLGSSM, sample_size: int, emissions: Float[Array, "nbatch ntime emission_dim"], - inputs: Optional[Float[Array, "nbatch ntime input_dim"]]=None + inputs: Optional[Float[Array, "nbatch ntime 
input_dim"]] = None, ) -> ParamsLGSSM: r"""Estimate parameter posterior using block-Gibbs sampler. @@ -532,8 +539,7 @@ def sufficient_stats_from_sample(states): sum_xnxnT = xn.T @ xn dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1) if not self.has_dynamics_bias: - dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, - num_timesteps - 1) + dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, num_timesteps - 1) # Quantities for the emissions # Let z[t] = [x[t], u[t]] for t = 0...T-1 @@ -558,21 +564,27 @@ def lgssm_params_sample(rng, stats): # Sample the dynamics params dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats) Q, FB = dynamics_posterior.sample(seed=next(rngs)) - F = FB[:, :self.state_dim] - B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \ - else (FB[:, self.state_dim:], jnp.zeros(self.state_dim)) + F = FB[:, : self.state_dim] + B, b = ( + (FB[:, self.state_dim : -1], FB[:, -1]) + if self.has_dynamics_bias + else (FB[:, self.state_dim :], jnp.zeros(self.state_dim)) + ) # Sample the emission params emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats) R, HD = emission_posterior.sample(seed=next(rngs)) - H = HD[:, :self.state_dim] - D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \ - else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim)) + H = HD[:, : self.state_dim] + D, d = ( + (HD[:, self.state_dim : -1], HD[:, -1]) + if self.has_emissions_bias + else (HD[:, self.state_dim :], jnp.zeros(self.emission_dim)) + ) params = ParamsLGSSM( initial=ParamsLGSSMInitial(mean=m, cov=S), dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q), - emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R) + emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R), ) return params @@ -585,7 +597,6 @@ def one_sample(_params, rng): _stats = sufficient_stats_from_sample(states) return lgssm_params_sample(rngs[1], _stats) - sample_of_params = [] keys = iter(jr.split(key, sample_size)) current_params = initial_params diff --git a/dynamax/linear_gaussian_ssm/models_test.py b/dynamax/linear_gaussian_ssm/models_test.py index c4394858..b4adec23 100644 --- a/dynamax/linear_gaussian_ssm/models_test.py +++ b/dynamax/linear_gaussian_ssm/models_test.py @@ -12,10 +12,11 @@ (LinearGaussianConjugateSSM, dict(state_dim=2, emission_dim=10), None), ] + @pytest.mark.parametrize(["cls", "kwargs", "inputs"], CONFIGS) def test_sample_and_fit(cls, kwargs, inputs): model = cls(**kwargs) - #key1, key2 = jr.split(jr.PRNGKey(int(datetime.now().timestamp()))) + # key1, key2 = jr.split(jr.PRNGKey(int(datetime.now().timestamp()))) key1, key2 = jr.split(jr.PRNGKey(0)) params, param_props = model.initialize(key1) states, emissions = model.sample(params, key2, num_timesteps=NUM_TIMESTEPS, inputs=inputs) diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py index 2e50e8d6..fc8b0453 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -1,4 +1,4 @@ -''' +""" Parallel filtering and smoothing for a lgssm. 
This implementation is adapted from the work of Adrien Corenflos:
@@ -28,7 +28,7 @@
 |  |  |  |
 Y₀ Y₁ Y₂ Y₃
-'''
+"""
 import jax.numpy as jnp
 from jax import vmap, lax
@@ -40,7 +40,8 @@
 from tensorflow_probability.substrates.jax.distributions import (
     MultivariateNormalDiagPlusLowRankCovariance as MVNLowRank,
-    MultivariateNormalFullCovariance as MVN)
+    MultivariateNormalFullCovariance as MVN,
+)
 from jax.scipy.linalg import cho_solve, cho_factor
 from dynamax.utils.utils import symmetrize, psd_solve
@@ -56,6 +57,7 @@ def _get_one_param(x, dim, t):
     else:
         return x

+
 def _get_params(params, num_timesteps, t):
     """Helper function to get parameters at time t."""
     assert not callable(params.emissions.cov), "Emission covariance cannot be a callable."
@@ -63,30 +65,32 @@
     F = _get_one_param(params.dynamics.weights, 2, t)
     b = _get_one_param(params.dynamics.bias, 1, t)
     Q = _get_one_param(params.dynamics.cov, 2, t)
-    H = _get_one_param(params.emissions.weights, 2, t+1)
-    d = _get_one_param(params.emissions.bias, 1, t+1)
+    H = _get_one_param(params.emissions.weights, 2, t + 1)
+    d = _get_one_param(params.emissions.bias, 1, t + 1)

-    if len(params.emissions.cov.shape) == 1:
-        R = _get_one_param(params.emissions.cov, 1, t+1)
-    elif len(params.emissions.cov.shape) > 2:
-        R = _get_one_param(params.emissions.cov, 2, t+1)
+    if len(params.emissions.cov.shape) == 1:
+        R = _get_one_param(params.emissions.cov, 1, t + 1)
+    elif len(params.emissions.cov.shape) > 2:
+        R = _get_one_param(params.emissions.cov, 2, t + 1)
     elif params.emissions.cov.shape[0] != num_timesteps:
-        R = _get_one_param(params.emissions.cov, 2, t+1)
+        R = _get_one_param(params.emissions.cov, 2, t + 1)
     elif params.emissions.cov.shape[1] != num_timesteps:
-        R = _get_one_param(params.emissions.cov, 1, t+1)
+        R = _get_one_param(params.emissions.cov, 1, t + 1)
     else:
-        R = _get_one_param(params.emissions.cov, 2, t+1)
+        R = _get_one_param(params.emissions.cov, 2, t + 1)
         warnings.warn(
             "Emission covariance has shape (N,N) where N is the number of timesteps. "
             "The covariance will be interpreted as static and non-diagonal. To "
-            "specify a dynamic and diagonal covariance, pass it as a 3D array.")
+            "specify a dynamic and diagonal covariance, pass it as a 3D array."
+        )
     return F, b, Q, H, d, R

-#---------------------------------------------------------------------------#
+# ---------------------------------------------------------------------------#
 #                                 Filtering                                  #
-#---------------------------------------------------------------------------#
+# ---------------------------------------------------------------------------#
+

 def _emissions_scale(Q, H, R):
     """Compute the scale matrix for the emissions given the state covariance.
@@ -110,13 +114,13 @@
     I = jnp.eye(Q.shape[0])
     U = H @ jnp.linalg.cholesky(Q)
     X = U / R[:, None]
-    S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T)
+    S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T)
     return S_inv


 def _marginal_loglik_elem(Q, H, R, y):
-    """Compute marginal log-likelihood elements.
-
+    """Compute marginal log-likelihood elements.
+
     Args:
         Q (state_dim, state_dim): State covariance.
         H (emission_dim, state_dim): Emission matrix.
@@ -143,11 +147,12 @@ class FilterMessage(NamedTuple):
         eta: P(z_{i-1} | y_{i:j}) mean.
        logZ: log P(y_{i:j}) marginal log-likelihood.
    """

-    A: Float[Array, "ntime state_dim state_dim"]
-    b: Float[Array, "ntime state_dim"]
-    C: Float[Array, "ntime state_dim state_dim"]
-    J: Float[Array, "ntime state_dim state_dim"]
-    eta: Float[Array, "ntime state_dim"]
+
+    A: Float[Array, "ntime state_dim state_dim"]
+    b: Float[Array, "ntime state_dim"]
+    C: Float[Array, "ntime state_dim state_dim"]
+    J: Float[Array, "ntime state_dim state_dim"]
+    eta: Float[Array, "ntime state_dim"]
     logZ: Float[Array, "ntime"]


@@ -155,13 +160,13 @@
 def _initialize_filtering_messages(params, emissions):
-    """Preprocess observations to construct input for filtering assocative scan."""
+    """Preprocess observations to construct input for the filtering associative scan."""
     num_timesteps = emissions.shape[0]
-
+
     def _first_message(params, y):
         H, d, R = _get_params(params, num_timesteps, -1)[3:]
         m = params.initial.mean
         P = params.initial.cov

-        S = H @ P @ H.T + (R if R.ndim==2 else jnp.diag(R))
+        S = H @ P @ H.T + (R if R.ndim == 2 else jnp.diag(R))
         S_inv = _emissions_scale(P, H, R)
         K = P @ H.T @ S_inv
@@ -174,14 +179,13 @@ def _first_message(params, y):
         logZ = _marginal_loglik_elem(P, H, R, y)
         return A, b, C, J, eta, logZ

-
     @partial(vmap, in_axes=(None, 0, 0))
     def _generic_message(params, y, t):
         F, b, Q, H, d, R = _get_params(params, num_timesteps, t)

         S_inv = _emissions_scale(Q, H, R)
         K = Q @ H.T @ S_inv
-
+
         eta = F.T @ H.T @ S_inv @ (y - H @ b - d)
         J = symmetrize(F.T @ H.T @ S_inv @ H @ F)
@@ -193,7 +197,7 @@ def _generic_message(params, y, t):
         return A, b, C, J, eta, logZ

     A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0])
-    At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(len(emissions)-1))
+    At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(len(emissions) - 1))

     return FilterMessage(
         A=jnp.concatenate([A0[None], At]),
         b=jnp.concatenate([b0[None], bt]),
         C=jnp.concatenate([C0[None], Ct]),
         J=jnp.concatenate([J0[None], Jt]),
         eta=jnp.concatenate([eta0[None], etat]),
-        logZ=jnp.concatenate([logZ0[None], logZt])
+        logZ=jnp.concatenate([logZ0[None], logZt]),
     )

-
-def lgssm_filter(
-    params: ParamsLGSSM,
-    emissions: Float[Array, "ntime emission_dim"]
-) -> PosteriorGSSMFiltered:
+def lgssm_filter(params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"]) -> PosteriorGSSMFiltered:
     """A parallel version of the lgssm filtering algorithm.

     See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002.

     Note: This function does not yet handle `inputs` to the system.
""" + @vmap def _operator(elem1, elem2): A1, b1, C1, J1, eta1, logZ1 = elem1 @@ -234,22 +235,22 @@ def _operator(elem1, elem2): J = symmetrize(temp @ J2 @ A1 + J1) mu = jnp.linalg.solve(C1, b1) - t1 = (b1 @ mu - (eta2 + mu) @ jnp.linalg.solve(I_C1J2, C1 @ eta2 + b1)) - logZ = (logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1) + t1 = b1 @ mu - (eta2 + mu) @ jnp.linalg.solve(I_C1J2, C1 @ eta2 + b1) + logZ = logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1 return FilterMessage(A, b, C, J, eta, logZ) initial_messages = _initialize_filtering_messages(params, emissions) final_messages = lax.associative_scan(_operator, initial_messages) return PosteriorGSSMFiltered( - filtered_means=final_messages.b, - filtered_covariances=final_messages.C, - marginal_loglik=-final_messages.logZ[-1]) + filtered_means=final_messages.b, filtered_covariances=final_messages.C, marginal_loglik=-final_messages.logZ[-1] + ) -#---------------------------------------------------------------------------# +# ---------------------------------------------------------------------------# # Smoothing # -#---------------------------------------------------------------------------# +# ---------------------------------------------------------------------------# + class SmoothMessage(NamedTuple): """ @@ -260,6 +261,7 @@ class SmoothMessage(NamedTuple): g: P(z_i | y_{1:j}, z_{j+1}) bias. L: P(z_i | y_{1:j}, z_{j+1}) covariance. """ + E: Float[Array, "ntime state_dim state_dim"] g: Float[Array, "ntime state_dim"] L: Float[Array, "ntime state_dim state_dim"] @@ -278,24 +280,21 @@ def _generic_message(params, m, P, t): F, b, Q = _get_params(params, num_timesteps, t)[:3] CF, low = cho_factor(F @ P @ F.T + Q) E = cho_solve((CF, low), F @ P).T - g = m - E @ (F @ m + b) - L = symmetrize(P - E @ F @ P) + g = m - E @ (F @ m + b) + L = symmetrize(P - E @ F @ P) return E, g, L - + En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1]) - Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means)-1)) - + Et, gt, Lt = _generic_message( + params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means) - 1) + ) + return SmoothMessage( - E=jnp.concatenate([Et, En[None]]), - g=jnp.concatenate([gt, gn[None]]), - L=jnp.concatenate([Lt, Ln[None]]) + E=jnp.concatenate([Et, En[None]]), g=jnp.concatenate([gt, gn[None]]), L=jnp.concatenate([Lt, Ln[None]]) ) -def lgssm_smoother( - params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"] -) -> PosteriorGSSMSmoothed: +def lgssm_smoother(params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"]) -> PosteriorGSSMSmoothed: """A parallel version of the lgssm smoothing algorithm. See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002. 
@@ -305,7 +304,7 @@ def lgssm_smoother(
     filtered_posterior = lgssm_filter(params, emissions)
     filtered_means = filtered_posterior.filtered_means
     filtered_covs = filtered_posterior.filtered_covariances
-
+
     @vmap
     def _operator(elem1, elem2):
         E1, g1, L1 = elem1
@@ -323,13 +322,14 @@ def _operator(elem1, elem2):
         filtered_means=filtered_means,
         filtered_covariances=filtered_covs,
         smoothed_means=final_messages.g,
-        smoothed_covariances=final_messages.L
+        smoothed_covariances=final_messages.L,
     )


-#---------------------------------------------------------------------------#
+# ---------------------------------------------------------------------------#
 #                                  Sampling                                  #
-#---------------------------------------------------------------------------#
+# ---------------------------------------------------------------------------#
+

 class SampleMessage(NamedTuple):
     """
@@ -339,14 +339,15 @@ class SampleMessage(NamedTuple):
         E: z_i ~ z_{j+1} weights.
         h: z_i ~ z_{j+1} bias.
     """
+
     E: Float[Array, "ntime state_dim state_dim"]
     h: Float[Array, "ntime state_dim"]


 def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances):
-    """A parallel version of the lgssm sampling algorithm.
-
-    Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`,
+    """A parallel version of the lgssm sampling algorithm.
+
+    Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`,
     the parallel sampling messages are `(E_i,h_i)` where `h_i ~ N(g_i, L_i)`.
     """
     E, g, L = _initialize_smoothing_messages(params, filtered_means, filtered_covariances)
@@ -354,9 +355,7 @@ def lgssm_posterior_sample(
-    key: PRNGKey,
-    params: ParamsLGSSM,
-    emissions: Float[Array, "ntime emission_dim"]
+    key: PRNGKey, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"]
 ) -> Float[Array, "ntime state_dim"]:
     """A parallel version of the lgssm sampling algorithm.

@@ -379,4 +378,4 @@ def _operator(elem1, elem2):
     initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs)
     _, samples = lax.associative_scan(_operator, initial_messages, reverse=True)

-    return samples
\ No newline at end of file
+    return samples
diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py
index cd6376b3..3bef9b75 100644
--- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py
+++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py
@@ -12,40 +12,40 @@
 from dynamax.linear_gaussian_ssm.inference_test import flatten_diagonal_emission_cov


-def allclose(x,y, atol=1e-2):
-    m = jnp.abs(jnp.max(x-y))
+def allclose(x, y, atol=1e-2):
+    m = jnp.max(jnp.abs(x - y))
     if m > atol:
         print(m)
         return False
     else:
         return True
-
+

 def make_static_lgssm_params():
     dt = 0.1
     F = jnp.eye(4) + dt * jnp.eye(4, k=2)
-    Q = 1. 
* jnp.kron(jnp.array([[dt**3/3, dt**2/2], - [dt**2/2, dt]]), - jnp.eye(2)) - + Q = 1.0 * jnp.kron(jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]]), jnp.eye(2)) + H = jnp.eye(2, 4) - R = 0.5 ** 2 * jnp.eye(2) - μ0 = jnp.array([0.,0.,1.,-1.]) + R = 0.5**2 * jnp.eye(2) + μ0 = jnp.array([0.0, 0.0, 1.0, -1.0]) Σ0 = jnp.eye(4) latent_dim = 4 observation_dim = 2 lgssm = LinearGaussianSSM(latent_dim, observation_dim) - params, _ = lgssm.initialize(jr.PRNGKey(0), - initial_mean=μ0, - initial_covariance= Σ0, - dynamics_weights=F, - dynamics_covariance=Q, - emission_weights=H, - emission_covariance=R) + params, _ = lgssm.initialize( + jr.PRNGKey(0), + initial_mean=μ0, + initial_covariance=Σ0, + dynamics_weights=F, + dynamics_covariance=Q, + emission_weights=H, + emission_covariance=R, + ) return params, lgssm - + def make_dynamic_lgssm_params(num_timesteps, latent_dim=4, observation_dim=2, seed=0): key = jr.PRNGKey(seed) @@ -53,39 +53,39 @@ def make_dynamic_lgssm_params(num_timesteps, latent_dim=4, observation_dim=2, se dt = 0.1 f_scale = jr.normal(key_f, (num_timesteps,)) * 0.5 - F = f_scale[:,None,None] * jnp.tile(jnp.eye(latent_dim), (num_timesteps, 1, 1)) + F = f_scale[:, None, None] * jnp.tile(jnp.eye(latent_dim), (num_timesteps, 1, 1)) F += dt * jnp.eye(latent_dim, k=2) - Q = 1. * jnp.kron(jnp.array([[dt**3/3, dt**2/2], - [dt**2/2, dt]]), - jnp.eye(latent_dim // 2)) + Q = 1.0 * jnp.kron(jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]]), jnp.eye(latent_dim // 2)) assert Q.shape[-1] == latent_dim H = jnp.eye(observation_dim, latent_dim) r_scale = jr.normal(key_r, (num_timesteps,)) * 0.1 - R = (r_scale**2)[:,None,None] * jnp.tile(jnp.eye(observation_dim), (num_timesteps, 1, 1)) - - μ0 = jnp.array([0.,0.,1.,-1.]) + R = (r_scale**2)[:, None, None] * jnp.tile(jnp.eye(observation_dim), (num_timesteps, 1, 1)) + + μ0 = jnp.array([0.0, 0.0, 1.0, -1.0]) Σ0 = jnp.eye(latent_dim) lgssm = LinearGaussianSSM(latent_dim, observation_dim) - params, _ = lgssm.initialize(key_init, - initial_mean=μ0, - initial_covariance=Σ0, - dynamics_weights=F, - dynamics_covariance=Q, - emission_weights=H, - emission_covariance=R) + params, _ = lgssm.initialize( + key_init, + initial_mean=μ0, + initial_covariance=Σ0, + dynamics_weights=F, + dynamics_covariance=Q, + emission_weights=H, + emission_covariance=R, + ) return params, lgssm class TestParallelLGSSMSmoother: - """ Compare parallel and serial lgssm smoothing implementations.""" - + """Compare parallel and serial lgssm smoothing implementations.""" + num_timesteps = 50 key = jr.PRNGKey(1) - params, lgssm = make_static_lgssm_params() + params, lgssm = make_static_lgssm_params() params_diag = flatten_diagonal_emission_cov(params) _, emissions = lgssm_joint_sample(params, key, num_timesteps) @@ -111,20 +111,21 @@ def test_smoothed_covariances(self): def test_marginal_loglik(self): assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, atol=2e-1) - assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1) - - + assert jnp.allclose( + self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1 + ) class TestTimeVaryingParallelLGSSMSmoother: """Compare parallel and serial time-varying lgssm smoothing implementations. - + Vary dynamics weights and observation covariances with time. 
""" + num_timesteps = 50 key = jr.PRNGKey(1) - params, lgssm = make_dynamic_lgssm_params(num_timesteps) + params, lgssm = make_dynamic_lgssm_params(num_timesteps) params_diag = flatten_diagonal_emission_cov(params) _, emissions = lgssm_joint_sample(params, key, num_timesteps) @@ -150,17 +151,18 @@ def test_smoothed_covariances(self): def test_marginal_loglik(self): assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, atol=2e-1) - assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1) - + assert jnp.allclose( + self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1 + ) -class TestTimeVaryingParallelLGSSMSampler(): +class TestTimeVaryingParallelLGSSMSampler: """Compare parallel and serial lgssm posterior sampling implementations in expectation.""" - + num_timesteps = 50 key = jr.PRNGKey(1) - params, lgssm = make_dynamic_lgssm_params(num_timesteps) + params, lgssm = make_dynamic_lgssm_params(num_timesteps) params_diag = flatten_diagonal_emission_cov(params) _, emissions = lgssm_joint_sample(params_diag, key, num_timesteps) @@ -168,14 +170,13 @@ class TestTimeVaryingParallelLGSSMSampler(): serial_keys = jr.split(jr.PRNGKey(2), num_samples) parallel_keys = jr.split(jr.PRNGKey(3), num_samples) - serial_samples = vmap(serial_lgssm_posterior_sample, in_axes=(0,None,None))( - serial_keys, params, emissions) - - parallel_samples = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None))( - parallel_keys, params, emissions) - + serial_samples = vmap(serial_lgssm_posterior_sample, in_axes=(0, None, None))(serial_keys, params, emissions) + + parallel_samples = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None))(parallel_keys, params, emissions) + parallel_samples_diag = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None))( - parallel_keys, params, emissions) + parallel_keys, params, emissions + ) def test_sampled_means(self): serial_mean = self.serial_samples.mean(axis=0) @@ -190,4 +191,4 @@ def test_sampled_covariances(self): parallel_cov = vmap(partial(jnp.cov, rowvar=False), in_axes=1)(self.parallel_samples) parallel_cov_diag = vmap(partial(jnp.cov, rowvar=False), in_axes=1)(self.parallel_samples) assert allclose(serial_cov, parallel_cov, atol=1e-1) - assert allclose(serial_cov, parallel_cov_diag, atol=1e-1) \ No newline at end of file + assert allclose(serial_cov, parallel_cov_diag, atol=1e-1)