Skip to content

Commit 6db9ebf

Browse files
author
Benjamin Moody
committed
SignalMixin.smooth_frames: use minimal data type for result.
Instead of always returning the result as an int64 or float64 array, select the output type based on the types of the input arrays. The output type should be the smallest type that has the correct "kind" and is able to represent all input values. For example, in digital mode, if the input includes some int8 arrays and some int16 arrays, the result should be an int16 array. In physical mode, if the inputs are all float32, then the result will be float32; otherwise the result will be float64. However, although the output type should generally match the input type, intermediate results may need to be stored as a different type. For example, if the input and output are both int16, and one or more signals have spf > 1 and use the entire 16-bit range, then the sum of N samples will overflow an int16. Previously, it was fine simply to store the intermediate results in the output array itself, because the output array was 64-bit, and no WFDB format has more than 32-bit precision, and spf is (in practice) limited to at most 2**31-1. For simplicity, continue using int64 or float64 as the intermediate type, regardless of the actual input types and spf. At the same time, we can also optimize the calculation slightly by reshaping the input array and using np.sum, avoiding another Python loop.
1 parent b09bab4 commit 6db9ebf

File tree

1 file changed

+26
-5
lines changed

1 file changed

+26
-5
lines changed

wfdb/io/_signal.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -830,32 +830,53 @@ def smooth_frames(self, sigtype='physical'):
830830
# Total samples per frame
831831
tspf = sum(spf)
832832

833+
# The output data type should be the smallest type that can
834+
# represent any input sample value. The intermediate data type
835+
# must be able to represent the sum of spf[ch] sample values.
836+
833837
if sigtype == 'physical':
834838
expanded_signal = self.e_p_signal
835-
output_dtype = np.dtype('float64')
839+
intermediate_dtype = np.dtype('float64')
840+
allowed_dtypes = [
841+
np.dtype('float32'),
842+
np.dtype('float64'),
843+
]
836844
elif sigtype == 'digital':
837845
expanded_signal = self.e_d_signal
838-
output_dtype = np.dtype('int64')
846+
intermediate_dtype = np.dtype('int64')
847+
allowed_dtypes = [
848+
np.dtype('int8'),
849+
np.dtype('int16'),
850+
np.dtype('int32'),
851+
np.dtype('int64'),
852+
]
839853
else:
840854
raise ValueError("sigtype must be 'physical' or 'digital'")
841855

842856
n_sig = len(expanded_signal)
843857
sig_len = int(len(expanded_signal[0])/spf[0])
858+
input_dtypes = set()
844859
for ch in range(n_sig):
845860
if len(expanded_signal[ch]) != sig_len * spf[ch]:
846861
raise ValueError("length mismatch: signal %d has %d samples,"
847862
" expected %dx%d"
848863
% (ch, len(expanded_signal),
849864
sig_len, spf[ch]))
865+
input_dtypes.add(expanded_signal[ch].dtype)
866+
867+
for output_dtype in allowed_dtypes:
868+
if all(dt <= output_dtype for dt in input_dtypes):
869+
break
870+
850871
signal = np.zeros((sig_len, n_sig), dtype=output_dtype)
851872

852873
for ch in range(n_sig):
853874
if spf[ch] == 1:
854875
signal[:, ch] = expanded_signal[ch]
855876
else:
856-
for frame in range(spf[ch]):
857-
signal[:, ch] += expanded_signal[ch][frame::spf[ch]]
858-
signal[:, ch] = signal[:, ch] / spf[ch]
877+
frames = expanded_signal[ch].reshape(-1, spf[ch])
878+
signal_sum = np.sum(frames, axis=1, dtype=intermediate_dtype)
879+
signal[:, ch] = signal_sum / spf[ch]
859880

860881
return signal
861882

0 commit comments

Comments
 (0)