Skip to content

Commit

Permalink
Merge pull request #100 from chorus-ai/wfdb-open-close
Browse files Browse the repository at this point in the history
Open/read/close API for WFDB format
  • Loading branch information
briangow authored Oct 18, 2024
2 parents 8fdce66 + 4ff6dec commit 260733d
Showing 1 changed file with 136 additions and 0 deletions.
136 changes: 136 additions & 0 deletions waveform_benchmark/formats/wfdb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import numpy
import soundfile
import wfdb

from waveform_benchmark.formats.base import BaseFormat
Expand Down Expand Up @@ -95,16 +96,151 @@ def read_waveforms(self, path, start_time, end_time, signal_names):
results[signal_name] = samples
return results

def open_waveforms(self, path, signal_names, **kwargs):
    """
    Open the signal files needed to read the given signals.

    The record header is read with wfdb.rdheader(path).  One reader is
    then opened for each distinct signal file (located in the same
    directory as path) that stores at least one of signal_names.
    Extra keyword arguments are accepted and ignored.

    Returns a dictionary with three keys:
      'fs'          - header.fs (frame frequency)
      'readers'     - list of every reader that was opened
      'sig_readers' - maps each signal name to a (reader, channel)
                      tuple, where channel is the index of the signal
                      within that reader's file

    A name that cannot be found in header.sig_name makes the lookup
    fail (ValueError if sig_name is a list -- the header type is not
    visible here).  If anything fails part-way through, the readers
    opened so far are closed before the exception propagates, so no
    file handles are leaked.
    """
    header = wfdb.rdheader(path)
    dir_name = os.path.dirname(path)
    readers = {}
    sig_readers = {}

    # Open a reader for each signal file of interest.  For each
    # individual signal, determine the corresponding reader object
    # and channel index within that signal file.
    try:
        for name in signal_names:
            i = header.sig_name.index(name)
            file_name = header.file_name[i]
            try:
                reader = readers[file_name]
            except KeyError:
                reader = self._open_reader(header, dir_name, file_name)
                readers[file_name] = reader
            channel = reader.channels.index(name)
            sig_readers[name] = (reader, channel)
    except BaseException:
        # The caller never receives these readers, so it could not
        # close them itself; release them here and re-raise.
        for reader in readers.values():
            reader.close()
        raise

    return {
        # Frame frequency
        'fs': header.fs,

        # List of all readers
        'readers': list(readers.values()),

        # Dictionary mapping signal names to (reader, channel)
        'sig_readers': sig_readers,
    }

def _open_reader(self, header, dir_name, file_name):
    """
    Build a self.Reader for one signal file of a record.

    header is the record header; its per-signal lists file_name, fmt,
    sig_name, samps_per_frame, adc_gain and baseline are consulted.
    Only the entries whose file_name equals the requested file_name
    are used, in header order.

    The reader receives the full file path (dir_name joined with
    file_name), the signal names, the samples-per-frame list, and
    float32 arrays holding 1/adc_gain and baseline for each signal.

    Every selected signal must use the storage format self.fmt;
    otherwise an AssertionError("incorrect format") is raised.
    """
    # Header positions of the signals stored in this file
    indices = [
        idx for idx, stored_in in enumerate(header.file_name)
        if stored_in == file_name
    ]

    assert all(header.fmt[idx] == self.fmt for idx in indices), \
        "incorrect format"

    def pick(per_signal):
        # Restrict a per-signal header list to this file's signals
        return [per_signal[idx] for idx in indices]

    sig_names = pick(header.sig_name)
    frame_sizes = pick(header.samps_per_frame)
    reciprocal_gain = numpy.array(
        [1 / gain for gain in pick(header.adc_gain)],
        dtype=numpy.float32,
    )
    zero_level = numpy.array(pick(header.baseline), dtype=numpy.float32)

    return self.Reader(
        os.path.join(dir_name, file_name),
        sig_names,
        frame_sizes,
        reciprocal_gain,
        zero_level,
    )

def close_waveforms(self, opened_files):
    """
    Release the readers held by an opened_files dictionary.

    opened_files is the dictionary returned by open_waveforms; the
    close() method of each entry in its 'readers' list is called once,
    in list order.  Nothing is returned.
    """
    handles = opened_files['readers']
    for handle in handles:
        handle.close()

def read_opened_waveforms(self, opened_files, start_time, end_time,
                          signal_names):
    """
    Read a time range of the given signals from already-opened files.

    opened_files is the dictionary returned by open_waveforms.
    start_time and end_time are multiplied by opened_files['fs'] and
    rounded to obtain the first frame number and the end frame number
    handed to each reader's load_frames().  signal_names may be any
    iterable of names; each must be a key of
    opened_files['sig_readers'], otherwise KeyError is raised.

    Returns a dictionary mapping each requested signal name to the
    value of get_channel() on its reader.

    Only the readers that actually hold a requested signal are asked
    to load data, and each of them only once, so signal files that are
    open but not needed for this call are not read.
    """
    # Determine start/end frame number
    fs = opened_files['fs']
    start_frame = round(start_time * fs)
    end_frame = round(end_time * fs)

    # Resolve every requested signal to its (reader, channel) pair.
    # The names are materialised first so a one-shot iterable works.
    names = list(signal_names)
    sig_readers = opened_files['sig_readers']
    selected = [sig_readers[name] for name in names]

    # Read all samples for that range of frame numbers, once per
    # signal file that is needed.  Readers are de-duplicated by
    # identity so they do not have to be hashable.
    needed = {id(reader): reader for reader, _ in selected}
    for reader in needed.values():
        reader.load_frames(start_frame, end_frame)

    # Extract the desired signals and return a dictionary of arrays
    results = {}
    for name, (reader, channel) in zip(names, selected):
        results[name] = reader.get_channel(channel)
    return results


