Skip to content

Commit 626307a

Browse files
committed
Add test for csv_to_wfdb().
1 parent 17b9349 commit 626307a

File tree

1 file changed

+68
-3
lines changed

1 file changed

+68
-3
lines changed

tests/io/test_convert.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
1+
import os
2+
import shutil
3+
import unittest
4+
15
import numpy as np
26

37
from wfdb.io.record import rdrecord
48
from wfdb.io.convert.edf import read_edf
9+
from wfdb.io.convert.csv import csv_to_wfdb
510

611

7-
class TestConvert:
12+
class TestEdfToWfdb:
13+
"""
14+
Tests for the io.convert.edf module.
15+
"""
816
def test_edf_uniform(self):
917
"""
1018
EDF format conversion to MIT for uniform sample rates.
11-
1219
"""
1320
# Uniform sample rates
1421
record_MIT = rdrecord("sample-data/n16").__dict__
@@ -60,7 +67,6 @@ def test_edf_uniform(self):
6067
def test_edf_non_uniform(self):
6168
"""
6269
EDF format conversion to MIT for non-uniform sample rates.
63-
6470
"""
6571
# Non-uniform sample rates
6672
record_MIT = rdrecord("sample-data/wave_4").__dict__
@@ -108,3 +114,62 @@ def test_edf_non_uniform(self):
108114

109115
target_results = len(fields) * [True]
110116
assert np.array_equal(test_results, target_results)
117+
118+
119+
class TestCsvToWfdb(unittest.TestCase):
120+
"""
121+
Tests for the io.convert.csv module.
122+
"""
123+
def setUp(self):
124+
"""
125+
Create a temporary directory containing data for testing.
126+
127+
Load 100.dat file for comparison to 100.csv file.
128+
"""
129+
self.test_dir = 'test_output'
130+
os.makedirs(self.test_dir, exist_ok=True)
131+
132+
self.record_100_csv = 'sample-data/100.csv'
133+
self.record_100_dat = rdrecord('sample-data/100', physical=True)
134+
135+
def tearDown(self):
136+
"""
137+
Remove the temporary directory after the test.
138+
"""
139+
if os.path.exists(self.test_dir):
140+
shutil.rmtree(self.test_dir)
141+
142+
def test_write_dir(self):
143+
"""
144+
Call the function with the write_dir argument.
145+
"""
146+
csv_to_wfdb(
147+
file_name=self.record_100_csv,
148+
fs=360,
149+
units='mV',
150+
write_dir=self.test_dir
151+
)
152+
153+
# Check if the output files are created in the specified directory
154+
base_name = os.path.splitext(os.path.basename(self.record_100_csv))[0]
155+
expected_dat_file = os.path.join(self.test_dir, f'{base_name}.dat')
156+
expected_hea_file = os.path.join(self.test_dir, f'{base_name}.hea')
157+
158+
self.assertTrue(os.path.exists(expected_dat_file))
159+
self.assertTrue(os.path.exists(expected_hea_file))
160+
161+
# Check that newly written file matches the 100.dat file
162+
record_write = rdrecord(os.path.join(self.test_dir, base_name))
163+
164+
self.assertEqual(record_write.fs, 360)
165+
self.assertEqual(record_write.fs, self.record_100_dat.fs)
166+
self.assertEqual(record_write.units, ['mV', 'mV'])
167+
self.assertEqual(record_write.units, self.record_100_dat.units)
168+
self.assertEqual(record_write.sig_name, ['MLII', 'V5'])
169+
self.assertEqual(record_write.sig_name, self.record_100_dat.sig_name)
170+
self.assertEqual(record_write.p_signal.size, 1300000)
171+
self.assertEqual(record_write.p_signal.size, self.record_100_dat.p_signal.size)
172+
173+
174+
if __name__ == '__main__':
175+
unittest.main()

0 commit comments

Comments
 (0)