
Commit

free buffer
qianhao0713 committed Aug 5, 2024
1 parent 8f43fc1 commit 86167eb
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/llamafactory/easy_context/dist_flash_attn/monkey_patch.py
@@ -15,6 +15,8 @@
 from .lightseq_async_attn import _lightseq_forward, _lightseq_backward
 from .async_communication import initialize_distributed, reset_global_memory_buffer
 
+import os
+
 # define a global buffer to save flash attention outputs
 # it's called global because it saves the outputs for all layers
 global_flash_attn_out_buffer = None
@@ -40,7 +42,7 @@ def clean_hook():
 
 def clear_all_buffers_at_the_end_of_training():
     # call it at the end of training
-    global lobal_flash_attn_out_buffer
+    global global_flash_attn_out_buffer
     global_flash_attn_out_buffer = None
     global local_res_grad_buffer
     local_res_grad_buffer = None
@@ -129,6 +131,8 @@ def forward(ctx, run_function, layer_idx, preserve_rng_state, *args):
 
         # save flash attention output to global buffer
         save_flash_attn_out_to_global_buffer(ctx.layer_idx, out)
+        if int(os.getenv('RANK')) == 0:
+            print(f"forward layer: {ctx.layer_idx}, MA: {torch.cuda.memory_allocated() // (1<<40):.2f}")
         tensor_inputs += [softmax_lse]
         ctx.softmax_scale = softmax_scale
 
@@ -202,7 +206,9 @@ def backward(ctx, *args):
         # write flash attention output gradients to buffer
         if ctx.layer_idx > 0:
             write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad)
-
+        free_flash_attn_out_buffer(ctx.layer_idx)
+        if int(os.getenv('RANK')) == 0:
+            print(f"backward layer: {ctx.layer_idx}, MA: {torch.cuda.memory_allocated() // (1<<40):.2f}")
         return (None, None, None) + grads
 
 
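The new free_flash_attn_out_buffer(ctx.layer_idx) call releases each layer's saved flash-attention output as soon as its gradient has been written back in backward, instead of keeping every layer's output alive until the end of the step. The helper's body is not part of this diff; the sketch below is only an assumption about how such a per-layer buffer could be managed (the list layout and init_flash_attn_buffers are illustrative, not the repository's actual code):

# Sketch only: assumes the global buffer is a per-layer list of attention outputs.
global_flash_attn_out_buffer = None

def init_flash_attn_buffers(num_layers):
    # hypothetical initializer: one slot per transformer layer
    global global_flash_attn_out_buffer
    global_flash_attn_out_buffer = [None] * num_layers

def save_flash_attn_out_to_global_buffer(layer_idx, out):
    global global_flash_attn_out_buffer
    global_flash_attn_out_buffer[layer_idx] = out

def free_flash_attn_out_buffer(layer_idx):
    # drop the reference after backward has written the gradient for layer_idx-1,
    # so this layer's output tensor can be reclaimed by the CUDA allocator
    global global_flash_attn_out_buffer
    global_flash_attn_out_buffer[layer_idx] = None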

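The prints added in forward and backward log torch.cuda.memory_allocated() from rank 0 only, using the RANK environment variable set by torchrun. A small standalone variant of the same rank-0 logging pattern, with a default when RANK is unset and GiB scaling chosen here for readability (the commit itself divides by 1 << 40):

import os
import torch

def log_cuda_memory(tag: str) -> None:
    # Print allocated CUDA memory from rank 0 only, mirroring the
    # os.getenv('RANK') guard added in this commit; RANK is set by
    # torchrun and other torch.distributed launchers.
    if int(os.getenv("RANK", "0")) != 0:
        return
    allocated_gib = torch.cuda.memory_allocated() / (1 << 30)
    print(f"{tag}: MA: {allocated_gib:.2f} GiB")

# example: log_cuda_memory(f"backward layer: {layer_idx}")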