
WIP: Multi query attention #3

Open
wants to merge 9 commits into base: load-iter
Conversation

@RaymondLi0 (Collaborator) commented Aug 9, 2022

#1
TODO:

  • inference speed benchmark to compare multi-query with multi-head
  • Test with 3D parallelism configuration

@RaymondLi0 self-assigned this Aug 9, 2022
@RaymondLi0 (Collaborator, Author) commented:

Ran an inference benchmark on an A100 GPU to compare multi-query (MQ) with multi-head (MH) attention. I used:

"--num-layers", "8", "--hidden-size", "1024", "--num-attention-heads", "16", "--seq-length", "1024", "--max-position-embeddings", "1024"


BATCH_SIZE = 512
TOKENS_TO_GENERATE = 128
PROMPT_LENGTH = 128
NUM_BATCHES = 8
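
For reference, here is a minimal PyTorch sketch of the difference being benchmarked at these dimensions (hidden size 1024, 16 heads). This is only an illustration, not the Megatron-LM implementation in this PR: in multi-query attention the key/value projections produce a single head that is shared by all query heads, so the K/V projections and the K/V cache used during generation are 16x smaller here.

```python
import torch
import torch.nn.functional as F

hidden, n_heads = 1024, 16          # matches --hidden-size / --num-attention-heads above
head_dim = hidden // n_heads        # 64

# Multi-head (MH): K and V get one projection per head, same as Q, e.g.
#   k_proj = torch.nn.Linear(hidden, n_heads * head_dim)
# Multi-query (MQ): Q keeps per-head projections, but K and V produce a single
# head that is broadcast across all query heads.
q_proj = torch.nn.Linear(hidden, n_heads * head_dim)
k_proj = torch.nn.Linear(hidden, head_dim)
v_proj = torch.nn.Linear(hidden, head_dim)

def multi_query_attention(x):
    b, s, _ = x.shape                                            # x: [b, s, hidden]
    q = q_proj(x).view(b, s, n_heads, head_dim).transpose(1, 2)  # [b, heads, s, d]
    k = k_proj(x).view(b, s, 1, head_dim).transpose(1, 2)        # [b, 1, s, d]
    v = v_proj(x).view(b, s, 1, head_dim).transpose(1, 2)        # [b, 1, s, d]
    scores = q @ k.transpose(-1, -2) / head_dim ** 0.5           # K broadcasts over heads
    out = F.softmax(scores, dim=-1) @ v                          # V broadcasts over heads
    return out.transpose(1, 2).reshape(b, s, hidden)
```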

Some findings:

Adding timers slows down inference, significantly more for the MH model than for the MQ model.

Times are in ms.
With timers in each layer:
MH: generate: 46721.59 | Transformer forward: 18433.44 | attention forward: 13127.98 | MLP forward: 2462.11
MQ: generate: 39200.74 | Transformer forward: 12065.38 | attention forward: 6762.22 | MLP forward: 2474.14

Only a timer for the whole model:
MH: generate: 40845.17 | Transformer forward: 13884.54
MQ: generate: 37670.10 | Transformer forward: 10263.32

The difference of about 3 seconds (in favour of MQ) with a single timer on the whole model jumps to about 6 seconds when timers are placed inside each layer. The timers call torch.cuda.synchronize(), which is probably the reason for the slowdown, though it is unclear why the slowdown is larger for the MH model.
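
For context, a synchronizing timer looks roughly like the sketch below (an illustration, not the exact timer code in this repo). Each start/stop calls torch.cuda.synchronize(), which blocks the CPU until all queued kernels have finished, so per-layer timers add two synchronization points per layer per step instead of two per step for a single model-level timer.

```python
import time
import torch

class SyncTimer:
    """Wall-clock timer that synchronizes the GPU around the timed region.

    The synchronize() calls force the CPU to wait for all outstanding kernels,
    which is what makes fine-grained (per-layer) timing intrusive.
    """

    def __init__(self, name):
        self.name = name
        self.elapsed_ms = 0.0

    def __enter__(self):
        torch.cuda.synchronize()
        self._start = time.time()
        return self

    def __exit__(self, *exc):
        torch.cuda.synchronize()
        self.elapsed_ms += (time.time() - self._start) * 1000.0

# Hypothetical usage inside a transformer layer:
# with SyncTimer("attention forward") as t:
#     out = attention(x)
```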

We end up with a 26% reduction on the transformer-forward step

when comparing 13884.54 ms (MH) against 10263.32 ms (MQ): 1 - 10263.32 / 13884.54 ≈ 0.26.
However, most of the inference time is not spent on model computation but elsewhere, so profiling would help to find the remaining bottlenecks.
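
A minimal torch.profiler sketch for that; the model and inputs below are placeholders, not this repo's generation loop:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Placeholder model and input standing in for the actual generation loop.
model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(512, 1024, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(8):
        y = model(x)

# Sorting by CUDA time separates GPU kernel time from host-side overhead
# (sampling, tokenization, synchronization, data movement, ...).
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```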
