Skip to content
This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Commit

Permalink
Merge pull request #112 from jsadler2/111-train-model-cli
Browse files Browse the repository at this point in the history
[#111] train_model.py -> train_model_cli.py
  • Loading branch information
jsadler2 authored Jun 15, 2021
2 parents 9d2d340 + 08762b3 commit 0c78af2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ rule prep_io_data:
# shell:
# """
# module load analytics cuda10.1/toolkit/10.1.105
# run_training -e /home/jsadler/.conda/envs/rgcn --no-node-list "python {code_dir}/train_model.py -o {params.run_dir} -i {input[0]} -p {params.pt_epochs} -f {params.ft_epochs} --lambdas {params.lamb} --loss_func multitask_rmse --model rgcn -s 135"
# run_training -e /home/jsadler/.conda/envs/rgcn --no-node-list "python {code_dir}/train_model_cli.py -o {params.run_dir} -i {input[0]} -p {params.pt_epochs} -f {params.ft_epochs} --lambdas {params.lamb} --loss_func multitask_rmse --model rgcn -s 135"
# """


Expand Down
6 changes: 6 additions & 0 deletions river_dl/train_model.py → river_dl/train_model_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""
This file provides a commandline interface (CLI) for the `train.train_model`
function. The commandline interface was originally provided to allow a command
to be sent to a slurm scheduler which was necessary to train the model using
GPUs. This has been tested on USGS's Tallgrass supercomputer.
"""
import argparse
from river_dl.train import train_model
import river_dl.loss_functions as lf
Expand Down

0 comments on commit 0c78af2

Please sign in to comment.