From c6e85d56b3288dc222d951c7e009690d2bad67c7 Mon Sep 17 00:00:00 2001 From: James Duncan Date: Wed, 26 Apr 2023 09:20:01 -0700 Subject: [PATCH] Add call_fitted_method to Vset --- vflow/vset.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vflow/vset.py b/vflow/vset.py index bf3ed91..4d26bfe 100644 --- a/vflow/vset.py +++ b/vflow/vset.py @@ -223,6 +223,19 @@ def evaluate(self, *args): """ return self._apply_func(*args) + def call_fitted_method(self, *args, method: str, with_uncertainty: bool=False, group_by: list=None): + if not self._fitted: + raise AttributeError('Please fit the Vset object before calling call_fitted_method.') + pred_dict = {} + for k, v in self.fitted_vfuncs.items(): + if k != '__prev__': + assert hasattr(v, method), f'{v} does not have a "{method}" method.' + pred_dict[k] = getattr(v, method) + preds = self._apply_func(*args, out_dict=pred_dict) + if with_uncertainty: + return prediction_uncertainty(preds, group_by) + return preds + def __call__(self, *args, n_out: int = None, keys=None, **kwargs): """Call args using `_apply_func`, optionally seperating output dictionary into `n_out` dictionaries with `keys`