From dd1b5503ae16e0df0cd78f68388c9cb3c2a01acc Mon Sep 17 00:00:00 2001
From: jloveric
Date: Fri, 29 Dec 2023 19:01:38 -0800
Subject: [PATCH] Add Mamba-specific generator

---
 language_interpolation/utils.py | 66 +++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/language_interpolation/utils.py b/language_interpolation/utils.py
index 273752e..a427c15 100644
--- a/language_interpolation/utils.py
+++ b/language_interpolation/utils.py
@@ -158,6 +158,62 @@ def justify_sample(sample):
     return results
 
 
+def generate_mamba_text(
+    model: nn.Module,
+    characters_per_feature: int,
+    max_characters: int,
+    text_list: List[str],
+    output_size: int,
+    topk: int = 1,
+):
+    """
+    TODO: This can be done much more efficiently. Right now the output from
+    the 0th input is re-computed on every step even though those results are
+    (I think) identical and could be re-used, so generation is quadratic here
+    when it should be linear.
+
+    :param model: the trained model used to predict the next character.
+    :param characters_per_feature: the number of characters that make up a
+        feature, typically 1.
+    :param max_characters: the maximum number of characters to use; acts like
+        a moving window. Set to 0 if all characters should be used.
+    :param text_list: list of prompts.
+    :param output_size: the number of characters to generate.
+    :param topk: sample the next character from the topk most likely choices.
+    :returns: the continuations of the prompts, i.e. the original text plus
+        the next output_size characters.
+    """
+    model.eval()
+
+    results = []
+    for text_raw in text_list:
+        text_in = text_raw
+        for _ in range(output_size):
+            encoding, text_used = encode_input_from_text(
+                text_in=text_in, features=max_characters
+            )
+            encoding = (
+                encoding
+                .to(model._device)
+                .reshape(1, -1, characters_per_feature)
+            )
+
+            output = model(encoding)
+            values, indices, ascii = decode_output_to_text(
+                encoding=output[0, -1, :], topk=topk
+            )
+
+            # Pick the next character weighted by the probability of each
+            # candidate; this prevents identical responses for every query.
+            values = values.nan_to_num(nan=1.0)
+            actual = random.choices(ascii, values.tolist())
+            text_in = text_in + actual[0]
+
+        results.append(text_in.replace("\n", " "))
+
+    return results
+
+
 class TextGenerationSampler(Callback):
     def __init__(self, cfg):
         self._cfg = cfg
@@ -180,6 +236,16 @@ def on_train_epoch_end(self, trainer, pl_module, outputs=None):
                 output_size=self._cfg.num_predict,
                 topk=topk,
             )
+        elif self._cfg.net.model_type in ["mamba"]:
+            predictions = generate_mamba_text(
+                pl_module,
+                characters_per_feature=self._cfg.data.characters_per_feature,
+                max_characters=self._cfg.data.characters_per_feature
+                * self._cfg.data.max_features,
+                text_list=self._cfg.prompts,
+                output_size=self._cfg.num_predict,
+                topk=topk,
+            )
         else:
             predictions = generate_text(
                 pl_module,
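
Reviewer note: for anyone smoke-testing this patch, a minimal sketch of driving
the new generator outside the Lightning callback is below. Only
generate_mamba_text and its signature come from the patch; the trained `model`
object, the prompts, and the parameter values are illustrative assumptions.

    # Hypothetical smoke test; assumes `model` is a trained module from this
    # repo that exposes the `_device` attribute generate_mamba_text relies on.
    from language_interpolation.utils import generate_mamba_text

    def smoke_test(model):
        prompts = ["The meaning of life is", "Call me Ishmael."]
        continuations = generate_mamba_text(
            model,
            characters_per_feature=1,  # one character per feature (typical)
            max_characters=256,        # moving window of at most 256 characters
            text_list=prompts,
            output_size=100,           # generate 100 new characters per prompt
            topk=5,                    # sample among the 5 most likely characters
        )
        for prompt, text in zip(prompts, continuations):
            print(f"{prompt!r} -> {text!r}")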
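
On the docstring's TODO: a Mamba block is recurrent, so the quadratic
re-encoding could in principle be replaced by consuming the prompt once and
then advancing a cached hidden state one character per step. The sketch below
only illustrates that idea, assuming a hypothetical stateful API
(`init_state()`, `step(token, state) -> (logits, state)`, and
`tokenize`/`detokenize` helpers); neither this patch nor the repo is known to
expose such methods, so treat it as pseudocode for the shape of the fix.

    import random
    import torch

    def generate_linear(model, tokenize, detokenize, prompt, output_size, topk=1):
        model.eval()
        state = model.init_state()  # hypothetical: fresh recurrent state
        with torch.no_grad():
            # Consume the prompt once, carrying the recurrent state forward.
            logits = None
            for token in tokenize(prompt):  # hypothetical helper
                logits, state = model.step(token, state)  # hypothetical
            text = prompt
            # Each new character now costs one recurrent step instead of a
            # full forward pass over the growing prefix: linear, not quadratic.
            for _ in range(output_size):
                probs = logits.softmax(dim=-1)
                values, indices = torch.topk(probs, k=topk)
                choice = random.choices(indices.tolist(), weights=values.tolist())[0]
                text += detokenize(choice)  # hypothetical helper
                logits, state = model.step(torch.tensor([choice]), state)
        return text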