7 #ifndef DMLC_INPUT_SPLIT_SHUFFLE_H_ 8 #define DMLC_INPUT_SPLIT_SHUFFLE_H_ 25 if (num_shuffle_parts_ > 1) {
26 std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_);
27 int idx = shuffle_indexes_[0] + part_index_ * num_shuffle_parts_;
28 source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_);
31 source_->BeforeFirst();
35 source_->HintChunkSize(chunk_size);
38 return source_->GetTotalSize();
42 if (num_shuffle_parts_ > 1) {
43 if (!source_->NextRecord(out_rec)) {
44 if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) {
49 shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_;
50 source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_);
56 return source_->NextRecord(out_rec);
61 if (num_shuffle_parts_ > 1) {
62 if (!source_->NextChunk(out_chunk)) {
63 if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) {
68 shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_;
69 source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_);
75 return source_->NextChunk(out_chunk);
80 CHECK(nsplit == num_parts_) <<
"num_parts is not consistent!";
81 int idx = shuffle_indexes_[0] + rank * num_shuffle_parts_;
82 source_->ResetPartition(idx, nsplit * num_shuffle_parts_);
104 unsigned num_shuffle_parts,
106 : part_index_(part_index),
107 num_parts_(num_parts),
108 num_shuffle_parts_(num_shuffle_parts),
109 cur_shuffle_idx_(0) {
110 for (
unsigned i = 0; i < num_shuffle_parts_; i++) {
111 shuffle_indexes_.push_back(i);
113 trnd_.seed(kRandMagic_ + part_index_ + num_parts_ + num_shuffle_parts_ +
115 std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_);
116 int idx = shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_;
142 unsigned num_shuffle_parts,
144 CHECK(num_shuffle_parts > 0) <<
"number of shuffle parts should be greater than zero!";
146 uri, part_index, num_parts, type, num_shuffle_parts, shuffle_seed);
151 static const int kRandMagic_ = 666;
155 std::unique_ptr<InputSplit> source_;
157 unsigned part_index_;
161 unsigned num_shuffle_parts_;
163 unsigned cur_shuffle_idx_;
165 std::vector<int> shuffle_indexes_;
168 #endif // DMLC_INPUT_SPLIT_SHUFFLE_H_
namespace for dmlc
Definition: array_view.h:12