Add transformer model
BobMcDear committed Jul 24, 2024
0 parents commit 336c92e
Showing 1 changed file with 166 additions and 0 deletions.
166 changes: 166 additions & 0 deletions transformer.apl
@@ -0,0 +1,166 @@
⎕IO←0

:Namespace TRANSFORMER

⍝ Model, training, and optimizer hyperparameters
ITERS←500 ⋄ BS←16 ⋄ VCB←128 ⋄ SEQ←32 ⋄ DPTH←4 ⋄ HEADS←4 ⋄ DIM←128
LR←3E¯4 ⋄ BETA1←0.9 ⋄ BETA2←0.999 ⋄ WD←1E¯2

⍝ Miscellaneous
MASK←-1E10×~(⍳SEQ)∘.≥⍳SEQ
GELUF←(2÷○1)*0.5

⍝ Parameter initialization functions
RND←{d←10*9 ⋄ (x y)←⊂[1+⍳≢,⍵](?(2,⍵)⍴d)÷d ⋄ ((¯2×⍟x)*0.5)×1○○2×y} ⍝ Samples from normal distribution
LIN←{WB←↑{(0.02×RND ⍵) ((1↓⍵)⍴0)}¨⍺⍴⊂⍵ ⋄ (WB[;0]) (WB[;1])} ⍝ Initializes weights & biases

⍝ Model parameters
WTE←⊃⊃1 LIN VCB DIM ⋄ WPE←⊃⊃1 LIN SEQ DIM ⋄ WH←⊃⊃1 LIN DIM VCB
WLN←(1+2×DPTH)⍴⊂DIM⍴1 ⋄ BLN←(1+2×DPTH)⍴⊂DIM⍴0
WQKV BQKV←DPTH LIN DIM HEADS (DIM÷HEADS) 3 ⋄ WO BO←DPTH LIN DIM DIM
W1 B1←DPTH LIN DIM (4×DIM) ⋄ W2 B2←DPTH LIN (4×DIM) DIM

⍝ Activation tensors for backpropagation
W2INP←ACTINP←W1INP←WOINP←ATTN←V←K←Q←WQKVINP←DPTH⍴⊂⍬ ⋄ WHINP←WTEINP←⍬
STD←MEAN←LNINP←(1+2×DPTH)⍴⊂⍬
CETARG←CEPROB←⍬

⍝ Gradients
∆B2←∆W2←∆B1←∆W1←∆BO←∆WO←∆BQKV←∆WQKV←DPTH⍴⊂⍬ ⋄ ∆WH←∆WPE←∆WTE←⍬
∆BLN←∆WLN←(1+2×DPTH)⍴⊂⍬

⍝ Optimizer states
T←0
M1WH←M1B2←M1W2←M1B1←M1W1←M1BO←M1WO←M1BQKV←M1WQKV←M1BLN←M1WLN←M1WPE←M1WTE←0
M2WH←M2B2←M2W2←M2B1←M2W1←M2BO←M2WO←M2BQKV←M2WQKV←M2BLN←M2WLN←M2WPE←M2WTE←0

⍝ Utilities
UNSQZ←{((⍴⍵),1)⍴⍵} ⍝ Inserts axis of size 1 as last dimension
AVG←{UNSQZ (+/⍵)÷¯1↑⍴⍵} ⍝ Averages along last dimension
SM←{exp←*⍵-⍤1⊢UNSQZ⌈/⍵ ⋄ exp÷⍤1⊢UNSQZ+/exp} ⍝ Applies softmax along last dimension

FWD←{
⍝ Gets token embeddings
TE←{WTE[WTEINP⊢←⍵;]}

⍝ Adds position embeddings to input
PE←{WPE[⍳1⊃⍴⍵;]+⍤2⊢⍵}

⍝ Layer-normalizes input
LN←{
diff-1MEAN[]AVG LNINP[]
(BLN)+1(WLN)×1diff÷1STD[](1E¯5+AVG diff*2)*0.5
}

⍝ Applies multi-headed self-attention to input
MHSA←{
qkv(BQKV)+3(WQKVINP[])+.×WQKV
q k v{0 2 1 31qkv}¨3 Q[]q K[]k V[]v
ATTN[]attnSM MASK[2q;2k]+2(q+.×22k)÷(¯1k)*0.5
(BO)+1(WOINP[]()0 2 1 3attn+.×2v)+.×WO
}

⍝ Transforms input using multilayer perceptron with one hidden layer
MLP←{
ACTINP[]h(B1)+1(W1INP[])+.×W1
(B2)+1(W2INP[]0.5×h×1+7GELUF×h+0.044715×h*3)+.×W2
}
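⍝ For reference, a reading of the activation computed in MLP above. It assumes
⍝ GELUF is (2÷π)*0.5 and that the circular function 7○ (tanh in Dyalog APL) is applied:
⍝   gelu(h) ≈ 0.5 × h × (1 + tanh(GELUF × (h + 0.044715 × h*3)))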

⍝ Passes input through a transformer block
BLK←{
ind inp←⍵
out←inp+ind MHSA (2×ind) LN inp
(ind+1) (out+ind MLP (1+2×ind) LN out)
}

⍝ Produces next token predictions
HEAD←{(WHINP⊢←(2×DPTH) LN ⍵)+.×WH}

⍝ Calculates cross-entropy loss
CE{{(+)÷},(UNSQZ CETARG)1-CEPROBSM }

out←HEAD 1⊃BLK⍣DPTH⊢0 (PE TE ⍵)
0=:out CE out
}

BWD←{
∆TE←{∆WTE⊢←(⍴WTE)⍴0 ⋄ ∆WTE[WTEINP;]+←⍵}

∆PE{∆WPE(WPE)0 ∆WPE[1;]+ }

∆LN←{
∆WLN[]+,[2]×prelin((LNINP)-1MEAN)÷1STD ∆BLN[]+,[2]
wgprod×1WLN
(wgprod-(prelin×1AVG prelin×wgprod)+1AVG wgprod)÷1STD
}

∆MHSA←{
∆WO[](,[2]WOINP)+.×,[2] ∆BO[]+,[2]
∆woinp0 2 1 3((2),HEADS (DIM÷HEADS))+.×WO
∆v(2ATTN)+.×2∆woinp
∆qkprod({(ATTN)×-1UNSQZ +/×ATTN}∆woinp+.×22V)÷(¯1K)*0.5
∆k2(2Q)+.×2∆qkprod
∆q∆qkprod+.×2K
∆qkv4 0 2 1 3∆q ∆k ∆v
∆WQKV[](,[2]WQKVINP)+.×,[2]∆qkv ∆BQKV[]+,[2]∆qkv
(,[2+3]∆qkv)+.×,[1+3]WQKV
}

∆MLP←{
∆W2[](,[2]W2INP)+.×,[2] ∆B2[]+,[2]
∆w2inp+.×W2
argGELUF×(ACTINP)+0.044715×(ACTINP)*3
∆h∆w2inp×(0.5×1+7arg)+(ACTINP)×0.5×(÷(6arg)*2)×GELUF×1+0.134145×(ACTINP)*2
∆W1[](,[2]W1INP)+.×,[2]∆h ∆B1[]+,[2]∆h
∆h+.×W1
}

∆BLK←{
ind ∆out←⍵
∆inp←∆out+(1+2×ind) ∆LN ind ∆MLP ∆out
(ind-1) (∆inp+(2×ind) ∆LN ind ∆MHSA ∆inp)
}

∆HEAD←{∆WH⊢←(⍉,[⍳2]WHINP)+.×,[⍳2]⍵ ⋄ (2×DPTH) ∆LN ⍵+.×⍉WH}

∆CE←{(CEPROB-CETARG∘.=⍳¯1↑⍴CEPROB)÷×/⍴CETARG}

_←∆TE ∆PE 1⊃∆BLK⍣DPTH⊢(DPTH-1) (∆HEAD ∆CE ⍵)
}

TRAIN←{
⍝ Updates set of parameters using AdamW
OPT←{
P ∆P M1 M2←⍵
M1←(BETA1×M1)+∆P×1-BETA1 ⋄ M2←(BETA2×M2)+∆P×∆P×1-BETA2
(P-LR×(WD×P)+(M1÷1-BETA1*T)÷1E¯8+(M2÷1-BETA2*T)*0.5) M1 M2
}
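⍝ A sketch of the update OPT performs, read off the code above and written in
⍝ conventional infix notation (g is the gradient ∆P, t the step counter T;
⍝ m1hat and m2hat are names introduced here for illustration only):
⍝   m1    = BETA1×m1 + (1-BETA1)×g
⍝   m2    = BETA2×m2 + (1-BETA2)×g×g
⍝   m1hat = m1 ÷ (1 - BETA1*t)
⍝   m2hat = m2 ÷ (1 - BETA2*t)
⍝   p     = p - LR × ((WD×p) + m1hat ÷ (1E¯8 + m2hat*0.5))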

⍝ Performs one training iteration
ITER←{
seq←data[(?BS⍴(≢data)-SEQ+1)∘.+⍳SEQ+1] ⋄ inp←¯1↓⍤1⊢seq ⋄ targ←1↓⍤1⊢seq
loss←targ FWD inp ⋄ _←BWD ⍬ ⋄ T+←1
XOPT WTE ∆WTE M1WTE M2WTE WTE0X M1WTE1X M2WTE2X
XOPT WPE ∆WPE M1WPE M2WPE WPE0X M1WPE1X M2WPE2X
XOPT WLN ∆WLN M1WLN M2WLN WLN0X M1WLN1X M2WLN2X
XOPT BLN ∆BLN M1BLN M2BLN BLN0X M1BLN1X M2BLN2X
XOPT WQKV ∆WQKV M1WQKV M2WQKV WQKV0X M1WQKV1X M2WQKV2X
XOPT BQKV ∆BQKV M1BQKV M2BQKV BQKV0X M1BQKV1X M2BQKV2X
XOPT WO ∆WO M1WO M2WO WO0X M1WO1X M2WO2X
XOPT BO ∆BO M1BO M2BO BO0X M1BO1X M2BO2X
XOPT W1 ∆W1 M1W1 M2W1 W10X M1W11X M2W12X
XOPT B1 ∆B1 M1B1 M2B1 B10X M1B11X M2B12X
XOPT W2 ∆W2 M1W2 M2W2 W20X M1W21X M2W22X
XOPT B2 ∆B2 M1B2 M2B2 B20X M1B21X M2B22X
XOPT WH ∆WH M1WH M2WH WH0X M1WH1X M2WH2X
}

data←⍵
_←ITER⍣ITERS⊢⍬
}

⍝ Greedily generates next tokens
GEN{{,()1¯1[1]FWD (-SEQ1)[1]}}

:EndNamespace
