mxnet
io.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_IO_H_
27 #define MXNET_CPP_IO_H_
28 
29 #include <map>
30 #include <string>
31 #include <vector>
32 #include <sstream>
33 #include "mxnet-cpp/base.h"
34 #include "mxnet-cpp/ndarray.h"
35 #include "dmlc/logging.h"
36 
37 namespace mxnet {
38 namespace cpp {
43 class DataBatch {
44  public:
47  int pad_num;
48  std::vector<int> index;
49 };
50 class DataIter {
51  public:
52  virtual void BeforeFirst(void) = 0;
53  virtual bool Next(void) = 0;
54  virtual NDArray GetData(void) = 0;
55  virtual NDArray GetLabel(void) = 0;
56  virtual int GetPadNum(void) = 0;
57  virtual std::vector<int> GetIndex(void) = 0;
58 
60  return DataBatch{GetData(), GetLabel(), GetPadNum(), GetIndex()};
61  }
62  void Reset() { BeforeFirst(); }
63 
64  virtual ~DataIter() = default;
65 };
66 
68  public:
69  inline MXDataIterMap() {
70  mx_uint num_data_iter_creators = 0;
71  DataIterCreator *data_iter_creators = nullptr;
72  int r = MXListDataIters(&num_data_iter_creators, &data_iter_creators);
73  CHECK_EQ(r, 0);
74  for (mx_uint i = 0; i < num_data_iter_creators; i++) {
75  const char *name;
76  const char *description;
77  mx_uint num_args;
78  const char **arg_names;
79  const char **arg_type_infos;
80  const char **arg_descriptions;
81  r = MXDataIterGetIterInfo(data_iter_creators[i], &name, &description,
82  &num_args, &arg_names, &arg_type_infos,
83  &arg_descriptions);
84  CHECK_EQ(r, 0);
85  mxdataiter_creators_[name] = data_iter_creators[i];
86  }
87  }
88  inline DataIterCreator GetMXDataIterCreator(const std::string &name) {
89  return mxdataiter_creators_[name];
90  }
91 
92  private:
93  std::map<std::string, DataIterCreator> mxdataiter_creators_;
94 };
95 
// NOTE(review): this rendered listing skips io.h line 96 (the class header,
// `class MXDataIterBlob {`), line 100 (the destructor ~MXDataIterBlob()) and
// line 101 (the public `DataIterHandle handle_` member), so the fragment
// below is not compilable as shown.
97  public:
// Default construction stores a null handle.
98  MXDataIterBlob() : handle_(nullptr) {}
// Wraps an existing C API iterator handle; explicit so a raw DataIterHandle
// is never converted into a blob implicitly.
99  explicit MXDataIterBlob(DataIterHandle handle) : handle_(handle) {}
102 
103  private:
// Copy assignment is declared private (presumably left undefined) to forbid
// assigning one blob to another.
// NOTE(review): the copy constructor is not suppressed. If ~MXDataIterBlob
// releases handle_ (presumably via MXDataIterFree — verify), copying a blob
// would release the same handle twice; consider `= delete` for both copy
// operations.
104  MXDataIterBlob &operator=(const MXDataIterBlob &);
105 };
106 
107 class MXDataIter : public DataIter {
108  public:
109  explicit MXDataIter(const std::string &mxdataiter_type);
110  MXDataIter(const MXDataIter &other) {
111  creator_ = other.creator_;
112  params_ = other.params_;
113  blob_ptr_ = other.blob_ptr_;
114  }
115  void BeforeFirst();
116  bool Next();
117  NDArray GetData();
118  NDArray GetLabel();
119  int GetPadNum();
120  std::vector<int> GetIndex();
121  MXDataIter CreateDataIter();
128  template <typename T>
129  MXDataIter &SetParam(const std::string &name, const T &value) {
130  std::string value_str;
131  std::stringstream ss;
132  ss << value;
133  ss >> value_str;
134 
135  params_[name] = value_str;
136  return *this;
137  }
138 
139  private:
140  DataIterCreator creator_;
141  std::map<std::string, std::string> params_;
142  std::shared_ptr<MXDataIterBlob> blob_ptr_;
143  static MXDataIterMap*& mxdataiter_map();
144 };
145 } // namespace cpp
146 } // namespace mxnet
147 
148 #endif // MXNET_CPP_IO_H_
149 
MXDataIterBlob(DataIterHandle handle)
Definition: io.h:99
void * DataIterCreator
handle a dataiter creator
Definition: c_api.h:81
MXDataIter & SetParam(const std::string &name, const T &value)
set config parameters
Definition: io.h:129
void Reset()
Definition: io.h:62
MXDataIterMap
Definition: io.h:67
mxnet
namespace of mxnet
Definition: api_registry.h:33
DataIterCreator GetMXDataIterCreator(const std::string &name)
Definition: io.h:88
~MXDataIterBlob()
Definition: io.h:100
NDArray interface.
Definition: ndarray.h:121
MXDataIter
Definition: io.h:107
MXDataIter(const MXDataIter &other)
Definition: io.h:110
MXDataIterMap()
Definition: io.h:69
std::vector< int > index
Definition: io.h:48
NDArray label
Definition: io.h:46
DataBatch GetDataBatch()
Definition: io.h:59
MXNET_DLL int MXDataIterGetIterInfo(DataIterCreator creator, const char **name, const char **description, uint32_t *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions)
Get the detailed information about data iterator.
DataIter
Definition: io.h:50
NDArray data
Definition: io.h:45
MXDataIterBlob()
Definition: io.h:98
MXNET_DLL int MXListDataIters(uint32_t *out_size, DataIterCreator **out_array)
List all the available iterator entries.
int pad_num
Definition: io.h:47
void * DataIterHandle
handle to a DataIterator
Definition: c_api.h:83
MXNET_DLL int MXDataIterFree(DataIterHandle handle)
Free the handle to the IO module.
DataIterHandle handle_
Definition: io.h:101
DataBatch
Default object for holding a mini-batch of data and related information.
Definition: io.h:43
uint32_t mx_uint
manually define unsigned int
Definition: c_api.h:58
MXDataIterBlob
Definition: io.h:96