Skip to content

Commit 5539c02

Browse files
committed
Fix file close issue.
1 parent 30b66fa commit 5539c02

File tree

2 files changed

+25
-21
lines changed

2 files changed

+25
-21
lines changed

tests/test_archive.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ def temp_record():
5454
dat_path = os.path.join(tmpdir, record_basename + ".dat")
5555
archive_path = os.path.join(tmpdir, record_basename + ".wfdb")
5656

57-
WFDBArchive.create_archive(
58-
None,
59-
file_list=[hea_path, dat_path],
60-
output_path=archive_path,
61-
)
57+
with WFDBArchive(record_name=record_basename, mode="w") as archive:
58+
archive.create_archive(
59+
file_list=[hea_path, dat_path],
60+
output_path=archive_path,
61+
)
6262

6363
yield {
6464
"record_name": os.path.join(tmpdir, record_basename),

wfdb/io/archive.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
import os
23
import shutil
34
import zipfile
@@ -44,10 +45,9 @@ def __init__(self, record_name, mode="r"):
4445
self.zipfile = zipfile.ZipFile(self.archive_path, mode="r")
4546

4647
elif mode == "w":
47-
# Initialize an empty archive on disk
48+
# Create archive file if needed
4849
if not os.path.exists(self.archive_path):
49-
with zipfile.ZipFile(self.archive_path, mode="w"):
50-
pass # Just create the file
50+
WFDBArchive.make_archive_file([], self.archive_path)
5151
self.zipfile = zipfile.ZipFile(self.archive_path, mode="a")
5252

5353
def __enter__(self):
@@ -62,6 +62,15 @@ def exists(self, filename):
6262
"""
6363
return self.zipfile and filename in self.zipfile.namelist()
6464

65+
@staticmethod
66+
def make_archive_file(file_list, output_path):
67+
with zipfile.ZipFile(output_path, mode="w") as zf:
68+
for file in file_list:
69+
compress = zipfile.ZIP_DEFLATED
70+
zf.write(
71+
file, arcname=os.path.basename(file), compress_type=compress
72+
)
73+
6574
@contextmanager
6675
def open(self, filename, mode="r"):
6776
"""
@@ -73,8 +82,6 @@ def open(self, filename, mode="r"):
7382
if "b" in mode:
7483
yield f
7584
else:
76-
import io
77-
7885
yield io.TextIOWrapper(f)
7986
else:
8087
raise FileNotFoundError(
@@ -97,7 +104,7 @@ def write(self, filename, data):
97104
self.zipfile.writestr(filename, data)
98105
return
99106

100-
# If already opened in read or append mode, use the replace-then-move trick
107+
# If already opened in read or append mode, use replace-then-move
101108
tmp_path = self.archive_path + ".tmp"
102109
with zipfile.ZipFile(self.archive_path, mode="r") as zin:
103110
with zipfile.ZipFile(tmp_path, mode="w") as zout:
@@ -114,16 +121,13 @@ def create_archive(self, file_list, output_path=None):
114121
If output_path is not specified, uses self.archive_path.
115122
"""
116123
output_path = output_path or self.archive_path
117-
with zipfile.ZipFile(output_path, mode="w") as zf:
118-
for file in file_list:
119-
compress = (
120-
zipfile.ZIP_STORED
121-
if file.endswith((".hea", ".hea.json", ".hea.yml"))
122-
else zipfile.ZIP_DEFLATED
123-
)
124-
zf.write(
125-
file, arcname=os.path.basename(file), compress_type=compress
126-
)
124+
WFDBArchive.make_archive_file(file_list, output_path)
125+
126+
# If this archive object points to the archive, reload the zipfile in append mode
127+
if output_path == self.archive_path:
128+
if self.zipfile:
129+
self.zipfile.close()
130+
self.zipfile = zipfile.ZipFile(self.archive_path, mode="a")
127131

128132

129133
def get_archive(record_base_name, mode="r"):

0 commit comments

Comments
 (0)