diff --git a/src/autora/experimentalist/sampler/model_disagreement/__init__.py b/src/autora/experimentalist/sampler/model_disagreement/__init__.py index 0934b65..a610967 100644 --- a/src/autora/experimentalist/sampler/model_disagreement/__init__.py +++ b/src/autora/experimentalist/sampler/model_disagreement/__init__.py @@ -63,6 +63,6 @@ def model_disagreement_sample(condition_pool: np.array, models: List, num_sample # sort the summed disagreements and select the top n idx = (-summed_disagreement).argsort()[:num_samples] - return X[idx] + return condition_pool[idx] model_disagreement_sampler = deprecated_alias(model_disagreement_sample, "model_disagreement_sampler")