diff --git a/configs/ag_news/albert/lime-100.jsonnet b/configs/ag_news/albert/lime-100.jsonnet new file mode 100644 index 0000000..450616c --- /dev/null +++ b/configs/ag_news/albert/lime-100.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "ag_news", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/albert-base-v2-ag-news", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/ag_news/bert/lime-100.jsonnet b/configs/ag_news/bert/lime-100.jsonnet new file mode 100644 index 0000000..0dce43e --- /dev/null +++ b/configs/ag_news/bert/lime-100.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "ag_news", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/bert-base-uncased-ag-news", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 
2.0, + "normalize": true, + } +} diff --git a/configs/ag_news/roberta/lime-100.jsonnet b/configs/ag_news/roberta/lime-100.jsonnet new file mode 100644 index 0000000..1d32911 --- /dev/null +++ b/configs/ag_news/roberta/lime-100.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "ag_news", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/roberta-base-ag-news", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/imdb/albert/lime-100.jsonnet b/configs/imdb/albert/lime-100.jsonnet new file mode 100644 index 0000000..53acd44 --- /dev/null +++ b/configs/imdb/albert/lime-100.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/albert-base-v2-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": 
true, + } +} diff --git a/configs/imdb/bert/lime-100.jsonnet b/configs/imdb/bert/lime-100.jsonnet new file mode 100644 index 0000000..607a0d2 --- /dev/null +++ b/configs/imdb/bert/lime-100.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/bert-base-uncased-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/imdb/electra/lime-100.jsonnet b/configs/imdb/electra/lime-100.jsonnet new file mode 100644 index 0000000..94cf611 --- /dev/null +++ b/configs/imdb/electra/lime-100.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "monologg/electra-small-finetuned-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + 
"normalize": true, + } +} diff --git a/configs/imdb/roberta/lime-100.jsonnet b/configs/imdb/roberta/lime-100.jsonnet new file mode 100644 index 0000000..b0a72d0 --- /dev/null +++ b/configs/imdb/roberta/lime-100.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/roberta-base-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/imdb/xlnet/lime-100.jsonnet b/configs/imdb/xlnet/lime-100.jsonnet new file mode 100644 index 0000000..d7b7fb9 --- /dev/null +++ b/configs/imdb/xlnet/lime-100.jsonnet @@ -0,0 +1,34 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "imdb", + "split": "test", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/xlnet-base-cased-imdb", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} \ No 
newline at end of file diff --git a/configs/mnli/albert/lime-100.jsonnet b/configs/mnli/albert/lime-100.jsonnet new file mode 100644 index 0000000..90ecbde --- /dev/null +++ b/configs/mnli/albert/lime-100.jsonnet @@ -0,0 +1,35 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "prajjwal1/albert-base-v2-mnli", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/bert/lime-100.jsonnet b/configs/mnli/bert/lime-100.jsonnet new file mode 100644 index 0000000..4b8f880 --- /dev/null +++ b/configs/mnli/bert/lime-100.jsonnet @@ -0,0 +1,35 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'special_tokens_mask', 'token_type_ids', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/bert-base-uncased-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": 
true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/electra/lime-100.jsonnet b/configs/mnli/electra/lime-100.jsonnet new file mode 100644 index 0000000..628a38c --- /dev/null +++ b/configs/mnli/electra/lime-100.jsonnet @@ -0,0 +1,35 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "howey/electra-base-mnli", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/mnli/roberta/lime-100.jsonnet b/configs/mnli/roberta/lime-100.jsonnet new file mode 100644 index 0000000..116e7d3 --- /dev/null +++ b/configs/mnli/roberta/lime-100.jsonnet @@ -0,0 +1,35 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/roberta-base-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + 
"max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} \ No newline at end of file diff --git a/configs/mnli/xlnet/lime-100.jsonnet b/configs/mnli/xlnet/lime-100.jsonnet new file mode 100644 index 0000000..5f5e2e7 --- /dev/null +++ b/configs/mnli/xlnet/lime-100.jsonnet @@ -0,0 +1,35 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "multi_nli", + "text_field": ["premise", "hypothesis"], + "split": "validation_matched", + "columns": ['input_ids', 'attention_mask', 'special_tokens_mask', 'token_type_ids', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/xlnet-base-cased-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "labels"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/albert/lime-100.jsonnet b/configs/xnli/albert/lime-100.jsonnet new file mode 100644 index 0000000..e2fb57f --- /dev/null +++ b/configs/xnli/albert/lime-100.jsonnet @@ -0,0 +1,36 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + 
"n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "prajjwal1/albert-base-v2-mnli", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/bert/lime-100.jsonnet b/configs/xnli/bert/lime-100.jsonnet new file mode 100644 index 0000000..7aff957 --- /dev/null +++ b/configs/xnli/bert/lime-100.jsonnet @@ -0,0 +1,36 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/bert-base-uncased-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/electra/lime-100.jsonnet b/configs/xnli/electra/lime-100.jsonnet new file mode 100644 index 0000000..fe0bbde --- /dev/null +++ b/configs/xnli/electra/lime-100.jsonnet @@ -0,0 +1,36 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + 
"root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "howey/electra-base-mnli", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/roberta/lime-100.jsonnet b/configs/xnli/roberta/lime-100.jsonnet new file mode 100644 index 0000000..6310552 --- /dev/null +++ b/configs/xnli/roberta/lime-100.jsonnet @@ -0,0 +1,36 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": ['input_ids', 'attention_mask', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/roberta-base-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/configs/xnli/xlnet/lime-100.jsonnet b/configs/xnli/xlnet/lime-100.jsonnet new file mode 100644 index 0000000..0a96cff --- /dev/null +++ b/configs/xnli/xlnet/lime-100.jsonnet @@ -0,0 +1,36 @@ +{ + "path": "$HOME/experiments/thermostat", + "device": "cuda", + "dataset": { + "name": "xnli", + "subset": "en", + "split": "test", + "text_field": ["premise", "hypothesis"], + "columns": 
['input_ids', 'attention_mask', 'token_type_ids', 'special_tokens_mask', 'labels'], + "batch_size": 1, + "root_dir": "$HOME/experiments/thermostat/datasets", + }, + "explainer": { + "name": "LimeBase", + "internal_batch_size": 1, + "n_samples": 100, + "mask_prob": 0.3, + }, + "model": { + "name": "textattack/xlnet-base-cased-MNLI", + "mode_load": "hf", + "path_model": null, + "tokenization": { + "max_length": 512, + "padding": "max_length", + "return_tensors": "np", + "truncation": true, + "special_tokens_mask": true, + } + }, + "visualization": { + "columns": ["attributions", "predictions", "input_ids", "label"], + "gamma": 2.0, + "normalize": true, + } +} diff --git a/demo.ipynb b/demo.ipynb index 6be4b05..38487d9 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -181,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -205,7 +205,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -228,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": { "pycharm": { "name": "#%%\n" @@ -236,22 +236,51 @@ }, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/imdb-bert-lig/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)\n" + "Loading Thermostat configuration: ag_news-bert-lime-100\n", + "Downloading and preparing dataset thermostat/ag_news-bert-lime-100 to C:\\Users\\49176\\.cache\\huggingface\\datasets\\thermostat\\ag_news-bert-lime-100\\1.0.1\\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b...\n" ] }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "10bcd13dc63844e4b73e7f382f9e2a7c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading: 0%| | 0.00/48.2M [00:00\n", " \n", - 
" \n", - " amazing\n", + " stunt\n", " \n", " \n", - " \n", - " movie\n", + " pilots\n", " \n", " \n", - " \n", - " .\n", + " to\n", " \n", " \n", - " \n", - " some\n", + " snag\n", " \n", " \n", - " \n", - " of\n", + " a\n", " \n", " \n", - " \n", - " the\n", + " falling\n", " \n", " \n", - " \n", - " script\n", + " nasa\n", " \n", " \n", - " \n", - " writing\n", + " craft\n", " \n", " \n", - " \n", - " could\n", + " nasa\n", " \n", " \n", - " \n", - " have\n", + " #\n", " \n", " \n", - " \n", - " been\n", + " 39\n", " \n", " \n", - " \n", - " better\n", + " ;\n", " \n", " \n", - " \n", - " (\n", + " s\n", " \n", " \n", - " \n", - " some\n", + " three\n", " \n", " \n", - " \n", - " cliched\n", + " -\n", " \n", " \n", - " \n", - " language\n", + " year\n", " \n", " \n", - " \n", - " )\n", + " effort\n", " \n", " \n", - " \n", - " .\n", + " to\n", " \n", " \n", - " \n", - " joyce\n", + " bring\n", " \n", " \n", - " \n", - " '\n", + " some\n", " \n", " \n", - " \n", - " s\n", + " genuine\n", " \n", " \n", - " \n", - " "\n", + " star\n", " \n", " \n", - " \n", - " the\n", + " dust\n", " \n", " \n", - " \n", - " dead\n", + " back\n", " \n", " \n", - " \n", - " "\n", + " to\n", " \n", " \n", - " \n", - " is\n", + " earth\n", " \n", " \n", - " \n", - " alluded\n", + " is\n", " \n", " \n", - " \n", - " to\n", + " set\n", " \n", " \n", - " \n", - " throughout\n", + " for\n", " \n", " \n", - " \n", - " the\n", + " a\n", " \n", " \n", - " \n", - " movie\n", + " dramatic\n", " \n", " \n", - " \n", - " .\n", + " finale\n", " \n", " \n", - " \n", - " beautiful\n", + " sept\n", " \n", " \n", - " \n", - " scenery\n", + " .\n", " \n", " \n", - " \n", - " and\n", + " 8\n", " \n", " \n", - " \n", - " great\n", + " when\n", " \n", " \n", - " \n", - " acting\n", + " hollywood\n", " \n", " \n", - " \n", - " .\n", + " helicopter\n", " \n", " \n", - " \n", - " very\n", + " pilots\n", " \n", " \n", - " \n", - " poetic\n", + " will\n", " \n", " \n", - " \n", - " .\n", + " attempt\n", " 
\n", " \n", - " \n", - " highly\n", + " a\n", " \n", " \n", - " \n", - " recommend\n", + " midair\n", " \n", " \n", - " \n", - " .\n", + " retrieval\n", " \n", " \n", "
\n", - " \n", " [CLS]\n", " \n", " \n", - " \n", - " as\n", + " california\n", " \n", " \n", - " \n", - " recent\n", + " group\n", " \n", " \n", - " \n", - " events\n", + " sues\n", " \n", " \n", - " \n", - " illustrate\n", + " albertson\n", " \n", " \n", - " \n", - " ,\n", + " '\n", " \n", " \n", - " \n", - " trust\n", + " s\n", " \n", " \n", - " \n", - " takes\n", + " over\n", " \n", " \n", - " \n", - " years\n", + " privacy\n", " \n", " \n", - " \n", - " to\n", + " concerns\n", " \n", " \n", - " \n", - " gain\n", + " a\n", " \n", " \n", - " \n", - " but\n", + " california\n", " \n", " \n", - " \n", - " can\n", + " -\n", " \n", " \n", - " \n", - " be\n", + " based\n", " \n", " \n", - " \n", - " lost\n", + " privacy\n", " \n", " \n", - " \n", - " in\n", + " advocacy\n", " \n", " \n", - " \n", - " an\n", + " group\n", " \n", " \n", - " \n", - " instant\n", + " is\n", " \n", " \n", - " \n", - " .\n", + " suing\n", " \n", " \n", - " \n", - " [SEP]\n", + " supermarket\n", " \n", " \n", - " \n", - " trust\n", + " giant\n", " \n", " \n", - " \n", - " ,\n", + " albertson\n", " \n", " \n", - " \n", - " once\n", + " '\n", " \n", " \n", - " \n", - " built\n", + " s\n", " \n", " \n", - " \n", - " ,\n", + " over\n", " \n", " \n", - " \n", - " is\n", + " alleged\n", " \n", " \n", - " \n", - " hard\n", + " privacy\n", " \n", " \n", - " \n", - " to\n", + " violations\n", " \n", " \n", - " \n", - " lose\n", + " involving\n", " \n", " \n", - " \n", + " its\n", + " \n", + " \n", + " \n", + " pharmacy\n", + " \n", + " \n", + " \n", + " customers\n", + " \n", + " \n", + " \n", " .\n", " \n", " \n", - " \n", " [SEP]\n", " \n", @@ -991,166 +1091,141 @@ "name": "stdout", "output_type": "stream", "text": [ - "Model: howey/electra-base-mnli | Pred: contradiction | True: contradiction\n" + "Model: textattack/albert-base-v2-ag-news | Pred: Business | True: Sci/Tech\n" ] }, { "data": { "text/html": [ "
\n", - " \n", " [CLS]\n", " \n", " \n", - " \n", - " as\n", - " \n", - " \n", - " \n", - " recent\n", - " \n", - " \n", - " \n", - " events\n", - " \n", - " \n", - " \n", - " illustrate\n", - " \n", - " \n", - " \n", - " ,\n", - " \n", - " \n", - " \n", - " trust\n", + " california\n", " \n", " \n", - " \n", - " takes\n", + " group\n", " \n", " \n", - " \n", - " years\n", + " sues\n", " \n", " \n", - " \n", - " to\n", + " albertson's\n", " \n", " \n", - " \n", - " gain\n", + " over\n", " \n", " \n", - " \n", - " but\n", + " privacy\n", " \n", " \n", - " \n", - " can\n", + " concerns\n", " \n", " \n", - " \n", - " be\n", + " a\n", " \n", " \n", " \n", - " lost\n", + " california-based\n", " \n", " \n", - " \n", - " in\n", + " privacy\n", " \n", " \n", - " \n", - " an\n", + " advocacy\n", " \n", " \n", - " \n", - " instant\n", + " group\n", " \n", " \n", - " \n", - " .\n", + " is\n", " \n", " \n", - " \n", - " [SEP]\n", + " suing\n", " \n", " \n", - " \n", - " trust\n", + " supermarket\n", " \n", " \n", - " \n", - " ,\n", + " giant\n", " \n", " \n", - " \n", - " once\n", + " albertson's\n", " \n", " \n", - " \n", - " built\n", + " over\n", " \n", " \n", - " \n", - " ,\n", + " alleged\n", " \n", " \n", - " \n", - " is\n", + " privacy\n", " \n", " \n", - " \n", - " hard\n", + " violations\n", " \n", " \n", - " \n", - " to\n", + " involving\n", " \n", " \n", - " \n", - " lose\n", + " its\n", " \n", " \n", - " \n", - " .\n", + " pharmacy\n", " \n", " \n", - " \n", - " [SEP]\n", + " customers.[SEP]\n", " \n", "
" ], @@ -1189,7 +1264,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -1250,14 +1325,14 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-occ/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)\n" + "Reusing dataset thermostat (C:\\Users\\49176\\.cache\\huggingface\\datasets\\thermostat\\multi_nli-bert-occ\\1.0.1\\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b)\n" ] }, { @@ -1265,10 +1340,28 @@ "output_type": "stream", "text": [ "Loading Thermostat configuration: multi_nli-bert-occ\n", + "Dataset path is D:\\Working Student\\repo\\thermostat\\src\\thermostat\\dataset.py\n", + "Additional parameters for loading: {}\n", "Loading Thermostat configuration: multi_nli-bert-lig\n", - "Downloading and preparing dataset thermostat/multi_nli-bert-lig (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lig/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f...\n" + "Dataset path is D:\\Working Student\\repo\\thermostat\\src\\thermostat\\dataset.py\n", + "Additional parameters for loading: {}\n", + "Downloading and preparing dataset thermostat/multi_nli-bert-lig to C:\\Users\\49176\\.cache\\huggingface\\datasets\\thermostat\\multi_nli-bert-lig\\1.0.1\\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b...\n" ] }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "032671160c2b41f99c2683f3579196df", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading: 0%| | 0.00/58.5M [00:00=0.3", "datasets>=1.5", - "jsonnet", + jsonnet, "numpy>=1.20", "overrides", "pandas", @@ -20,7 
+28,7 @@ setup( name="thermostat-datasets", - version="1.0.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="1.0.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="Collection of NLP model explanations and accompanying analysis tools", long_description="Thermostat is a large collection of NLP model explanations and accompanying analysis tools. " "Combines explainability methods from the captum library with Hugging Face's datasets and " diff --git a/src/thermostat/data/dataset_utils.py b/src/thermostat/data/dataset_utils.py index 20c6244..223b5e0 100644 --- a/src/thermostat/data/dataset_utils.py +++ b/src/thermostat/data/dataset_utils.py @@ -1,5 +1,8 @@ import numpy as np import os +# new change +from sys import platform +# new change from datasets import Dataset, load_dataset from itertools import groupby from overrides import overrides @@ -345,8 +348,15 @@ def load(config_str: str = None, **kwargs) -> Thermopack: print(f'Loading Thermostat configuration: {config_str}') if ld_kwargs: print(f'Additional parameters for loading: {ld_kwargs}') - dataset_script_path = os.path.dirname(os.path.realpath(__file__)).replace('/thermostat/data', + # new change + if platform == "win32": + dataset_script_path = os.path.dirname(os.path.realpath(__file__)).replace('\\thermostat\\data', + '\\thermostat\\dataset.py') + else: + dataset_script_path = os.path.dirname(os.path.realpath(__file__)).replace('/thermostat/data', '/thermostat/dataset.py') + # new change + data = load_dataset(path=dataset_script_path, name=config_str, split="test", **ld_kwargs) diff --git a/src/thermostat/data/thermostat_configs.py b/src/thermostat/data/thermostat_configs.py index e8c4456..b7b5690 100644 --- a/src/thermostat/data/thermostat_configs.py +++ b/src/thermostat/data/thermostat_configs.py @@ -177,6 +177,15 @@ def __init__( 
data_url="https://cloud.dfki.de/owncloud/index.php/s/8SLyHdDgRk2pXSL/download", **_AGNEWS_ALBERT_KWARGS, ), + # new change + ThermostatConfig( + name="ag_news-albert-lime-100", + description="AG News dataset, ALBERT model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/W3GT4ZDT2BzR5mj/download", + **_AGNEWS_ALBERT_KWARGS, + ), + # new change ThermostatConfig( name="ag_news-albert-occlusion", description="AG News dataset, ALBERT model, Occlusion explanations", @@ -212,6 +221,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/rW8MJyAjBGQxsK9/download", **_AGNEWS_BERT_KWARGS, ), + # new change + ThermostatConfig( + name="ag_news-bert-lime-100", + description="AG News dataset, BERT model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/FkSdXZPpN78HSHR/download", + **_AGNEWS_BERT_KWARGS, + ), + # new change ThermostatConfig( name="ag_news-bert-occlusion", description="AG News dataset, BERT model, Occlusion explanations", @@ -247,6 +265,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/qRgBtwfjaXceJoL/download", **_AGNEWS_ROBERTA_KWARGS, ), + # new change + ThermostatConfig( + name="ag_news-roberta-lime-100", + description="AG News dataset, RoBERTa model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/kFyjX2LqBdcW9bp/download", + **_AGNEWS_ROBERTA_KWARGS, + ), + # new change ThermostatConfig( name="ag_news-roberta-occlusion", description="AG News dataset, RoBERTa model, Occlusion explanations", @@ -282,6 +309,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/Tgktb4fq4EdXJNx/download", **_IMDB_ALBERT_KWARGS, ), + # new change + ThermostatConfig( + name="imdb-albert-lime-100", + description="IMDb dataset, ALBERT model, LIME explanations, 100 samples", + explainer="LimeBase", + 
data_url="https://cloud.dfki.de/owncloud/index.php/s/FzErcT9TcFcG2Pr/download", + **_IMDB_ALBERT_KWARGS, + ), + # new change ThermostatConfig( name="imdb-albert-occ", description="IMDb dataset, ALBERT model, Occlusion explanations", @@ -317,6 +353,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/ZQEdEmFtKeGkYWp/download", **_IMDB_BERT_KWARGS, ), + # new change + ThermostatConfig( + name="imdb-bert-lime-100", + description="IMDb dataset, BERT model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/Qx7z8SFcMTB5bFa/download", + **_IMDB_BERT_KWARGS, + ), + # new change ThermostatConfig( name="imdb-bert-occ", description="IMDb dataset, BERT model, Occlusion explanations", @@ -352,6 +397,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/7p2576kFqiQLL9x/download", **_IMDB_ELECTRA_KWARGS, ), + # new change + ThermostatConfig( + name="imdb-electra-lime-100", + description="IMDb dataset, ELECTRA model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/LBqzn6JiQNzwMAC/download", + **_IMDB_ELECTRA_KWARGS, + ), + # new change ThermostatConfig( name="imdb-electra-occ", description="IMDb dataset, ELECTRA model, Occlusion explanations", @@ -387,6 +441,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/rpsMTw3S6JkQgcF/download", **_IMDB_ROBERTA_KWARGS, ), + # new change + ThermostatConfig( + name="imdb-roberta-lime-100", + description="IMDb dataset, RoBERTa model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/YZsAoJmR4EcwnG2/download", + **_IMDB_ROBERTA_KWARGS, + ), + # new change ThermostatConfig( name="imdb-roberta-occ", description="IMDb dataset, RoBERTa model, Occlusion explanations", @@ -422,6 +485,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/YCDW67f49wj5NXg/download", **_IMDB_XLNET_KWARGS, ), 
+ # new change + ThermostatConfig( + name="imdb-xlnet-lime-100", + description="IMDb dataset, XLNet model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/T2KsA8ragxPz6eL/download", + **_IMDB_XLNET_KWARGS, + ), + # new change ThermostatConfig( name="imdb-xlnet-occ", description="IMDb dataset, XLNet model, Occlusion explanations", @@ -457,6 +529,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/e6JRy9fidSAC5zK/download", **_MNLI_ALBERT_KWARGS, ), + # new change + ThermostatConfig( + name="multi_nli-albert-lime-100", + description="MultiNLI dataset, ALBERT model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/WB2N3nFkHTGkXY8/download", + **_MNLI_ALBERT_KWARGS, + ), + # new change ThermostatConfig( name="multi_nli-albert-occ", description="MultiNLI dataset, ALBERT model, Occlusion explanations", @@ -492,6 +573,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/ptspBexoHaXtqXD/download", **_MNLI_BERT_KWARGS, ), + # new change + ThermostatConfig( + name="multi_nli-bert-lime-100", + description="MultiNLI dataset, BERT model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/LjFccwQ2mCAnsmH/download", + **_MNLI_BERT_KWARGS, + ), + # new change ThermostatConfig( name="multi_nli-bert-occ", description="MultiNLI dataset, BERT model, Occlusion explanations", @@ -527,6 +617,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/WzBwpwC9FoQZCwB/download", **_MNLI_ELECTRA_KWARGS, ), + # new change + ThermostatConfig( + name="multi_nli-electra-lime-100", + description="MultiNLI dataset, ELECTRA model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/TX6jWs9wBdsJA9w/download", + **_MNLI_ELECTRA_KWARGS, + ), + # new change ThermostatConfig( name="multi_nli-electra-occ", 
description="MultiNLI dataset, ELECTRA model, Occlusion explanations", @@ -562,6 +661,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/dY4z4ptcMtiYzZs/download", **_MNLI_ROBERTA_KWARGS, ), + # new change + ThermostatConfig( + name="multi_nli-roberta-lime-100", + description="MultiNLI dataset, RoBERTa model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/KTQWmCDX2EjHtQE/download", + **_MNLI_ROBERTA_KWARGS, + ), + # new change ThermostatConfig( name="multi_nli-roberta-occ", description="MultiNLI dataset, RoBERTa model, Occlusion explanations", @@ -597,6 +705,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/B7tfLSRKBGYxJ3s/download", **_MNLI_XLNET_KWARGS, ), + # new change + ThermostatConfig( + name="multi_nli-xlnet-lime-100", + description="MultiNLI dataset, XLNet model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/SesZACA2AwyefFp/download", + **_MNLI_XLNET_KWARGS, + ), + # new change ThermostatConfig( name="multi_nli-xlnet-occ", description="MultiNLI dataset, XLNet model, Occlusion explanations", @@ -632,6 +749,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/sijLW3ceigxDsKY/download", **_XNLI_ALBERT_KWARGS, ), + # new change + ThermostatConfig( + name="xnli-albert-lime-100", + description="XNLI dataset, ALBERT model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/oQW5cRc6GbqHtB6/download", + **_XNLI_ALBERT_KWARGS, + ), + # new change ThermostatConfig( name="xnli-albert-occ", description="XNLI dataset, ALBERT model, Occlusion explanations", @@ -667,6 +793,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/KfjqkRTd7FSWSkx/download", **_XNLI_BERT_KWARGS, ), + # new change + ThermostatConfig( + name="xnli-bert-lime-100", + description="XNLI dataset, BERT model, LIME explanations, 100 
samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/FXHt989a2En8aZZ/download", + **_XNLI_BERT_KWARGS, + ), + # new change ThermostatConfig( name="xnli-bert-occ", description="XNLI dataset, BERT model, Occlusion explanations", @@ -702,6 +837,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/XnkiHXgxNsptxTJ/download", **_XNLI_ELECTRA_KWARGS, ), + # new change + ThermostatConfig( + name="xnli-electra-lime-100", + description="XNLI dataset, ELECTRA model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/7zNtxCHxEZk2tzC/download", + **_XNLI_ELECTRA_KWARGS, + ), + # new change ThermostatConfig( name="xnli-electra-occ", description="XNLI dataset, ELECTRA model, Occlusion explanations", @@ -737,6 +881,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/pZKo7m4g9WJXfoe/download", **_XNLI_ROBERTA_KWARGS, ), + # new change + ThermostatConfig( + name="xnli-roberta-lime-100", + description="XNLI dataset, RoBERTa model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/CHSR7Arw8M56bxN/download", + **_XNLI_ROBERTA_KWARGS, + ), + # new change ThermostatConfig( name="xnli-roberta-occ", description="XNLI dataset, RoBERTa model, Occlusion explanations", @@ -772,6 +925,15 @@ def __init__( data_url="https://cloud.dfki.de/owncloud/index.php/s/6s4DFPNYpzi8722/download", **_XNLI_XLNET_KWARGS, ), + # new change + ThermostatConfig( + name="xnli-xlnet-lime-100", + description="XNLI dataset, XLNet model, LIME explanations, 100 samples", + explainer="LimeBase", + data_url="https://cloud.dfki.de/owncloud/index.php/s/ZzN9PSkiRrJNza2/download", + **_XNLI_XLNET_KWARGS, + ), + # new change ThermostatConfig( name="xnli-xlnet-occ", description="XNLI dataset, XLNet model, Occlusion explanations", diff --git a/src/thermostat/explain.py b/src/thermostat/explain.py index 2e21044..f297537 100644 --- 
a/src/thermostat/explain.py
+++ b/src/thermostat/explain.py
@@ -129,8 +129,10 @@ def from_config(cls, config):
         res.name_model = config['model']['name']
         if config['model']['path_model']:  # can be empty when loading a HF model!
             res.path_model = config['model']['path_model']
 
-        if not config['model']['class']:
+        # 'class' may be absent from the model config or explicitly null/empty; in either
+        # case load the model through AutoModelForSequenceClassification.
+        if not config['model'].get('class'):
             res.num_labels = len(config['dataset']['label_names'])
             # TODO: Assert that num_labels in dataset corresponds to classification head in model
             res.model = AutoModelForSequenceClassification.from_pretrained(res.name_model, num_labels=res.num_labels)