# Source code for mxnet.gluon.data.sampler

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=
"""Dataset sampler."""
# Public names exported by a star-import of this module.
__all__ = [
    'Sampler',
    'SequentialSampler',
    'RandomSampler',
    'FilterSampler',
    'BatchSampler',
    'IntervalSampler',
]

import numpy as np

class Sampler(object):
    """Abstract base for every sampler in this module.

    A concrete sampler subclasses `Sampler` and overrides both
    `__iter__` and `__len__`; the versions defined here do nothing
    except raise `NotImplementedError`.
    """

    def __len__(self):
        """Must be overridden by subclasses; always raises here."""
        raise NotImplementedError

    def __iter__(self):
        """Must be overridden by subclasses; always raises here."""
        raise NotImplementedError
class SequentialSampler(Sampler):
    """Yields the indices of [start, start+length) in increasing order.

    Parameters
    ----------
    length : int
        Number of indices to produce.
    start : int, default is 0
        First index of the sequence.
    """

    def __init__(self, length, start=0):
        self._start = start
        self._length = length

    def __len__(self):
        return self._length

    def __iter__(self):
        first = self._start
        stop = first + self._length
        return iter(range(first, stop))
class RandomSampler(Sampler):
    """Yields every index of [0, length) exactly once, in shuffled order.

    The order is drawn from NumPy's global random state each time
    iteration starts.

    Parameters
    ----------
    length : int
        Number of indices to produce.
    """

    def __init__(self, length):
        self._length = length

    def __len__(self):
        return self._length

    def __iter__(self):
        # Build 0..length-1, then shuffle the array in place.
        order = np.arange(self._length)
        np.random.shuffle(order)
        return iter(order)
class FilterSampler(Sampler):
    """Yields the indices of the dataset samples accepted by `fn`.

    The dataset is scanned once, at construction time; iterating the
    sampler afterwards only replays the stored indices.

    Parameters
    ----------
    fn : callable
        Called with one sample; the sample's index is kept when the
        result is true.
    dataset : Dataset
        The dataset to filter.
    """

    def __init__(self, fn, dataset):
        self._fn = fn
        self._dataset = dataset
        kept = []
        for position, sample in enumerate(dataset):
            if fn(sample):
                kept.append(position)
        self._indices = kept

    def __len__(self):
        return len(self._indices)

    def __iter__(self):
        return iter(self._indices)
class BatchSampler(Sampler):
    """Groups the output of another `Sampler` into mini-batches (lists).

    Parameters
    ----------
    sampler : Sampler
        The source Sampler.
    batch_size : int
        Size of mini-batch.
    last_batch : {'keep', 'discard', 'rollover'}
        What to do with the trailing elements when `batch_size` does not
        evenly divide the sequence length.

        If 'keep', the short final batch is yielded as is.
        If 'discard', the short final batch is dropped.
        If 'rollover', the trailing elements are stored and placed at the
        front of the first batch of the next iteration.

    Examples
    --------
    >>> sampler = gluon.data.SequentialSampler(10)
    >>> batch_sampler = gluon.data.BatchSampler(sampler, 3, 'keep')
    >>> list(batch_sampler)
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
    """

    def __init__(self, sampler, batch_size, last_batch='keep'):
        self._sampler = sampler
        self._batch_size = batch_size
        self._last_batch = last_batch
        # Elements carried over from the previous pass ('rollover' only).
        self._prev = []

    def __iter__(self):
        # Seed the first batch with whatever the previous pass left behind.
        current = self._prev
        self._prev = []
        for index in self._sampler:
            current.append(index)
            if len(current) == self._batch_size:
                yield current
                current = []

        # Nothing left over: the mode is never consulted.
        if not current:
            return

        mode = self._last_batch
        if mode == 'keep':
            yield current
            return
        if mode == 'discard':
            return
        if mode == 'rollover':
            self._prev = current
            return
        raise ValueError(
            "last_batch must be one of 'keep', 'discard', or 'rollover', "
            f"but got {self._last_batch}")

    def __len__(self):
        mode = self._last_batch
        if mode == 'keep':
            # Round up: a short final batch still counts.
            total = len(self._sampler)
            return (total + self._batch_size - 1) // self._batch_size
        if mode == 'discard':
            # Round down: a short final batch is dropped.
            total = len(self._sampler)
            return total // self._batch_size
        if mode == 'rollover':
            # Carried-over elements take part in the next pass.
            total = len(self._prev) + len(self._sampler)
            return total // self._batch_size
        raise ValueError(
            "last_batch must be one of 'keep', 'discard', or 'rollover', "
            f"but got {self._last_batch}")
class IntervalSampler(Sampler):
    """Samples elements from [0, length) at fixed intervals.

    Parameters
    ----------
    length : int
        Length of the sequence.
    interval : int
        Step between two consecutive sampled indices.
    rollover : bool, default True
        Whether to start again from the first skipped item after reaching
        the end. If true, this sampler starts again from the first skipped
        item until all items are visited. Otherwise, iteration stops when
        the end is reached and skipped items are ignored.

    Examples
    --------
    >>> sampler = gluon.data.IntervalSampler(13, interval=3)
    >>> list(sampler)
    [0, 3, 6, 9, 12, 1, 4, 7, 10, 2, 5, 8, 11]
    >>> len(sampler)
    13
    >>> sampler = gluon.data.IntervalSampler(13, interval=3, rollover=False)
    >>> list(sampler)
    [0, 3, 6, 9, 12]
    >>> len(sampler)
    5
    """

    def __init__(self, length, interval, rollover=True):
        # Kept as an assert so callers still see AssertionError on bad input.
        assert interval <= length, \
            "Interval {} must be smaller than or equal to length {}".format(interval, length)
        self._length = length
        self._interval = interval
        self._rollover = rollover

    def __iter__(self):
        # With rollover, make one pass per starting offset 0..interval-1;
        # without it, make only the single pass that starts at 0.
        for i in range(self._interval if self._rollover else 1):
            for j in range(i, self._length, self._interval):
                yield j

    def __len__(self):
        # With rollover, the passes together cover every index in [0, length).
        if self._rollover:
            return self._length
        # Without rollover only the pass starting at 0 runs, so the count
        # is the size of that single stride.
        return len(range(0, self._length, self._interval))