Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable model and data sharding #96

Draft
wants to merge 37 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
52e96ea
edit installation instructions in readme
gianlucadetommaso May 15, 2023
5e0076d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 15, 2023
4c7fd28
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 15, 2023
6cb6581
bump up version
gianlucadetommaso May 15, 2023
1b39780
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 16, 2023
cb2b49a
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 16, 2023
14e3ca4
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 25, 2023
580067d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 27, 2023
048ef09
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 2, 2023
ad542a4
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 12, 2023
41417c1
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 12, 2023
64be374
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 14, 2023
a2d0f34
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 14, 2023
66bba06
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
911aa82
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
01f959b
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
79f8dca
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
99a3b78
add sequence probit
gianlucadetommaso Jun 19, 2023
1c23a9e
add possibility to run sequential probit on last steps only
gianlucadetommaso Jun 20, 2023
4dea50f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 21, 2023
915a1ea
Merge branch 'main' into seqprobit
gianlucadetommaso Jun 21, 2023
e966745
refactor sequential probit implementation
gianlucadetommaso Jun 23, 2023
529f9aa
add stop gradient flag
gianlucadetommaso Jun 24, 2023
42d2117
pre-commit
gianlucadetommaso Jun 24, 2023
734f597
add probit options in example script
gianlucadetommaso Jun 25, 2023
404840e
mesh
gianlucadetommaso Jun 25, 2023
4444907
enable model and data sharding
gianlucadetommaso Jun 25, 2023
830fbe8
make further changes after training roberta
gianlucadetommaso Jul 11, 2023
e3e1c4f
further changes
gianlucadetommaso Jul 16, 2023
6d47a47
refactoring laplace
gianlucadetommaso Jul 17, 2023
ed571de
start debugging swag
gianlucadetommaso Jul 18, 2023
1ced008
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
6992692
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
b2540c1
make small change in readme because of publish to pypi error
gianlucadetommaso Jul 18, 2023
2362998
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
ba52081
debug deep ensemble
gianlucadetommaso Jul 18, 2023
d2fc289
fix sghmc and sgld
gianlucadetommaso Jul 25, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions benchmarks/transformers/masked_language_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,13 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:

try:
logger.info(list(pathlib.Path(args.restore_checkpoint_dir).rglob("*")))
restore_checkpoint_path = unpack_model_tar(
restore_checkpoint_dir = unpack_model_tar(
list(pathlib.Path(args.restore_checkpoint_dir).rglob("*"))[0]
)
logger.info(list(pathlib.Path(restore_checkpoint_path).rglob("*")))
logger.info(list(pathlib.Path(restore_checkpoint_dir).rglob("*")))
except:
logger.info("No checkpoint to restore")
restore_checkpoint_path = None
restore_checkpoint_dir = None

tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

Expand Down Expand Up @@ -303,11 +303,11 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
#### TRAIN! ####
#####################################
def accuracy_mlm(preds: Array, targets: Array) -> jnp.ndarray:
if preds.ndim > 2:
if preds.ndim > 1:
raise ValueError(
"""`preds` must be a one-dimensional array of predicted classes."""
)
if targets.ndim > 2:
if targets.ndim > 1:
raise ValueError(
"""`targets` must be a one-dimensional array of target classes."""
)
Expand Down Expand Up @@ -341,7 +341,7 @@ def accuracy_mlm(preds: Array, targets: Array) -> jnp.ndarray:
save_checkpoint_dir=args.save_checkpoint_dir,
save_every_n_steps=args.save_every_n_steps,
keep_top_n_checkpoints=args.keep_top_n_checkpoints,
restore_checkpoint_path=restore_checkpoint_path,
restore_checkpoint_dir=restore_checkpoint_dir,
),
)
if args.last_layer_only and (
Expand All @@ -357,7 +357,7 @@ def accuracy_mlm(preds: Array, targets: Array) -> jnp.ndarray:
and args.last_layer_only
else None,
)
if restore_checkpoint_path is not None:
if restore_checkpoint_dir is not None:
fit_config.optimizer = last_layer_optimizer
train_kwargs = {"fit_config": fit_config}
else:
Expand Down
43 changes: 35 additions & 8 deletions benchmarks/transformers/prob_model_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
accuracy,
expected_calibration_error,
)
from fortuna.model_editor import ProbitModelEditor
from fortuna.prob_model import (
ADVIPosteriorApproximator,
DeepEnsemblePosteriorApproximator,
Expand Down Expand Up @@ -213,6 +214,11 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
parser.add_argument("--sgmcmc_polynomial_schedule_gamma", type=float, default=0.55)
parser.add_argument("--sgmcmc_preconditioner", type=strbool, default=False)
parser.add_argument("--sghmc_momentum_decay", type=float, default=0.01)
# model editor
parser.add_argument("--enable_probit_model_editor", type=strbool, default=False)
parser.add_argument("--probit_init_log_var", type=float, default=-5)
parser.add_argument("--probit_stop_gradient", type=strbool, default=False)
parser.add_argument("--probit_last_layer_only", type=strbool, default=False)
# optimizer
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--adam_eps", type=float, default=1e-8)
Expand All @@ -234,13 +240,13 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:

try:
logger.info(list(pathlib.Path(args.load_model_dir).rglob("*")))
restore_checkpoint_path = unpack_model_tar(
restore_checkpoint_dir = unpack_model_tar(
list(pathlib.Path(args.load_model_dir).rglob("*"))[0]
)
logger.info(list(pathlib.Path(restore_checkpoint_path).rglob("*")))
logger.info(list(pathlib.Path(restore_checkpoint_dir).rglob("*")))
except:
logger.info("No checkpoint to restore")
restore_checkpoint_path = None
restore_checkpoint_dir = None

tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

Expand Down Expand Up @@ -392,6 +398,21 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
),
}

