Skip to content

Commit

Permalink
Merge pull request #99 from chorus-ai/format_zarr_bit16int
Browse files Browse the repository at this point in the history
zarr format updated to bit16int
  • Loading branch information
briangow authored Oct 18, 2024
2 parents 260733d + e52b458 commit b8334cb
Showing 1 changed file with 36 additions and 50 deletions.
86 changes: 36 additions & 50 deletions waveform_benchmark/formats/zarr.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,45 @@
import numpy as np
import zarr

from waveform_benchmark.formats.base import BaseFormat


class Zarr(BaseFormat):
"""
Example format using Zarr
Example format using Zarr with 16-bit integer waveforms.
"""
def write_waveforms(self, path, waveforms):
# Initialize Zarr group
root_group = zarr.open_group(path, mode='w')
nanval = -32768 # Sentinel value for NaN

for name, waveform in waveforms.items():
length = waveform['chunks'][-1]['end_sample']
samples = np.empty(length, dtype=np.float32)
samples[:] = np.nan
samples = np.empty(length, dtype=np.int16)
samples[:] = nanval

max_gain = max(chunk['gain'] for chunk in waveform['chunks']) # Get max gain from the chunks

for chunk in waveform['chunks']:
start = chunk['start_sample']
end = chunk['end_sample']
samples[start:end] = chunk['samples']
# Replace NaN values in the chunk with sentinel value
cursamples = np.where(np.isnan(chunk['samples']), nanval, np.round(chunk['samples'] * chunk['gain']).astype(np.int16))
samples[start:end] = cursamples

if self.fmt == 'Compressed':
ds = root_group.create_dataset(
name,
data=samples,
chunks=True,
dtype=np.int16,
compressor=zarr.Blosc(cname='zstd', clevel=3, shuffle=zarr.Blosc.BITSHUFFLE)
)
else:
ds = root_group.create_dataset(name, data=samples, chunks=True, dtype=np.int16)

# Create a dataset for each waveform within the root group.
ds = root_group.create_dataset(name, data=samples, chunks=True, dtype=np.float32)
ds.attrs['units'] = waveform['units']
ds.attrs['samples_per_second'] = waveform['samples_per_second']
ds.attrs['nanvalue'] = nanval # Store the sentinel value for NaN
ds.attrs['gain'] = max_gain # Store the gain

def read_waveforms(self, path, start_time, end_time, signal_names):
# Open the Zarr group
Expand All @@ -35,53 +49,25 @@ def read_waveforms(self, path, start_time, end_time, signal_names):
for signal_name in signal_names:
ds = root_group[signal_name]
samples_per_second = ds.attrs['samples_per_second']
nanval = ds.attrs['nanvalue'] # Retrieve the sentinel value for NaN
gain = ds.attrs['gain'] # Retrieve the gain

start_sample = round(start_time * samples_per_second)
end_sample = round(end_time * samples_per_second)

# Random access the Zarr array
results[signal_name] = ds[start_sample:end_sample]

return results

class Zarr_compressed(BaseFormat):
"""
Example format using Zarr with compression.
"""

def write_waveforms(self, path, waveforms):
# Initialize Zarr group
root_group = zarr.open_group(path, mode='w')

for name, waveform in waveforms.items():
length = waveform['chunks'][-1]['end_sample']
samples = np.empty(length, dtype=np.float32)
samples[:] = np.nan

for chunk in waveform['chunks']:
start = chunk['start_sample']
end = chunk['end_sample']
samples[start:end] = chunk['samples']

# each waveform within the root group with compression.
ds = root_group.create_dataset(name, data=samples, chunks=True, dtype=np.float32, compressor=zarr.Blosc(cname='zstd', clevel=9, shuffle=zarr.Blosc.BITSHUFFLE))
ds.attrs['units'] = waveform['units']
ds.attrs['samples_per_second'] = waveform['samples_per_second']

sig_data = ds[start_sample:end_sample]
naninds = (sig_data == nanval)
sig_data = sig_data.astype(np.float32)
sig_data = sig_data / gain
sig_data[naninds] = np.nan

def read_waveforms(self, path, start_time, end_time, signal_names):
# Open the Zarr group
root_group = zarr.open_group(path, mode='r')
results[signal_name] = sig_data

results = {}
for signal_name in signal_names:
ds = root_group[signal_name]
samples_per_second = ds.attrs['samples_per_second']

start_sample = round(start_time * samples_per_second)
end_sample = round(end_time * samples_per_second)

# Random access the Zarr array
results[signal_name] = ds[start_sample:end_sample]
return results

class Zarr_Compressed(Zarr):
fmt = 'Compressed'

return results
class Zarr_Uncompressed(Zarr):
fmt = 'Uncompressed'

0 comments on commit b8334cb

Please sign in to comment.