mxnet
serializer.h
Go to the documentation of this file.
1 
7 #ifndef DMLC_SERIALIZER_H_
8 #define DMLC_SERIALIZER_H_
9 
10 #include <vector>
11 #include <string>
12 #include <map>
13 #include <set>
14 #include <list>
15 #include <deque>
16 #include <utility>
17 
18 #include "./base.h"
19 #include "./io.h"
20 #include "./logging.h"
21 #include "./type_traits.h"
22 #include "./endian.h"
23 
24 #if DMLC_USE_CXX11
25 #include <unordered_map>
26 #include <unordered_set>
27 #endif
28 
29 namespace dmlc {
31 namespace serializer {
37 template<typename T>
38 struct Handler;
39 
41 
48 template<bool cond, typename Then, typename Else, typename Return>
49 struct IfThenElse;
50 
51 template<typename Then, typename Else, typename T>
52 struct IfThenElse<true, Then, Else, T> {
53  inline static void Write(Stream *strm, const T &data) {
54  Then::Write(strm, data);
55  }
56  inline static bool Read(Stream *strm, T *data) {
57  return Then::Read(strm, data);
58  }
59 };
60 template<typename Then, typename Else, typename T>
61 struct IfThenElse<false, Then, Else, T> {
62  inline static void Write(Stream *strm, const T &data) {
63  Else::Write(strm, data);
64  }
65  inline static bool Read(Stream *strm, T *data) {
66  return Else::Read(strm, data);
67  }
68 };
69 
71 template<typename T>
72 struct NativePODHandler {
73  inline static void Write(Stream *strm, const T &data) {
74  strm->Write(&data, sizeof(T));
75  }
76  inline static bool Read(Stream *strm, T *dptr) {
77  return strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*)
78  }
79 };
80 
82 template<typename T>
83 struct ArithmeticHandler {
84  inline static void Write(Stream *strm, const T &data) {
86  strm->Write(&data, sizeof(T));
87  } else {
88  T copy = data;
89  ByteSwap(&copy, sizeof(T), 1);
90  strm->Write(&copy, sizeof(T));
91  }
92  }
93  inline static bool Read(Stream *strm, T *dptr) {
94  bool ret = strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*)
96  ByteSwap(dptr, sizeof(T), 1);
97  }
98  return ret;
99  }
100 };
101 
102 // serializer for class that have save/load function
103 template<typename T>
104 struct SaveLoadClassHandler {
105  inline static void Write(Stream *strm, const T &data) {
106  data.Save(strm);
107  }
108  inline static bool Read(Stream *strm, T *data) {
109  return data->Load(strm);
110  }
111 };
112 
/*!
 * \brief placeholder selected by the dispatch logic when no serializer
 *  matches T.
 *
 *  It intentionally declares no Write/Read. Any attempt to serialize such a
 *  T therefore fails at compile time, and the diagnostic mentions
 *  UndefinedSerializerFor<T>, which names the unsupported type.
 */
