mxnet
input_split_shuffle.h
Go to the documentation of this file.
1 
7 #ifndef DMLC_INPUT_SPLIT_SHUFFLE_H_
8 #define DMLC_INPUT_SPLIT_SHUFFLE_H_
9 
10 #include <cstdio>
11 #include <cstring>
12 #include <vector>
13 #include <string>
14 #include <algorithm>
15 #include <memory>
16 
17 namespace dmlc {
19 class InputSplitShuffle : public InputSplit {
20  public:
21  // destructor
22  virtual ~InputSplitShuffle(void) { source_.reset(); }
23  // implement BeforeFirst
24  virtual void BeforeFirst(void) {
25  if (num_shuffle_parts_ > 1) {
26  std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_);
27  int idx = shuffle_indexes_[0] + part_index_ * num_shuffle_parts_;
28  source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_);
29  cur_shuffle_idx_ = 0;
30  } else {
31  source_->BeforeFirst();
32  }
33  }
34  virtual void HintChunkSize(size_t chunk_size) {
35  source_->HintChunkSize(chunk_size);
36  }
37  virtual size_t GetTotalSize(void) {
38  return source_->GetTotalSize();
39  }
40  // implement next record
41  virtual bool NextRecord(Blob *out_rec) {
42  if (num_shuffle_parts_ > 1) {
43  if (!source_->NextRecord(out_rec)) {
44  if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) {
45  return false;
46  }
47  ++cur_shuffle_idx_;
48  int idx =
49  shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_;
50  source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_);
51  return NextRecord(out_rec);
52  } else {
53  return true;
54  }
55  } else {
56  return source_->NextRecord(out_rec);
57  }
58  }
59  // implement next chunk
60  virtual bool NextChunk(Blob* out_chunk) {
61  if (num_shuffle_parts_ > 1) {
62  if (!source_->NextChunk(out_chunk)) {
63  if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) {
64  return false;
65  }
66  ++cur_shuffle_idx_;
67  int idx =
68  shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_;
69  source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_);
70  return NextChunk(out_chunk);
71  } else {
72  return true;
73  }
74  } else {
75  return source_->NextChunk(out_chunk);
76  }
77  }
78  // implement ResetPartition.
79  virtual void ResetPartition(unsigned rank, unsigned nsplit) {
80  CHECK(nsplit == num_parts_) << "num_parts is not consistent!";
81  int idx = shuffle_indexes_[0] + rank * num_shuffle_parts_;
82  source_->ResetPartition(idx, nsplit * num_shuffle_parts_);
83  cur_shuffle_idx_ = 0;
84  }
100  InputSplitShuffle(const char* uri,
101  unsigned part_index,
102  unsigned num_parts,
103  const char* type,
104  unsigned num_shuffle_parts,
105  int shuffle_seed)
106  : part_index_(part_index),
107  num_parts_(num_parts),
108  num_shuffle_parts_(num_shuffle_parts),
109  cur_shuffle_idx_(0) {
110  for (unsigned i = 0; i < num_shuffle_parts_; i++) {
111  shuffle_indexes_.push_back(i);
112  }
113  trnd_.seed(kRandMagic_ + part_index_ + num_parts_ + num_shuffle_parts_ +
114  shuffle_seed);
115  std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_);
116  int idx = shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_;
117  source_.reset(
118  InputSplit::Create(uri, idx , num_parts_ * num_shuffle_parts_, type));
119  }
138  static InputSplit* Create(const char* uri,
139  unsigned part_index,
140  unsigned num_parts,
141  const char* type,
142  unsigned num_shuffle_parts,
143  int shuffle_seed) {
144  CHECK(num_shuffle_parts > 0) << "number of shuffle parts should be greater than zero!";
145  return new InputSplitShuffle(
146  uri, part_index, num_parts, type, num_shuffle_parts, shuffle_seed);
147  }
148 
149  private:
150  // magic nyumber for seed
151  static const int kRandMagic_ = 666;
153  std::mt19937 trnd_;
155  std::unique_ptr<InputSplit> source_;
157  unsigned part_index_;
159  unsigned num_parts_;
161  unsigned num_shuffle_parts_;
163  unsigned cur_shuffle_idx_;
165  std::vector<int> shuffle_indexes_;
166 };
167 } // namespace dmlc
168 #endif // DMLC_INPUT_SPLIT_SHUFFLE_H_
static InputSplit * Create(const char *uri, unsigned part_index, unsigned num_parts, const char *type, unsigned num_shuffle_parts, int shuffle_seed)
factory function: create input split with chunk shuffling given a uri
Definition: input_split_shuffle.h:138
virtual void HintChunkSize(size_t chunk_size)
hint to the InputSplit how large a chunk it should return when implementing NextChunk; this is a h...
Definition: input_split_shuffle.h:34
a blob of memory region
Definition: io.h:158
virtual void BeforeFirst(void)
reset the position of InputSplit to beginning
Definition: input_split_shuffle.h:24
class to construct input split with global shuffling
Definition: input_split_shuffle.h:19
InputSplitShuffle(const char *uri, unsigned part_index, unsigned num_parts, const char *type, unsigned num_shuffle_parts, int shuffle_seed)
constructor
Definition: input_split_shuffle.h:100
virtual bool NextChunk(Blob *out_chunk)
get a chunk of memory that can contain multiple records, the caller needs to parse the content of the...
Definition: input_split_shuffle.h:60
virtual ~InputSplitShuffle(void)
Definition: input_split_shuffle.h:22
namespace for dmlc
Definition: array_view.h:12
virtual size_t GetTotalSize(void)
get the total size of the InputSplit
Definition: input_split_shuffle.h:37
static InputSplit * Create(const char *uri, unsigned part_index, unsigned num_parts, const char *type)
factory function: create input split given a uri
input split that allows reading of records from a split of data; an independent part that covers a...
Definition: io.h:155
virtual bool NextRecord(Blob *out_rec)
get the next record, the returning value is valid until next call to NextRecord, NextChunk or NextBat...
Definition: input_split_shuffle.h:41
virtual void ResetPartition(unsigned rank, unsigned nsplit)
reset the InputSplit to a certain part id; the InputSplit will be pointed to the head of the new spe...
Definition: input_split_shuffle.h:79