Skip to content

Commit

Permalink
Merge pull request #867 from dianna-ai/eu-law-dashboard
Browse files Browse the repository at this point in the history
Add EU law example to dashboard
  • Loading branch information
cwmeijer authored Oct 23, 2024
2 parents 116ffeb + d2edc6e commit 4500fac
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 18 deletions.
77 changes: 77 additions & 0 deletions dianna/dashboard/_model_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from pathlib import Path
from typing import Iterable
import numpy as np
import onnx
import pandas as pd
import torch
import xgboost
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel
from transformers import AutoTokenizer
from dianna.utils.tokenizers import SpacyTokenizer


def load_data(file):
Expand Down Expand Up @@ -81,3 +89,72 @@ def load_penguins(penguins):
X_test.insert(0, 'Index', X_test.index)

return X_train.to_numpy(dtype=np.float32), X_test


def features_eulaw(texts: list[str], model_tag="law-ai/InLegalBERT"):
    """Create one feature vector per text using a pretrained transformer model.

    Every text is cropped, tokenized and passed through the model loaded from
    ``model_tag``. The last hidden states are averaged over the token axis, giving
    a single feature vector per text.

    Parameters
    ----------
    texts
        The texts to create features for.
    model_tag
        Name or path passed to ``AutoTokenizer.from_pretrained`` and
        ``AutoModel.from_pretrained``.

    Returns
    -------
    Numpy array with one row of features per text. When ``texts`` is empty an
    array with zero rows is returned.
    """
    max_length = 512
    tokenizer = AutoTokenizer.from_pretrained(model_tag)
    model = AutoModel.from_pretrained(model_tag)

    def process_batch(batch: Iterable[str]):
        # NOTE: this crop counts characters, whereas the tokenizer's max_length below counts tokens.
        # It is kept as is so the features stay the same as before.
        cropped_texts = [text[:max_length] for text in batch]
        encoded_inputs = tokenizer(cropped_texts, padding='longest', truncation=True, max_length=max_length,
                                   return_tensors="pt")
        with torch.no_grad():
            outputs = model(**encoded_inputs)
        last_hidden_states = outputs.last_hidden_state
        # Plain mean over the token axis; no attention mask is applied here.
        sentence_features = last_hidden_states.mean(dim=1)
        return sentence_features

    # batch size of 1 was quickest during development; with a single text per batch
    # 'longest' padding has nothing to pad, so the unmasked mean above only sees real tokens.
    dataloader = DataLoader(texts, batch_size=1)
    features = [process_batch(batch) for batch in tqdm(dataloader, desc='Creating features')]

    if not features:
        # torch.cat raises on an empty list, so return an empty feature matrix explicitly.
        # Falls back to zero columns if the model config does not expose hidden_size.
        n_features = getattr(model.config, 'hidden_size', 0)
        return np.empty((0, n_features), dtype=np.float32)

    return np.array(torch.cat(features, dim=0))


def classify_texts_eulaw(texts: list[str], model_path, return_proba: bool = False):
    """Classify every text in a list using the xgboost model stored in model_path.

    The texts are first turned into features by ``features_eulaw``, which runs them
    through a large language model. The xgboost model is then loaded from
    ``model_path`` and applied to those features.

    For training the xgboost model, see train_legalbert_xgboost.py.

    Parameters
    ----------
    texts
        A list of strings of which each needs to be classified.
    model_path
        The path to a stored xgboost model.
    return_proba
        If True, return the result of ``predict_proba`` instead of ``predict``.

    Returns
    -------
    The output of the xgboost model for the texts: ``predict_proba`` output when
    ``return_proba`` is True, otherwise the ``predict`` output.
    """
    text_features = features_eulaw(texts)

    classifier = xgboost.XGBClassifier()
    classifier.load_model(model_path)

    # Pick the bound prediction method once, then apply it to the features.
    predict = classifier.predict_proba if return_proba else classifier.predict
    return predict(text_features)


class StatementClassifierEUlaw:
    """Callable model runner that returns class probabilities for EU law statements."""

    def __init__(self, model_path):
        """Create a spacy tokenizer and remember the path of the stored xgboost model."""
        self.tokenizer = SpacyTokenizer(name='en_core_web_sm')
        self.model_path = model_path

    def __call__(self, sentences):
        """Return a two-column array of probabilities, one row per sentence.

        A single string is treated as a batch of one sentence. The first column is the
        first column of the classifier probabilities, the second column is its complement.
        """
        # ensure the input has a batch axis
        batch = [sentences] if isinstance(sentences, str) else sentences

        probabilities = classify_texts_eulaw(batch, self.model_path, return_proba=True)

        first_class = probabilities[:, 0]
        return np.asarray([first_class, 1 - first_class]).T
3 changes: 2 additions & 1 deletion dianna/dashboard/_models_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def _run_rise_text(_model, text, **kwargs):

@st.cache_data
def _run_lime_text(_model, text, **kwargs):
relevances = explain_text(_model, text, tokenizer, method='LIME', **kwargs)
relevances = explain_text(_model, text, tokenizer, method='LIME',
**kwargs)
return relevances


