Skip to content

Commit

Permalink
Add deepspeed config parse error message
Browse files Browse the repository at this point in the history
  • Loading branch information
HHousen committed Feb 26, 2021
1 parent 322f21b commit 2be444e
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def main(args):
)
args.callbacks.append(custom_checkpoint_callback)

if args.plugins.startswith("deepspeed"):
if args.plugins and args.plugins.startswith("deepspeed"):
deepspeed_config_path = args.plugins.split(":")[1]
with open(deepspeed_config_path, "r") as deepspeed_config_file:
deepspeed_config = json.load(deepspeed_config_file)
Expand Down Expand Up @@ -431,6 +431,15 @@ def main(args):
"You must specify the `--weights_save_path` to use `--custom_checkpoint_every_n`."
)

if (
main_args[0].plugins
and main_args[0].plugins.startswith("deepspeed")
and (":" not in main_args[0].plugins)
):
logger.error(
"If you are using the 'deepspeed' plugin, you must specify the path the to deepspeed config like so: `--plugins deepspeed:/path/to/config.json`."
)

main_args = parser.parse_args()

# Setup logging config
Expand Down

0 comments on commit 2be444e

Please sign in to comment.