hyena-finetune.py
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import TrainingArguments, Trainer
import torch
from datasets import load_dataset
import os
os.environ["WANDB_PROJECT"] = "CompBio-Evo" # name your W&B project
os.environ["WANDB_LOG_MODEL"] = "checkpoint" # log all model checkpoints
# instantiate pretrained model
checkpoint = 'LongSafari/hyenadna-tiny-16k-seqlen-d128'
max_length = 16_000  # this checkpoint supports a ~16k context, not 160k
# bfloat16 for better speed and reduced memory usage
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
ds = load_dataset("anudaw/genome-classification")
def preprocess_function(examples):
    # truncate each sequence to the model's maximum context length
    return tokenizer(examples['seq'], truncation=True, max_length=max_length)
tokenized_ds = ds.map(preprocess_function, batched=True)
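# Quick sanity check (assumes the "train" split and 'seq' column used
# above): confirm truncation actually capped the tokenized length
print(len(tokenized_ds["train"][0]["input_ids"]))  # should be <= max_length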
# Initialize Trainer
# Note that we're using extremely small batch sizes to maximize
# our ability to fit long sequences in memory!
args = {
    "output_dir": "tmp",
    "num_train_epochs": 1,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 4,
    "gradient_checkpointing": True,
    "learning_rate": 2e-5,
    "report_to": "wandb",  # send logs to the W&B project configured above
}
training_args = TrainingArguments(**args)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"],
    tokenizer=tokenizer,  # enables dynamic padding via the default collator
)
result = trainer.train()
print(result)
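# Optional: evaluate on the held-out split (assumes the "test" split used
# above carries labels compatible with the classification head)
metrics = trainer.evaluate()
print(metrics)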
# Now we can save_pretrained() or push_to_hub() to share the trained model!
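# A minimal sketch of both options; "hyena-genome-classifier" is a
# hypothetical name, and push_to_hub assumes you've authenticated with
# `huggingface-cli login`:
# model.save_pretrained("hyena-genome-classifier")
# tokenizer.save_pretrained("hyena-genome-classifier")
# model.push_to_hub("your-username/hyena-genome-classifier")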