From 8f82962c329a70103ed6f0cad186e34b25e1a6ae Mon Sep 17 00:00:00 2001 From: zhengp0 Date: Tue, 15 Oct 2024 15:15:19 -0700 Subject: [PATCH] add _infer_shape function to mrbert class --- src/mrtool/core/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/mrtool/core/model.py b/src/mrtool/core/model.py index ca94e4c..df8bed3 100644 --- a/src/mrtool/core/model.py +++ b/src/mrtool/core/model.py @@ -523,6 +523,7 @@ def __init__( self.weights = np.ones(self.num_sub_models) / self.num_sub_models + def _infer_shape(self) -> None: # inherent the dimension variable self.num_x_vars = self.sub_models[0].num_x_vars self.num_z_vars = self.sub_models[0].num_z_vars @@ -542,6 +543,8 @@ def fit_model( for sub_model in self.sub_models: sub_model.fit_model(**fit_options) + self._infer_shape() + self.score_model( scores_weights=scores_weights, slopes=slopes, quantiles=quantiles )