Commit: update fp8 handler
zigzagcai committed Nov 5, 2024
1 parent 92e40ab commit c6ddc54
Showing 1 changed file with 7 additions and 0 deletions.
internlm/train/pipeline.py
@@ -45,6 +45,7 @@
     TensorParallelCommunicator,
 )
 from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler
+from internlm.core.quantization.fp8handler import Float8Handler
 from internlm.core.trainer import TrainState
 from internlm.data.utils import unpack_type_ids
 from internlm.model.builder import create_model
@@ -278,6 +279,12 @@ def inject_model(model):
     if hasattr(model, IS_INJECTED) and getattr(model, IS_INJECTED):
         return model
 
+    # FP8 Linear and compile model
+    if hasattr(gpc.config, "use_fp8") and gpc.config.get("use_fp8", False):
+        float8_handler = Float8Handler()
+        float8_handler.convert_to_float8_training(model)
+        model = torch.compile(model)
+
     inject_model_helper(model, inject_info=gpc.config.model.get("inject_info", None))
 
     # should be set before NaiveAMPModel
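For context, below is a minimal, self-contained sketch of the general pattern behind a Float8Handler-style convert_to_float8_training call: walk the model and swap nn.Linear submodules for FP8-aware replacements before handing the model to torch.compile. This is an illustration only, not the repository's Float8Handler; the SimulatedFloat8Linear and convert_to_float8_training_sketch names are hypothetical, the FP8 effect is simulated with per-tensor quantize/dequantize of weights rather than real scaled FP8 matmul kernels, and it assumes a PyTorch build that ships the float8_e4m3fn dtype (>= 2.1).

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimulatedFloat8Linear(nn.Module):
    """Wraps an nn.Linear and fake-quantizes its weight to float8_e4m3fn."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.inner = linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.inner.weight
        # Per-tensor scale so the largest weight magnitude maps near the FP8 max.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = w.abs().max().clamp(min=1e-12) / fp8_max
        w_fp8 = (w / scale).to(torch.float8_e4m3fn)  # quantize to FP8
        w_deq = w_fp8.to(w.dtype) * scale            # dequantize for compute
        return F.linear(x, w_deq, self.inner.bias)


def convert_to_float8_training_sketch(model: nn.Module) -> nn.Module:
    """Recursively replace nn.Linear submodules with the simulated FP8 wrapper."""
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            setattr(model, name, SimulatedFloat8Linear(child))
        else:
            convert_to_float8_training_sketch(child)
    return model


if __name__ == "__main__":
    mlp = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
    convert_to_float8_training_sketch(mlp)
    print(mlp(torch.randn(4, 16)).shape)  # torch.Size([4, 16])

With gpc.config.use_fp8 set to True, the new branch in inject_model performs an analogous module conversion via Float8Handler and then compiles the converted model with torch.compile.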