Expand Down
1 change: 0 additions & 1 deletion dianna/dashboard/_models_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def run_model(ts_data):
method='LIME',
num_features=len(ts_data[0]),
num_slices=len(ts_data[0]),
num_samples=100,
distance_method='dtw',
**kwargs,
)
Expand Down
7 changes: 5 additions & 2 deletions dianna/dashboard/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,14 @@ def _get_params(method: str, key):
elif method == 'LIME':
if 'Tabular' in key:
return {
'random_state': st.number_input('Random state', value=0, key=f'{key}_{method}_rs'),
}
'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'),
'num_samples': st.number_input('Number of samples', value=2000, key=f'{key}_{method}_ns')
}
else:
return {
'random_state': st.number_input('Random state', value=2, key=f'{key}_{method}_rs'),
'num_features': st.number_input('Number of features', 999, key=f'{key}_{method}_rf'),
'num_samples': st.number_input('Number of samples', value=2000, key=f'{key}_{method}_ns')
}

else:
Expand Down
56 changes: 47 additions & 9 deletions dianna/dashboard/pages/Text.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import sys
import streamlit as st
from _model_utils import StatementClassifierEUlaw
from _model_utils import load_labels
from _model_utils import load_model
from _models_text import explain_text_dispatcher
Expand Down Expand Up @@ -66,7 +67,7 @@ def description_explainer(open='open'):
if input_type == 'Use an example':
load_example = st.sidebar.radio(
label='Use example',
options=('Movie sentiment classification',),
options=('Movie sentiment classification', 'Nature of EU laws'),
index = None,
on_change = reset_method,
key='Text_load_example')
Expand All @@ -93,6 +94,37 @@ def description_explainer(open='open'):
""",
unsafe_allow_html=True
)
elif load_example == 'Nature of EU laws':
text_input = st.sidebar.selectbox(
'Select EU law statement',
(
"The relevant Member State shall inform the other Member States of any authorisation granted under "
"this Article.",
"The purchase, import or transport from Syria of crude oil and petroleum products shall be prohibited.",
"This Decision shall enter into force on the twentieth day following that of its publication in the "
"Official Journal of the European Union.",
"Where observations are submitted, or where substantial new evidence is presented, the Council shall "
"review its decision and inform the person or entity concerned accordingly.",
"Member States shall cooperate, in accordance with their national legislation, with inspections and "
"disposals undertaken pursuant to paragraphs 1 and 2.")
)
text_model_file = download('inlegal_bert_xgboost_classifier.json', 'model')

description_explainer("")
st.markdown(
"""
**********************************************************************
This notebook demonstrates how to use the LIME explainable-AI method in [DIANNA](https://github.com/dianna-ai/dianna)
to explain a text classification model created as part of the [Nature of EU Rules project
](https://research-software-directory.org/projects/the-nature-of-eu-rules-strict-and-detailed-or-lacking-bite).
The model is used to perform binary classification of individual sentences from EU legislation to determine
whether they specify a regulation or not (i.e., whether they specify a legal obligation or prohibition that some
legal entity should comply with).
[Here's an example](https://eur-lex.europa.eu/legal-content/EN/TXT/HTML/?uri=CELEX:32012R1215&qid=1724343987254)
of what an EU legislative document looks like.
""",
unsafe_allow_html=True
)
else:
description_explainer()
st.info('Select an example in the left panel to coninue')
Expand Down Expand Up @@ -122,23 +154,29 @@ def description_explainer(open='open'):
st.info('Select which input type to use in the left panel to continue')
st.stop()

model = load_model(text_model_file)
serialized_model = model.SerializeToString()
if load_example == 'Nature of EU laws':
labels = ['constitutive', 'regulatory']
choices = ('LIME',)

labels = load_labels(text_label_file)

choices = ('RISE', 'LIME')
else:
labels = load_labels(text_label_file)
choices = ('RISE', 'LIME')
model = load_model(text_model_file)
serialized_model = model.SerializeToString()

st.text("")

with st.container(border=True):
prediction_placeholder = st.empty()
methods, method_params = _methods_checkboxes(choices=choices, key='Text_cb')

model_runner = MovieReviewsModelRunner(serialized_model)

with st.spinner('Predicting class'):
predictions = predict(model=serialized_model, text_input=text_input)
if load_example == 'Nature of EU laws':
model_runner = StatementClassifierEUlaw(text_model_file)
predictions = model_runner([text_input])
else:
model_runner = MovieReviewsModelRunner(serialized_model)
predictions = predict(model=serialized_model, text_input=text_input)

with prediction_placeholder:
top_indices, top_labels = _get_top_indices_and_labels(
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ dashboard =
streamlit-aggrid
streamlit
streamlit_option_menu
transformers
xgboost
notebooks =
keras
nbmake
Expand Down
6 changes: 1 addition & 5 deletions tests/test_dashboard_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,8 @@ def test_timeseries_page(page: Page):
page.get_by_role('heading', name='winter').get_by_text('winter'),
page.get_by_role('img', name='0').first,
page.get_by_role('img', name='0').nth(1),
# Second image
page.get_by_role('heading', name='summer').get_by_text('summer'),
page.get_by_role('img', name='0').nth(2),
page.get_by_role('img', name='0').nth(3),
):
expect(selector).to_be_visible(timeout=100_000)
expect(selector).to_be_visible(timeout=200_000)

# Test FRB example
page.locator("label").filter(has_text="Use an example").locator("div").nth(1).click()
Expand Down

0 comments on commit 4500fac

Please sign in to comment.