template<typename T> struct UndefinedSerializerFor {};
122 
127 template<typename T>
128 struct NativePODVectorHandler {
129  inline static void Write(Stream *strm, const std::vector<T> &vec) {
130  uint64_t sz = static_cast<uint64_t>(vec.size());
131  strm->Write<uint64_t>(sz);
132  if (sz != 0) {
133  strm->Write(&vec[0], sizeof(T) * vec.size());
134  }
135  }
136  inline static bool Read(Stream *strm, std::vector<T> *out_vec) {
137  uint64_t sz;
138  if (!strm->Read<uint64_t>(&sz)) return false;
139  size_t size = static_cast<size_t>(sz);
140  out_vec->resize(size);
141  if (sz != 0) {
142  size_t nbytes = sizeof(T) * size;
143  return strm->Read(&(*out_vec)[0], nbytes) == nbytes;
144  }
145  return true;
146  }
147 };
148 
153 template<typename T>
154 struct ComposeVectorHandler {
155  inline static void Write(Stream *strm, const std::vector<T> &vec) {
156  uint64_t sz = static_cast<uint64_t>(vec.size());
157  strm->Write<uint64_t>(sz);
158  strm->WriteArray(dmlc::BeginPtr(vec), vec.size());
159  }
160  inline static bool Read(Stream *strm, std::vector<T> *out_vec) {
161  uint64_t sz;
162  if (!strm->Read<uint64_t>(&sz)) return false;
163  size_t size = static_cast<size_t>(sz);
164  out_vec->resize(size);
165  return strm->ReadArray(dmlc::BeginPtr(*out_vec), size);
166  }
167 };
168 
173 template<typename T>
174 struct NativePODStringHandler {
175  inline static void Write(Stream *strm, const std::basic_string<T> &vec) {
176  uint64_t sz = static_cast<uint64_t>(vec.length());
177  strm->Write<uint64_t>(sz);
178  if (sz != 0) {
179  strm->Write(&vec[0], sizeof(T) * vec.length());
180  }
181  }
182  inline static bool Read(Stream *strm, std::basic_string<T> *out_vec) {
183  uint64_t sz;
184  if (!strm->Read<uint64_t>(&sz)) return false;
185  size_t size = static_cast<size_t>(sz);
186  out_vec->resize(size);
187  if (sz != 0) {
188  size_t nbytes = sizeof(T) * size;
189  return strm->Read(&(*out_vec)[0], nbytes) == nbytes;
190  }
191  return true;
192  }
193 };
194 
196 template<typename TA, typename TB>
197 struct PairHandler {
198  inline static void Write(Stream *strm, const std::pair<TA, TB> &data) {
199  Handler<TA>::Write(strm, data.first);
200  Handler<TB>::Write(strm, data.second);
201  }
202  inline static bool Read(Stream *strm, std::pair<TA, TB> *data) {
203  return Handler<TA>::Read(strm, &(data->first)) &&
204  Handler<TB>::Read(strm, &(data->second));
205  }
206 };
207 
208 // set type handler that can handle most collection type case
209 template<typename ContainerType, typename ElemType>
210 struct CollectionHandler {
211  inline static void Write(Stream *strm, const ContainerType &data) {
212  // dump data to vector
213  std::vector<ElemType> vdata(data.begin(), data.end());
214  // serialize the vector
215  Handler<std::vector<ElemType> >::Write(strm, vdata);
216  }
217  inline static bool Read(Stream *strm, ContainerType *data) {
218  std::vector<ElemType> vdata;
219  if (!Handler<std::vector<ElemType> >::Read(strm, &vdata)) return false;
220  data->clear();
221  data->insert(vdata.begin(), vdata.end());
222  return true;
223  }
224 };
225 
226 
227 // handler that can handle most list type case
228 // this type insert function takes additional iterator
229 template<typename ListType>
230 struct ListHandler {
231  inline static void Write(Stream *strm, const ListType &data) {
232  typedef typename ListType::value_type ElemType;
233  // dump data to vector
234  std::vector<ElemType> vdata(data.begin(), data.end());
235  // serialize the vector
236  Handler<std::vector<ElemType> >::Write(strm, vdata);
237  }
238  inline static bool Read(Stream *strm, ListType *data) {
239  typedef typename ListType::value_type ElemType;
240  std::vector<ElemType> vdata;
241  if (!Handler<std::vector<ElemType> >::Read(strm, &vdata)) return false;
242  data->clear();
243  data->insert(data->begin(), vdata.begin(), vdata.end());
244  return true;
245  }
246 };
247 
249 
258 template<typename T>
259 struct Handler {
265  inline static void Write(Stream *strm, const T &data) {
266  IfThenElse<dmlc::is_arithmetic<T>::value,
267  ArithmeticHandler<T>,
268  IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
269  NativePODHandler<T>,
270  IfThenElse<dmlc::has_saveload<T>::value,
271  SaveLoadClassHandler<T>,
272  UndefinedSerializerFor<T>, T>,
273  T>,
274  T>
275  ::Write(strm, data);
276  }
283  inline static bool Read(Stream *strm, T *data) {
284  return
285  IfThenElse<dmlc::is_arithmetic<T>::value,
286  ArithmeticHandler<T>,
287  IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
288  NativePODHandler<T>,
289  IfThenElse<dmlc::has_saveload<T>::value,
290  SaveLoadClassHandler<T>,
291  UndefinedSerializerFor<T>, T>,
292  T>,
293  T>
294  ::Read(strm, data);
295  }
296 };
297 
299 template<typename T>
300 struct Handler<std::vector<T> > {
301  inline static void Write(Stream *strm, const std::vector<T> &data) {
302  IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
303  NativePODVectorHandler<T>,
304  ComposeVectorHandler<T>, std::vector<T> >
305  ::Write(strm, data);
306  }
307  inline static bool Read(Stream *strm, std::vector<T> *data) {
308  return IfThenElse<dmlc::is_pod<T>::value && DMLC_IO_NO_ENDIAN_SWAP,
309  NativePODVectorHandler<T>,
310  ComposeVectorHandler<T>,
311  std::vector<T> >
312  ::Read(strm, data);
313  }
314 };
315 
316 template<typename T>
317 struct Handler<std::basic_string<T> > {
318  inline static void Write(Stream *strm, const std::basic_string<T> &data) {
319  IfThenElse<dmlc::is_pod<T>::value && (DMLC_IO_NO_ENDIAN_SWAP || sizeof(T) == 1),
320  NativePODStringHandler<T>,
321  UndefinedSerializerFor<T>,
322  std::basic_string<T> >
323  ::Write(strm, data);
324  }
325  inline static bool Read(Stream *strm, std::basic_string<T> *data) {
326  return IfThenElse<dmlc::is_pod<T>::value && (DMLC_IO_NO_ENDIAN_SWAP || sizeof(T) == 1),
327  NativePODStringHandler<T>,
328  UndefinedSerializerFor<T>,
329  std::basic_string<T> >
330  ::Read(strm, data);
331  }
332 };
333 
334 template<typename TA, typename TB>
335 struct Handler<std::pair<TA, TB> > {
336  inline static void Write(Stream *strm, const std::pair<TA, TB> &data) {
337  IfThenElse<dmlc::is_pod<TA>::value &&
340  NativePODHandler<std::pair<TA, TB> >,
341  PairHandler<TA, TB>,
342  std::pair<TA, TB> >
343  ::Write(strm, data);
344  }
345  inline static bool Read(Stream *strm, std::pair<TA, TB> *data) {
346  return IfThenElse<dmlc::is_pod<TA>::value &&
349  NativePODHandler<std::pair<TA, TB> >,
350  PairHandler<TA, TB>,
351  std::pair<TA, TB> >
352  ::Read(strm, data);
353  }
354 };
355 
356 template<typename K, typename V>
357 struct Handler<std::map<K, V> >
358  : public CollectionHandler<std::map<K, V>, std::pair<K, V> > {
359 };
360 
361 template<typename K, typename V>
362 struct Handler<std::multimap<K, V> >
363  : public CollectionHandler<std::multimap<K, V>, std::pair<K, V> > {
364 };
365 
366 template<typename T>
367 struct Handler<std::set<T> >
368  : public CollectionHandler<std::set<T>, T> {
369 };
370 
371 template<typename T>
372 struct Handler<std::multiset<T> >
373  : public CollectionHandler<std::multiset<T>, T> {
374 };
375 
376 template<typename T>
377 struct Handler<std::list<T> >
378  : public ListHandler<std::list<T> > {
379 };
380 
381 template<typename T>
382 struct Handler<std::deque<T> >
383  : public ListHandler<std::deque<T> > {
384 };
385 
386 #if DMLC_USE_CXX11
387 template<typename K, typename V>
388 struct Handler<std::unordered_map<K, V> >
389  : public CollectionHandler<std::unordered_map<K, V>, std::pair<K, V> > {
390 };
391 
392 template<typename K, typename V>
393 struct Handler<std::unordered_multimap<K, V> >
394  : public CollectionHandler<std::unordered_multimap<K, V>, std::pair<K, V> > {
395 };
396 
397 template<typename T>
398 struct Handler<std::unordered_set<T> >
399  : public CollectionHandler<std::unordered_set<T>, T> {
400 };
401 
402 template<typename T>
403 struct Handler<std::unordered_multiset<T> >
404  : public CollectionHandler<std::unordered_multiset<T>, T> {
405 };
406 #endif
407 } // namespace serializer
409 } // namespace dmlc
410 #endif // DMLC_SERIALIZER_H_
whether a type is a POD type
Definition: type_traits.h:21
T * BeginPtr(std::vector< T > &vec)
safely get the beginning address of a vector
Definition: base.h:239
interface of stream I/O for serialization
Definition: io.h:30
static void Write(Stream *strm, const T &data)
write data to stream
Definition: serializer.h:265
void WriteArray(const T *data, size_t num_elems)
Endian aware write array of data.
Definition: io.h:460
Endian testing; requires C++11.
generic serialization handler
Definition: serializer.h:38
namespace for dmlc
Definition: array_view.h:12
#define DMLC_IO_NO_ENDIAN_SWAP
whether to serialize using little-endian byte order
Definition: endian.h:39
void ByteSwap(void *data, size_t elem_bytes, size_t num_elems)
A generic inplace byte swapping function.
Definition: endian.h:51
bool ReadArray(T *data, size_t num_elems)
Endian aware read array of data.
Definition: io.h:467
virtual size_t Read(void *ptr, size_t size)=0
reads data from a stream
virtual void Write(const void *ptr, size_t size)=0
writes data to a stream
static bool Read(Stream *strm, T *data)
read data from stream
Definition: serializer.h:283
type traits information header