Skip to content

Commit

Permalink
Add teardown for util.classification plotting tests
Browse files Browse the repository at this point in the history
  • Loading branch information
morganjwilliams committed May 17, 2024
1 parent a77be50 commit 9662162
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions test/util/util_classification.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import unittest

import matplotlib.pyplot as plt
import pandas as pd

from pyrolite.util.classification import (TAS,
USDASoilTexture,
QAP,
FeldsparTernary,
JensenPlot,
PeralkalinityClassifier)
from pyrolite.util.classification import (
QAP,
TAS,
FeldsparTernary,
JensenPlot,
PeralkalinityClassifier,
USDASoilTexture,
)
from pyrolite.util.synthetic import normal_frame, random_cov_matrix


Expand Down Expand Up @@ -50,6 +53,9 @@ def test_classifer_predict(self):
_ = classes.apply(lambda x: cm.fields.get(x, {"name": None})["name"])
self.assertFalse(pd.isnull(classes).all())

def tearDown(self):
plt.close("all")


class TestUSDASoilTexture(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -77,6 +83,9 @@ def test_classifer_predict(self):
classes = cm.predict(df, data_scale=1.0)
self.assertFalse(pd.isnull(classes).all())

def tearDown(self):
plt.close("all")


class TestQAP(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -104,6 +113,9 @@ def test_classifer_predict(self):
classes = cm.predict(df)
self.assertFalse(pd.isnull(classes).all())

def tearDown(self):
plt.close("all")


class TestFeldsparTernary(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -131,6 +143,9 @@ def test_classifer_predict(self):
classes = cm.predict(df)
self.assertFalse(pd.isnull(classes).all())

def tearDown(self):
plt.close("all")


class TestJensenPlot(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -159,6 +174,9 @@ def test_classifer_predict(self):
classes = cm.predict(df)
self.assertFalse(pd.isnull(classes).all())

def tearDown(self):
plt.close("all")


class TestPeralkalinity(unittest.TestCase):
"""Test the peralkalinity classifier."""
Expand All @@ -175,6 +193,9 @@ def test_classifer_predict(self):
cm = PeralkalinityClassifier()
df.loc[:, "Peralk"] = cm.predict(df)

def tearDown(self):
plt.close("all")


if __name__ == "__main__":
unittest.main()

0 comments on commit 9662162

Please sign in to comment.