Skip to content
This repository has been archived by the owner on Sep 11, 2023. It is now read-only.

Commit

Permalink
Merge pull request #101 from BirkhoffG/doc
Browse files Browse the repository at this point in the history
Reorgnaize the getting started tutorial
  • Loading branch information
BirkhoffG authored Jan 18, 2023
2 parents 59ee4c8 + 575859f commit 355a8b7
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 27 deletions.
1 change: 1 addition & 0 deletions nbs/_quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ format:
number-sections: false
number-depth: 2
html-math-method: katex
highlight-style: flatly

website:
twitter-card: true
Expand Down
1 change: 1 addition & 0 deletions nbs/theme.scss
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ $h3-font-size: 1.25em;
// #fa0

$hover-color: #fa0;
$code-block-bg: #f5f5f5;

/*-- scss:rules --*/
a{
Expand Down
85 changes: 58 additions & 27 deletions nbs/tutorials/getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,10 @@
"source": [
"# This approach is equivalent to using `TabularDataModuleConfigs`\n",
"data_config_dict = {\n",
" # The name of this dataset is \"adult\"\n",
" \"data_name\": \"adult\",\n",
" # The data file is located in `../assets/data/s_adult.csv`.\n",
" \"data_dir\": \"../assets/data/s_adult.csv\",\n",
" # Contains 2 features with continuous variables\n",
" \"continous_cols\": [\"age\",\"hours_per_week\"],\n",
" # Contains 6 features with categorical (discrete) variables\n",
" \"discret_cols\": [\"workclass\",\"education\",\"marital_status\",\"occupation\",\"race\",\"gender\"],\n",
" # Contains 2 features which we do not wish to change\n",
" \"imutable_cols\": [\"race\",\"gender\"]\n",
"}\n",
"datamodule = TabularDataModule(data_config_dict)"
Expand Down Expand Up @@ -237,25 +232,31 @@
"Next, we use `train_model` to train the model on `TabularDataModule`.\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3b3cf023",
"metadata": {},
"source": [
"### Define the Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e914226",
"metadata": {},
"outputs": [],
"source": [
"from relax.module import PredictiveTrainingModuleConfigs, PredictiveTrainingModule\n",
"from relax.trainer import TrainingConfigs, train_model"
"from relax.module import PredictiveTrainingModuleConfigs, PredictiveTrainingModule"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3b3cf023",
"id": "bc7c8b22",
"metadata": {},
"source": [
"### Define the Model\n",
"\n",
"Defining `PredictiveTrainingModule` is similar with defining `TabularDataModule`.\n",
"We first specify the configurator as `PredictiveTrainingModuleConfigs`,\n",
"and pass this configurator to `PredictiveTrainingModule`.\n"
Expand Down Expand Up @@ -292,11 +293,27 @@
":::\n",
"\n",
"\n",
"### Train the Model\n",
"\n",
"\n",
"### Train the Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29383a39",
"metadata": {},
"outputs": [],
"source": [
"from relax.trainer import TrainingConfigs, train_model"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "8c36f447",
"metadata": {},
"source": [
"To train `PredictiveTrainingModule` for the entire dataset (specified in `TabularDataModule`),\n",
"we can simply call `train_model`:"
"we can simply call `train_model`:\n"
]
},
{
Expand Down Expand Up @@ -337,7 +354,7 @@
"### Make Predictions\n",
"\n",
"The forward pass is done via `PredictiveTrainingModule.forward`. \n",
"We wrap the `pred_fn` as follows:"
"We wrap the `pred_fn` as follows, which will be called later:"
]
},
{
Expand All @@ -357,7 +374,9 @@
"id": "4d1f2129",
"metadata": {},
"source": [
"## Generate Counterfactual Explanations"
"## Generate Counterfactual Explanations\n",
"\n",
"Now, it is time to use `ReLax` to generate counterfactual explanations (or recourse)."
]
},
{
Expand All @@ -367,16 +386,17 @@
"metadata": {},
"outputs": [],
"source": [
"from relax.methods import VanillaCF\n",
"from relax.evaluate import generate_cf_explanations"
"from relax.methods import VanillaCF, VanillaCFConfig"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f4a05346",
"metadata": {},
"source": [
"Setup the counterfactual method. Here we use `VanillaCF`."
"We use `VanillaCF` as an example for this tutorial.\n",
"Defining `VanillaCF` is similar to define `TabularDataModule` and `PredictiveTrainingModule`."
]
},
{
Expand All @@ -386,12 +406,11 @@
"metadata": {},
"outputs": [],
"source": [
"cf_configs = { \n",
" 'n_steps': 1000, # Number of steps\n",
" 'lr': 0.001 # Learning rate\n",
"}\n",
"\n",
"cf_exp = VanillaCF(cf_configs)"
"cf_config = VanillaCFConfig(\n",
" n_steps=1000, # Number of steps\n",
" lr=0.001 # Learning rate\n",
")\n",
"cf_exp = VanillaCF(cf_config)"
]
},
{
Expand All @@ -403,6 +422,16 @@
"Generate counterfactual examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1117363",
"metadata": {},
"outputs": [],
"source": [
"from relax.evaluate import generate_cf_explanations"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -419,9 +448,11 @@
],
"source": [
"cf_results = generate_cf_explanations(\n",
" cf_exp, datamodule, pred_fn, pred_fn_args={\n",
" cf_exp, datamodule, pred_fn, \n",
" pred_fn_args={\n",
" 'params': params, 'rng_key': jax.random.PRNGKey(0)\n",
"})"
" }\n",
")"
]
},
{
Expand Down

0 comments on commit 355a8b7

Please sign in to comment.