Skip to content

Commit

Permalink
Simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Dec 27, 2023
1 parent 0738286 commit 2268b7c
Showing 1 changed file with 1 addition and 49 deletions.
50 changes: 1 addition & 49 deletions language_interpolation/state_space_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat, einsum
from typing import Union


@dataclass
Expand Down Expand Up @@ -93,55 +94,6 @@ class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ss

return logits


@staticmethod
def from_pretrained(pretrained_model_name: str):
"""Load pretrained weights from HuggingFace into model.
Args:
pretrained_model_name: One of
* 'state-spaces/mamba-2.8b-slimpj'
* 'state-spaces/mamba-2.8b'
* 'state-spaces/mamba-1.4b'
* 'state-spaces/mamba-790m'
* 'state-spaces/mamba-370m'
* 'state-spaces/mamba-130m'
Returns:
model: Mamba model with weights loaded
"""
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file

def load_config_hf(model_name):
resolved_archive_file = cached_file(model_name, CONFIG_NAME,
_raise_exceptions_for_missing_entries=False)
return json.load(open(resolved_archive_file))


def load_state_dict_hf(model_name, device=None, dtype=None):
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
_raise_exceptions_for_missing_entries=False)
return torch.load(resolved_archive_file, weights_only=True)

config_data = load_config_hf(pretrained_model_name)
args = ModelArgs(
d_model=config_data['d_model'],
n_layer=config_data['n_layer'],
vocab_size=config_data['vocab_size']
)
model = Mamba(args)

state_dict = load_state_dict_hf(pretrained_model_name)
new_state_dict = {}
for key in state_dict:
new_key = key.replace('backbone.', '')
new_state_dict[new_key] = state_dict[key]
model.load_state_dict(new_state_dict)

return model


class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
Expand Down

0 comments on commit 2268b7c

Please sign in to comment.