mxnet
io.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_IO_H_
25 #define MXNET_IO_H_
26 
27 #include <vector>
28 #include <string>
29 #include <utility>
30 #include <queue>
31 #include "dmlc/data.h"
32 #include "dmlc/registry.h"
33 #include "./base.h"
34 #include "./ndarray.h"
35 
36 namespace mxnet {
41 template <typename DType>
42 class IIterator : public dmlc::DataIter<DType> {
43  public:
48  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
50  virtual void BeforeFirst(void) = 0;
52  virtual bool Next(void) = 0;
54  virtual const DType& Value(void) const = 0;
56  virtual ~IIterator(void) {}
58  std::vector<std::string> data_names;
60  inline void SetDataName(const std::string data_name) {
61  data_names.push_back(data_name);
62  }
67  virtual int64_t GetLenHint(void) const {
68  return -1;
69  }
70 }; // class IIterator
71 
73 struct DataInst {
75  unsigned index;
77  std::vector<TBlob> data;
79  std::string extra_data;
80 }; // struct DataInst
81 
85 struct DataBatch {
87  std::vector<NDArray> data;
89  std::vector<uint64_t> index;
91  std::string extra_data;
94 }; // struct DataBatch
95 
97 typedef std::function<IIterator<DataBatch>*()> DataIteratorFactory;
101 struct DataIteratorReg : public dmlc::FunctionRegEntryBase<DataIteratorReg, DataIteratorFactory> {};
102 //--------------------------------------------------------------
103 // The following part are API Registration of Iterators
104 //--------------------------------------------------------------
117 #define MXNET_REGISTER_IO_ITER(name) \
118  DMLC_REGISTRY_REGISTER(::mxnet::DataIteratorReg, DataIteratorReg, name)
119 
126 class Dataset {
127  public:
131  virtual uint64_t GetLen(void) const = 0;
137  virtual bool GetItem(uint64_t idx, std::vector<NDArray>* ret) = 0;
138  // virtual destructor
139  virtual ~Dataset(void) {}
140 }; // class Dataset
141 
143 typedef std::function<Dataset*(const std::vector<std::pair<std::string, std::string> >&)>
148 struct DatasetReg : public dmlc::FunctionRegEntryBase<DatasetReg, DatasetFactory> {};
149 //--------------------------------------------------------------
150 // The following part are API Registration of Datasets
151 //--------------------------------------------------------------
164 #define MXNET_REGISTER_IO_DATASET(name) \
165  DMLC_REGISTRY_REGISTER(::mxnet::DatasetReg, DatasetReg, name)
166 
168  public:
170  virtual ~BatchifyFunction(void) {}
172  virtual bool Batchify(const std::vector<std::vector<NDArray> >& inputs,
173  std::vector<NDArray>* outputs) = 0;
174 }; // class BatchifyFunction
175 
176 using BatchifyFunctionPtr = std::shared_ptr<BatchifyFunction>;
177 
179 typedef std::function<BatchifyFunction*(const std::vector<std::pair<std::string, std::string> >&)>
185  : public dmlc::FunctionRegEntryBase<BatchifyFunctionReg, BatchifyFunctionFactory> {};
186 //--------------------------------------------------------------
187 // The following part are API Registration of Batchify Function
188 //--------------------------------------------------------------
201 #define MXNET_REGISTER_IO_BATCHIFY_FUNCTION(name) \
202  DMLC_REGISTRY_REGISTER(::mxnet::BatchifyFunctionReg, BatchifyFunctionReg, name)
203 } // namespace mxnet
204 #endif // MXNET_IO_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::DataBatch::extra_data
std::string extra_data
extra data to be fed to the network
Definition: io.h:91
mxnet::DataInst::data
std::vector< TBlob > data
content of data
Definition: io.h:77
mxnet::IIterator::data_names
std::vector< std::string > data_names
stores the name of each data entry; it can be used when creating NDArrays
Definition: io.h:58
dmlc::FunctionRegEntryBase
Common base class for function registry.
Definition: registry.h:151
mxnet::BatchifyFunctionFactory
std::function< BatchifyFunction *(const std::vector< std::pair< std::string, std::string > > &)> BatchifyFunctionFactory
typedef the factory function of batchify function
Definition: io.h:180
mxnet::DatasetFactory
std::function< Dataset *(const std::vector< std::pair< std::string, std::string > > &)> DatasetFactory
typedef the factory function of dataset
Definition: io.h:144
mxnet::BatchifyFunction
Definition: io.h:167
mxnet::Dataset
A randomly accessible dataset which provides GetLen() and GetItem(). Unlike DataIter,...
Definition: io.h:126
mxnet::Dataset::~Dataset
virtual ~Dataset(void)
Definition: io.h:139
mxnet::DataIteratorFactory
std::function< IIterator< DataBatch > *()> DataIteratorFactory
typedef the factory function of data iterator
Definition: io.h:97
mxnet::DataInst
a single data instance
Definition: io.h:73
mxnet::Dataset::GetLen
virtual uint64_t GetLen(void) const =0
Get the size of the dataset.
mxnet::BatchifyFunctionReg
Registry entry for BatchifyFunction factory functions.
Definition: io.h:184
mxnet::DatasetReg
Registry entry for Dataset factory functions.
Definition: io.h:148
mxnet::BatchifyFunction::~BatchifyFunction
virtual ~BatchifyFunction(void)
Destructor.
Definition: io.h:170
mxnet::DataBatch::data
std::vector< NDArray > data
content of dense data, if this DataBatch is dense
Definition: io.h:87
mxnet::DataInst::extra_data
std::string extra_data
extra data to be fed to the network
Definition: io.h:79
mxnet::IIterator::BeforeFirst
virtual void BeforeFirst(void)=0
reset the iterator
mxnet::DataIteratorReg
Registry entry for DataIterator factory functions.
Definition: io.h:101
mxnet::IIterator
iterator type
Definition: io.h:42
mxnet::DataBatch::index
std::vector< uint64_t > index
index of image data
Definition: io.h:89
dmlc::DataIter
data iterator interface. This is not a C++ style iterator, but it is nice for data pulling :) This interface ...
Definition: data.h:56
mxnet::BatchifyFunction::Batchify
virtual bool Batchify(const std::vector< std::vector< NDArray > > &inputs, std::vector< NDArray > *outputs)=0
The batchify logic.
mxnet::IIterator::GetLenHint
virtual int64_t GetLenHint(void) const
request iterator length hint for current epoch. Note that the returned value can be < 0,...
Definition: io.h:67
mxnet::IIterator::~IIterator
virtual ~IIterator(void)
destructor
Definition: io.h:56
mxnet::DataInst::index
unsigned index
unique id for instance
Definition: io.h:75
mxnet::DataBatch::num_batch_padd
int num_batch_padd
number of examples padded to the batch
Definition: io.h:93
mxnet::IIterator::Value
virtual const DType & Value(void) const =0
get current data
data.h
defines common input data structure, and interface for handling the input data
mxnet::IIterator::Next
virtual bool Next(void)=0
move to next item
mxnet::DataBatch
DataBatch of NDArray, returned by Iterator.
Definition: io.h:85
mxnet::IIterator::SetDataName
void SetDataName(const std::string data_name)
set data name to each attribute of data
Definition: io.h:60
registry.h
Registry utility that helps to build registry singletons.
ndarray.h
NDArray interface that handles array arithmetic.
mxnet::Dataset::GetItem
virtual bool GetItem(uint64_t idx, std::vector< NDArray > *ret)=0
Get the ndarray items given index in dataset.
mxnet::IIterator::Init
virtual void Init(const std::vector< std::pair< std::string, std::string > > &kwargs)=0
set the parameters and initialize the iterator
mxnet::BatchifyFunctionPtr
std::shared_ptr< BatchifyFunction > BatchifyFunctionPtr
Definition: io.h:176
base.h
configuration of MXNet as well as basic data structure.