diff --git a/language_interpolation/state_space_network.py b/language_interpolation/state_space_network.py
index 0c5dc4a..e12e398 100644
--- a/language_interpolation/state_space_network.py
+++ b/language_interpolation/state_space_network.py
@@ -30,6 +30,7 @@
 import torch.nn.functional as F
 from dataclasses import dataclass
 from einops import rearrange, repeat, einsum
+from typing import Union
 
 
 @dataclass
@@ -93,55 +94,6 @@ class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ss
 
         return logits
 
-    
-    @staticmethod
-    def from_pretrained(pretrained_model_name: str):
-        """Load pretrained weights from HuggingFace into model.
-
-        Args:
-            pretrained_model_name: One of
-                * 'state-spaces/mamba-2.8b-slimpj'
-                * 'state-spaces/mamba-2.8b'
-                * 'state-spaces/mamba-1.4b'
-                * 'state-spaces/mamba-790m'
-                * 'state-spaces/mamba-370m'
-                * 'state-spaces/mamba-130m'
-
-        Returns:
-            model: Mamba model with weights loaded
-
-        """
-        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
-        from transformers.utils.hub import cached_file
-
-        def load_config_hf(model_name):
-            resolved_archive_file = cached_file(model_name, CONFIG_NAME,
-                                                _raise_exceptions_for_missing_entries=False)
-            return json.load(open(resolved_archive_file))
-
-
-        def load_state_dict_hf(model_name, device=None, dtype=None):
-            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
-                                                _raise_exceptions_for_missing_entries=False)
-            return torch.load(resolved_archive_file, weights_only=True)
-
-        config_data = load_config_hf(pretrained_model_name)
-        args = ModelArgs(
-            d_model=config_data['d_model'],
-            n_layer=config_data['n_layer'],
-            vocab_size=config_data['vocab_size']
-        )
-        model = Mamba(args)
-
-        state_dict = load_state_dict_hf(pretrained_model_name)
-        new_state_dict = {}
-        for key in state_dict:
-            new_key = key.replace('backbone.', '')
-            new_state_dict[new_key] = state_dict[key]
-        model.load_state_dict(new_state_dict)
-
-        return model
-
 
 class ResidualBlock(nn.Module):
     def __init__(self, args: ModelArgs):
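
Note: callers that relied on the removed Mamba.from_pretrained can keep the same behavior as a standalone helper outside the class. A minimal sketch reusing the deleted code follows; the function name load_pretrained_mamba and the import path for Mamba/ModelArgs are assumptions, not part of this change:

    import json

    import torch

    # Assumed import path for this repo's Mamba / ModelArgs definitions.
    from language_interpolation.state_space_network import Mamba, ModelArgs


    def load_pretrained_mamba(pretrained_model_name: str) -> Mamba:
        """Build a Mamba model and load HuggingFace weights into it.

        Hypothetical helper reproducing the removed Mamba.from_pretrained,
        e.g. pretrained_model_name='state-spaces/mamba-130m'.
        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        # Resolve and parse config.json from the HuggingFace hub cache.
        config_file = cached_file(
            pretrained_model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False
        )
        with open(config_file) as f:
            config_data = json.load(f)

        model = Mamba(
            ModelArgs(
                d_model=config_data["d_model"],
                n_layer=config_data["n_layer"],
                vocab_size=config_data["vocab_size"],
            )
        )

        # Fetch the checkpoint and strip the 'backbone.' prefix used by the
        # official state-spaces checkpoints before loading.
        weights_file = cached_file(
            pretrained_model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        state_dict = torch.load(weights_file, weights_only=True)
        model.load_state_dict(
            {key.replace("backbone.", ""): value for key, value in state_dict.items()}
        )
        return model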