mxnet
utils.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file utils.h
 * \brief Basic utility functions.
 */
#ifndef MXNET_COMMON_UTILS_H_
#define MXNET_COMMON_UTILS_H_

#include <dmlc/logging.h>
#include <dmlc/omp.h>
#include <dmlc/thread_local.h>
#include <nnvm/graph.h>
#include <nnvm/node.h>
#include <mxnet/imperative.h>
#include <mxnet/engine.h>
#include <mxnet/ndarray.h>
#include <mxnet/storage.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/graph_attr_types.h>
#include <nnvm/graph_attr_types.h>

#include <memory>
#include <vector>
#include <type_traits>
#include <utility>
#include <random>
#include <string>
#include <sstream>
#include <cstring>
#include <thread>
#include <algorithm>
#include <functional>
#include <limits>
#include <unordered_set>

#include "../operator/mxnet_op.h"
#if MXNET_USE_ONEDNN == 1
#include "../operator/nn/dnnl/dnnl_base-inl.h"
#endif

#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
#include <windows.h>
#else
#include <unistd.h>
#endif

namespace mxnet {
namespace common {

#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
inline size_t current_process_id() {
  return ::GetCurrentProcessId();
}
#else
inline size_t current_process_id() {
  return getpid();
}
#endif

/*!
 * \brief IndPtr should be non-negative, in non-decreasing order, start with 0
 *        and end with a value equal to the size of indices.
 */
struct csr_indptr_check {
  template <typename DType, typename IType>
  MSHADOW_XINLINE static void Map(int i,
                                  DType* out,
                                  const IType* indptr,
                                  const nnvm::dim_t end,
                                  const nnvm::dim_t idx_size) {
    if (indptr[i + 1] < 0 || indptr[i + 1] < indptr[i] || (i == 0 && indptr[i] != 0) ||
        (i == end - 1 && indptr[end] != idx_size))
      *out = kCSRIndPtrErr;
  }
};

/*!
 * \brief Indices should be non-negative, less than the number of columns
 *        and in ascending order per row.
 */
struct csr_idx_check {
  template <typename DType, typename IType, typename RType>
  MSHADOW_XINLINE static void Map(int i,
                                  DType* out,
                                  const IType* idx,
                                  const RType* indptr,
                                  const nnvm::dim_t ncols) {
    for (RType j = indptr[i]; j < indptr[i + 1]; j++) {
      if (idx[j] >= ncols || idx[j] < 0 || (j < indptr[i + 1] - 1 && idx[j] >= idx[j + 1])) {
        *out = kCSRIdxErr;
        break;
      }
    }
  }
};

/*!
 * \brief Indices of RSPNDArray should be non-negative, less than the size
 *        of the first dimension and in ascending order.
 */
struct rsp_idx_check {
  template <typename DType, typename IType>
  MSHADOW_XINLINE static void Map(int i,
                                  DType* out,
                                  const IType* idx,
                                  const nnvm::dim_t end,
                                  const nnvm::dim_t nrows) {
    if ((i < end && idx[i + 1] <= idx[i]) || idx[i] < 0 || idx[i] >= nrows)
      *out = kRSPIdxErr;
  }
};

template <typename xpu>
void CheckFormatWrapper(const RunContext& rctx,
                        const NDArray& input,
                        const TBlob& err_cpu,
                        const bool full_check);

/*!
 * \brief Check the validity of CSRNDArray.
 * \param rctx Execution context.
 * \param input Input NDArray of CSRStorage.
 * \param err_cpu Error number on cpu.
 * \param full_check If true, rigorous check, O(N) operations,
 *        otherwise basic check, O(1) operations.
 */
template <typename xpu>
void CheckFormatCSRImpl(const RunContext& rctx,
                        const NDArray& input,
                        const TBlob& err_cpu,
                        const bool full_check) {
  using namespace op::mxnet_op;
  CHECK_EQ(input.storage_type(), kCSRStorage) << "CheckFormatCSRImpl is for CSRNDArray";
  const mxnet::TShape shape         = input.shape();
  const mxnet::TShape idx_shape     = input.aux_shape(csr::kIdx);
  const mxnet::TShape indptr_shape  = input.aux_shape(csr::kIndPtr);
  const mxnet::TShape storage_shape = input.storage_shape();
  if ((shape.ndim() != 2) ||
      (idx_shape.ndim() != 1 || indptr_shape.ndim() != 1 || storage_shape.ndim() != 1) ||
      (indptr_shape[0] != shape[0] + 1) || (idx_shape[0] != storage_shape[0])) {
    MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, {
      DType* err = err_cpu.dptr<DType>();
      *err       = kCSRShapeErr;
    });
    return;
  }
  if (full_check) {
    MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, {
      MSHADOW_IDX_TYPE_SWITCH(input.aux_type(csr::kIndPtr), RType, {
        MSHADOW_IDX_TYPE_SWITCH(input.aux_type(csr::kIdx), IType, {
          mshadow::Stream<xpu>* s = rctx.get_stream<xpu>();
          NDArray ret_xpu = NDArray(mshadow::Shape1(1), rctx.get_ctx(), false, err_cpu.type_flag_);
          TBlob val_xpu   = ret_xpu.data();
          Kernel<set_to_int<kNormalErr>, xpu>::Launch(s, val_xpu.Size(), val_xpu.dptr<DType>());
          Kernel<csr_indptr_check, xpu>::Launch(s,
                                                indptr_shape[0] - 1,
                                                val_xpu.dptr<DType>(),
                                                input.aux_data(csr::kIndPtr).dptr<RType>(),
                                                indptr_shape[0] - 1,
                                                idx_shape[0]);
          // no need to check indices if indices are empty
          if (idx_shape[0] != 0) {
            Kernel<csr_idx_check, xpu>::Launch(s,
                                               indptr_shape[0] - 1,
                                               val_xpu.dptr<DType>(),
                                               input.aux_data(csr::kIdx).dptr<IType>(),
                                               input.aux_data(csr::kIndPtr).dptr<RType>(),
                                               shape[1]);
          }
          mshadow::Copy(err_cpu.get<cpu, 1, DType>(), val_xpu.get<xpu, 1, DType>(s), s);
        });
      });
    });
  }
}

/*!
 * \brief Check the validity of RowSparseNDArray.
 * \param rctx Execution context.
 * \param input Input NDArray of RowSparseStorage.
 * \param err_cpu Error number on cpu.
 * \param full_check If true, rigorous check, O(N) operations,
 *        otherwise basic check, O(1) operations.
 */
template <typename xpu>
void CheckFormatRSPImpl(const RunContext& rctx,
                        const NDArray& input,
                        const TBlob& err_cpu,
                        const bool full_check) {
  using namespace op::mxnet_op;
  CHECK_EQ(input.storage_type(), kRowSparseStorage) << "CheckFormatRSPImpl is for RSPNDArray";
  const mxnet::TShape idx_shape = input.aux_shape(rowsparse::kIdx);
  if (idx_shape[0] != input.storage_shape()[0]) {
    MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, {
      DType* err = err_cpu.dptr<DType>();
      *err       = kRSPShapeErr;
    });
    return;
  }
  if (idx_shape[0] == 0) {
    return;
  }
  if (full_check) {
    MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, {
      MSHADOW_IDX_TYPE_SWITCH(input.aux_type(rowsparse::kIdx), IType, {
        mshadow::Stream<xpu>* s = rctx.get_stream<xpu>();
        NDArray ret_xpu = NDArray(mshadow::Shape1(1), rctx.get_ctx(), false, err_cpu.type_flag_);
        TBlob val_xpu   = ret_xpu.data();
        Kernel<set_to_int<kNormalErr>, xpu>::Launch(s, val_xpu.Size(), val_xpu.dptr<DType>());

        Kernel<rsp_idx_check, xpu>::Launch(s,
                                           idx_shape[0],
                                           val_xpu.dptr<DType>(),
                                           input.aux_data(rowsparse::kIdx).dptr<IType>(),
                                           idx_shape[0] - 1,
                                           input.shape()[0]);
        mshadow::Copy(err_cpu.get<cpu, 1, DType>(), val_xpu.get<xpu, 1, DType>(s), s);
      });
    });
  }
}

template <typename xpu>
void CheckFormatImpl(const RunContext& rctx,
                     const NDArray& input,
                     const TBlob& err_cpu,
                     const bool full_check) {
  int stype = input.storage_type();
  if (stype == kCSRStorage) {
    CheckFormatCSRImpl<xpu>(rctx, input, err_cpu, full_check);
  } else if (stype == kRowSparseStorage) {
    CheckFormatRSPImpl<xpu>(rctx, input, err_cpu, full_check);
  } else if (stype == kDefaultStorage) {
    // no-op for default storage
  } else {
    LOG(FATAL) << "Unknown storage type " << stype;
  }
}

/*!
 * \brief Pick rows specified by user input index array from a row sparse
 *        ndarray and save them in the output sparse ndarray.
 */
template <typename xpu>
void SparseRetainOpForwardRspWrapper(mshadow::Stream<xpu>* s,
                                     const NDArray& input_nd,
                                     const TBlob& idx_data,
                                     const OpReqType req,
                                     NDArray* output_nd);

/* \brief Casts tensor storage type to the new type.
 */
template <typename xpu>
void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output);

/*!
 * \brief Returns true if all storage types in `vstorage` are the same as
 *        target `stype`; false is returned for empty inputs.
 */
inline bool ContainsOnlyStorage(const StorageTypeVector& vstorage, const NDArrayStorageType stype) {
  if (!vstorage.empty()) {
    for (const auto& i : vstorage) {
      if (i != stype)
        return false;
    }
    return true;
  }
  return false;
}

/*!
 * \brief Returns true if all storage types in `vstorage` are the same as
 *        target `stype1` or `stype2`; sets `has_both` if both occur.
 *        false is returned for empty inputs.
 */
inline bool ContainsOnlyStorage(const StorageTypeVector& vstorage,
                                const NDArrayStorageType stype1,
                                const NDArrayStorageType stype2,
                                bool* has_both) {
  if (has_both) {
    *has_both = false;
  }
  if (!vstorage.empty()) {
    uint8_t has = 0;
    for (const auto i : vstorage) {
      if (i == stype1) {
        has |= 1;
      } else if (i == stype2) {
        has |= 2;
      } else {
        return false;
      }
    }
    if (has_both) {
      *has_both = has == 3;
    }
    return true;
  }
  return false;
}

/*!
 * \brief Returns true if the storage types of all arrays in `ndarrays` are
 *        the same as target `stype`; false is returned for empty inputs.
 */
inline bool ContainsOnlyStorage(const std::vector<NDArray>& ndarrays,
                                const NDArrayStorageType stype) {
  if (!ndarrays.empty()) {
    for (const auto& nd : ndarrays) {
      if (nd.storage_type() != stype) {
        return false;
      }
    }
    return true;
  }
  return false;
}

/*!
 * \brief Returns true if the storage types of all arrays in `ndarrays` are
 *        the same as target `stype1` or `stype2`; sets `has_both` if both
 *        occur. false is returned for empty inputs.
 */
inline bool ContainsOnlyStorage(const std::vector<NDArray>& ndarrays,
                                const NDArrayStorageType stype1,
                                const NDArrayStorageType stype2,
                                bool* has_both) {
  if (has_both) {
    *has_both = false;
  }
  if (!ndarrays.empty()) {
    uint8_t has = 0;
    for (const auto& nd : ndarrays) {
      const NDArrayStorageType stype = nd.storage_type();
      if (stype == stype1) {
        has |= 1;
      } else if (stype == stype2) {
        has |= 2;
      } else {
        return false;
      }
    }
    if (has_both) {
      *has_both = has == 3;
    }
    return true;
  }
  return false;
}

/*!
 * \brief Returns true if the storage type of any array in `ndarrays` is the
 *        same as the target `stype`; false is returned for empty inputs.
 */
inline bool ContainsStorageType(const std::vector<NDArray>& ndarrays,
                                const NDArrayStorageType stype) {
  if (!ndarrays.empty()) {
    for (const auto& nd : ndarrays) {
      if (nd.storage_type() == stype) {
        return true;
      }
    }
  }
  return false;
}

/*!
 * \brief Returns true if any storage type in `ndstypes` is the same as the
 *        target `stype`; false is returned for empty inputs.
 */
inline bool ContainsStorageType(const std::vector<int>& ndstypes, const NDArrayStorageType stype) {
  if (!ndstypes.empty()) {
    for (const auto& ndstype : ndstypes) {
      if (ndstype == stype) {
        return true;
      }
    }
  }
  return false;
}

/*! \brief get string representation of dispatch_mode */
inline std::string dispatch_mode_string(const DispatchMode x) {
  switch (x) {
    case DispatchMode::kFCompute:
      return "fcompute";
    case DispatchMode::kFComputeEx:
      return "fcompute_ex";
    case DispatchMode::kFComputeFallback:
      return "fcompute_fallback";
    case DispatchMode::kVariable:
      return "variable";
    case DispatchMode::kUndefined:
      return "undefined";
  }
  return "unknown";
}

/*! \brief get string representation of storage_type */
inline std::string stype_string(const int x) {
  switch (x) {
    case kDefaultStorage:
      return "default";
    case kCSRStorage:
      return "csr";
    case kRowSparseStorage:
      return "row_sparse";
  }
  return "unknown";
}

/*! \brief get string representation of device type */
inline std::string dev_type_string(const int dev_type) {
  switch (dev_type) {
    case Context::kCPU:
      return "cpu";
    case Context::kGPU:
      return "gpu";
    case Context::kCPUPinned:
      return "cpu_pinned";
    case Context::kCPUShared:
      return "cpu_shared";
  }
  return "unknown";
}

inline std::string attr_value_string(const nnvm::NodeAttrs& attrs,
                                     const std::string& attr_name,
                                     std::string default_val = "") {
  if (attrs.dict.find(attr_name) == attrs.dict.end()) {
    return default_val;
  }
  return attrs.dict.at(attr_name);
}

/*! \brief Seeks an attribute in a node and its subgraphs and invokes a function on each. */
template <typename Fn>
inline void attr_foreach(const nnvm::NodeAttrs& attrs, const std::string& attr_name, const Fn& fn) {
  const auto& found_it = attrs.dict.find(attr_name);
  if (found_it != attrs.dict.end()) {
    fn(found_it->second);
  }
  for (const auto& subgraph : attrs.subgraphs) {
    DFSVisit(subgraph->outputs,
             [&](const nnvm::ObjectPtr& node) { attr_foreach(node->attrs, attr_name, fn); });
  }
}

template <typename ValueType>
inline ValueType flag_attr_accumulate(const nnvm::NodeAttrs& attrs, const std::string& attr_name) {
  static_assert(std::is_integral<ValueType>::value, "ValueType must be an integral type.");

  ValueType result = 0;
  attr_foreach(attrs, attr_name, [&](const std::string& attr_value) {
    std::istringstream ss(attr_value);
    ValueType temp;
    ss >> temp;
    result |= temp;

    if (ss.fail() || !ss.eof()) {
      LOG(WARNING) << "Incorrect value of an attribute: " << attr_name
                   << ". Expected an integer, but got: " << attr_value;
    }
  });
  return result;
}

/*! \brief get string representation of the operator stypes */
inline std::string operator_stype_string(const nnvm::NodeAttrs& attrs,
                                         const int dev_mask,
                                         const std::vector<int>& in_attrs,
                                         const std::vector<int>& out_attrs) {
  std::ostringstream os;
  os << "operator = " << attrs.op->name << "\ninput storage types = [";
  for (const int attr : in_attrs) {
    os << stype_string(attr) << ", ";
  }
  os << "]\n"
     << "output storage types = [";
  for (const int attr : out_attrs) {
    os << stype_string(attr) << ", ";
  }
  os << "]\n"
     << "params = {";
  for (auto kv : attrs.dict) {
    os << "\"" << kv.first << "\" : " << kv.second << ", ";
  }
  os << "}\n"
     << "context.dev_mask = " << dev_type_string(dev_mask);
  return os.str();
}

/*! \brief get string representation of the operator */
inline std::string operator_string(const nnvm::NodeAttrs& attrs,
                                   const OpContext& ctx,
                                   const std::vector<NDArray>& inputs,
                                   const std::vector<OpReqType>& req,
                                   const std::vector<NDArray>& outputs) {
  std::string result = "";
  std::vector<int> in_stypes;
  std::vector<int> out_stypes;
  in_stypes.reserve(inputs.size());
  out_stypes.reserve(outputs.size());
  auto xform = [](const NDArray arr) -> int { return arr.storage_type(); };
  std::transform(inputs.begin(), inputs.end(), std::back_inserter(in_stypes), xform);
  std::transform(outputs.begin(), outputs.end(), std::back_inserter(out_stypes), xform);
  result += operator_stype_string(attrs, ctx.run_ctx.ctx.dev_mask(), in_stypes, out_stypes);
  return result;
}

/*! \brief log message once. Intended for storage fallback warning messages. */
inline void LogOnce(const std::string& message) {
  typedef dmlc::ThreadLocalStore<std::unordered_set<std::string>> LogStore;
  auto log_store = LogStore::Get();
  if (log_store->find(message) == log_store->end()) {
    LOG(INFO) << message;
    log_store->insert(message);
  }
}

/*! \brief log storage fallback event */
inline void LogStorageFallback(const nnvm::NodeAttrs& attrs,
                               const int dev_mask,
                               const std::vector<int>* in_attrs,
                               const std::vector<int>* out_attrs) {
  static bool log = dmlc::GetEnv("MXNET_STORAGE_FALLBACK_LOG_VERBOSE", true);
  if (!log)
    return;
  const std::string op_str = operator_stype_string(attrs, dev_mask, *in_attrs, *out_attrs);
  std::ostringstream os;
  const char* warning =
      "\n WARNING:\n"
      "Execution of the operator above will fall back to the generic implementation "
#if MXNET_USE_ONEDNN == 1
      "(not utilizing kernels from the oneDNN library) "
#endif
      "with the default dense storage type. You are seeing this warning message because "
#if MXNET_USE_ONEDNN == 1
      "the MXNET_ONEDNN_ENABLED flag is set to 0, in which case you can re-enable the default "
      "execution path by setting MXNET_ONEDNN_ENABLED back to 1, or "
#endif
      "the operator above is unable to process the given ndarrays with the specified storage "
      "types, context and/or parameters, in which case temporary dense ndarrays are generated "
      "in order to execute the operator. The fallback does not affect the correctness of the "
      "program. With the default storage type, performance degradation might be observed. \n"
      "You can set the environment variable MXNET_STORAGE_FALLBACK_LOG_VERBOSE to 0 to "
      "suppress this warning.";
  os << "\nStorage type fallback detected:\n" << op_str << warning;
  LogOnce(os.str());
#if MXNET_USE_ONEDNN == 1
  if (GetDNNLCacheSize() != -1)
    LogOnce(
        "MXNET_ONEDNN_CACHE_NUM is set. "
        "Should only be set if "
        "your model has variable input shapes, "
        "as cache size may grow unbounded");
#endif
}

// heuristic to determine the number of threads per GPU
inline int GetNumThreadsPerGPU() {
  // This is the resource-efficient option.
  return dmlc::GetEnv("MXNET_GPU_WORKER_NTHREADS", 2);
}

// heuristic to get the number of matching colors.
// this decides how much parallelism we can get in each GPU.
inline int GetExecNumMatchColor() {
  // This is the resource-efficient option.
  int num_match_color = dmlc::GetEnv("MXNET_EXEC_NUM_TEMP", 1);
  return std::min(num_match_color, GetNumThreadsPerGPU());
}

template <typename T, typename V>
V ParallelAccumulate(const T* a, const int n, V start) {
  V sum = start;
#pragma omp parallel for reduction(+ : sum)
  for (int i = 0; i < n; ++i) {
    sum += a[i];
  }
  return sum;
}

/*!
 * \brief Helper function for ParallelSort.
 *        DO NOT call this function directly.
 *        Use the interface ParallelSort instead.
 */
template <typename RandomIt, typename Compare>
void ParallelSortHelper(RandomIt first, size_t len, size_t grainsize, const Compare& comp) {
  if (len < grainsize) {
    std::sort(first, first + len, comp);
  } else {
    std::thread thr(ParallelSortHelper<RandomIt, Compare>, first, len / 2, grainsize, comp);
    ParallelSortHelper(first + len / 2, len - len / 2, grainsize, comp);
    thr.join();
    std::inplace_merge(first, first + len / 2, first + len, comp);
  }
}

/*!
 * \brief Sort the elements in the range [first, last) into ascending order
 *        defined by the comparator comp. If the length of the range exceeds
 *        the grain size, the sort runs in parallel on num_threads threads.
 */
template <typename RandomIt, typename Compare>
void ParallelSort(RandomIt first, RandomIt last, size_t num_threads, Compare comp) {
  const auto num = std::distance(first, last);
  size_t grainsize = std::max(num / num_threads + 5, static_cast<size_t>(1024 * 16));
  ParallelSortHelper(first, num, grainsize, comp);
}

/*!
 * \brief Sort the elements in the range [first, last) into ascending order.
 *        Elements are compared using std::less. If the length of the range
 *        exceeds the grain size, the sort runs in parallel on num_threads
 *        threads.
 */
template <typename RandomIt>
void ParallelSort(RandomIt first, RandomIt last, size_t num_threads) {
  ParallelSort(
      first, last, num_threads, std::less<typename std::iterator_traits<RandomIt>::value_type>());
}
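
/*!
 * Illustrative usage sketch: ranges shorter than the grain size (at least
 * 16384 elements) are sorted serially with std::sort; longer ranges are split
 * recursively across threads and combined with std::inplace_merge.
 * \code
 * std::vector<float> values(1 << 20);
 * std::generate(values.begin(), values.end(), std::rand);
 * ParallelSort(values.begin(), values.end(), 4);  // ascending via std::less
 * ParallelSort(values.begin(), values.end(), 4, std::greater<float>());
 * \endcode
 */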

/*!
 * \brief Random Engine
 */
typedef std::mt19937 RANDOM_ENGINE;

/*!
 * \brief Helper functions.
 */
namespace helper {

/*!
 * \brief Helper for non-array type `T`.
 */
template <class T>
struct UniqueIf {
  /*!
   * \brief Type of `T`.
   */
  using SingleObject = std::unique_ptr<T>;
};

/*!
 * \brief Helper for an array of unknown bound `T`.
 */
template <class T>
struct UniqueIf<T[]> {
  /*!
   * \brief Type of `T`.
   */
  using UnknownBound = std::unique_ptr<T[]>;
};

/*!
 * \brief Helper for an array of known bound `T`.
 */
template <class T, size_t kSize>
struct UniqueIf<T[kSize]> {
  /*!
   * \brief Type of `T`.
   */
  using KnownBound = void;
};

}  // namespace helper

/*!
 * \brief Constructs an object of non-array type `T` and wraps it in a
 *        `std::unique_ptr`. The arguments `args` are passed to the
 *        constructor of `T`.
 */
template <class T, class... Args>
typename helper::UniqueIf<T>::SingleObject MakeUnique(Args&&... args) {
  return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}

/*!
 * \brief Constructs an array of unknown bound `T` of size `n`,
 *        value-initializes it, and wraps it in a `std::unique_ptr`.
 */
template <class T>
typename helper::UniqueIf<T>::UnknownBound MakeUnique(size_t n) {
  using U = typename std::remove_extent<T>::type;
  return std::unique_ptr<T>(new U[n]{});
}

/*!
 * \brief Constructing arrays of known bound is disallowed.
 */
template <class T, class... Args>
typename helper::UniqueIf<T>::KnownBound MakeUnique(Args&&... args) = delete;

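/*!
 * Illustrative usage sketch for the MakeUnique overloads (a pre-C++14
 * stand-in for std::make_unique):
 * \code
 * auto single = MakeUnique<std::string>(5, 'x');  // unique_ptr<std::string>, "xxxxx"
 * auto buffer = MakeUnique<int[]>(16);            // unique_ptr<int[]>, value-initialized
 * // MakeUnique<int[16]>() does not compile: the known-bound overload is deleted.
 * \endcode
 */
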
template <typename FCompType>
FCompType GetFCompute(const nnvm::Op* op, const std::string& name, const Context& ctx) {
  static auto& fcompute_cpu = nnvm::Op::GetAttr<FCompType>(name + "<cpu>");
  static auto& fcompute_gpu = nnvm::Op::GetAttr<FCompType>(name + "<gpu>");

  if (ctx.dev_mask() == cpu::kDevMask) {
    return fcompute_cpu.get(op, nullptr);
  } else if (ctx.dev_mask() == gpu::kDevMask) {
    return fcompute_gpu.get(op, nullptr);
  } else {
    LOG(FATAL) << "Unknown device mask " << ctx.dev_mask();
    return nullptr;
  }
}

/*!
 * \brief Return the max integer value representable in the type T without
 *        loss of precision.
 */
template <typename T>
constexpr size_t MaxIntegerValue() {
  return std::is_integral<T>::value ? std::numeric_limits<T>::max() :
                                      size_t(2) << (std::numeric_limits<T>::digits - 1);
}

template <>
constexpr size_t MaxIntegerValue<mshadow::half::half_t>() {
  return size_t(2) << 10;
}

template <>
constexpr size_t MaxIntegerValue<mshadow::bfloat::bf16_t>() {
  return size_t(2) << 14;
}
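
/*!
 * Note (illustrative): for a floating-point type with \f$d\f$ significand
 * digits (std::numeric_limits<T>::digits), the generic overload evaluates
 * size_t(2) << (digits - 1), i.e. \f$2^{d}\f$, and every integer of magnitude
 * up to \f$2^{d}\f$ is exactly representable. Half precision has \f$d = 11\f$,
 * matching the \f$2^{11} = 2048\f$ specialization above.
 * \code
 * static_assert(MaxIntegerValue<mshadow::half::half_t>() == 2048, "2 << 10 == 2^11");
 * static_assert(MaxIntegerValue<int32_t>() ==
 *                   static_cast<size_t>(std::numeric_limits<int32_t>::max()),
 *               "integral types take the numeric_limits path");
 * \endcode
 */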

// Returns the 1-based position of the highest set bit of `a`, i.e. its bit
// width (floor(log2(a)) + 1 for a > 0; returns 1 for a == 0).
MSHADOW_XINLINE int ilog2ul(size_t a) {
  int k = 1;
  while (a >>= 1)
    ++k;
  return k;
}

// Same as ilog2ul, for unsigned int arguments.
MSHADOW_XINLINE int ilog2ui(unsigned int a) {
  int k = 1;
  while (a >>= 1)
    ++k;
  return k;
}

/*!
 * \brief Return an NDArray of all zeros.
 */
inline NDArray InitZeros(const NDArrayStorageType stype,
                         const mxnet::TShape& shape,
                         const Context& ctx,
                         const int dtype) {
  // NDArray with default storage
  if (stype == kDefaultStorage) {
    NDArray ret(shape, ctx, false, dtype);
    ret = 0;
    return ret;
  }
  // NDArray with non-default storage. Storage allocation is always delayed.
  return NDArray(stype, shape, ctx, true, dtype);
}

/*!
 * \brief Helper to add an NDArray of zeros to a std::vector.
 */
inline void EmplaceBackZeros(const NDArrayStorageType stype,
                             const mxnet::TShape& shape,
                             const Context& ctx,
                             const int dtype,
                             std::vector<NDArray>* vec) {
  // NDArray with default storage
  if (stype == kDefaultStorage) {
    vec->emplace_back(shape, ctx, false, dtype);
    vec->back() = 0;
  } else {
    // NDArray with non-default storage. Storage allocation is always delayed.
    vec->emplace_back(stype, shape, ctx, true, dtype);
  }
}
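
/*!
 * Illustrative usage sketch: dense zeros are written eagerly, while
 * non-default storage stays as a delayed, empty allocation. The context and
 * shapes below are hypothetical.
 * \code
 * Context ctx = Context::CPU();
 * NDArray dense = InitZeros(kDefaultStorage, mxnet::TShape({2, 3}), ctx, mshadow::kFloat32);
 * std::vector<NDArray> outs;
 * EmplaceBackZeros(kCSRStorage, mxnet::TShape({2, 3}), ctx, mshadow::kFloat32, &outs);
 * \endcode
 */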

/*!
 * \brief parallelize copy by OpenMP.
 */
template <typename DType>
inline void ParallelCopy(DType* dst, const DType* src, index_t size) {
  static index_t copy_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 200000);
  if (size >= copy_block_size) {
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
    for (index_t i = 0; i < size; ++i) {
      dst[i] = src[i];
    }
  } else {
#pragma GCC diagnostic push
#if __GNUC__ >= 8
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#endif
    std::memcpy(dst, src, sizeof(DType) * size);
#pragma GCC diagnostic pop
  }
}

/*!
 * \brief parallelize add by OpenMP.
 */
template <typename DType>
inline void ParallelAdd(DType* dst, const DType* src, index_t size) {
  static index_t add_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 200000);
  if (size >= add_block_size) {
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
    for (index_t i = 0; i < size; ++i) {
      dst[i] += src[i];
    }
  } else {
    for (index_t i = 0; i < size; ++i) {
      dst[i] += src[i];
    }
  }
}
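
/*!
 * Illustrative usage sketch: below the MXNET_CPU_PARALLEL_SIZE threshold
 * (200000 elements by default) both helpers run serially; at or above it
 * they use an OpenMP parallel-for.
 * \code
 * std::vector<float> src(1 << 20, 1.0f), dst(1 << 20, 0.0f);
 * ParallelCopy(dst.data(), src.data(), src.size());  // dst[i] == 1.0f
 * ParallelAdd(dst.data(), src.data(), src.size());   // dst[i] == 2.0f
 * \endcode
 */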

/*!
 * \brief If numpy compatibility is turned off (default), the shapes passed in
 * by users follow the legacy shape definition:
 * 1. 0 ndim means the shape is completely unknown.
 * 2. 0 dim size means the dim size is unknown.
 * These shapes are converted to the numpy shape definition:
 * 1. 0 ndim means it's a scalar tensor.
 * 2. -1 ndim means the shape is unknown.
 * 3. 0 dim size means no elements in that dimension.
 * 4. -1 dim size means the dimension's size is unknown.
 * so that the operator's infer-shape function can work in the backend.
 * \param shape the shape to be converted.
 */
inline void ConvertToNumpyShape(mxnet::TShape* shape) {
  if (shape->ndim() == 0) {    // legacy shape ndim = 0 means unknown
    *shape = mxnet::TShape();  // unknown shape ndim = -1
  } else {
    for (int j = 0; j < shape->ndim(); ++j) {
      if ((*shape)[j] == 0) {  // legacy shape dim_size = 0 means unknown
        (*shape)[j] = -1;      // unknown dim size = -1
      }
    }
  }
}

inline void ConvertToNumpyShape(mxnet::ShapeVector* shapes) {
  for (size_t i = 0; i < shapes->size(); ++i) {
    ConvertToNumpyShape(&(shapes->at(i)));
  }
}

/*!
 * \brief This function is used to convert shapes returned by
 * the infer-shape functions/pass to the legacy shape definition.
 */
inline void ConvertToLegacyShape(mxnet::TShape* shape) {
  if (!mxnet::ndim_is_known(*shape)) {
    *shape = mxnet::TShape(0, -1);
  } else {
    for (int j = 0; j < shape->ndim(); ++j) {
      if (!mxnet::dim_size_is_known(*shape, j)) {
        (*shape)[j] = 0;
      }
    }
  }
}

inline void ConvertToLegacyShape(mxnet::ShapeVector* shapes) {
  for (size_t i = 0; i < shapes->size(); ++i) {
    ConvertToLegacyShape(&(shapes->at(i)));
  }
}
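
/*!
 * Illustrative sketch of the two conventions: legacy shapes use 0 for
 * "unknown", numpy-compatible shapes use -1 (and keep 0 for genuinely
 * zero-size dimensions).
 * \code
 * mxnet::TShape s({2, 0, 3});  // legacy: middle dim unknown
 * ConvertToNumpyShape(&s);     // s == (2, -1, 3)
 * ConvertToLegacyShape(&s);    // s == (2, 0, 3) again
 * \endcode
 */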

void ExecuteMonInputCallback(
    const nnvm::IndexedGraph& idx,
    const std::vector<NDArray*>& state_arrays,
    size_t nid,
    const std::function<void(const char*, const char*, void*)>& monitor_callback);

void ExecuteMonOutputCallback(
    const nnvm::IndexedGraph& idx,
    const std::vector<NDArray*>& state_arrays,
    size_t nid,
    const std::function<void(const char*, const char*, void*)>& monitor_callback);

inline mxnet::TShape CanonicalizeAxes(const mxnet::TShape& src) {
  // convert negative axes to positive values
  const int ndim     = src.ndim();
  mxnet::TShape axes = src;
  for (int i = 0; i < ndim; ++i) {
    if (axes[i] < 0) {
      axes[i] += ndim;
    }
    CHECK(axes[i] >= 0 && axes[i] < ndim)
        << "axes[" << i << "]=" << axes[i] << " exceeds the range [" << 0 << ", " << ndim << ")";
  }
  return axes;
}
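
/*!
 * Illustrative sketch: each negative axis is shifted by ndim, so for a
 * three-element axis tuple (-1, 0, 1) the result is (2, 0, 1); any entry
 * outside [-ndim, ndim) trips the CHECK.
 * \code
 * mxnet::TShape axes({-1, 0, 1});
 * mxnet::TShape canonical = CanonicalizeAxes(axes);  // (2, 0, 1)
 * \endcode
 */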

inline bool is_float(const int dtype) {
  return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16 ||
         dtype == mshadow::kBfloat16;
}

inline bool is_int(const int dtype) {
  return dtype == mshadow::kUint8 || dtype == mshadow::kInt8 || dtype == mshadow::kUint16 ||
         dtype == mshadow::kInt16 || dtype == mshadow::kUint32 || dtype == mshadow::kInt32 ||
         dtype == mshadow::kUint64 || dtype == mshadow::kInt64;
}

inline bool is_signed_int(const int dtype) {
  return dtype == mshadow::kInt8 || dtype == mshadow::kInt16 || dtype == mshadow::kInt32 ||
         dtype == mshadow::kInt64;
}

inline bool is_unsigned_int(const int dtype) {
  return dtype == mshadow::kUint8 || dtype == mshadow::kUint16 || dtype == mshadow::kUint32 ||
         dtype == mshadow::kUint64;
}

static int bits_of(const int type_flag) {
  switch (type_flag) {
    case mshadow::kFloat32:
      return sizeof(float) * CHAR_BIT;
    case mshadow::kFloat64:
      return sizeof(double) * CHAR_BIT;
    case mshadow::kUint8:
      return sizeof(uint8_t) * CHAR_BIT;
    case mshadow::kInt32:
      return sizeof(int32_t) * CHAR_BIT;
    case mshadow::kInt8:
      return sizeof(int8_t) * CHAR_BIT;
    case mshadow::kInt64:
      return sizeof(int64_t) * CHAR_BIT;
    case mshadow::kBool:
      return sizeof(bool) * CHAR_BIT;
    case mshadow::kInt16:
      return sizeof(int16_t) * CHAR_BIT;
    case mshadow::kUint16:
      return sizeof(uint16_t) * CHAR_BIT;
    case mshadow::kUint32:
      return sizeof(uint32_t) * CHAR_BIT;
    case mshadow::kUint64:
      return sizeof(uint64_t) * CHAR_BIT;
    default: {
      LOG(FATAL) << "Unknown type_flag=" << type_flag;
      return -1;
    }
  }
}

inline int type_promotion(const int type1, const int type2) {
  if (type1 == type2)
    return type1;
  if (is_float(type1) && is_float(type2)) {
    if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) {
      return mshadow::kFloat64;
    }
    if (type1 == mshadow::kFloat32 || type2 == mshadow::kFloat32) {
      return mshadow::kFloat32;
    }
    return mshadow::kFloat16;
  } else if (is_float(type1) || is_float(type2)) {
    return is_float(type1) ? type1 : type2;
  }
  if (is_signed_int(type1) && is_signed_int(type2)) {
    if (type1 == mshadow::kInt64 || type2 == mshadow::kInt64) {
      return mshadow::kInt64;
    }
    if (type1 == mshadow::kInt32 || type2 == mshadow::kInt32) {
      return mshadow::kInt32;
    }
    if (type1 == mshadow::kInt16 || type2 == mshadow::kInt16) {
      return mshadow::kInt16;
    }
    return mshadow::kInt8;
  } else if (is_unsigned_int(type1) && is_unsigned_int(type2)) {
    if (type1 == mshadow::kUint64 || type2 == mshadow::kUint64) {
      return mshadow::kUint64;
    }
    if (type1 == mshadow::kUint32 || type2 == mshadow::kUint32) {
      return mshadow::kUint32;
    }
    if (type1 == mshadow::kUint16 || type2 == mshadow::kUint16) {
      return mshadow::kUint16;
    }
    return mshadow::kUint8;
  } else if (type1 == mshadow::kBool) {
    return type2;
  } else if (type2 == mshadow::kBool) {
    return type1;
  } else if (is_unsigned_int(type1) || is_unsigned_int(type2)) {
    if (bits_of(type1) < bits_of(type2)) {
      if (type1 == mshadow::kInt8 && type2 == mshadow::kUint16) {
        return mshadow::kInt32;
      } else if (type1 == mshadow::kInt8 && type2 == mshadow::kUint32) {
        return mshadow::kInt64;
      } else if (type1 == mshadow::kInt16 && type2 == mshadow::kUint32) {
        return mshadow::kInt64;
      } else if (type2 == mshadow::kUint64) {
        LOG(FATAL) << "Unsupported type promotions between " << mshadow::dtype_string(type1)
                   << " and " << mshadow::dtype_string(type2);
      } else {
        return type2;
      }
    } else if (bits_of(type2) < bits_of(type1)) {
      if (type2 == mshadow::kInt8 && type1 == mshadow::kUint16) {
        return mshadow::kInt32;
      } else if (type2 == mshadow::kInt8 && type1 == mshadow::kUint32) {
        return mshadow::kInt64;
      } else if (type2 == mshadow::kInt16 && type1 == mshadow::kUint32) {
        return mshadow::kInt64;
      } else if (type1 == mshadow::kUint64) {
        LOG(FATAL) << "Unsupported type promotions between " << mshadow::dtype_string(type1)
                   << " and " << mshadow::dtype_string(type2);
      } else {
        return type1;
      }
    } else {
      if (type1 == mshadow::kUint8 || type2 == mshadow::kUint8) {
        return mshadow::kInt16;
      }
      if (type1 == mshadow::kUint16 || type2 == mshadow::kUint16) {
        return mshadow::kInt32;
      }
      if (type1 == mshadow::kUint32 || type2 == mshadow::kUint32) {
        return mshadow::kInt64;
      }
    }
  }
  LOG(FATAL) << "Unsupported type promotions between " << mshadow::dtype_string(type1) << " and "
             << mshadow::dtype_string(type2);
  return -1;
}
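
/*!
 * Illustrative examples of the promotion rules implemented above:
 * \code
 * type_promotion(mshadow::kFloat16, mshadow::kFloat32);  // kFloat32: wider float wins
 * type_promotion(mshadow::kInt32, mshadow::kFloat16);    // kFloat16: float beats integer
 * type_promotion(mshadow::kInt8, mshadow::kUint8);       // kInt16: same width, mixed sign
 * type_promotion(mshadow::kBool, mshadow::kInt64);       // kInt64: bool defers to the other type
 * \endcode
 */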

inline const std::string NodeAttrsGetProfilerScope(const nnvm::NodeAttrs& attrs) {
  // obtain the profiler scope name, if assigned previously
  std::string profiler_scope = MXNET_STORAGE_DEFAULT_PROFILER_SCOPE_CSTR;
  const std::unordered_map<std::string, std::string>& node_attrs_dict = attrs.dict;
  const std::unordered_map<std::string, std::string>::const_iterator profiler_scope_iter =
      node_attrs_dict.find("__profiler_scope__");
  if (profiler_scope_iter != node_attrs_dict.end()) {
    profiler_scope = profiler_scope_iter->second;
  }
  return profiler_scope;
}

inline int GetDefaultDtype() {
  return Imperative::Get()->is_np_default_dtype() ? mshadow::kFloat64 : mshadow::kFloat32;
}

inline int GetDefaultDtype(int dtype) {
  if (dtype != -1)
    return dtype;
  return Imperative::Get()->is_np_default_dtype() ? mshadow::kFloat64 : mshadow::kFloat32;
}

struct MShadowTypeInfo {
  std::string name;
  int size;
  int acc_size;

  MShadowTypeInfo(const std::string name, const int size, const int acc_size)
      : name(std::move(name)), size(size), acc_size(acc_size) {}

  MShadowTypeInfo(const std::string name, const int size) : MShadowTypeInfo(name, size, size) {}
};

MShadowTypeInfo mshadow_type_info(const int type_flag);

inline bool AlignedMemAlloc(void** ptr, size_t size, size_t alignment) {
#if _MSC_VER
  *ptr = _aligned_malloc(size, alignment);
  if (*ptr == nullptr)
    return false;
#else
  int res = posix_memalign(ptr, alignment, size);
  if (res != 0)
    return false;
#endif
  return true;
}

inline void AlignedMemFree(void* ptr) {
#if _MSC_VER
  _aligned_free(ptr);
#else
  free(ptr);
#endif
}
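
/*!
 * Illustrative usage sketch for the aligned-allocation helpers:
 * \code
 * void* buf = nullptr;
 * if (AlignedMemAlloc(&buf, 4096, 64)) {  // 4 KiB block, 64-byte aligned
 *   // ... use buf ...
 *   AlignedMemFree(buf);
 * }
 * \endcode
 */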

inline index_t div_round(const index_t a, const index_t b) {
  return (a + b - 1) / b;
}

inline bool IsPower2(size_t N) {
  return ((N & (N - 1)) == 0) && N != 0;
}

inline size_t RoundToPower2(size_t N) {
  size_t ret   = 1;
  size_t copyN = N;
  while (N >= 2) {
    ret *= 2;
    N /= 2;
  }
  if (ret < copyN) {
    ret *= 2;
  }
  return ret;
}
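
/*!
 * Illustrative examples: div_round is a ceiling division, and RoundToPower2
 * rounds up to the next power of two.
 * \code
 * div_round(10, 3);   // 4
 * IsPower2(64);       // true
 * RoundToPower2(10);  // 16
 * RoundToPower2(16);  // 16
 * \endcode
 */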

}  // namespace common
}  // namespace mxnet
#endif  // MXNET_COMMON_UTILS_H_