# llama_test.py
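"""Quick memory-profiling experiment: build a small Kira LLaMA model and
apply a naive global min/max quantization of its weights to uint8 or
float16."""
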
from typing import Literal

import equinox as eqx
import jax
import jax.numpy as jnp
from jax.tree import leaves, map as tree_map  # avoid shadowing builtin `map`
from jaxtyping import Array, PyTree
from memory_profiler import profile

from kira.model_args import LLaMAModelArgs
from kira.models.llama.llama2 import LLaMA


def find_min_max(pytree: PyTree) -> tuple[float, float]:
    # Global min and max over every array leaf in the pytree.
    arrays = leaves(pytree)
    return min(jnp.min(x) for x in arrays), max(jnp.max(x) for x in arrays)


def quantize(pytree: PyTree, bits: Literal[8, 16]) -> PyTree:
    # Partition (rather than filter) so the static, non-array parts of the
    # model are preserved and recombined afterwards; `eqx.filter` alone would
    # replace them with None and break the returned model.
    params, static = eqx.partition(pytree, eqx.is_inexact_array)
    min_val, max_val = find_min_max(params)
    scale = (max_val - min_val) / (2**bits - 1)

    def quantize_array(x: Array) -> Array:
        match bits:
            case 8:
                # Affine quantization to uint8 codes: shifting by `min_val`
                # keeps the codes non-negative (a plain float-to-uint8 cast
                # would wrap negatives). Recover floats later with
                # `codes * scale + min_val`.
                return jnp.clip(
                    jnp.round((x - min_val) / scale), 0, 255
                ).astype(jnp.uint8)
            case 16:
                # Round onto the quantization grid, then store at half
                # precision.
                return (jnp.round(x / scale) * scale).astype(jnp.float16)

    return eqx.combine(tree_map(quantize_array, params), static)


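# A minimal sketch (not part of the original script) of the inverse of the
# 8-bit path above: uint8 codes map back to floats via
# `codes * scale + min_val`, assuming the caller kept the `min_val` and
# `scale` computed inside `quantize`.
def dequantize(pytree: PyTree, min_val: float, scale: float) -> PyTree:
    def dequantize_array(x: Array) -> Array:
        return x.astype(jnp.float32) * scale + min_val

    # Only the uint8 code arrays are mapped; everything else passes through.
    is_uint8 = lambda x: eqx.is_array(x) and x.dtype == jnp.uint8
    params, static = eqx.partition(pytree, is_uint8)
    return eqx.combine(tree_map(dequantize_array, params), static)

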
@profile
def main():
    llama_args = LLaMAModelArgs(dim=4096, n_layers=4, n_heads=4, vocab_size=384)
    llama, state = eqx.nn.make_with_state(LLaMA)(
        model_args=llama_args, key=jax.random.PRNGKey(-1)
    )

    # The 8-bit path stores integer codes, so the model cannot be called
    # directly afterwards; see `dequantize` above.
    # llama = quantize(llama, 8)
    llama = quantize(llama, 16)

    # test_x = jnp.ones((8,), dtype=jnp.int32)
    # y, state = llama(test_x, state, key=None, inference=True)
    # print(y.shape)


if __name__ == "__main__":
    main()