diff --git a/doc/bibliography.bib b/doc/bibliography.bib index 77c61e8901..000642f8cf 100644 --- a/doc/bibliography.bib +++ b/doc/bibliography.bib @@ -432,6 +432,16 @@ @Article{durlofsky87a doi = {10.1017/S002211208700171X}, } +@TechReport{elijosius24a, + title = {Zero Shot Molecular Generation via Similarity Kernels}, + author = {Elijo{\v s}ius, Rokas and Zills, Fabian and Batatia, Ilyes and Norwood, Sam Walton and Kov{\'a}cs, D{\'a}vid P{\'e}ter and Holm, Christian and Cs{\'a}nyi, G{\'a}bor}, + year = {2024}, + type = {Preprint}, + number = {arXiv:2402.08708}, + doi = {10.48550/arXiv.2402.08708}, + institution = {arXiv}, +} + @Article{ermak78a, title={{B}rownian dynamics with hydrodynamic interactions}, author={Ermak, Donald L. and McCammon, J. A.}, diff --git a/doc/sphinx/conf.py.in b/doc/sphinx/conf.py.in index 7137175212..8c1311ff8a 100644 --- a/doc/sphinx/conf.py.in +++ b/doc/sphinx/conf.py.in @@ -258,7 +258,7 @@ napoleon_use_param = False # Suppress warnings for features not compiled in # https://stackoverflow.com/questions/12206334/sphinx-autosummary-toctree-contains-reference-to-nonexisting-document-warnings -autodoc_mock_imports = ['featuredefs', 'matplotlib', 'OpenGL', 'vtk'] +autodoc_mock_imports = ['featuredefs', 'matplotlib', 'OpenGL', 'vtk', 'zndraw', 'znsocket', 'znjson'] # Add custom stylesheets def setup(app): diff --git a/doc/sphinx/installation.rst b/doc/sphinx/installation.rst index 6b86106b8b..198e90d2f8 100644 --- a/doc/sphinx/installation.rst +++ b/doc/sphinx/installation.rst @@ -128,6 +128,12 @@ Optionally the ccmake utility can be installed for easier configuration: sudo apt install cmake-curses-gui +To install the ZnDraw visualizer: + +.. code-block:: bash + + python3 -m pip install --user -c requirements.txt 'zndraw==0.4.5' + .. _Nvidia GPU acceleration: Nvidia GPU acceleration diff --git a/doc/sphinx/visualization.rst b/doc/sphinx/visualization.rst index 8f994cf56f..b421f49aca 100644 --- a/doc/sphinx/visualization.rst +++ b/doc/sphinx/visualization.rst @@ -12,13 +12,18 @@ to user input. It requires the Python module *PyOpenGL*. It is not meant to produce high quality renderings, but rather to debug the simulation setup and equilibration process. +.. _OpenGL visualizer: + +OpenGL visualizer +----------------- + .. _General usage: General usage -------------- +~~~~~~~~~~~~~ The recommended usage is to instantiate the visualizer and pass it the -:class:`espressomd.System() ` object. Then write +:class:`~espressomd.system.System` object. Then write your integration loop in a separate function, which is started in a non-blocking thread. Whenever needed, call ``update()`` to synchronize the renderer with your system. Finally start the blocking visualization @@ -50,7 +55,7 @@ window with ``start()``. See the following minimal code example:: .. _Setting up the visualizer: Setting up the visualizer -------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~ :class:`espressomd.visualization.openGLLive()` @@ -81,7 +86,6 @@ live plotting (see sample script :file:`/samples/visualization_ljliquid.py`). default package manager of your operating system. On Ubuntu the required package is called ``libgle3-dev``, on Fedora ``libgle`` -- just to name two examples. - .. _Running the visualizer: Running the visualizer @@ -292,6 +296,79 @@ With the keyword ``drag_enabled`` set to ``True``, the mouse can be used to exert a force on particles in drag direction (scaled by ``drag_force`` and the distance of particle and mouse cursor). +.. _ZnDraw: + +ZnDraw visualizer +----------------- + +|es| supports the ZnDraw visualizer :cite:`elijosius24a` in Jupyter Notebooks. +With ZnDraw [1]_, you can visualize your simulation live in a notebook or +web browser. The visualizer is based on ``THREE.js``. + +.. _ZnDraw General usage: + +General usage +~~~~~~~~~~~~~ + +The recommended usage is to instantiate the visualizer :class:`espressomd.zn.Visualizer` and pass it the :class:`~espressomd.system.System` object. +With the initialization you can also assign all particle types a color and radii through a type mapping. There are standard +colors like ``red``, ``black`` etc., but one can also use hex colors like ``#ff0000``. The radii can be set to a float value. +Then write your integration loop in a separate function, and call the update function of the visualizer to capture +the current state of the system and visualize it. Note that the visualizer needs to be started by pressing space. + +Example code:: + + import espressomd + import espressomd.zn + + system = espressomd.System(box_l=[10, 10, 10]) + system.cell_system.skin = 0.4 + system.time_step = 0.001 + + system.part.add(pos=[1, 1, 1], v=[1, 0, 0]) + system.part.add(pos=[9, 9, 9], v=[0, 1, 0]) + + vis = espressomd.zn.Visualizer(system, colors={0: "red"}, radii={0: 0.5}) + + for i in range(1000): + system.integrator.run(25) + vis.update() + +The visualizer supports further features like bonds, constraints, folding and lattice-Boltzmann solvers. The particle coordinates +can be folded by initalizing the visualizer with the keyword ``folded=True``. The display of bonds can be enabled by setting +``bonds=True``. + +Constraints can be drawn using the :meth:`~espressomd.zn.Visualizer.draw_constraints` method. +The method takes a list of all ESPResSo shapes that should be drawn as an argument. + +Furthermore the visualizer supports the visualization of the lattice-Boltzmann solver. The lattice-Boltzmann solver can be visualized +by setting the keyword ``vector_field`` to a lattice-Boltzmann solver :class:`~espressomd.zn.LBField` object, which has to be created +before initializing the visualizer and takes in several parameters like the node spacing, node offset and scale. One can also apply a +color map to the vector field by setting the keyword ``arrow_config`` to a dictionary containing the arrow settings. + +The arrow config contains a ``colormap`` using a list of 2 HSL-color values from which vector colors are interpolated using their length +as a criterium. The ``normalize`` boolean which normalizes the color to the largest vector. The ``colorrange`` list which is only used when +``normalize`` is false and describes the range to what the colorrange is applied to. ``scale_vector_thickness`` is a boolean and changes +the thickness scaling of the vectors and ``opacity`` is a float value that sets the opacity of the vectors. + +An example code snippet containing the :class:`~espressomd.zn.LBField` object:: + + import espressomd.zn + + color = {0: "#00f0f0"} + radii = {0: 0.5} + arrows_config = {'colormap': [[-0.5, 0.9, 0.5], [-0.0, 0.9, 0.5]], + 'normalize': True, + 'colorrange': [0, 1], + 'scale_vector_thickness': True, + 'opacity': 1.0} + + lbfield = espressomd.zn.LBField(system, step_x=2, step_y=2, step_z=5, scale=1) + vis = espressomd.zn.Visualizer(system, colors=color, radii=radii, folded=True, + vector_field=lbfield) + + vis.draw_constraints([wall1, wall2]) + .. _Visualization example scripts: Visualization example scripts @@ -299,3 +376,8 @@ Visualization example scripts Various :ref:`Sample Scripts` can be found in :file:`/samples/visualization_*.py` or in the :ref:`Tutorials` "Visualization" and "Charged Systems". + +____ + +.. [1] + https://github.com/zincware/ZnDraw diff --git a/maintainer/CI/doc_warnings.sh b/maintainer/CI/doc_warnings.sh index d9f4eac093..f1765090d3 100755 --- a/maintainer/CI/doc_warnings.sh +++ b/maintainer/CI/doc_warnings.sh @@ -51,7 +51,7 @@ if [ "${?}" = "0" ]; then # skip if broken link refers to a standard Python type or to a # class/function from an imported module other than espressomd is_standard_type_or_module="false" - grep -Pq '^([a-zA-Z0-9_]+Error|[a-zA-Z0-9_]*Exception|(?!espressomd\.)[a-zA-Z0-9_]+\.[a-zA-Z0-9_\.]+)$' <<< "${reference}" + grep -Pq '^([a-zA-Z0-9_]+Error|[a-zA-Z0-9_]*Exception|ConverterBase|(?!espressomd\.)[a-zA-Z0-9_]+\.[a-zA-Z0-9_\.]+)$' <<< "${reference}" [ "${?}" = "0" ] && is_standard_type_or_module="true" # private objects are not documented and cannot be linked is_private="false" diff --git a/src/python/espressomd/zn.py b/src/python/espressomd/zn.py new file mode 100644 index 0000000000..7cb887cd5c --- /dev/null +++ b/src/python/espressomd/zn.py @@ -0,0 +1,624 @@ +# +# Copyright (C) 2024 The ESPResSo project +# +# This file is part of ESPResSo. +# +# ESPResSo is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# ESPResSo is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + +import subprocess +import numpy as np +import zndraw.zndraw +import zndraw.utils +import zndraw.draw +import znsocket +import znjson +import espressomd +import secrets +import time +import urllib.parse +import typing as t + + +# Standard colors +color_dict = {"black": "#303030", + "red": "#e6194B", + "green": "#3cb44b", + "yellow": "#ffe119", + "blue": "#4363d8", + "orange": "#f58231", + "purple": "#911eb4", + "cyan": "#42d4f4", + "magenta": "#f032e6", + "lime": "#bfef45", + "brown": "#9A6324", + "grey": "#a9a9a9", + "white": "#f0f0f0"} + + +class EspressoConverter(znjson.ConverterBase): + """ + Converter for ESPResSo systems to ASEDict + """ + level = 100 + representation = "ase.Atoms" + instance = espressomd.system.System + + def encode(self, system) -> zndraw.utils.ASEDict: + self.system = system + self.particles = self.system.part.all() + self.num_particles = len(self.particles) + self.params = system.visualizer_params + + self.numbers = self.num_particles * [1] + + if self.params["folded"] is True: + self.positions = self.particles.pos_folded + else: + self.positions = self.particles.pos + + if self.params["colors"] is None: + self.colors = self.get_default_colors() + else: + self.colors = self.set_colors(self.params["colors"]) + + if self.params["radii"] is None: + self.radii = self.get_default_radii() + else: + self.radii = self.set_radii(self.params["radii"]) + + if self.params["bonds"] is True: + bonds = self.get_bonds() + else: + bonds = [] + + arrays = { + "colors": self.colors, + "radii": self.radii, + } + cell = [[system.box_l[0], 0, 0], + [0, system.box_l[1], 0], + [0, 0, system.box_l[2]]] + pbc = system.periodicity + calc = None + info = {} + + if self.params["vector_field"] is not None: + vectors = self.params["vector_field"]() + else: + vectors = [] + + return zndraw.utils.ASEDict( + numbers=self.numbers, + positions=self.positions.tolist(), + connectivity=bonds, + arrays=arrays, + info=info, + calc=calc, + pbc=pbc.tolist(), + cell=cell, + vectors=vectors, + ) + + def decode(self, value): + value = None + return value + + def get_default_colors(self): + return [color_dict["white"]] * self.num_particles + + def get_default_radii(self): + return [0.5] * self.num_particles + + def set_colors(self, colors): + color_list = list() + for p in self.particles: + color = colors[p.type] + # if color starts with #, assume it is a hex color + if color.startswith("#"): + color_list.append(color) + else: + if color not in color_dict: + raise ValueError( + f"Color {color} not found in color dictionary") + color_list.append(color_dict[color]) + return color_list + + def set_radii(self, radii): + radius_list = list() + for p in self.particles: + radius_list.append(radii[p.type]) + return radius_list + + def get_bonds(self): + bonds = [] + for p in self.particles: + if not p.bonds: + continue + for bond in p.bonds: + if len(bond) == 4: + bonds.append([p.id, bond[1], 1]) + bonds.append([p.id, bond[2], 1]) + bonds.append([bond[2], bond[3], 1]) + else: + for bond_partner in bond[1:]: + bonds.append([p.id, bond_partner, 1]) + + self.process_bonds(bonds) + + return bonds + + def process_bonds(self, bonds): + half_box_l = 0.5 * self.system.box_l + num_part = len(self.positions) + bonds_to_remove = [] + bonds_to_add = [] + + for b in bonds: + try: + if self.params["folded"] is True: + x_a = self.system.part.by_id(b[0]).pos_folded + x_b = self.system.part.by_id(b[1]).pos_folded + else: + x_a = self.system.part.by_id(b[0]).pos + x_b = self.system.part.by_id(b[1]).pos + except Exception: + bonds_to_remove.append(b) + continue + + dx = x_b - x_a + + if np.all(np.abs(dx) < half_box_l): + continue + + if self.params["folded"] is False: + bonds_to_remove.append(b) + continue + + d = self.cut_bond(x_a, dx) + if d is np.inf: + bonds_to_remove.append(b) + continue + + s_a = x_a + 0.8 * dx + s_b = x_b - 0.8 * dx + + bonds_to_remove.append(b) + + self.add_ghost_particle(pos=s_a, color=self.colors[b[0]]) + bonds_to_add.append([b[0], num_part, 1]) + + self.add_ghost_particle(pos=s_b, color=self.colors[b[1]]) + bonds_to_add.append([b[1], num_part + 1, 1]) + num_part += 2 + + for b in bonds_to_remove: + bonds.remove(b) + + bonds.extend(bonds_to_add) + + def cut_bond(self, x_a, dx): + if np.dot(dx, dx) < 1e-9: + return np.inf + shift = np.rint(dx / self.system.box_l) + dx -= shift * self.system.box_l + best_d = np.inf + for i in range(3): + if dx[i] == 0: + continue + elif dx[i] > 0: + p0_i = self.system.box_l[i] + else: + p0_i = 0 + + d = (p0_i - x_a[i]) / dx[i] + if d < best_d: + best_d = d + return best_d + + def add_ghost_particle(self, pos, color): + self.positions = np.vstack([self.positions, pos]) + self.radii.append(1e-6 * min(self.radii)) + self.colors.append(color) + self.numbers.append(2) + + +znjson.config.register(EspressoConverter) + + +class LBField: + """ + Convert the ESPResSo lattice-Boltzmann field to a vector field for visualization. Samples + the field at a given step size and offset over the lattice nodes. + + Parameters + ---------- + system : :class:`~espressomd.system.System` + ESPResSo system + step_x : :obj:`int`, optional + Step size in x direction, by default 1 + step_y : :obj:`int`, optional + Step size in y direction, by default 1 + step_z : :obj:`int`, optional + Step size in z direction, by default 1 + offset_x : :obj:`int`, optional + Offset in x direction, by default 0 + offset_y : :obj:`int`, optional + Offset in y direction, by default 0 + offset_z : :obj:`int`, optional + Offset in z direction, by default 0 + scale : :obj:`float`, optional + Scale the velocity vectors, by default 1.0 + arrow_config : :obj:`dict`, optional + Configuration for the arrows, by default None and then uses the default configuration: + + 'colormap': [[-0.5, 0.9, 0.5], [-0.0, 0.9, 0.5]] + HSL colormap for the arrows, where the first value is the minimum value and the second value is the maximum value. + 'normalize': True + Normalize the colormap to the maximum value each frame + 'colorrange': [0, 1] + Range of the colormap, only used if normalize is False + 'scale_vector_thickness': True + Scale the thickness of the arrows with the velocity + 'opacity': 1.0 + Opacity of the arrows + """ + + def __init__(self, system: espressomd.system.System, + step_x: int = 1, + step_y: int = 1, + step_z: int = 1, + offset_x: int = 0, + offset_y: int = 0, + offset_z: int = 0, + scale: float = 1.0, + arrow_config: dict = None): + self.step_x = step_x + self.step_y = step_y + self.step_z = step_z + self.offset_x = offset_x + self.offset_y = offset_y + self.offset_z = offset_z + self.scale = scale + self.box = system.box_l + + if system.lb is None: + raise ValueError("System does not have a lattice-Boltzmann solver") + self.lbf = system.lb + self.agrid = system.lb.agrid + self.arrow_config = arrow_config + + def _build_grid(self): + x = np.arange(self.offset_x + self.agrid / 2, + self.box[0], self.step_x * self.agrid) + y = np.arange(self.offset_y + self.agrid / 2, + self.box[1], self.step_y * self.agrid) + z = np.arange(self.offset_z + self.agrid / 2, + self.box[2], self.step_z * self.agrid) + + origins = np.array(np.meshgrid(x, y, z)).T.reshape(-1, 3) + + return origins + + def _get_velocities(self): + velocities = self.lbf[:, :, :].velocity + velocities = velocities[::self.step_x, ::self.step_y, ::self.step_z] + velocities = self.scale * velocities + velocities = np.swapaxes(velocities, 0, 3) + velocities = np.swapaxes(velocities, 2, 3) + velocities = velocities.T.reshape(-1, 3) + + return velocities + + def __call__(self): + origins = self._build_grid() + velocities = self._get_velocities() + velocities = origins + velocities + vector_field = np.stack([origins, velocities], axis=1) + + return vector_field.tolist() + + +class VectorField: + """ + Give an array of origins and vectors to create a vector field for visualization. + The vectorfield is updated every time it is called. Both origins and vectors must have the same shape. + Both origins and vectors must be 3D numpy arrays in the shape of ``(n, m, 3)``, + where ``n`` is the number of frames the field has and ``m`` is the number of vectors. + The number of frames ``n`` must larger or equal to the number of times the update function will be called in the update loop. + + Parameters + ---------- + origins : (n, m, 3) array_like of :obj:`float` + Array of origins for the vectors + vectors : (n, m, 3) array_like of :obj:`float` + Array of vectors + scale : :obj:`float`, optional + Scale the vectors, by default 1 + arrow_config : :obj:`dict`, optional + Configuration for the arrows, by default ``None`` and then uses the default configuration: + + 'colormap': [[-0.5, 0.9, 0.5], [-0.0, 0.9, 0.5]] + HSL colormap for the arrows, where the first value is the minimum value and the second value is the maximum value. + 'normalize': True + Normalize the colormap to the maximum value each frame + 'colorrange': [0, 1] + Range of the colormap, only used if normalize is False + 'scale_vector_thickness': True + Scale the thickness of the arrows with the velocity + 'opacity': 1.0 + Opacity of the arrows + """ + + def __init__(self, origins: np.ndarray, vectors: np.ndarray, + scale: float = 1, arrow_config: dict = None): + self.origins = origins + self.vectors = vectors + self.scale = scale + self.frame_count = 0 + self.arrow_config = arrow_config + + def __call__(self): + if self.origins.shape != self.vectors.shape: + raise ValueError("Origins and vectors must have the same shape") + + origins = self.origins[self.frame_count] + vectors = origins + self.vectors[self.frame_count] + self.frame_count += 1 + origins = origins.reshape(-1, 3) + vectors = vectors.reshape(-1, 3) + vectors = self.scale * vectors + vector_field = np.stack([origins, vectors], axis=1) + + return vector_field.tolist() + + +class Visualizer(): + """ + Visualizer for ESPResSo simulations using ZnDraw. + + Main component of the visualizer is the ZnDraw server, which is started as a subprocess. + The ZnDraw client is used to communicate with the server and send the visualized data. + The visualized data is encoded using the :class:`EspressoConverter`, + which converts the ESPResSo system to an typed dict. + The visualizer uploads a new frame to the server every time the update method is called. + + Parameters + ---------- + system : :class:`~espressomd.system.System` + ESPResSo system to visualize + port : :obj:`int`, optional + Port for the ZnDraw server, by default 1234, if taken, the next available port is used + token : :obj:`str`, optional + Token for the ZnDraw server, by default a random token is generated + folded : :obj:`bool`, optional + Fold the positions of the particles into the simulation box, by default True + colors : :obj:`dict`, optional + Dictionary containing color type mapping for the particles, by default all particles are white + radii : :obj:`dict`, optional + Dictionary containing radii type mapping for the particles, by default all particles have a radius of 0.5 + bonds : :obj:`bool`, optional + Draw bonds between particles, by default ``False`` + jupyter : :obj:`bool`, optional + Show the visualizer in a Jupyter notebook, by default True + vector_field : :class:`~espressomd.zn.VectorField` or :class:`~espressomd.zn.LBField`, optional + Vector field to visualize, by default ``None`` + + """ + + def __init__(self, + system: espressomd.system.System = None, + port: int = 1234, + token: str = secrets.token_hex(4), + folded: bool = True, + colors: dict = None, + radii: dict = None, + bonds: bool = False, + jupyter: bool = True, + vector_field: t.Union[VectorField, LBField] = None, + ): + + self.system = system + self.params = { + "folded": folded, + "colors": colors, + "radii": radii, + "bonds": bonds, + "vector_field": vector_field, + } + + self.url = "http://127.0.0.1" + self.frame_count = 0 + self.port = port + self.token = token + + # A server is started in a subprocess, and we have to wait for it + print("Starting ZnDraw server, this may take a few seconds") + self._start_server() + time.sleep(10) + self._start_zndraw() + time.sleep(1) + + if vector_field is not None: + self.arrow_config = {'colormap': [[-0.5, 0.9, 0.5], [-0.0, 0.9, 0.5]], + 'normalize': True, + 'colorrange': [0, 1], + 'scale_vector_thickness': True, + 'opacity': 1.0} + + if vector_field.arrow_config is not None: + for key, value in vector_field.arrow_config.items(): + if key not in self.arrow_config: + raise ValueError(f"Invalid key {key} in arrow_config") + self.arrow_config[key] = value + + if self.params["bonds"] and not self.params["folded"]: + print( + "Warning: Unfolded positions may result in incorrect bond visualization") + + if jupyter: + self._show_jupyter() + else: + # Problems with server being terminated at the end of the script + raise NotImplementedError("Only Jupyter is supported for now") + # webbrowser.open_new_tab(self.address) + + def _start_server(self): + """ + Start the ZnDraw server through a subprocess + """ + self.socket_port = zndraw.utils.get_port(default=6374) + + self.server = subprocess.Popen(["zndraw", "--no-browser", f"--port={self.port}", f"--storage-port={self.socket_port}"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL + ) + + def _start_zndraw(self): + """ + Start the ZnDraw client and connect to the server + """ + config = zndraw.zndraw.TimeoutConfig( + connection=10, + modifier=0.25, + between_calls=0.1, + emit_retries=3, + call_retries=1, + connect_retries=3, + ) + while True: + try: + self.r = znsocket.Client( + address=f"{self.url}:{self.socket_port}") + break + except BaseException: + time.sleep(0.5) + + url = f"{self.url}:{self.port}" + self.zndraw = zndraw.zndraw.ZnDrawLocal( + r=self.r, url=url, token=self.token, timeout=config) + parsed_url = urllib.parse.urlparse( + f"{self.zndraw.url}/token/{self.zndraw.token}") + self.address = parsed_url._replace(scheme="http").geturl() + + def _show_jupyter(self): + """ + Show the visualizer in a Jupyter notebook + """ + from IPython.display import IFrame, display + print(f"Showing ZnDraw at {self.address}") + display(IFrame(src=self.address, width="100%", height="700px")) + + def update(self): + """ + Update the visualizer with the current state of the system + """ + self.system.visualizer_params = self.params + + data = znjson.dumps( + self.system, cls=znjson.ZnEncoder.from_converters( + [EspressoConverter]) + ) + + # Catch when the server is initializing an empty frame + if self.frame_count == 0 and len(self.zndraw) > 0: + self.zndraw.__setitem__(0, data) + else: + self.zndraw.append(data) + + if self.params["vector_field"] is not None and self.frame_count == 0: + self.zndraw.socket.sleep(1) + + for key, value in self.arrow_config.items(): + setattr(self.zndraw.config.arrows, key, value) + + self.frame_count += 1 + + def draw_constraints(self, shapes: list): + """ + Draw constraints on the visualizer + """ + if not isinstance(shapes, list): + raise ValueError("Constraints must be given in a list") + + objects = [] + + for shape in shapes: + + shape_type = shape.__class__.__name__ + + mat = zndraw.draw.Material(color="#b0b0b0", opacity=0.8) + + if shape_type == "Cylinder": + center = shape.center + axis = shape.axis + length = shape.length + radius = shape.radius + + rotation_angles = zndraw.utils.direction_to_euler( + axis, roll=np.pi / 2) + + objects.append(zndraw.draw.Cylinder(position=center, + rotation=rotation_angles, + radius_bottom=radius, + radius_top=radius, + height=length, + material=mat)) + + elif shape_type == "Wall": + dist = shape.dist + normal = np.array(shape.normal) + + rotation_angles = zndraw.utils.direction_to_euler(normal) + + position = (dist * normal).tolist() + # Not optimal, but ensures its always larger than the box. + wall_width = wall_height = 2 * max(self.system.box_l) + + objects.append(zndraw.draw.Plane( + position=position, + rotation=rotation_angles, + width=wall_width, + height=wall_height, + material=mat)) + + elif shape_type == "Sphere": + center = shape.center + radius = shape.radius + + objects.append( + zndraw.draw.Sphere(position=center, radius=radius, material=mat)) + + elif shape_type == "Rhomboid": + a = shape.a + b = shape.b + c = shape.c + corner = shape.corner + + objects.append( + zndraw.draw.Rhomboid(position=corner, vectorA=a, vectorB=b, vectorC=c, material=mat)) + + elif shape_type == "Ellipsoid": + center = shape.center + a = shape.a + b = shape.b + + objects.append(zndraw.draw.Ellipsoid(position=center, + a=a, b=b, c=b, material=mat)) + + else: + raise NotImplementedError( + f"Shape of type {shape_type} isn't available in ZnDraw") + + self.zndraw.geometries = objects