mxnet
io.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_CPP_IO_H_
26 #define MXNET_CPP_IO_H_
27 
28 #include <map>
29 #include <string>
30 #include <vector>
31 #include <sstream>
32 #include "mxnet-cpp/base.h"
33 #include "mxnet-cpp/ndarray.h"
34 #include "dmlc/logging.h"
35 
36 namespace mxnet {
37 namespace cpp {
42 class DataBatch {
43  public:
46  int pad_num;
47  std::vector<int> index;
48 };
49 class DataIter {
50  public:
51  virtual void BeforeFirst(void) = 0;
52  virtual bool Next(void) = 0;
53  virtual NDArray GetData(void) = 0;
54  virtual NDArray GetLabel(void) = 0;
55  virtual int GetPadNum(void) = 0;
56  virtual std::vector<int> GetIndex(void) = 0;
57 
59  return DataBatch{GetData(), GetLabel(), GetPadNum(), GetIndex()};
60  }
61  void Reset() {
62  BeforeFirst();
63  }
64 
65  virtual ~DataIter() = default;
66 };
67 
69  public:
70  inline MXDataIterMap() {
71  mx_uint num_data_iter_creators = 0;
72  DataIterCreator* data_iter_creators = nullptr;
73  int r = MXListDataIters(&num_data_iter_creators, &data_iter_creators);
74  CHECK_EQ(r, 0);
75  for (mx_uint i = 0; i < num_data_iter_creators; i++) {
76  const char* name;
77  const char* description;
78  mx_uint num_args;
79  const char** arg_names;
80  const char** arg_type_infos;
81  const char** arg_descriptions;
82  r = MXDataIterGetIterInfo(data_iter_creators[i],
83  &name,
84  &description,
85  &num_args,
86  &arg_names,
87  &arg_type_infos,
88  &arg_descriptions);
89  CHECK_EQ(r, 0);
90  mxdataiter_creators_[name] = data_iter_creators[i];
91  }
92  }
93  inline DataIterCreator GetMXDataIterCreator(const std::string& name) {
94  return mxdataiter_creators_[name];
95  }
96 
97  private:
98  std::map<std::string, DataIterCreator> mxdataiter_creators_;
99 };
100 
102  public:
103  MXDataIterBlob() : handle_(nullptr) {}
104  explicit MXDataIterBlob(DataIterHandle handle) : handle_(handle) {}
107  }
109 
110  private:
111  MXDataIterBlob& operator=(const MXDataIterBlob&);
112 };
113 
114 class MXDataIter : public DataIter {
115  public:
116  explicit MXDataIter(const std::string& mxdataiter_type);
117  MXDataIter(const MXDataIter& other) {
118  creator_ = other.creator_;
119  params_ = other.params_;
120  blob_ptr_ = other.blob_ptr_;
121  }
122  void BeforeFirst();
123  bool Next();
124  NDArray GetData();
125  NDArray GetLabel();
126  int GetPadNum();
127  std::vector<int> GetIndex();
135  template <typename T>
136  MXDataIter& SetParam(const std::string& name, const T& value) {
137  std::string value_str;
138  std::stringstream ss;
139  ss << value;
140  ss >> value_str;
141 
142  params_[name] = value_str;
143  return *this;
144  }
145 
146  private:
147  DataIterCreator creator_;
148  std::map<std::string, std::string> params_;
149  std::shared_ptr<MXDataIterBlob> blob_ptr_;
150  static MXDataIterMap*& mxdataiter_map();
151 };
152 } // namespace cpp
153 } // namespace mxnet
154 
155 #endif // MXNET_CPP_IO_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::cpp::MXDataIter::MXDataIter
MXDataIter(const std::string &mxdataiter_type)
mxnet::cpp::MXDataIter::MXDataIter
MXDataIter(const MXDataIter &other)
Definition: io.h:117
mxnet::cpp::MXDataIter
Definition: io.h:114
mxnet::cpp::DataBatch
Default object for holding a mini-batch of data and related information.
Definition: io.h:42
mxnet::cpp::DataIter::GetDataBatch
DataBatch GetDataBatch()
Definition: io.h:58
mxnet::cpp::DataIter::GetIndex
virtual std::vector< int > GetIndex(void)=0
mxnet::cpp::MXDataIter::GetData
NDArray GetData()
mxnet::cpp::MXDataIterBlob::~MXDataIterBlob
~MXDataIterBlob()
Definition: io.h:105
MXDataIterGetIterInfo
MXNET_DLL int MXDataIterGetIterInfo(DataIterCreator creator, const char **name, const char **description, uint32_t *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions)
Get detailed information about a data iterator.
mxnet::cpp::MXDataIterMap::MXDataIterMap
MXDataIterMap()
Definition: io.h:70
mxnet::cpp::MXDataIter::GetPadNum
int GetPadNum()
mxnet::cpp::MXDataIter::GetLabel
NDArray GetLabel()
mxnet::cpp::MXDataIter::SetParam
MXDataIter & SetParam(const std::string &name, const T &value)
set config parameters
Definition: io.h:136
mxnet::cpp::MXDataIterBlob::MXDataIterBlob
MXDataIterBlob()
Definition: io.h:103
mxnet::cpp::NDArray
NDArray interface.
Definition: ndarray.h:122
ndarray.h
definition of ndarray
mxnet::cpp::DataIter::GetData
virtual NDArray GetData(void)=0
mxnet::cpp::MXDataIterBlob::handle_
DataIterHandle handle_
Definition: io.h:108
mxnet::cpp::DataIter::Next
virtual bool Next(void)=0
DataIterHandle
void * DataIterHandle
handle to a DataIterator
Definition: c_api.h:90
mxnet::cpp::DataIter::Reset
void Reset()
Definition: io.h:61
mxnet::cpp::MXDataIter::CreateDataIter
MXDataIter CreateDataIter()
DataIterCreator
void * DataIterCreator
handle to a data iterator creator
Definition: c_api.h:88
mxnet::cpp::DataIter
Definition: io.h:49
MXListDataIters
MXNET_DLL int MXListDataIters(uint32_t *out_size, DataIterCreator **out_array)
List all the available iterator entries.
mxnet::cpp::DataBatch::data
NDArray data
Definition: io.h:44
mxnet::cpp::MXDataIterBlob
Definition: io.h:101
mxnet::cpp::DataBatch::pad_num
int pad_num
Definition: io.h:46
mxnet::cpp::MXDataIterBlob::MXDataIterBlob
MXDataIterBlob(DataIterHandle handle)
Definition: io.h:104
mxnet::cpp::DataBatch::index
std::vector< int > index
Definition: io.h:47
mxnet::cpp::DataIter::GetLabel
virtual NDArray GetLabel(void)=0
mxnet::cpp::MXDataIter::Next
bool Next()
mxnet::cpp::MXDataIterMap
Definition: io.h:68
mxnet::cpp::MXDataIter::BeforeFirst
void BeforeFirst()
mxnet::cpp::DataIter::BeforeFirst
virtual void BeforeFirst(void)=0
base.h
base definitions for mxnet-cpp
MXDataIterFree
MXNET_DLL int MXDataIterFree(DataIterHandle handle)
Free the handle to the IO module.
mxnet::cpp::MXDataIterMap::GetMXDataIterCreator
DataIterCreator GetMXDataIterCreator(const std::string &name)
Definition: io.h:93
mx_uint
uint32_t mx_uint
manually defined 32-bit unsigned integer type
Definition: c_api.h:65
mxnet::cpp::DataBatch::label
NDArray label
Definition: io.h:45
mxnet::cpp::MXDataIter::GetIndex
std::vector< int > GetIndex()
mxnet::cpp::DataIter::GetPadNum
virtual int GetPadNum(void)=0
mxnet::cpp::DataIter::~DataIter
virtual ~DataIter()=default