Skip to content

Commit caff031

Browse files
committed
Add output_dir argument to csv_to_wfdb. Fixes #67.
1 parent 34b989e commit caff031

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

wfdb/io/convert/csv.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def csv_to_wfdb(
3333
header=True,
3434
delimiter=",",
3535
verbose=False,
36+
output_dir=None,
3637
):
3738
"""
3839
Read a WFDB header file and return either a `Record` object with the
@@ -235,6 +236,9 @@ def csv_to_wfdb(
235236
verbose : bool, optional
236237
Whether to print all the information read about the file (True) or
237238
not (False).
239+
output_dir : str, optional
240+
The directory where the output files will be saved. If not provided,
241+
the output files will be saved in the same directory as the input file.
238242
239243
Returns
240244
-------
@@ -291,6 +295,7 @@ def csv_to_wfdb(
291295
df_CSV = pd.read_csv(file_name, delimiter=delimiter, header=None)
292296
if verbose:
293297
print("Successfully read CSV")
298+
294299
# Extract the entire signal from the dataframe
295300
p_signal = df_CSV.values
296301
# The dataframe should be in (`sig_len`, `n_sig`) dimensions
@@ -300,6 +305,7 @@ def csv_to_wfdb(
300305
n_sig = p_signal.shape[1]
301306
if verbose:
302307
print("Number of signals: {}".format(n_sig))
308+
303309
# Check if signal names are valid and set defaults
304310
if not sig_name:
305311
if header:
@@ -318,15 +324,23 @@ def csv_to_wfdb(
318324
if verbose:
319325
print("Signal names: {}".format(sig_name))
320326

321-
# Set the output header file name to be the same, remove path
322-
if os.sep in file_name:
323-
file_name = file_name.split(os.sep)[-1]
324-
record_name = file_name.replace(".csv", "")
327+
# Determine the output directory
328+
if output_dir:
329+
if not os.path.exists(output_dir):
330+
os.makedirs(output_dir)
331+
output_base = os.path.join(
332+
output_dir, os.path.basename(file_name).replace(".csv", "")
333+
)
334+
else:
335+
if os.sep in file_name:
336+
file_name = file_name.split(os.sep)[-1]
337+
output_base = file_name.replace(".csv", "")
338+
325339
if verbose:
326-
print("Output header: {}.hea".format(record_name))
340+
print("Output base: {}".format(output_base))
327341

328342
# Replace the CSV file tag with DAT
329-
dat_file_name = file_name.replace(".csv", ".dat")
343+
dat_file_name = output_base + ".dat"
330344
dat_file_name = [dat_file_name] * n_sig
331345
if verbose:
332346
print("Output record: {}".format(dat_file_name[0]))
@@ -419,7 +433,7 @@ def csv_to_wfdb(
419433
if record_only:
420434
# Create the record from the input and generated values
421435
record = Record(
422-
record_name=record_name,
436+
record_name=output_base,
423437
n_sig=n_sig,
424438
fs=fs,
425439
samps_per_frame=samps_per_frame,
@@ -454,7 +468,7 @@ def csv_to_wfdb(
454468
else:
455469
# Write the information to a record and header file
456470
wrsamp(
457-
record_name=record_name,
471+
record_name=output_base,
458472
fs=fs,
459473
units=units,
460474
sig_name=sig_name,

0 commit comments

Comments
 (0)