class WFDBFormat16(BaseWFDBFormat):
    """
    WFDB with 16-bit binary storage.
    """
    fmt = '16'

    class Reader:
        """
        Random-access reader for a single format-16 signal file.

        The file is read as a sequence of fixed-size frames of
        little-endian 16-bit integers.  Within a frame, the samples of
        the first channel come first (spf[0] of them), then those of
        the second channel (spf[1]), and so on.
        """

        def __init__(self, path, channels, spf, inv_gain, baseline):
            """
            path is the signal file to open; channels lists the signal
            names stored in it; spf gives the samples per frame of each
            channel; inv_gain and baseline hold, per channel, the
            multiplier and offset used by get_channel().
            """
            self.inv_gain = inv_gain
            self.baseline = baseline
            self.channels = channels

            # Column range occupied by each channel within one frame
            channel_start = numpy.cumsum([0, *spf])
            self.channel_slice = [
                slice(x, y) for x, y in zip(channel_start, channel_start[1:])
            ]
            self.total_spf = int(channel_start[-1])

            # One frame = total_spf int16 values.  The shape is given
            # as an explicit tuple: with an integer count of 1 numpy
            # yields the plain scalar dtype instead of a sub-array,
            # load_frames() would then produce a 1-D array, and the
            # column indexing in get_channel() would fail for a file
            # holding a single sample per frame.
            self.frame_dtype = numpy.dtype(('<i2', (self.total_spf,)))
            self.bytes_per_frame = self.frame_dtype.itemsize

            # Open the file last so that a failure above cannot leave
            # an unclosed file handle behind.
            self.fp = open(path, 'rb')

        def close(self):
            """Close the underlying signal file."""
            self.fp.close()

        def load_frames(self, start_frame, end_frame):
            """
            Read frames start_frame up to (not including) end_frame
            into self.data, one row per frame.
            """
            self.fp.seek(start_frame * self.bytes_per_frame)
            self.data = numpy.fromfile(self.fp, self.frame_dtype,
                                       end_frame - start_frame)

        def get_channel(self, channel):
            """
            Return the most recently loaded samples of one channel as
            a flat array of (raw - baseline) * inv_gain values.

            Raw samples equal to -32768 are reported as NaN.
            load_frames() must have been called first.
            """
            data = self.data[:, self.channel_slice[channel]]
            result = data - self.baseline[channel]
            result *= self.inv_gain[channel]
            result[data == -32768] = numpy.nan
            return result.reshape(-1)


class WFDBFormat516(BaseWFDBFormat):
    """
    WFDB with FLAC compression.
    """
    fmt = '516'

    class Reader:
        """
        Random-access reader for a single format-516 signal file,
        accessed through soundfile.SoundFile.

        The samples-per-frame value of the first channel is applied to
        the whole file.
        """

        def __init__(self, path, channels, spf, inv_gain, baseline):
            """
            path is the signal file to open; channels lists the signal
            names stored in it; spf gives the samples per frame of each
            channel (only spf[0] is used); inv_gain and baseline hold,
            per channel, the multiplier and offset used by
            get_channel().
            """
            self.fp = soundfile.SoundFile(path)
            self.spf = spf[0]
            self.channels = channels
            self.baseline = baseline
            self.inv_gain = inv_gain

        def close(self):
            """Close the underlying sound file."""
            self.fp.close()

        def load_frames(self, start_frame, end_frame):
            """
            Read frames start_frame up to (not including) end_frame
            into self.data as int16 values, one column per channel.
            """
            samples_per_frame = self.spf
            sample_count = (end_frame - start_frame) * samples_per_frame

            # Frame numbers are converted to sample positions in the
            # sound file.
            self.fp.seek(start_frame * samples_per_frame)

            # Note that the following call may fail, for very large
            # numbers of samples, if you are using an outdated version
            # of libsndfile (prior to 1.2.0.)
            self.data = self.fp.read(
                sample_count,
                dtype='int16',
                always_2d=True,
            )

        def get_channel(self, channel):
            """
            Return the most recently loaded samples of one channel as
            an array of (raw - baseline) * inv_gain values.

            Raw samples equal to -32768 are reported as NaN.
            load_frames() must have been called first.
            """
            raw = self.data[:, channel]
            invalid = raw == -32768
            physical = numpy.subtract(raw, self.baseline[channel])
            numpy.multiply(physical, self.inv_gain[channel], out=physical)
            physical[invalid] = numpy.nan
            return physical

0 comments on commit 260733d

Please sign in to comment.