Open
Description
We could add utilities for PyTorch. Basically, the sampler should inherit from
torch.utils.data.Sampler.
The implementation could look something like this:
class BalancedSampler(Sampler):
    """PyTorch ``Sampler`` that yields a class-balanced subset of indices.

    The balancing itself is delegated to an imbalanced-learn sampler; the
    selected indices are shuffled and then iterated, so a ``DataLoader``
    using this sampler draws a balanced, randomly ordered subset.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Training data, forwarded to the underlying sampler's ``fit_sample``.
    y : array-like of shape (n_samples,)
        Target labels used to balance the classes.
    sampler : object, optional (default=None)
        An imbalanced-learn sampler exposing a ``return_indices`` attribute.
        When None, a ``RandomUnderSampler`` is used.
    random_state : int, RandomState instance or None, optional
        Seeds both the resampling and the final shuffle.

    Attributes
    ----------
    sampler_ : object
        The fitted sampler actually used (clone of ``sampler`` or a fresh
        ``RandomUnderSampler``).
    indices_ : ndarray
        Shuffled indices of the samples retained by the sampler.
    """

    def __init__(self, X, y, sampler=None, random_state=None):
        self.X = X
        self.y = y
        self.sampler = sampler
        self.random_state = random_state
        self._sample()

    def _sample(self):
        """Fit the underlying sampler and cache the balanced indices.

        Raises
        ------
        ValueError
            If a user-provided ``sampler`` has no ``return_indices``
            attribute, since the indices of the kept samples are required.
        """
        random_state = check_random_state(self.random_state)
        if self.sampler is None:
            self.sampler_ = RandomUnderSampler(return_indices=True,
                                               random_state=random_state)
        else:
            # The wrapped sampler must be able to report which samples it
            # kept; reject samplers that cannot do so.
            if not hasattr(self.sampler, 'return_indices'):
                raise ValueError("'sampler' needs to return the indices of "
                                 "the samples selected. Provide a sampler "
                                 "which has an attribute 'return_indices'.")
            # Clone so the caller's sampler instance is left untouched.
            self.sampler_ = clone(self.sampler)
            self.sampler_.set_params(return_indices=True)
            set_random_state(self.sampler_, random_state)
        _, _, self.indices_ = self.sampler_.fit_sample(self.X, self.y)
        # Shuffle the indices since the sampler packs them class by class.
        random_state.shuffle(self.indices_)

    def __iter__(self):
        return iter(self.indices_.tolist())

    def __len__(self):
        # Bug fix: the original returned len(self.X.shape[0]), i.e. len()
        # of an int, which raises TypeError. Report the number of samples
        # in X, matching the evident intent.
        # NOTE(review): len(self.indices_) may be more appropriate, since
        # __iter__ yields only the balanced subset — confirm against the
        # DataLoader usage before changing.
        return self.X.shape[0]