\n",
- "
\n",
" [CLS]\n",
" \n",
" \n",
- "
\n",
- " as\n",
+ " california\n",
" \n",
" \n",
- "
\n",
- " recent\n",
+ " group\n",
" \n",
" \n",
- "
\n",
- " events\n",
+ " sues\n",
" \n",
" \n",
- "
\n",
- " illustrate\n",
+ " albertson\n",
" \n",
" \n",
- "
\n",
- " ,\n",
+ " '\n",
" \n",
" \n",
- "
\n",
- " trust\n",
+ " s\n",
" \n",
" \n",
- "
\n",
- " takes\n",
+ " over\n",
" \n",
" \n",
- "
\n",
- " years\n",
+ " privacy\n",
" \n",
" \n",
- "
\n",
- " to\n",
+ " concerns\n",
" \n",
" \n",
- "
\n",
- " gain\n",
+ " a\n",
" \n",
" \n",
- "
\n",
- " but\n",
+ " california\n",
" \n",
" \n",
- "
\n",
- " can\n",
+ " -\n",
" \n",
" \n",
- "
\n",
- " be\n",
+ " based\n",
" \n",
" \n",
- "
\n",
- " lost\n",
+ " privacy\n",
" \n",
" \n",
- "
\n",
- " in\n",
+ " advocacy\n",
" \n",
" \n",
- "
\n",
- " an\n",
+ " group\n",
" \n",
" \n",
- "
\n",
- " instant\n",
+ " is\n",
" \n",
" \n",
- "
\n",
- " .\n",
+ " suing\n",
" \n",
" \n",
- "
\n",
- " [SEP]\n",
+ " supermarket\n",
" \n",
" \n",
- "
\n",
- " trust\n",
+ " giant\n",
" \n",
" \n",
- "
\n",
- " ,\n",
+ " albertson\n",
" \n",
" \n",
- "
\n",
- " once\n",
+ " '\n",
" \n",
" \n",
- "
\n",
- " built\n",
+ " s\n",
" \n",
" \n",
- "
\n",
- " ,\n",
+ " over\n",
" \n",
" \n",
- "
\n",
- " is\n",
+ " alleged\n",
" \n",
" \n",
- "
\n",
- " hard\n",
+ " privacy\n",
" \n",
" \n",
- "
\n",
- " to\n",
+ " violations\n",
" \n",
" \n",
- "
\n",
- " lose\n",
+ " involving\n",
" \n",
" \n",
- "
\n",
+ " its\n",
+ " \n",
+ " \n",
+ "
\n",
+ " pharmacy\n",
+ " \n",
+ " \n",
+ "
\n",
+ " customers\n",
+ " \n",
+ " \n",
+ "
\n",
" .\n",
" \n",
" \n",
- "
\n",
" [SEP]\n",
" \n",
@@ -991,166 +1091,141 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Model: howey/electra-base-mnli | Pred: contradiction | True: contradiction\n"
+ "Model: textattack/albert-base-v2-ag-news | Pred: Business | True: Sci/Tech\n"
]
},
{
"data": {
"text/html": [
"
\n",
- " \n",
" [CLS]\n",
" \n",
" \n",
- " \n",
- " as\n",
- " \n",
- " \n",
- " \n",
- " recent\n",
- " \n",
- " \n",
- " \n",
- " events\n",
- " \n",
- " \n",
- " \n",
- " illustrate\n",
- " \n",
- " \n",
- " \n",
- " ,\n",
- " \n",
- " \n",
- " \n",
- " trust\n",
+ " california\n",
" \n",
" \n",
- " \n",
- " takes\n",
+ " group\n",
" \n",
" \n",
- " \n",
- " years\n",
+ " sues\n",
" \n",
" \n",
- " \n",
- " to\n",
+ " albertson's\n",
" \n",
" \n",
- " \n",
- " gain\n",
+ " over\n",
" \n",
" \n",
- " \n",
- " but\n",
+ " privacy\n",
" \n",
" \n",
- " \n",
- " can\n",
+ " concerns\n",
" \n",
" \n",
- " \n",
- " be\n",
+ " a\n",
" \n",
" \n",
" \n",
- " lost\n",
+ " california-based\n",
" \n",
" \n",
- " \n",
- " in\n",
+ " privacy\n",
" \n",
" \n",
- " \n",
- " an\n",
+ " advocacy\n",
" \n",
" \n",
- " \n",
- " instant\n",
+ " group\n",
" \n",
" \n",
- " \n",
- " .\n",
+ " is\n",
" \n",
" \n",
- " \n",
- " [SEP]\n",
+ " suing\n",
" \n",
" \n",
- " \n",
- " trust\n",
+ " supermarket\n",
" \n",
" \n",
- " \n",
- " ,\n",
+ " giant\n",
" \n",
" \n",
- " \n",
- " once\n",
+ " albertson's\n",
" \n",
" \n",
- " \n",
- " built\n",
+ " over\n",
" \n",
" \n",
- " \n",
- " ,\n",
+ " alleged\n",
" \n",
" \n",
- " \n",
- " is\n",
+ " privacy\n",
" \n",
" \n",
- " \n",
- " hard\n",
+ " violations\n",
" \n",
" \n",
- " \n",
- " to\n",
+ " involving\n",
" \n",
" \n",
- " \n",
- " lose\n",
+ " its\n",
" \n",
" \n",
- " \n",
- " .\n",
+ " pharmacy\n",
" \n",
" \n",
- " \n",
- " [SEP]\n",
+ " customers.[SEP]\n",
" \n",
"
"
],
@@ -1189,7 +1264,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -1250,14 +1325,14 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-occ/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)\n"
+ "Reusing dataset thermostat (C:\\Users\\49176\\.cache\\huggingface\\datasets\\thermostat\\multi_nli-bert-occ\\1.0.1\\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b)\n"
]
},
{
@@ -1265,10 +1340,28 @@
"output_type": "stream",
"text": [
"Loading Thermostat configuration: multi_nli-bert-occ\n",
+ "Dataset path is D:\\Working Student\\repo\\thermostat\\src\\thermostat\\dataset.py\n",
+ "Additional parameters for loading: {}\n",
"Loading Thermostat configuration: multi_nli-bert-lig\n",
- "Downloading and preparing dataset thermostat/multi_nli-bert-lig (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lig/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f...\n"
+ "Dataset path is D:\\Working Student\\repo\\thermostat\\src\\thermostat\\dataset.py\n",
+ "Additional parameters for loading: {}\n",
+ "Downloading and preparing dataset thermostat/multi_nli-bert-lig to C:\\Users\\49176\\.cache\\huggingface\\datasets\\thermostat\\multi_nli-bert-lig\\1.0.1\\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b...\n"
]
},
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "032671160c2b41f99c2683f3579196df",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading: 0%| | 0.00/58.5M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
@@ -1277,7 +1370,7 @@
"version_minor": 0
},
"text/plain": [
- "HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…"
+ "0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
@@ -1287,11 +1380,27 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Dataset thermostat downloaded and prepared to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lig/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f. Subsequent calls will reuse this data.\n",
+ "Dataset thermostat downloaded and prepared to C:\\Users\\49176\\.cache\\huggingface\\datasets\\thermostat\\multi_nli-bert-lig\\1.0.1\\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b. Subsequent calls will reuse this data.\n",
"Loading Thermostat configuration: multi_nli-bert-lime\n",
- "Downloading and preparing dataset thermostat/multi_nli-bert-lime (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lime/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f...\n"
+ "Dataset path is D:\\Working Student\\repo\\thermostat\\src\\thermostat\\dataset.py\n",
+ "Additional parameters for loading: {}\n",
+ "Downloading and preparing dataset thermostat/multi_nli-bert-lime to C:\\Users\\49176\\.cache\\huggingface\\datasets\\thermostat\\multi_nli-bert-lime\\1.0.1\\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b...\n"
]
},
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "419d61c9f31343c3ac6b55bf386170d9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading: 0%| | 0.00/59.4M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
@@ -1300,7 +1409,7 @@
"version_minor": 0
},
"text/plain": [
- "HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…"
+ "0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
@@ -1310,7 +1419,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Dataset thermostat downloaded and prepared to /home/nfel/.cache/huggingface/datasets/thermostat/multi_nli-bert-lime/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f. Subsequent calls will reuse this data.\n"
+ "Dataset thermostat downloaded and prepared to C:\\Users\\49176\\.cache\\huggingface\\datasets\\thermostat\\multi_nli-bert-lime\\1.0.1\\0cbe93e1fbe5b8ed0217559442d8b49a80fd4c2787185f2d7940817c67d8707b. Subsequent calls will reuse this data.\n"
]
}
],
@@ -1323,7 +1432,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 15,
"metadata": {
"scrolled": true
},
@@ -1870,38 +1979,9 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/imdb-bert-lime/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Loading Thermostat configuration: imdb-bert-lime\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Reusing dataset thermostat (/home/nfel/.cache/huggingface/datasets/thermostat/imdb-bert-lig/1.0.0/d4c1fec14831f7d2677ccb8fba33151c9fea8119c4921e647af71cb81299899f)\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Loading Thermostat configuration: imdb-bert-lig\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"imdb_lime = thermostat.load(\"imdb-bert-lime\")\n",
"imdb_intg = thermostat.load(\"imdb-bert-lig\")"
@@ -2015,7 +2095,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -2029,7 +2109,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.9"
+ "version": "3.9.7"
}
},
"nbformat": 4,
diff --git a/setup.py b/setup.py
index 6a3513e..d3bb012 100644
--- a/setup.py
+++ b/setup.py
@@ -1,9 +1,17 @@
from setuptools import find_packages, setup
+# Windows gets the "jsonnet-binary" distribution, every other platform "jsonnet".
+# PEP 508 environment markers let the installer decide on the target machine; branching
+# on sys.platform here would bake the build machine's platform into a built wheel.
+JSONNET_REQUIREMENTS = [
+    "jsonnet; sys_platform != 'win32'",
+    "jsonnet-binary; sys_platform == 'win32'",
+]
+
REQUIRED_PKGS = [
"captum>=0.3",
"datasets>=1.5",
- "jsonnet",
+    *JSONNET_REQUIREMENTS,
"numpy>=1.20",
"overrides",
"pandas",
@@ -20,7 +28,7 @@
setup(
name="thermostat-datasets",
- version="1.0.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="1.0.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Collection of NLP model explanations and accompanying analysis tools",
long_description="Thermostat is a large collection of NLP model explanations and accompanying analysis tools. "
"Combines explainability methods from the captum library with Hugging Face's datasets and "
diff --git a/src/thermostat/data/dataset_utils.py b/src/thermostat/data/dataset_utils.py
index 20c6244..223b5e0 100644
--- a/src/thermostat/data/dataset_utils.py
+++ b/src/thermostat/data/dataset_utils.py
@@ -1,5 +1,8 @@
import numpy as np
import os
+# new change
+from sys import platform
+# new change
from datasets import Dataset, load_dataset
from itertools import groupby
from overrides import overrides
@@ -345,8 +348,15 @@ def load(config_str: str = None, **kwargs) -> Thermopack:
print(f'Loading Thermostat configuration: {config_str}')
if ld_kwargs:
print(f'Additional parameters for loading: {ld_kwargs}')
- dataset_script_path = os.path.dirname(os.path.realpath(__file__)).replace('/thermostat/data',
- '/thermostat/dataset.py')
+    # The loading script dataset.py sits in the package directory one level above
+    # this module's directory (<pkg>/data/dataset_utils.py -> <pkg>/dataset.py).
+    # Walking the directory tree with os.path works with any path separator, so no
+    # per-platform branch is needed, and unlike str.replace() it cannot rewrite an
+    # unrelated '/thermostat/data' earlier in the path or silently return the
+    # unmodified directory when the pattern does not match.
+    data_dir = os.path.dirname(os.path.realpath(__file__))
+    dataset_script_path = os.path.join(os.path.dirname(data_dir), 'dataset.py')
+
data = load_dataset(path=dataset_script_path,
name=config_str, split="test", **ld_kwargs)
diff --git a/src/thermostat/data/thermostat_configs.py b/src/thermostat/data/thermostat_configs.py
index e8c4456..b7b5690 100644
--- a/src/thermostat/data/thermostat_configs.py
+++ b/src/thermostat/data/thermostat_configs.py
@@ -177,6 +177,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/8SLyHdDgRk2pXSL/download",
**_AGNEWS_ALBERT_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="ag_news-albert-lime-100",
+ description="AG News dataset, ALBERT model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/W3GT4ZDT2BzR5mj/download",
+ **_AGNEWS_ALBERT_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="ag_news-albert-occlusion",
description="AG News dataset, ALBERT model, Occlusion explanations",
@@ -212,6 +221,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/rW8MJyAjBGQxsK9/download",
**_AGNEWS_BERT_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="ag_news-bert-lime-100",
+ description="AG News dataset, BERT model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/FkSdXZPpN78HSHR/download",
+ **_AGNEWS_BERT_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="ag_news-bert-occlusion",
description="AG News dataset, BERT model, Occlusion explanations",
@@ -247,6 +265,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/qRgBtwfjaXceJoL/download",
**_AGNEWS_ROBERTA_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="ag_news-roberta-lime-100",
+ description="AG News dataset, RoBERTa model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/kFyjX2LqBdcW9bp/download",
+ **_AGNEWS_ROBERTA_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="ag_news-roberta-occlusion",
description="AG News dataset, RoBERTa model, Occlusion explanations",
@@ -282,6 +309,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/Tgktb4fq4EdXJNx/download",
**_IMDB_ALBERT_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="imdb-albert-lime-100",
+ description="IMDb dataset, ALBERT model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/FzErcT9TcFcG2Pr/download",
+ **_IMDB_ALBERT_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="imdb-albert-occ",
description="IMDb dataset, ALBERT model, Occlusion explanations",
@@ -317,6 +353,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/ZQEdEmFtKeGkYWp/download",
**_IMDB_BERT_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="imdb-bert-lime-100",
+ description="IMDb dataset, BERT model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/Qx7z8SFcMTB5bFa/download",
+ **_IMDB_BERT_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="imdb-bert-occ",
description="IMDb dataset, BERT model, Occlusion explanations",
@@ -352,6 +397,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/7p2576kFqiQLL9x/download",
**_IMDB_ELECTRA_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="imdb-electra-lime-100",
+ description="IMDb dataset, ELECTRA model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/LBqzn6JiQNzwMAC/download",
+ **_IMDB_ELECTRA_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="imdb-electra-occ",
description="IMDb dataset, ELECTRA model, Occlusion explanations",
@@ -387,6 +441,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/rpsMTw3S6JkQgcF/download",
**_IMDB_ROBERTA_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="imdb-roberta-lime-100",
+ description="IMDb dataset, RoBERTa model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/YZsAoJmR4EcwnG2/download",
+ **_IMDB_ROBERTA_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="imdb-roberta-occ",
description="IMDb dataset, RoBERTa model, Occlusion explanations",
@@ -422,6 +485,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/YCDW67f49wj5NXg/download",
**_IMDB_XLNET_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="imdb-xlnet-lime-100",
+ description="IMDb dataset, XLNet model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/T2KsA8ragxPz6eL/download",
+ **_IMDB_XLNET_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="imdb-xlnet-occ",
description="IMDb dataset, XLNet model, Occlusion explanations",
@@ -457,6 +529,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/e6JRy9fidSAC5zK/download",
**_MNLI_ALBERT_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="multi_nli-albert-lime-100",
+ description="MultiNLI dataset, ALBERT model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/WB2N3nFkHTGkXY8/download",
+ **_MNLI_ALBERT_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="multi_nli-albert-occ",
description="MultiNLI dataset, ALBERT model, Occlusion explanations",
@@ -492,6 +573,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/ptspBexoHaXtqXD/download",
**_MNLI_BERT_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="multi_nli-bert-lime-100",
+ description="MultiNLI dataset, BERT model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/LjFccwQ2mCAnsmH/download",
+ **_MNLI_BERT_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="multi_nli-bert-occ",
description="MultiNLI dataset, BERT model, Occlusion explanations",
@@ -527,6 +617,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/WzBwpwC9FoQZCwB/download",
**_MNLI_ELECTRA_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="multi_nli-electra-lime-100",
+ description="MultiNLI dataset, ELECTRA model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/TX6jWs9wBdsJA9w/download",
+ **_MNLI_ELECTRA_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="multi_nli-electra-occ",
description="MultiNLI dataset, ELECTRA model, Occlusion explanations",
@@ -562,6 +661,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/dY4z4ptcMtiYzZs/download",
**_MNLI_ROBERTA_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="multi_nli-roberta-lime-100",
+ description="MultiNLI dataset, RoBERTa model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/KTQWmCDX2EjHtQE/download",
+ **_MNLI_ROBERTA_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="multi_nli-roberta-occ",
description="MultiNLI dataset, RoBERTa model, Occlusion explanations",
@@ -597,6 +705,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/B7tfLSRKBGYxJ3s/download",
**_MNLI_XLNET_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="multi_nli-xlnet-lime-100",
+ description="MultiNLI dataset, XLNet model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/SesZACA2AwyefFp/download",
+ **_MNLI_XLNET_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="multi_nli-xlnet-occ",
description="MultiNLI dataset, XLNet model, Occlusion explanations",
@@ -632,6 +749,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/sijLW3ceigxDsKY/download",
**_XNLI_ALBERT_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="xnli-albert-lime-100",
+ description="XNLI dataset, ALBERT model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/oQW5cRc6GbqHtB6/download",
+ **_XNLI_ALBERT_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="xnli-albert-occ",
description="XNLI dataset, ALBERT model, Occlusion explanations",
@@ -667,6 +793,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/KfjqkRTd7FSWSkx/download",
**_XNLI_BERT_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="xnli-bert-lime-100",
+ description="XNLI dataset, BERT model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/FXHt989a2En8aZZ/download",
+ **_XNLI_BERT_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="xnli-bert-occ",
description="XNLI dataset, BERT model, Occlusion explanations",
@@ -702,6 +837,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/XnkiHXgxNsptxTJ/download",
**_XNLI_ELECTRA_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="xnli-electra-lime-100",
+ description="XNLI dataset, ELECTRA model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/7zNtxCHxEZk2tzC/download",
+ **_XNLI_ELECTRA_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="xnli-electra-occ",
description="XNLI dataset, ELECTRA model, Occlusion explanations",
@@ -737,6 +881,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/pZKo7m4g9WJXfoe/download",
**_XNLI_ROBERTA_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="xnli-roberta-lime-100",
+ description="XNLI dataset, RoBERTa model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/CHSR7Arw8M56bxN/download",
+ **_XNLI_ROBERTA_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="xnli-roberta-occ",
description="XNLI dataset, RoBERTa model, Occlusion explanations",
@@ -772,6 +925,15 @@ def __init__(
data_url="https://cloud.dfki.de/owncloud/index.php/s/6s4DFPNYpzi8722/download",
**_XNLI_XLNET_KWARGS,
),
+ # new change
+ ThermostatConfig(
+ name="xnli-xlnet-lime-100",
+ description="XNLI dataset, XLNet model, LIME explanations, 100 samples",
+ explainer="LimeBase",
+ data_url="https://cloud.dfki.de/owncloud/index.php/s/ZzN9PSkiRrJNza2/download",
+ **_XNLI_XLNET_KWARGS,
+ ),
+ # new change
ThermostatConfig(
name="xnli-xlnet-occ",
description="XNLI dataset, XLNet model, Occlusion explanations",
diff --git a/src/thermostat/explain.py b/src/thermostat/explain.py
index 2e21044..f297537 100644
--- a/src/thermostat/explain.py
+++ b/src/thermostat/explain.py
@@ -129,8 +129,11 @@ def from_config(cls, config):
res.name_model = config['model']['name']
if config['model']['path_model']: # can be empty when loading a HF model!
res.path_model = config['model']['path_model']
-
- if not config['model']['class']:
+
+        # Fall back to the Auto* classes when no explicit model class is configured.
+        # .get() treats a missing 'class' key like an empty/None value, so configs that
+        # omit the key and configs that leave it blank take the same branch (no KeyError).
+        if not config['model'].get('class'):
res.num_labels = len(config['dataset']['label_names'])
# TODO: Assert that num_labels in dataset corresponds to classification head in model
res.model = AutoModelForSequenceClassification.from_pretrained(res.name_model, num_labels=res.num_labels)