diff --git a/viz.py b/viz.py index a5a0d078..9ee9b8b6 100644 --- a/viz.py +++ b/viz.py @@ -5,25 +5,31 @@ import pytest from mesa.visualization import SolaraViz + def get_viz_files(directory): viz_files = [] for root, dirs, files in os.walk(directory): for file in files: - if file in ['app.py', 'viz.py']: + if file in ["app.py", "viz.py"]: module_name = os.path.relpath(os.path.join(root, file[:-3])).replace( os.sep, "." ) viz_files.append(module_name) return viz_files + @pytest.mark.parametrize("module_name", get_viz_files("examples")) def test_solara_viz(module_name): # Add the 'examples' directory to the Python path - examples_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'examples')) + examples_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "examples") + ) sys.path.insert(0, examples_dir) # Add the parent directory of the module to the Python path - module_parent_dir = os.path.abspath(os.path.join(examples_dir, os.path.dirname(module_name.replace('.', os.sep)))) + module_parent_dir = os.path.abspath( + os.path.join(examples_dir, os.path.dirname(module_name.replace(".", os.sep))) + ) if module_parent_dir not in sys.path: sys.path.insert(0, module_parent_dir) @@ -83,6 +89,7 @@ def test_solara_viz(module_name): if module_parent_dir in sys.path: sys.path.remove(module_parent_dir) + # Run the tests if __name__ == "__main__": pytest.main([__file__, "-v"])