mxnet
data.h
Go to the documentation of this file.
1 
7 #ifndef DMLC_DATA_H_
8 #define DMLC_DATA_H_
9 
10 #include <string>
11 #include <vector>
12 #include <map>
13 #include "./base.h"
14 #include "./io.h"
15 #include "./logging.h"
16 #include "./registry.h"
17 
18 // To help the C preprocessor with C++ templated types: __DMLC_COMMA is a comma
// token that can be written inside a macro argument. Writing Foo<A __DMLC_COMMA B>
// keeps the preprocessor from splitting the template argument list into two macro
// arguments; the comma only appears when the argument itself gets expanded.
// NOTE(review): identifiers beginning with "__" are reserved for the implementation.
19 #define __DMLC_COMMA ,
20 
21 namespace dmlc {
/*!
 * \brief floating point type used to store feature values; it is also the
 *  default DType of Row, RowBlock and Parser, and the type of instance weights.
 */
26 typedef float real_t;
27
/*! \brief unsigned integer type that can normally be used to store a feature index */
32 typedef unsigned index_t;
33 
34 // This file describes common data structures that can be used
35 // for large-scale machine learning. This may not be a complete list,
36 // but we will keep the most common and useful ones, and keep adding new ones.
55 template<typename DType>
56 class DataIter {
57  public:
61  virtual void BeforeFirst(void) = 0;
63  virtual bool Next(void) = 0;
65  virtual const DType &Value(void) const = 0;
66 };
67 
73 template<typename IndexType, typename DType = real_t>
74 class Row {
75  public:
77  const DType *label;
79  const real_t *weight;
81  const uint64_t *qid;
83  size_t length;
87  const IndexType *field;
91  const IndexType *index;
96  const DType *value;
101  inline IndexType get_field(size_t i) const {
102  return field[i];
103  }
108  inline IndexType get_index(size_t i) const {
109  return index[i];
110  }
116  inline DType get_value(size_t i) const {
117  return value == NULL ? DType(1.0f) : value[i];
118  }
122  inline DType get_label() const {
123  return *label;
124  }
129  inline real_t get_weight() const {
130  return weight == NULL ? 1.0f : *weight;
131  }
136  inline uint64_t get_qid() const {
137  return qid == NULL ? 0 : *qid;
138  }
146  template<typename V>
147  inline V SDot(const V *weight, size_t size) const {
148  V sum = static_cast<V>(0);
149  if (value == NULL) {
150  for (size_t i = 0; i < length; ++i) {
151  CHECK(index[i] < size) << "feature index exceed bound";
152  sum += weight[index[i]];
153  }
154  } else {
155  for (size_t i = 0; i < length; ++i) {
156  CHECK(index[i] < size) << "feature index exceed bound";
157  sum += weight[index[i]] * value[i];
158  }
159  }
160  return sum;
161  }
162 };
163 
174 template<typename IndexType, typename DType = real_t>
175 struct RowBlock {
177  size_t size;
179  const size_t *offset;
181  const DType *label;
183  const real_t *weight;
185  const uint64_t *qid;
187  const IndexType *field;
189  const IndexType *index;
191  const DType *value;
197  inline Row<IndexType, DType> operator[](size_t rowid) const;
199  inline size_t MemCostBytes(void) const {
200  size_t cost = size * (sizeof(size_t) + sizeof(DType));
201  if (weight != NULL) cost += size * sizeof(real_t);
202  if (qid != NULL) cost += size * sizeof(size_t);
203  size_t ndata = offset[size] - offset[0];
204  if (field != NULL) cost += ndata * sizeof(IndexType);
205  if (index != NULL) cost += ndata * sizeof(IndexType);
206  if (value != NULL) cost += ndata * sizeof(DType);
207  return cost;
208  }
215  inline RowBlock Slice(size_t begin, size_t end) const {
216  CHECK(begin <= end && end <= size);
217  RowBlock ret;
218  ret.size = end - begin;
219  ret.label = label + begin;
220  if (weight != NULL) {
221  ret.weight = weight + begin;
222  } else {
223  ret.weight = NULL;
224  }
225  if (qid != NULL) {
226  ret.qid = qid + begin;
227  } else {
228  ret.qid = NULL;
229  }
230  ret.offset = offset + begin;
231  ret.field = field;
232  ret.index = index;
233  ret.value = value;
234  return ret;
235  }
236 };
237 
253 template<typename IndexType, typename DType = real_t>
254 class RowBlockIter : public DataIter<RowBlock<IndexType, DType> > {
255  public:
268  Create(const char *uri,
269  unsigned part_index,
270  unsigned num_parts,
271  const char *type);
273  virtual size_t NumCol() const = 0;
274 };
275 
292 template <typename IndexType, typename DType = real_t>
293 class Parser : public DataIter<RowBlock<IndexType, DType> > {
294  public:
307  static Parser<IndexType, DType> *
308  Create(const char *uri_,
309  unsigned part_index,
310  unsigned num_parts,
311  const char *type);
313  virtual size_t BytesRead(void) const = 0;
315  typedef Parser<IndexType, DType>* (*Factory)
316  (const std::string& path,
317  const std::map<std::string, std::string>& args,
318  unsigned part_index,
319  unsigned num_parts);
320 };
321 
327 template<typename IndexType, typename DType = real_t>
329  : public FunctionRegEntryBase<ParserFactoryReg<IndexType, DType>,
330  typename Parser<IndexType, DType>::Factory> {};
331 
/*!
 * \brief register a data parser factory under a name.
 *  Expands to DMLC_REGISTRY_REGISTER on the entry type
 *  ParserFactoryReg<IndexT, ValueT>, with the registry variable named
 *  ParserFactoryReg_<IndexT>_<ValueT>, followed by .set_body(Factory).
 *  IndexT and ValueT must each be a single identifier (e.g. uint32_t, real_t),
 *  because they are token-pasted into that variable name. __DMLC_COMMA keeps
 *  the template argument list from being split into two macro arguments.
 * \param IndexT index type of the parser
 * \param ValueT value type of the parser
 * \param Name registered name of the parser
 * \param Factory factory function, presumably matching Parser<IndexT, ValueT>::Factory
 */
#define DMLC_REGISTER_DATA_PARSER(IndexT, ValueT, Name, Factory)  \
  DMLC_REGISTRY_REGISTER(                                          \
      ParserFactoryReg<IndexT __DMLC_COMMA ValueT>,                \
      ParserFactoryReg##_##IndexT##_##ValueT,                      \
      Name)                                                        \
      .set_body(Factory)
362 
363 
364 // implementation of operator[]
365 template<typename IndexType, typename DType>
366 inline Row<IndexType, DType>
368  CHECK(rowid < size);
370  inst.label = label + rowid;
371  if (weight != NULL) {
372  inst.weight = weight + rowid;
373  } else {
374  inst.weight = NULL;
375  }
376  if (qid != NULL) {
377  inst.qid = qid + rowid;
378  } else {
379  inst.qid = NULL;
380  }
381  inst.length = offset[rowid + 1] - offset[rowid];
382  if (field != NULL) {
383  inst.field = field + offset[rowid];
384  } else {
385  inst.field = NULL;
386  }
387  inst.index = index + offset[rowid];
388  if (value == NULL) {
389  inst.value = NULL;
390  } else {
391  inst.value = value + offset[rowid];
392  }
393  return inst;
394 }
395 
396 } // namespace dmlc
397 #endif // DMLC_DATA_H_
virtual bool Next(void)=0
move to next item
IndexType get_field(size_t i) const
Definition: data.h:101
const real_t * weight
With weight: array[size] weight of each instance, otherwise nullptr.
Definition: data.h:183
float real_t
this defines the float point that will be used to store feature values
Definition: data.h:26
Common base class for function registry.
Definition: registry.h:151
IndexType get_index(size_t i) const
Definition: data.h:108
DType get_value(size_t i) const
Definition: data.h:116
uint64_t get_qid() const
Definition: data.h:136
const DType * value
array value of each instance, this can be NULL indicating every value is set to be 1 ...
Definition: data.h:96
const DType * value
feature value, can be NULL, indicating all values are 1
Definition: data.h:191
#define DMLC_THROW_EXCEPTION
Definition: base.h:224
data iterator interface this is not a C++ style iterator, but nice for data pulling:) This interface ...
Definition: data.h:56
size_t MemCostBytes(void) const
Definition: data.h:199
const uint64_t * qid
session-id of the instance
Definition: data.h:81
Data structure that holds the data Row block iterator interface that gets RowBlocks Difference betwee...
Definition: data.h:254
const IndexType * field
field id
Definition: data.h:187
const DType * label
array[size] label of each instance
Definition: data.h:181
unsigned index_t
this defines the unsigned integer type that can normally be used to store feature index ...
Definition: data.h:32
parser interface that parses input data used to load dmlc data format into your own data format Diffe...
Definition: data.h:293
one row of training instance
Definition: data.h:74
namespace for dmlc
Definition: array_view.h:12
const real_t * weight
weight of the instance
Definition: data.h:79
size_t length
length of the sparse vector
Definition: data.h:83
virtual const DType & Value(void) const =0
get current data
const uint64_t * qid
With qid: array[size] session id of each instance, otherwise nullptr.
Definition: data.h:185
RowBlock Slice(size_t begin, size_t end) const
slice a RowBlock to get rows in [begin, end)
Definition: data.h:215
const IndexType * index
index of each instance
Definition: data.h:91
a block of data, containing several rows in sparse matrix. This is useful for (streaming-style) algor...
Definition: data.h:175
V SDot(const V *weight, size_t size) const
helper function to compute dot product of current
Definition: data.h:147
const DType * label
label of the instance
Definition: data.h:77
const IndexType * field
field of each instance
Definition: data.h:87
const size_t * offset
array[size+1], row pointer to the beginning of each row
Definition: data.h:179
real_t get_weight() const
Definition: data.h:129
Row< IndexType, DType > operator[](size_t rowid) const
get specific rows in the batch
Definition: data.h:367
virtual void BeforeFirst(void)=0
set before first of the item
const IndexType * index
feature index
Definition: data.h:189
DType get_label() const
Definition: data.h:122
registry entry of parser factory
Definition: data.h:328
virtual ~DataIter(void) DMLC_THROW_EXCEPTION
destructor
Definition: data.h:59
size_t size
batch size
Definition: data.h:177