Skip to content

Commit

Permalink
Compressed version handled with gain
Browse files Browse the repository at this point in the history
  • Loading branch information
del42 committed Oct 18, 2024
1 parent b7ab475 commit e52b458
Showing 1 changed file with 15 additions and 64 deletions.
79 changes: 15 additions & 64 deletions waveform_benchmark/formats/zarr.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import numpy as np
import zarr

from waveform_benchmark.formats.base import BaseFormat


class Zarr(BaseFormat):
"""
Example format using Zarr with 16-bit integer waveforms.
Expand All @@ -27,8 +25,17 @@ def write_waveforms(self, path, waveforms):
cursamples = np.where(np.isnan(chunk['samples']), nanval, np.round(chunk['samples'] * chunk['gain']).astype(np.int16))
samples[start:end] = cursamples

# Create a dataset for each waveform within the root group.
ds = root_group.create_dataset(name, data=samples, chunks=True, dtype=np.int16)
if self.fmt == 'Compressed':
ds = root_group.create_dataset(
name,
data=samples,
chunks=True,
dtype=np.int16,
compressor=zarr.Blosc(cname='zstd', clevel=3, shuffle=zarr.Blosc.BITSHUFFLE)
)
else:
ds = root_group.create_dataset(name, data=samples, chunks=True, dtype=np.int16)

ds.attrs['units'] = waveform['units']
ds.attrs['samples_per_second'] = waveform['samples_per_second']
ds.attrs['nanvalue'] = nanval # Store the sentinel value for NaN
Expand Down Expand Up @@ -59,64 +66,8 @@ def read_waveforms(self, path, start_time, end_time, signal_names):

return results

def open_waveforms(self, path: str, signal_names: list, **kwargs):
    """
    Open Zarr waveforms.

    Opens the Zarr group at ``path`` read-only and returns a dict mapping
    each requested signal name to the corresponding member of that group.
    Extra keyword arguments are accepted but not used here.
    """
    group = zarr.open_group(path, mode='r')
    # Look each signal up by name; a missing name propagates the lookup error.
    return {name: group[name] for name in signal_names}


class Zarr_compressed(BaseFormat):
"""
Example format using Zarr with compression and 16-bit integer waveforms.
"""

def write_waveforms(self, path, waveforms):
    """
    Write every waveform as a Blosc/zstd-compressed int16 Zarr dataset.

    Parameters
    ----------
    path
        Location of the Zarr group; it is opened with ``mode='w'``.
    waveforms
        Mapping of signal name to a dict providing ``'units'``,
        ``'samples_per_second'`` and ``'chunks'``. Each chunk provides
        ``'start_sample'``, ``'end_sample'`` and ``'samples'``.

    Notes
    -----
    Sample values are stored as given (no gain scaling is applied in this
    method), rounded to the nearest integer. NaN samples, and any positions
    not covered by a chunk, are stored as the sentinel ``-32768``, which is
    also recorded in the dataset's ``'nanvalue'`` attribute.
    """
    # Initialize Zarr group
    root_group = zarr.open_group(path, mode='w')
    nanval = -32768  # Sentinel value for NaN

    for name, waveform in waveforms.items():
        # The end of the last chunk is taken as the total signal length
        # (assumes chunks are ordered by sample index -- TODO confirm).
        length = waveform['chunks'][-1]['end_sample']
        # Pre-fill with the sentinel so uncovered gaps read back as NaN.
        samples = np.full(length, nanval, dtype=np.int16)

        for chunk in waveform['chunks']:
            start = chunk['start_sample']
            end = chunk['end_sample']
            # Replace NaN first so no NaN ever reaches the integer cast.
            cursamples = np.where(np.isnan(chunk['samples']), nanval, chunk['samples'])
            # Round explicitly: assigning floats straight into an int16
            # buffer truncates toward zero and silently biases the data.
            samples[start:end] = np.round(cursamples).astype(np.int16)

        # Each waveform is stored as a dataset within the root group.
        ds = root_group.create_dataset(
            name,
            data=samples,
            chunks=True,
            dtype=np.int16,
            compressor=zarr.Blosc(cname='zstd', clevel=1, shuffle=zarr.Blosc.BITSHUFFLE),
        )
        ds.attrs['units'] = waveform['units']
        ds.attrs['samples_per_second'] = waveform['samples_per_second']
        ds.attrs['nanvalue'] = nanval

def read_waveforms(self, path, start_time, end_time, signal_names):
# Open the Zarr group
root_group = zarr.open_group(path, mode='r')

results = {}
for signal_name in signal_names:
ds = root_group[signal_name]
samples_per_second = ds.attrs['samples_per_second']
nanval = ds.attrs['nanvalue'] # Retrieve the sentinel value

start_sample = round(start_time * samples_per_second)
end_sample = round(end_time * samples_per_second)

# Random access the Zarr array
sig_data = ds[start_sample:end_sample]
naninds = (sig_data == nanval)
sig_data = sig_data.astype(float)
sig_data[naninds] = np.nan

results[signal_name] = sig_data
class Zarr_Compressed(Zarr):
    """
    Zarr format variant tagged ``fmt = 'Compressed'``.

    The write path shown in this diff checks ``self.fmt == 'Compressed'``
    and, when true, passes a Blosc compressor (zstd, clevel=3, bit-shuffle)
    to ``create_dataset``.
    """
    # Tag compared against 'Compressed' when datasets are created.
    fmt = 'Compressed'

return results
class Zarr_Uncompressed(Zarr):
    """
    Zarr format variant tagged ``fmt = 'Uncompressed'``.

    Any ``fmt`` other than ``'Compressed'`` takes the write branch that
    calls ``create_dataset`` without a ``compressor`` argument.
    """
    # NOTE(review): with no ``compressor`` argument Zarr may still apply its
    # default compressor -- confirm whether ``compressor=None`` is needed
    # for the data to be truly uncompressed.
    fmt = 'Uncompressed'

0 comments on commit e52b458

Please sign in to comment.