model_editor = None
if args.enable_probit_model_editor:
probit_freeze_fun = (
lambda p, v: True
if "classifier" in p
else False
if args.probit_last_layer_only
else None
)
model_editor = ProbitModelEditor(
freeze_fun=probit_freeze_fun,
init_log_var=args.probit_init_log_var,
stop_gradient=args.probit_stop_gradient,
)

### TRAINING
prob_model = ProbClassifier(
model=model,
Expand All @@ -400,6 +421,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
],
prior=IsotropicGaussianPrior(log_var=args.prior_log_var),
output_calibrator=None,
model_editor=model_editor
)

fit_config = FitConfig(
Expand All @@ -422,7 +444,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
save_checkpoint_dir=args.output_data_dir,
save_every_n_steps=args.save_every_n_steps,
keep_top_n_checkpoints=args.keep_top_n_checkpoints,
restore_checkpoint_path=restore_checkpoint_path,
restore_checkpoint_dir=restore_checkpoint_dir,
),
callbacks=[
ResetCovarianceCallback(
Expand Down Expand Up @@ -453,7 +475,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
last_layer_optimizer = FitOptimizer(
method=optimizer, n_epochs=args.num_train_epochs, freeze_fun=freeze_fun
)
if restore_checkpoint_path is not None:
if restore_checkpoint_dir is not None:
fit_config.optimizer = last_layer_optimizer
train_kwargs = {"fit_config": fit_config}
else:
Expand All @@ -478,11 +500,16 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
calib_data_loader=None,
**train_kwargs,
)
elif restore_checkpoint_path is not None:
prob_model.load_state(restore_checkpoint_path)
elif restore_checkpoint_dir is not None:
prob_model.load_state(restore_checkpoint_dir)
else:
raise ValueError(
"Either restore_checkpoint_path or num_train_epochs > 0 should be specified."
"Either restore_checkpoint_dir or num_train_epochs > 0 should be specified."
)

if args.enable_probit_model_editor:
logger.info(
f"Probit log-variance: {prob_model.posterior.state.get().params['model_editor']['params']['log_var']}"
)

### IN-D PERFORMANCE
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ hparams:
per_device_eval_batch_size: 32
per_device_train_batch_size: 32
learning_rate: 2e-05
num_warmup_steps: 10000
num_warmup_steps: 500
prior_log_var: 100.0
weight_decay: 0.01
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ Please find their references below.

.. automodule:: fortuna.output_calib_model.classification
:members:
:exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint
:exclude-members: save_checkpoint, restore_checkpoint

.. _output_calib_regressor:

.. automodule:: fortuna.output_calib_model.regression
:members:
:exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint
:exclude-members: save_checkpoint, restore_checkpoint

.. _output_calib_base:

.. automodule:: fortuna.output_calib_model.base
:members:
:exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint
:exclude-members: save_checkpoint, restore_checkpoint

.. toctree::
:maxdepth: 1
Expand Down
2 changes: 1 addition & 1 deletion examples/scaling_up_bayesian_inference.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __call__(self, x, train: bool = False, **kwargs) -> jnp.ndarray:

# We are ready to call `prob_model.train`, which will perform posterior inference under-the-hood. In order to do Bayesian inference on the last layer only and freeze the other parameters, all we need to do is to pass a function `freeze_fun` to the optimizer configuration object, deciding which parameters should be "frozen" and which should be "trainable".
#
# In addition, we configure `map_fit_config` to make a preliminary run with MAP, and set the frozen parameters to a meaningful value. Alternatively, if any of these is available, you can also either restore an existing checkpoint by configuring `FitCheckpointer.restore_checkpoint_path`, or start from a current state by setting `FitCheckpointer.start_from_current_state` to `True`.
# In addition, we configure `map_fit_config` to make a preliminary run with MAP, and set the frozen parameters to a meaningful value. Alternatively, if any of these is available, you can also either restore an existing checkpoint by configuring `FitCheckpointer.restore_checkpoint_dir`, or start from a current state by setting `FitCheckpointer.start_from_current_state` to `True`.

from fortuna.prob_model import FitConfig, FitOptimizer

Expand Down
Loading