Skip to content

Commit 396d894

Browse files
committed
update wfdb.io.wrsamp to allow writing a signal with unique samps_per_frame
1 parent 6a0de80 commit 396d894

File tree

1 file changed

+81
-39
lines changed

1 file changed

+81
-39
lines changed

wfdb/io/record.py

Lines changed: 81 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import datetime
2-
import multiprocessing.dummy
32
import posixpath
43
import os
54
import re
@@ -2822,6 +2821,9 @@ def wrsamp(
28222821
sig_name,
28232822
p_signal=None,
28242823
d_signal=None,
2824+
e_p_signal=None,
2825+
e_d_signal=None,
2826+
samps_per_frame=None,
28252827
fmt=None,
28262828
adc_gain=None,
28272829
baseline=None,
@@ -2860,6 +2862,14 @@ def wrsamp(
28602862
file(s). The dtype must be an integer type. Either p_signal or d_signal
28612863
must be set, but not both. In addition, if d_signal is set, fmt, gain
28622864
and baseline must also all be set.
2865+
e_p_signal : ndarray, optional
2866+
The expanded physical conversion of the signal. Either a 2d numpy
2867+
array or a list of 1d numpy arrays.
2868+
e_d_signal : ndarray, optional
2869+
The expanded digital conversion of the signal. Either a 2d numpy
2870+
array or a list of 1d numpy arrays.
2871+
samps_per_frame : int or list of ints, optional
2872+
The total number of samples per frame.
28632873
fmt : list, optional
28642874
A list of strings giving the WFDB format of each file used to store each
28652875
channel. Accepted formats are: '80','212','16','24', and '32'. There are
@@ -2911,59 +2921,91 @@ def wrsamp(
29112921
if "." in record_name:
29122922
raise Exception("Record name must not contain '.'")
29132923
# Check input field combinations
2914-
if p_signal is not None and d_signal is not None:
2924+
signal_list = [p_signal, d_signal, e_p_signal, e_d_signal]
2925+
signals_set = sum(1 for var in signal_list if var is not None)
2926+
if signals_set != 1:
29152927
raise Exception(
2916-
"Must only give one of the inputs: p_signal or d_signal"
2928+
"Must provide one and only one input signal: p_signal, d_signal, e_p_signal, or e_d_signal"
29172929
)
2918-
if d_signal is not None:
2930+
if d_signal is not None or e_d_signal is not None:
29192931
if fmt is None or adc_gain is None or baseline is None:
29202932
raise Exception(
2921-
"When using d_signal, must also specify 'fmt', 'gain', and 'baseline' fields."
2933+
"When using d_signal or e_d_signal, must also specify 'fmt', 'gain', and 'baseline' fields"
29222934
)
2923-
# Depending on whether d_signal or p_signal was used, set other
2924-
# required features.
2935+
2936+
# If samps_per_frame is a list, check that it aligns as expected with the channels in the signal
2937+
if len(samps_per_frame) > 1:
2938+
# Get properties of the signal being passed
2939+
non_none_signal = next(signal for signal in signal_list if signal is not None)
2940+
if isinstance(non_none_signal, np.ndarray):
2941+
num_sig_channels = non_none_signal.shape[1]
2942+
channel_samples = [non_none_signal.shape[0]] * non_none_signal.shape[1]
2943+
elif isinstance(non_none_signal, list):
2944+
num_sig_channels = len(non_none_signal)
2945+
channel_samples = [len(channel) for channel in non_none_signal]
2946+
else:
2947+
raise TypeError("Unsupported signal format. Must be ndarray or list of lists.")
2948+
2949+
# Check that the number of channels matches the number of samps_per_frame entries
2950+
if num_sig_channels != len(samps_per_frame):
2951+
raise Exception(
2952+
"When passing samps_per_frame as a list, it must have the same number of entries as the signal has channels"
2953+
)
2954+
2955+
# Check that the number of frames is the same across all channels
2956+
frames = [a / b for a, b in zip(channel_samples, samps_per_frame)]
2957+
if len(set(frames)) > 1:
2958+
raise Exception(
2959+
"The number of samples in a channel divided by the corresponding samples_per_frame entry must be uniform"
2960+
)
2961+
2962+
# Create the Record object
2963+
record = Record(
2964+
record_name=record_name,
2965+
p_signal=p_signal,
2966+
d_signal=d_signal,
2967+
e_p_signal=e_p_signal,
2968+
e_d_signal=e_d_signal,
2969+
samps_per_frame=samps_per_frame,
2970+
fs=fs,
2971+
fmt=fmt,
2972+
units=units,
2973+
sig_name=sig_name,
2974+
adc_gain=adc_gain,
2975+
baseline=baseline,
2976+
comments=comments,
2977+
base_time=base_time,
2978+
base_date=base_date,
2979+
base_datetime=base_datetime,
2980+
)
2981+
2982+
# Depending on which signal was used, set other required fields.
29252983
if p_signal is not None:
2926-
# Create the Record object
2927-
record = Record(
2928-
record_name=record_name,
2929-
p_signal=p_signal,
2930-
fs=fs,
2931-
fmt=fmt,
2932-
units=units,
2933-
sig_name=sig_name,
2934-
adc_gain=adc_gain,
2935-
baseline=baseline,
2936-
comments=comments,
2937-
base_time=base_time,
2938-
base_date=base_date,
2939-
base_datetime=base_datetime,
2940-
)
29412984
# Compute optimal fields to store the digital signal, carry out adc,
29422985
# and set the fields.
29432986
record.set_d_features(do_adc=1)
2944-
else:
2945-
# Create the Record object
2946-
record = Record(
2947-
record_name=record_name,
2948-
d_signal=d_signal,
2949-
fs=fs,
2950-
fmt=fmt,
2951-
units=units,
2952-
sig_name=sig_name,
2953-
adc_gain=adc_gain,
2954-
baseline=baseline,
2955-
comments=comments,
2956-
base_time=base_time,
2957-
base_date=base_date,
2958-
base_datetime=base_datetime,
2959-
)
2987+
elif d_signal is not None:
29602988
# Use d_signal to set the fields directly
29612989
record.set_d_features()
2990+
elif e_p_signal is not None:
2991+
# Compute optimal fields to store the digital signal, carry out adc,
2992+
# and set the fields.
2993+
record.set_d_features(do_adc=1, expanded=True)
2994+
elif e_d_signal is not None:
2995+
# Use e_d_signal to set the fields directly
2996+
record.set_d_features(expanded=True)
29622997

29632998
# Set default values of any missing field dependencies
29642999
record.set_defaults()
3000+
3001+
# Determine whether the signal is expanded
3002+
if (e_d_signal or e_p_signal) is not None:
3003+
expanded = True
3004+
else:
3005+
expanded = False
3006+
29653007
# Write the record files - header and associated dat
2966-
record.wrsamp(write_dir=write_dir)
3008+
record.wrsamp(write_dir=write_dir, expanded=expanded)
29673009

29683010

29693011
def dl_database(

0 commit comments

Comments
 (0)