Description
Original discussion thread: https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252
Previously closed issue: #78
Related merged PR: #96
I'm opening a new issue because the previous issue was closed and the PR merged without providing the complete and thorough tutorial that was felt to be needed in the initial discussion.
tl;dr: how to properly implement `torch.utils.data.Sampler`
Specifically, for my current use case I have a deep metric learning model that implements an online hard-mining strategy (per epoch, some samples are selected with higher probability than the rest, based on certain metrics).
It didn't feel right to put this logic in the transforms, so I currently do the mining in the `run` function (see the sketch after the list below):
- Pull the current minibatch1 from the dataloader
- Apply the hard-mining logic to pick the samples to train on from the current batch:
  - do a dry forward run without back-prop
  - take all misclassified samples as the "hard samples" for the current batch
  - compute a probability ranking over this subset based on certain heuristics (a wrongly classified sample with higher similarity gets higher probability)
- Based on these rankings, build a dataset on the fly for the hard samples, whose `__getitem__` chooses a minibatch2 as a subset of them (possibly with repeats of samples that have a higher probability ranking)
- Run the forward and backward pass on the samples in minibatch2
For reference, minibatch1 is roughly 10x the size of minibatch2.
The strategy works pretty well in training, though one can imagine the state of the code and the running time 😞
I understand if the dataloader class was not intended for online sampling that requires a forward pass, but can we at least have a complete tutorial on `data.sampler` and related methods, showing different offline sampling techniques, i.e. choosing samples from the current batch based on some fixed set of heuristics?
Or did I completely misunderstand the use of Samplers?
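For context, this is the kind of offline pattern I assume Samplers are meant for: a minimal sketch of a custom `Sampler` that draws dataset indices with probability proportional to fixed per-sample weights computed beforehand by some heuristic. All names and the toy data below are placeholders, and the built-in `WeightedRandomSampler` does essentially the same thing:

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset

class HeuristicSampler(Sampler):
    """Draws dataset indices with probability proportional to fixed
    per-sample weights computed offline from some heuristic."""

    def __init__(self, weights, num_samples):
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples

    def __iter__(self):
        # higher-weight samples are drawn more often (with replacement)
        idx = torch.multinomial(self.weights, self.num_samples, replacement=True)
        return iter(idx.tolist())

    def __len__(self):
        return self.num_samples

# toy dataset and heuristic scores standing in for the real ones
dataset = TensorDataset(torch.randn(256, 16), torch.randint(0, 4, (256,)))
scores = torch.rand(len(dataset))

loader = DataLoader(dataset,
                    batch_size=64,
                    sampler=HeuristicSampler(scores, num_samples=len(dataset)))

for inputs, targets in loader:
    pass  # batches are biased toward high-score samples
```

If this is roughly the intended usage, a tutorial walking through it (and through what is and isn't possible when the sampling depends on a forward pass) would be very helpful.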