Skip to content

Transformer implementation with simple scikit-like API

Notifications You must be signed in to change notification settings

mhulden/eztransformer

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

11 Commits
 
 
 
 
 
 
 
 

Repository files navigation


EZTransformer is a Transformer implementation that allows you to train translation models with simple API calls.

  • Dependencies: PyTorch, tqdm

Quickstart

from eztr import EZTransformer
import pickle

# Have data
mydata = pickle.load(open("data/sigmorphon2016spanish.p", "rb")) # Example word inflection data for Spanish

# Data format (whitespace-separated tokens, two-tuples for input/target)
>>> mydata['train'][:5]

[('a b a b o l # N PL', 'a b a b o l e s'),
 ('á b a c o # N SG', 'á b a c o'),
 ('a b a c o r a r # V IND PRS 1 PL IPFV/PFV', 'a b a c o r a m o s'),
 ('a b a c o r a r # V IND PRS 3 PL IPFV/PFV', 'a b a c o r a n'),
 ('a b a d e r n a r # V COND 1 PL', 'a b a d e r n a r í a m o s')]


# Initialize model
trf = EZTransformer(device = 'cuda')  # Change device as needed; 'cuda' (NVIDIA), 'mps' (Apple), or 'cpu'

# Train model
trf.fit(mydata['train'], valid_data = mydata['valid'], print_validation_examples = 2, max_epochs = 100)

# Load back the best model wrt validation set (or skip this step to use final weights with trf directly)
trf = EZTransformer(load_model = "best_model.pt")

# Make Predictions
>>> trf.predict(["c o m p r o m e t e r # V IND PST 3 PL PFV", "h a b l a r # V IND PST 1 SG IPFV"])
   ['c o m p r o m e t i e r o n', 'h a b l a b a']

# Evaluate on test set
trf.score(mydata['test_inputs'], mydata['test_targets'])

About

Transformer implementation with simple scikit-like API

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages