|
| 1 | +import os |
| 2 | +import shutil |
| 3 | +import unittest |
| 4 | + |
1 | 5 | import numpy as np
|
2 | 6 |
|
3 | 7 | from wfdb.io.record import rdrecord
|
4 | 8 | from wfdb.io.convert.edf import read_edf
|
| 9 | +from wfdb.io.convert.csv import csv_to_wfdb |
5 | 10 |
|
6 | 11 |
|
7 |
| -class TestConvert: |
| 12 | +class TestEdfToWfdb: |
| 13 | + """ |
| 14 | + Tests for the io.convert.edf module. |
| 15 | + """ |
8 | 16 | def test_edf_uniform(self):
|
9 | 17 | """
|
10 | 18 | EDF format conversion to MIT for uniform sample rates.
|
11 |
| -
|
12 | 19 | """
|
13 | 20 | # Uniform sample rates
|
14 | 21 | record_MIT = rdrecord("sample-data/n16").__dict__
|
@@ -60,7 +67,6 @@ def test_edf_uniform(self):
|
60 | 67 | def test_edf_non_uniform(self):
|
61 | 68 | """
|
62 | 69 | EDF format conversion to MIT for non-uniform sample rates.
|
63 |
| -
|
64 | 70 | """
|
65 | 71 | # Non-uniform sample rates
|
66 | 72 | record_MIT = rdrecord("sample-data/wave_4").__dict__
|
@@ -108,3 +114,62 @@ def test_edf_non_uniform(self):
|
108 | 114 |
|
109 | 115 | target_results = len(fields) * [True]
|
110 | 116 | assert np.array_equal(test_results, target_results)
|
| 117 | + |
| 118 | + |
| 119 | +class TestCsvToWfdb(unittest.TestCase): |
| 120 | + """ |
| 121 | + Tests for the io.convert.csv module. |
| 122 | + """ |
| 123 | + def setUp(self): |
| 124 | + """ |
| 125 | + Create a temporary directory containing data for testing. |
| 126 | +
|
| 127 | + Load 100.dat file for comparison to 100.csv file. |
| 128 | + """ |
| 129 | + self.test_dir = 'test_output' |
| 130 | + os.makedirs(self.test_dir, exist_ok=True) |
| 131 | + |
| 132 | + self.record_100_csv = 'sample-data/100.csv' |
| 133 | + self.record_100_dat = rdrecord('sample-data/100', physical=True) |
| 134 | + |
| 135 | + def tearDown(self): |
| 136 | + """ |
| 137 | + Remove the temporary directory after the test. |
| 138 | + """ |
| 139 | + if os.path.exists(self.test_dir): |
| 140 | + shutil.rmtree(self.test_dir) |
| 141 | + |
| 142 | + def test_write_dir(self): |
| 143 | + """ |
| 144 | + Call the function with the write_dir argument. |
| 145 | + """ |
| 146 | + csv_to_wfdb( |
| 147 | + file_name=self.record_100_csv, |
| 148 | + fs=360, |
| 149 | + units='mV', |
| 150 | + write_dir=self.test_dir |
| 151 | + ) |
| 152 | + |
| 153 | + # Check if the output files are created in the specified directory |
| 154 | + base_name = os.path.splitext(os.path.basename(self.record_100_csv))[0] |
| 155 | + expected_dat_file = os.path.join(self.test_dir, f'{base_name}.dat') |
| 156 | + expected_hea_file = os.path.join(self.test_dir, f'{base_name}.hea') |
| 157 | + |
| 158 | + self.assertTrue(os.path.exists(expected_dat_file)) |
| 159 | + self.assertTrue(os.path.exists(expected_hea_file)) |
| 160 | + |
| 161 | + # Check that newly written file matches the 100.dat file |
| 162 | + record_write = rdrecord(os.path.join(self.test_dir, base_name)) |
| 163 | + |
| 164 | + self.assertEqual(record_write.fs, 360) |
| 165 | + self.assertEqual(record_write.fs, self.record_100_dat.fs) |
| 166 | + self.assertEqual(record_write.units, ['mV', 'mV']) |
| 167 | + self.assertEqual(record_write.units, self.record_100_dat.units) |
| 168 | + self.assertEqual(record_write.sig_name, ['MLII', 'V5']) |
| 169 | + self.assertEqual(record_write.sig_name, self.record_100_dat.sig_name) |
| 170 | + self.assertEqual(record_write.p_signal.size, 1300000) |
| 171 | + self.assertEqual(record_write.p_signal.size, self.record_100_dat.p_signal.size) |
| 172 | + |
| 173 | + |
| 174 | +if __name__ == '__main__': |
| 175 | + unittest.main() |
0 commit comments