Skip to content

PyTorch utilities sampler #424

Open
Open
@glemaitre

Description

@glemaitre

We could add utilities for PyTorch.
Basically, such a utility should inherit from torch.utils.data.Sampler.

The implementation could look something like this:

class BalancedSampler(Sampler):
    """Sampler that yields the indices selected by a resampling estimator.

    The resampling is performed once, at construction time: the estimator
    is fitted on ``X`` and ``y``, and the indices it selects are shuffled
    and stored in ``indices_``.

    Parameters
    ----------
    X : array-like
        Data forwarded to the estimator's ``fit_sample``.
    y : array-like
        Targets forwarded to the estimator's ``fit_sample``.
    sampler : object or None, default=None
        Resampling estimator. It must expose a ``return_indices`` attribute;
        it is cloned and the clone is configured with
        ``return_indices=True``. When ``None``, a ``RandomUnderSampler`` is
        used.
    random_state : default=None
        Seed specification converted with ``check_random_state``. The
        resulting generator seeds the estimator and shuffles the indices.

    Attributes
    ----------
    sampler_ : object
        The estimator actually used for the resampling.
    indices_ : array-like
        Shuffled indices returned by ``sampler_.fit_sample``.
        NOTE(review): presumably a NumPy integer array (``tolist`` and
        in-place shuffling are used on it) -- confirm against the estimator.
    """

    def __init__(self, X, y, sampler=None, random_state=None):
        self.X = X
        self.y = y
        self.sampler = sampler
        self.random_state = random_state
        # Resample eagerly so that ``indices_`` exists before any iteration.
        self._sample()

    def _sample(self):
        """Fit the resampling estimator and store the shuffled indices.

        Raises
        ------
        ValueError
            If ``sampler`` has no ``return_indices`` attribute.
        """
        random_state = check_random_state(self.random_state)
        if self.sampler is None:
            self.sampler_ = RandomUnderSampler(return_indices=True,
                                               random_state=random_state)
        else:
            if not hasattr(self.sampler, 'return_indices'):
                raise ValueError("'sampler' needs to return the indices of "
                                 "the samples selected. Provide a sampler "
                                 "which has an attribute 'return_indices'.")
            # Work on a clone so the estimator given by the caller is never
            # modified.
            self.sampler_ = clone(self.sampler)
            self.sampler_.set_params(return_indices=True)
            set_random_state(self.sampler_, random_state)

        _, _, self.indices_ = self.sampler_.fit_sample(self.X, self.y)
        # Shuffle the indices since the samplers pack them by class.
        random_state.shuffle(self.indices_)

    def __iter__(self):
        """Return an iterator over the selected indices."""
        return iter(self.indices_.tolist())

    def __len__(self):
        """Return the number of indices yielded by ``__iter__``."""
        # The length must describe what is iterated over, i.e. the selected
        # indices, and not ``X``: the resampling may keep a different number
        # of samples than ``X`` contains.
        return len(self.indices_)

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions