Expose mixed_precision dtype arguments
add training.mixed_precision_param and .mixed_precision_reduce options

refactor a util to map strings to torch dtypes

ghstack-source-id: 387e1ca13ad23e859d21d7760f858ee6e269a796
Pull Request resolved: #348
wconstab authored and tianyu-l committed May 28, 2024
1 parent 3602679 commit 0bc9ecb
Showing 3 changed files with 34 additions and 12 deletions.
11 changes: 2 additions & 9 deletions torchtitan/checkpoint.py
@@ -22,17 +22,10 @@
     set_optimizer_state_dict,
 )
 from torch.distributed.checkpoint.stateful import Stateful
-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging_utils import init_logger, logger


-DTYPE_MAP = {
-    "float16": torch.float16,
-    "float32": torch.float32,
-    "bfloat16": torch.bfloat16,
-}
-
-
 class IntervalType(enum.Enum):
     SECONDS = enum.auto()
     STEPS = enum.auto()
@@ -141,7 +134,7 @@ def __init__(
         self.pg = dist.new_group(backend="gloo")

         self.model_weights_only = ckpt_config.model_weights_only
-        self.export_dtype = DTYPE_MAP[ckpt_config.export_dtype]
+        self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]

         self.mp = None
         async_mode = ckpt_config.async_mode.lower()
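Not part of the commit, but for context: a minimal sketch of how the resolved export_dtype could be applied when model_weights_only=true, per the help text for --checkpoint.export_dtype below ("Converts to the specified precision when training completes"). The function name and the cast logic are illustrative assumptions; the actual save path is not shown in this diff.

import torch

from torchtitan.config_manager import TORCH_DTYPE_MAP  # the shared map this commit introduces


def cast_for_export(state_dict: dict[str, torch.Tensor], export_dtype: str) -> dict[str, torch.Tensor]:
    # Resolve the config string to a torch dtype, e.g. "bfloat16" -> torch.bfloat16.
    dtype = TORCH_DTYPE_MAP[export_dtype]
    # Cast only floating-point tensors; leave integer buffers untouched.
    return {
        name: t.to(dtype) if t.is_floating_point() else t
        for name, t in state_dict.items()
    }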
29 changes: 29 additions & 0 deletions torchtitan/config_manager.py
@@ -9,13 +9,21 @@
 from collections import defaultdict
 from typing import Tuple, Union

+import torch
+
 try:
     import tomllib
 except ModuleNotFoundError:
     import tomli as tomllib

 from torchtitan.logging_utils import logger

+TORCH_DTYPE_MAP = {
+    "float16": torch.float16,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+

 class JobConfig:
     """
@@ -207,6 +215,26 @@ def __init__(self):
             default=1,
             help="Pipeline Parallelism degree. 1 means disabled.",
         )
+        self.parser.add_argument(
+            "--training.mixed_precision_param",
+            type=str,
+            default="bfloat16",
+            choices=["bfloat16", "float32"],
+            help="""
+                torch dtype to use for parameters when applying mixed precision via FSDP.
+                This feature only takes effect when data_parallel_degree > 1
+            """,
+        )
+        self.parser.add_argument(
+            "--training.mixed_precision_reduce",
+            type=str,
+            default="float32",
+            choices=["float32"],
+            help="""
+                torch dtype to use for reductions when applying mixed precision via FSDP.
+                This feature only takes effect when data_parallel_degree > 1
+            """,
+        )
         self.parser.add_argument(
             "--training.compile",
             action="store_true",
@@ -275,6 +303,7 @@ def __init__(self):
             "--checkpoint.export_dtype",
             type=str,
             default="float32",
+            choices=["float16", "bfloat16", "float32"],
             help="""
                 Converts to the specified precision when training completes and model_weights_only=true.
                 Currently supports float32, float16, and bfloat16.
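Not part of the commit: a hedged usage sketch for the two new options. The flag names and accepted values come straight from the arguments registered above; the TOML layout assumes JobConfig's usual section.option mapping, and the rest of the config file is omitted.

[training]
mixed_precision_param = "bfloat16"    # dtype used for parameters under FSDP mixed precision
mixed_precision_reduce = "float32"    # dtype used for FSDP gradient reductions

The same settings can be passed on the command line as --training.mixed_precision_param bfloat16 and --training.mixed_precision_reduce float32; in either case TORCH_DTYPE_MAP converts the strings to torch.bfloat16 / torch.float32 before they reach FSDP.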
6 changes: 3 additions & 3 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -28,7 +28,7 @@

 from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint

-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging_utils import logger

@@ -209,9 +209,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
-        # TODO: Expose `reduce_dtype` as a config option.
         mp_policy = MixedPrecisionPolicy(
-            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )
         ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
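Not part of the commit: a minimal sketch of where mp_policy ends up, i.e. FSDP2-style sharding with fully_shard, as used by the surrounding parallelize_llama code (import path as of the PyTorch builds torchtitan targeted at the time). The helper name, the model.layers loop, and the hard-coded dtypes are illustrative assumptions, not lines from this diff.

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.device_mesh import DeviceMesh


def apply_fsdp_mixed_precision(model: torch.nn.Module, dp_mesh: DeviceMesh) -> None:
    # Mirror the commit's defaults: bf16 parameters, fp32 gradient reductions.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,   # --training.mixed_precision_param
        reduce_dtype=torch.float32,   # --training.mixed_precision_reduce
    )
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    # Shard each transformer block, then the root module (assumes a Llama-style
    # `model.layers` container; adjust for other architectures).
    for layer in getattr(model, "layers", []):
        fully_shard(layer, **fsdp_config)
    fully_shard(model, **fsdp_config)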
