Skip to content

Commit 95131de

Browse files
ehnryxfmassa
authored andcommitted
expose audio_channels as a parameter to kinetics dataset (#1559)
1 parent be6f398 commit 95131de

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

torchvision/datasets/kinetics.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class Kinetics400(VisionDataset):
3939
def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
4040
extensions=('avi',), transform=None, _precomputed_metadata=None,
4141
num_workers=1, _video_width=0, _video_height=0,
42-
_video_min_dimension=0, _audio_samples=0):
42+
_video_min_dimension=0, _audio_samples=0, _audio_channels=0):
4343
super(Kinetics400, self).__init__(root)
4444

4545
classes = list(sorted(list_dir(root)))
@@ -58,6 +58,7 @@ def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
5858
_video_height=_video_height,
5959
_video_min_dimension=_video_min_dimension,
6060
_audio_samples=_audio_samples,
61+
_audio_channels=_audio_channels,
6162
)
6263
self.transform = transform
6364

torchvision/datasets/video_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class VideoClips(object):
7171
def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1,
7272
frame_rate=None, _precomputed_metadata=None, num_workers=0,
7373
_video_width=0, _video_height=0, _video_min_dimension=0,
74-
_audio_samples=0):
74+
_audio_samples=0, _audio_channels=0):
7575

7676
self.video_paths = video_paths
7777
self.num_workers = num_workers
@@ -81,6 +81,7 @@ def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1
8181
self._video_height = _video_height
8282
self._video_min_dimension = _video_min_dimension
8383
self._audio_samples = _audio_samples
84+
self._audio_channels = _audio_channels
8485

8586
if _precomputed_metadata is None:
8687
self._compute_frame_pts()
@@ -149,7 +150,8 @@ def subset(self, indices):
149150
_video_width=self._video_width,
150151
_video_height=self._video_height,
151152
_video_min_dimension=self._video_min_dimension,
152-
_audio_samples=self._audio_samples)
153+
_audio_samples=self._audio_samples,
154+
_audio_channels=self._audio_channels)
153155

154156
@staticmethod
155157
def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
@@ -298,6 +300,7 @@ def get_clip(self, idx):
298300
video_pts_range=(video_start_pts, video_end_pts),
299301
video_timebase=info["video_timebase"],
300302
audio_samples=self._audio_samples,
303+
audio_channels=self._audio_channels,
301304
audio_pts_range=(audio_start_pts, audio_end_pts),
302305
audio_timebase=audio_timebase,
303306
)

0 commit comments

Comments
 (0)