Skip to content

Commit

Permalink
Freezing observed when changing the parameter sw in the module `pic…
Browse files Browse the repository at this point in the history
…k_cest` (#237)

Fixes #236
  • Loading branch information
gbouvignies authored May 29, 2024
1 parent f6e2b0c commit 0cef75c
Showing 1 changed file with 78 additions and 38 deletions.
116 changes: 78 additions & 38 deletions chemex/tools/pick_cest/buttons.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,54 +26,87 @@


class Buttons:
"""A class to manage the interaction with Matplotlib plots for chemical shift analysis.
Attributes:
fig (Figure): The Matplotlib figure object.
axis (Axes): The Matplotlib axes object.
data (dict): Data structure to hold curves for each spin system.
spin_systems (list): Sorted list of spin systems.
out (Path): Path to save output files.
spin_system (SpinSystem): Current spin system being analyzed.
curves (list): List of curves for the current spin system.
cs_a (dict): Dictionary to store chemical shift 'a' for each spin system.
cs_b (dict): Dictionary to store chemical shift 'b' for each spin system.
artists (list): List of Matplotlib artist objects.
sw (Optional[float]): Sweep width for the curves.
cursor (Cursor): Matplotlib cursor object for interaction.
index (int): Index to track the current spin system.
"""

def __init__(
self,
figure: Figure,
axis: Axes,
experiments: Experiments,
path: Path,
sw: float | None,
sw: float | None = None,
) -> None:
self.fig = figure
self.axis = axis
self.data: dict[SpinSystem, list[Curve]] = {}
spin_systems: set[SpinSystem] = set()

for experiment in experiments:
for profile in experiment:
spin_system = profile.spin_system
spin_systems.add(spin_system)
self.data.setdefault(spin_system, []).append(Curve(profile, sw))

self.spin_systems = sorted(spin_systems)
self.data: dict[SpinSystem, list[Curve]] = self._init_data(experiments, sw)
self.spin_systems = sorted(self.data.keys())
self.out = path

self.spin_system = SpinSystem(name="")
self.curves: list[Curve] = []
self.cs_a: dict[SpinSystem, float | None] = {}
self.cs_b: dict[SpinSystem, float | None] = {}
self.artists: list[Artist] = []
self.sw = sw

self.cursor = Cursor(self.axis, horizOn=False, useblit=True)

self.index = -1
self.next()

@staticmethod
def _init_data(
experiments: Experiments, sw: float | None
) -> dict[SpinSystem, list[Curve]]:
"""Initialize data structure from experiments.
Args:
experiments (Experiments): Container of experiments.
sw (Optional[float]): Sweep width for the curves.
Returns:
dict: Data structure with spin systems and their corresponding curves.
"""
data = {}
for experiment in experiments:
for profile in experiment:
spin_system = profile.spin_system
if spin_system not in data:
data[spin_system] = []
data[spin_system].append(Curve(profile, sw))
return data

def _clear_artists(self) -> None:
"""Remove all artist objects from the plot."""
while self.artists:
self.artists.pop().remove()

def _clear_axis(self) -> None:
"""Clear the current axis and remove all artists."""
self._clear_artists()
self.axis.clear()

def _show_labels(self) -> None:
"""Set the title and axis labels for the plot."""
self.axis.set_title(str(self.spin_system))
self.axis.set_xlabel(XLABELS[self.spin_system.nuclei["i"]])
self.axis.set_ylabel("$I/I_0$")

def _plot_profiles(self) -> None:
"""Plot the experimental profiles and their splines."""
if not self.curves:
return
xranges = np.concatenate([curve.get_xrange(self.sw) for curve in self.curves])
Expand All @@ -89,13 +122,20 @@ def _plot_profiles(self) -> None:
self.fig.canvas.draw_idle()

def _get_click_position(self, event: Event) -> float | None:
if not isinstance(event, LocationEvent):
return None
if event.inaxes != self.axis:
return None
return event.xdata
"""Get the x-coordinate of a click event.
Args:
event (Event): Matplotlib event object.
Returns:
Optional[float]: The x-coordinate of the click, or None if invalid.
"""
if isinstance(event, LocationEvent) and event.inaxes == self.axis:
return event.xdata
return None

def _add_line(self, position: float, state: Literal["a", "b"]) -> None:
"""Add a vertical line and a text label to the plot."""
text_ = rf"$\varpi_{state}$ = {position:.3f} ppm"
text = self.fig.text(0.82, TEXT_Y[state], text_)
line = self.axis.axvline(
Expand All @@ -107,13 +147,14 @@ def _add_line(self, position: float, state: Literal["a", "b"]) -> None:
self.artists.extend([line, text])

def _add_text_dw(self, dw_ab: float) -> None:
"""Add a text label for the chemical shift difference."""
text_ = rf"$\Delta\varpi_{{ab}}$ = {dw_ab:.3f} ppm"
text = self.fig.text(0.82, 0.7, text_)
self.artists.append(text)

def _save(self) -> None:
"""Save the chemical shift data to TOML files."""
self.out.mkdir(parents=True, exist_ok=True)

fname1 = self.out / "cs_a.toml"
fname2 = self.out / "dw_ab.toml"

Expand All @@ -137,7 +178,7 @@ def _save(self) -> None:
file2.write(f"{str(name).upper():10s} = {dw_ab:8.3f}\n")

def set_cs(self, event: Event) -> None:
"""Set the chemical shift."""
"""Set the chemical shift based on a click event."""
xdata = self._get_click_position(event)
if xdata is None:
return
Expand All @@ -153,6 +194,7 @@ def set_cs(self, event: Event) -> None:
self._plot_lines()

def _plot_lines(self) -> None:
"""Plot the vertical lines for chemical shifts."""
key = self.spin_system
cs_a = self.cs_a.get(key)
cs_b = self.cs_b.get(key)
Expand All @@ -169,47 +211,45 @@ def _plot_lines(self) -> None:
self._add_text_dw(cs_b - cs_a)

self._save()

self.fig.canvas.draw_idle()
self.fig.canvas.flush_events()

def _plot(self, event: Event | None = None) -> None:
"""Main plotting function to clear axis and plot profiles and lines."""
self._clear_axis()
self._plot_lines()
self._plot_profiles()
self.fig.canvas.draw_idle()

def _shift(self, step: int) -> None:
self.index += step
self.index %= len(self.spin_systems)
"""Shift the current spin system index by a given step."""
self.index = (self.index + step) % len(self.spin_systems)
self.spin_system = self.spin_systems[self.index]
self.curves = self.data[self.spin_system]
self._clear_axis()
self._plot()

def next(self, _event: Event | None = None) -> None:
"""Go to next residue."""
self._shift(+1)
"""Go to the next residue."""
self._shift(1)

def previous(self, _event: Event) -> None:
"""Go to previous residue."""
def previous(self, _event: Event | None = None) -> None:
"""Go to the previous residue."""
self._shift(-1)

def swap(self, event: Event) -> None:
"""Swap peak peak positions for major/minor states."""
name = self.spin_system
if self.cs_b[name] is not None:
self.cs_a[name], self.cs_b[name] = self.cs_b[name], self.cs_a[name]
"""Swap peak positions for major/minor states."""
key = self.spin_system
if self.cs_b[key] is not None:
self.cs_a[key], self.cs_b[key] = self.cs_b[key], self.cs_a[key]
self._plot_lines()

def clear(self, event: Event) -> None:
name = self.spin_system
self.cs_a[name], self.cs_b[name] = None, None
"""Clear the chemical shifts for the current spin system."""
key = self.spin_system
self.cs_a[key], self.cs_b[key] = None, None
self._plot_lines()

def set_sw(self, sw: float) -> None:
"""Set the sweep width and update the plot."""
with contextlib.suppress(ValueError):
self.sw = sw

self._clear_axis()
self._plot()

0 comments on commit 0cef75c

Please sign in to comment.