mxnet
ndarray.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_NDARRAY_H_
26 #define MXNET_NDARRAY_H_
27 
28 #include <dmlc/base.h>
29 #include <dmlc/logging.h>
30 #include <dmlc/io.h>
31 #include <dmlc/type_traits.h>
32 #include <dmlc/registry.h>
33 #include <nnvm/node.h>
34 #include <vector>
35 #include <map>
36 #include <string>
37 #include <algorithm>
38 #include <memory>
39 #include <algorithm>
40 #if MXNET_USE_MKLDNN == 1
41 #include <mkldnn.hpp>
42 #endif
43 #include "./base.h"
44 #include "./storage.h"
45 #include "./engine.h"
46 // check c++11
47 #if DMLC_USE_CXX11 == 0
48 #error "cxx11 was required for ndarray module"
49 #endif
50 
51 namespace mxnet {
52 // enum for storage types
53 namespace csr {
55 }
56 
57 namespace rowsparse {
59 }
60 
62  kUndefinedStorage = -1, // undefined storage
63  kDefaultStorage, // dense
64  kRowSparseStorage, // row sparse
65  kCSRStorage, // csr
66 };
67 
69  kNormalErr, // normal
70  kCSRShapeErr, // shape mismatch for csr
71  kCSRIndPtrErr, // indptr error for csr
72  kCSRIdxErr, // idx error for csr
73  kRSPShapeErr, // shape mismatch for row sparse
74  kRSPIdxErr, // indices error for row sparse
75 };
76 
77 class MKLDNNMemory;
78 
82 class NDArray {
83  public:
86  : entry_(nullptr) {
87  }
95  NDArray(const mxnet::TShape &shape, Context ctx,
96  bool delay_alloc = false, int dtype = mshadow::default_type_flag)
97  : ptr_(std::make_shared<Chunk>(shape, ctx, delay_alloc, dtype)),
98  shape_(shape),
99  dtype_(dtype),
100  storage_type_(kDefaultStorage),
101  entry_(nullptr) {
102  }
105  NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, Context ctx,
106  bool delay_alloc = true, int dtype = mshadow::default_type_flag,
107  std::vector<int> aux_types = {}, mxnet::ShapeVector aux_shapes = {},
108  mxnet::TShape storage_shape = mxnet::TShape(mshadow::Shape1(0)));
115  explicit NDArray(Context ctx, int dtype = mshadow::default_type_flag)
116  : ptr_(std::make_shared<Chunk>(mxnet::TShape(mshadow::Shape1(0)), ctx, true, dtype)),
117  shape_(),
118  dtype_(dtype),
119  storage_type_(kDefaultStorage),
120  entry_(nullptr) {
121  }
129  NDArray(const TBlob &data, int dev_id)
130  : ptr_(std::make_shared<Chunk>(data, dev_id)),
131  shape_(data.shape_),
132  dtype_(data.type_flag_),
133  storage_type_(kDefaultStorage),
134  entry_(nullptr) {
135  }
136 
145  NDArray(const TBlob &data, int dev_id, const std::function<void()>& deleter)
146  : ptr_(new Chunk(data, dev_id), [deleter](Chunk *p) {
147  deleter(); // call custom deleter
148  delete p; // delete Chunk object
149  }),
150  shape_(data.shape_),
151  dtype_(data.type_flag_), storage_type_(kDefaultStorage),
152  entry_(nullptr) {
153  }
154 
156  NDArray(int shared_pid, int shared_id, const mxnet::TShape& shape, int dtype)
157  : ptr_(std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)),
158  shape_(shape),
159  dtype_(dtype),
160  storage_type_(kDefaultStorage),
161  entry_(nullptr) {
162  }
163 
174  NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape,
175  const TBlob &data, const std::vector<TBlob> &aux_data, int dev_id)
176  : ptr_(std::make_shared<Chunk>(stype, data, aux_data, dev_id)),
177  shape_(shape),
178  dtype_(data.type_flag_),
179  storage_type_(stype),
180  entry_(nullptr) {
181  }
186  void Init(const mxnet::TShape &shape) {
187  ptr_->Init(shape, this->dtype_);
188  this->shape_ = shape;
189  }
193  void SetShapeFromChunk();
194  /*
195  * This indicates whether an array is a view of another array (created by
196  * reshape or slice). If an array is a view and the data is stored in
197  * MKLDNN format, we need to convert the data to the default format when
198  * data in the view is accessed.
199  */
200  inline bool IsView() const {
201  // View only works on the default storage
202  if (storage_type() != kDefaultStorage)
203  return false;
204  // If the array reuses memory, its shape may be different from the storage
205  // shape. However, we shouldn't consider it as a view.
206  if (reuse_)
207  return false;
208  return byte_offset_ > 0 || shape() != ptr_->storage_shape;
209  }
210 
211  /* \brief Check whether the two arrays are the same array */
212  inline bool IsSame(const NDArray& other) const {
213  return ptr_ == other.ptr_ &&
214  shape_ == other.shape_ &&
215  byte_offset_ == other.byte_offset_ &&
216  dtype_ == other.dtype_;
217  }
218 
222  inline const mxnet::TShape& shape() const {
223  return shape_;
224  }
230  inline const mxnet::TShape &storage_shape() const {
231  CHECK(ptr_ != nullptr);
232  CHECK_NE(storage_type(), kDefaultStorage)
233  << "storage_shape() is not intended for kDefaultStorage.";
234  return ptr_->storage_shape;
235  }
236 
242  inline const mxnet::TShape& aux_shape(size_t index) const {
243  CHECK_NE(storage_type(), kDefaultStorage)
244  << "aux_shape() is not intended for kDefaultStorage.";
245  return ptr_->aux_shapes[index];
246  }
247 
248  /* \return the shapes of all aux data */
250  CHECK_NE(storage_type(), kDefaultStorage)
251  << "aux_shapes() is not intended for kDefaultStorage.";
252  return ptr_->aux_shapes;
253  }
254 
256  const std::vector<int>& aux_types() const {
257  CHECK_NE(storage_type(), kDefaultStorage)
258  << "aux_types() is not intended for kDefaultStorage.";
259  return ptr_->aux_types;
260  }
261 
269  inline void set_aux_shape(size_t index, const mxnet::TShape& shape) const {
270  CHECK_NE(storage_type(), kDefaultStorage)
271  << "set_aux_shape() is not intended for kDefaultStorage.";
272  ptr_->set_aux_shape(index, shape);
273  }
274 
278  inline const TBlob& data() const {
279  if (storage_type() == kDefaultStorage) CheckAndAlloc();
280  SetTBlob();
281  return tblob_;
282  }
286  NDArray grad() const;
287 
291  inline TBlob aux_data(size_t i) const {
292  auto stype = storage_type();
293  TBlob res;
294  auto shape = aux_shape(i);
295  auto type = aux_type(i);
296  MSHADOW_TYPE_SWITCH(type, DType, {
297  auto dptr = static_cast<DType*>(ptr_->aux_handles[i].dptr);
298  CHECK(stype == kRowSparseStorage || stype == kCSRStorage)
299  << "Unexpected storage type: " << stype;
300  res = TBlob(dptr, shape, ptr_->aux_handles[i].ctx.dev_mask(), type);
301  });
302  return res;
303  }
307  inline Context ctx() const {
308  CHECK(!is_none());
309  return ptr_->shandle.ctx;
310  }
314  inline int dtype() const {
315  return dtype_;
316  }
317  inline int aux_type(size_t i) const {
318  CHECK(!is_none());
319  return ptr_->aux_types[i];
320  }
321 
323  return storage_type_;
324  }
326  inline bool is_none() const {
327  return ptr_.get() == nullptr;
328  }
330  bool fresh_out_grad() const;
332  void set_fresh_out_grad(bool state) const;
337  inline bool storage_initialized() const {
338  if (is_none()) return false;
339  auto stype = storage_type();
340  CHECK_NE(stype, kDefaultStorage)
341  << "storage_initialized() is not intended for kDefaultStorage.";
342  if (stype == kRowSparseStorage) {
343  CHECK_EQ(aux_shape(rowsparse::kIdx)[0], storage_shape()[0])
344  << "inconsistent storage shape " << storage_shape()
345  << " vs. aux shape " << aux_shape(rowsparse::kIdx);
346  return aux_shape(rowsparse::kIdx).Size() != 0;
347  } else if (stype == kCSRStorage) {
348  CHECK_EQ(aux_shape(csr::kIdx)[0], storage_shape()[0])
349  << "inconsistent storage shape " << storage_shape()
350  << " vs. aux shape " << aux_shape(csr::kIdx);
351  return aux_shape(csr::kIdx).Size() != 0;
352  } else {
353  LOG(FATAL) << "Unknown storage type";
354  }
355  return true;
356  }
359  CHECK(!is_none());
360  CHECK_EQ(storage_type(), kDefaultStorage);
361  CheckAndAlloc();
362  return ptr_->shandle;
363  }
368  inline void WaitToRead() const {
369  if (is_none()) return;
370  Engine::Get()->WaitForVar(ptr_->var);
371  }
376  inline void WaitToWrite() const {
377  if (is_none()) return;
383  [](RunContext, Engine::CallbackOnComplete on_complete) {
384  on_complete();
385  }, Context{}, {}, {ptr_->var});
386  Engine::Get()->WaitForVar(ptr_->var);
387  }
389  inline Engine::VarHandle var() const {
390  return ptr_->var;
391  }
393  inline size_t byte_offset() const {
394  return byte_offset_;
395  }
397  inline size_t version() const {
398  return var()->version();
399  }
404  void Save(dmlc::Stream *strm) const;
410  bool LegacyLoad(dmlc::Stream *strm, const uint32_t magic);
416  bool Load(dmlc::Stream *strm);
422  NDArray &operator=(real_t scalar);
429  NDArray &operator+=(const NDArray &src);
436  NDArray &operator+=(const real_t &src);
443  NDArray &operator-=(const NDArray &src);
450  NDArray &operator-=(const real_t &src);
457  NDArray &operator*=(const NDArray &src);
464  NDArray &operator*=(const real_t &src);
471  NDArray &operator/=(const NDArray &src);
478  NDArray &operator/=(const real_t &src);
484  NDArray Copy(Context ctx) const;
495  void SyncCopyFromCPU(const void *data, size_t size) const;
496 
500  void SyncCopyFromNDArray(const NDArray &src, int i = -1, int j = -1);
501 
512  void SyncCopyToCPU(void *data, size_t size) const;
518  void SyncCheckFormat(const bool full_check) const;
525  NDArray Slice(index_t begin, index_t end) const;
532  NDArray SliceWithRecord(index_t begin, index_t end);
538  NDArray At(index_t idx) const;
544  NDArray AtWithRecord(index_t idx);
549  NDArray aux_ndarray(size_t i) const;
550 
555  NDArray data_ndarray() const;
556 
564  inline NDArray AsArray(const mxnet::TShape &shape, int dtype) const {
565  CHECK_EQ(storage_type(), kDefaultStorage)
566  << "AsArray is intended only for kDefaultStorage.";
567  CHECK_GE(ptr_->shandle.size,
568  shape.Size() * mshadow::mshadow_sizeof(dtype))
569  << "NDArray.AsArray: target memory size is bigger";
570  // We can't reuse memory in a view.
571  CHECK(!IsView());
572  NDArray ret = *this;
573  ret.shape_ = shape;
574  ret.dtype_ = dtype;
575  ret.reuse_ = true;
576  return ret;
577  }
578 
584  DLManagedTensor* ToDLPack() const;
585 
597  static NDArray FromDLPack(const DLManagedTensor* tensor, bool transient_handle);
598 
606  inline void SparseUpdateChunk(const NDArray &arr) const {
607  CHECK(shape_ == arr.shape_) << "ndarray shape is different from the target";
608  CHECK(dtype_ == arr.dtype_) << "ndarray dtype is different from the target";
609  auto stype = arr.storage_type();
610  CHECK(stype == kCSRStorage || stype == kRowSparseStorage)
611  << "Only to be used with CSR and RSP storage types";
612  // swap shandles between src and dst
613  Storage::Handle shandle_dst = arr.ptr_->shandle;
614  arr.ptr_->shandle = ptr_->shandle;
615  ptr_->shandle = shandle_dst;
616 
617  ptr_->storage_shape = arr.ptr_->storage_shape;
618  ptr_->storage_type = arr.ptr_->storage_type;
619  ptr_->ctx = arr.ptr_->ctx;
620 
621  // swap aux_handles between src and dst
622  size_t aux_idx = 0;
623  CHECK(ptr_->aux_handles.size() == arr.ptr_->aux_handles.size())
624  << "ndarray number of aux_handles is different from target";
625  for (auto &aux_handle : arr.ptr_->aux_handles) {
626  Storage::Handle aux_dst = ptr_->aux_handles[aux_idx];
627  ptr_->aux_handles[aux_idx] = aux_handle;
628  aux_handle = aux_dst;
629  aux_idx++;
630  }
631  ptr_->aux_types = arr.ptr_->aux_types;
632  ptr_->aux_shapes = arr.ptr_->aux_shapes;
633  }
634 
640  NDArray Reshape(const mxnet::TShape &shape) const;
646  NDArray ReshapeWithRecord(const mxnet::TShape &shape);
650  NDArray Detach() const {
651  NDArray ret(*this);
652  ret.entry_ = nnvm::NodeEntry(nullptr);
653  return ret;
654  }
655 
656  nnvm::Symbol get_autograd_symbol() const;
661  inline void CheckAndAlloc() const {
662  CHECK_EQ(storage_type(), kDefaultStorage);
663  ptr_->CheckAndAlloc();
664  }
665 
675  void ReshapeAndAlloc(const mxnet::TShape& shape) {
676  CHECK_EQ(storage_type(), kDefaultStorage);
677  CHECK(!is_none());
678  shape_ = shape;
679  ptr_->CheckAndAlloc(shape.Size() * mshadow::mshadow_sizeof(dtype_));
680  }
681 
 682  /*!
683  * \brief Alloc memory for non-default storage
684  * aux_shape is only known at run time
685  */
686  inline void CheckAndAlloc(const mxnet::ShapeVector &aux_shapes) const {
687  CHECK_NE(storage_type(), kDefaultStorage)
688  << "CheckAndAlloc(aux_shapes) is not intended for kDefaultStorage";
689  ptr_->CheckAndAlloc(shape_, aux_shapes, dtype_);
690  }
691  inline void CheckAndAllocData(const mxnet::TShape &storage_shape) const {
692  CHECK_NE(storage_type(), kDefaultStorage)
693  << "CheckAndAllocData is not intended for kDefaultStorage";
694  ptr_->CheckAndAllocData(storage_shape, dtype_);
695  }
696  inline void CheckAndAllocAuxData(size_t i, const mxnet::TShape &aux_shape) const {
697  CHECK_NE(storage_type(), kDefaultStorage)
698  << "CheckAndAllocAuxData is not intended for kDefaultStorage";
699  ptr_->CheckAndAllocAuxData(i, aux_shape);
700  }
701 
702 #if MXNET_USE_MKLDNN == 1
703  /*
704  * Create NDArray from mkldnn memory.
705  * mkldnn_mem The mkldnn memory to be managed.
706  */
707  explicit NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem);
708  /*
709  * Create NDArray from mkldnn memory descriptor.
710  * mem_pd The mkldnn memory descriptor to be created.
711  */
712  explicit NDArray(const mkldnn::memory::desc &md);
713  /*
 714  * Test if the data is stored in one of the special MKLDNN formats.
715  */
716  bool IsMKLDNNData() const {
717  return ptr_->IsMKLDNN();
718  }
719  /*
 720  * Test if the data is stored in one of the default MXNet formats.
721  */
722  bool IsDefaultData() const {
723  return ptr_->IsDefault();
724  }
725  /*
 726  * All functions below return a raw pointer to mkldnn memory. There is in fact
 727  * a shared pointer that holds the memory, either in the NDArray or in the MKLDNN
 728  * stream. As long as we call these functions inside an operator, the returned
 729  * memory is always valid.
730  */
731 
732  /*
733  * This function returns mkldnn::memory with the default primitive_desc.
734  */
735  const mkldnn::memory *GetMKLDNNData() const;
736  /*
737  * This function returns mkldnn::memory with the given primitive_desc
738  * as long as the array size meets the required size in the given primitive_desc.
739  */
740  const mkldnn::memory *GetMKLDNNData(const mkldnn::memory::desc &md) const;
741  /*
742  * This function returns mkldnn::memory with the given primitive_desc.
743  * The returned mkldnn::memory will have the same physical layout as
744  * the given primitive_desc.
745  */
746  const mkldnn::memory *GetMKLDNNDataReorder(
747  const mkldnn::memory::desc &md) const;
748 
749  /*
750  * This function copies data from mkldnn memory.
751  */
752  void CopyFrom(const mkldnn::memory &mem);
753  /*
754  * This function allocates memory for array and creates mkldnn memory
755  * with the specified format.
756  */
757  mkldnn::memory *CreateMKLDNNData(const mkldnn::memory::desc &md);
758 
759  /*
 760  * These are the async versions of the methods above.
 761  * They change the layout of this NDArray, but the change happens only after all
 762  * accesses to the array are complete.
763  */
764  void Reorder2DefaultAsync() const;
765  void MKLDNNDataReorderAsync(const mkldnn::memory::desc &md) const;
766 
767  /*
768  * This creates a new NDArray with the reordered data.
769  * It doesn't affect the data of the original NDArray.
770  */
771  NDArray Reorder2Default() const;
772 
773  /*
774  * This creates a new NDArray using f32 with the reordered data.
775  * It doesn't affect the data of the original NDArray.
776  */
777  NDArray Reorder2DefaultFloatFormat() const;
778 
779  void InvalidateMKLDNNData();
780 
781  /*
782  * This function is used inside operators to reshape an array.
 783  * It doesn't change the layout of the original array, and it allocates memory from
 784  * the temporary buffer. The returned array is only valid inside the current
785  * invocation of this operator.
786  * This is different from Reshape. Reshape will cause data in the array to be
787  * converted to the default layout and allocate memory from malloc directly,
788  * which can be expensive.
789  * It's used by FullyConnected right now.
790  */
791  NDArray MKLDNNDataReshape(const mxnet::TShape &shape) const;
792 
796  void UpdateMKLDNNMemDesc(const mkldnn::memory::desc &desc);
797 #endif
798 
805  static void Save(dmlc::Stream* fo,
806  const std::vector<NDArray>& data,
807  const std::vector<std::string>& names);
814  static void Load(dmlc::Stream* fi,
815  std::vector<NDArray>* data,
816  std::vector<std::string>* keys);
817 
818  private:
819  friend class Imperative;
821  // shandle is used to store the actual values in the NDArray
822  // aux_handles store the aux data(such as indices) if it's needed by non-default storage.
823  struct Chunk {
827  Storage::Handle shandle;
832  std::vector<Storage::Handle> aux_handles;
833 
834 #if MXNET_USE_MKLDNN == 1
835 
837  std::shared_ptr<MKLDNNMemory> mkl_mem_;
838 #endif
839 
840  Engine::VarHandle var;
846  bool static_data;
849  bool delay_alloc;
850  // the type of the storage. The storage_type is never kUndefinedStorage once the chunk
851  // is constructed.
852  NDArrayStorageType storage_type = kDefaultStorage;
854  std::vector<int> aux_types;
855  // context of data
856  Context ctx;
857  // The shape of the chunk data.
858  // This might not be the same shape as the NDArray, since the storage may be sparse.
859  // The default value for storage_shape is {0} when an empty non-default NDArray is created.
860  mxnet::TShape storage_shape;
861  // The shape of aux data. The default value for the shape depends on the type of storage.
862  // If aux_shapes[i].Size() is zero, aux data i is empty.
863  mxnet::ShapeVector aux_shapes;
865  std::shared_ptr<Storage> storage_ref_;
867  std::weak_ptr<Engine> engine_ref_;
868 
869 
871  Chunk() : static_data(true), delay_alloc(false),
872  storage_ref_(Storage::_GetSharedRef()),
873  engine_ref_(Engine::_GetSharedRef()) {}
874 
876  Chunk(mxnet::TShape shape, Context ctx_, bool delay_alloc_, int dtype)
877  : static_data(false), delay_alloc(true), ctx(ctx_),
878  storage_ref_(Storage::_GetSharedRef()),
879  engine_ref_(Engine::_GetSharedRef()) {
880  storage_shape = shape;
881  if (shape_is_known(storage_shape)) {
882  shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);
883  }
884  var = Engine::Get()->NewVariable();
885  shandle.ctx = ctx_;
886  if (!delay_alloc_) {
887  this->CheckAndAlloc();
888  }
889  }
890 
891  Chunk(const TBlob &data, int dev_id)
892  : static_data(true), delay_alloc(false),
893  storage_ref_(Storage::_GetSharedRef()),
894  engine_ref_(Engine::_GetSharedRef()) {
895  CHECK(storage_type == kDefaultStorage);
896  var = Engine::Get()->NewVariable();
897  if (data.dev_mask() == cpu::kDevMask) {
898  ctx = Context::CPU();
899  } else {
900  CHECK_EQ(data.dev_mask(), gpu::kDevMask);
901  ctx = Context::GPU(dev_id);
902  }
903  // init shandle
904  shandle.ctx = ctx;
905  shandle.dptr = data.dptr_;
906  shandle.size = data.shape_.Size() * mshadow::mshadow_sizeof(data.type_flag_);
907  storage_shape = data.shape_;
908  }
909 
910  Chunk(int shared_pid, int shared_id, const mxnet::TShape& shape, int dtype)
911  : static_data(false), delay_alloc(false),
912  storage_ref_(Storage::_GetSharedRef()),
913  engine_ref_(Engine::_GetSharedRef()) {
914  var = Engine::Get()->NewVariable();
915  ctx = Context::CPUShared(0);
916  shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);
917  shandle.ctx = ctx;
918  shandle.shared_pid = shared_pid;
919  shandle.shared_id = shared_id;
920  Storage::Get()->Alloc(&shandle);
921  storage_shape = shape;
922  }
923  // Constructor for a non-default storage chunk
924  Chunk(NDArrayStorageType storage_type_, const mxnet::TShape &storage_shape_, Context ctx_,
925  bool delay_alloc_, int dtype, const std::vector<int> &aux_types_,
926  const mxnet::ShapeVector &aux_shapes_)
927  : static_data(false), delay_alloc(delay_alloc_), storage_type(storage_type_),
928  aux_types(aux_types_), ctx(ctx_), storage_shape(storage_shape_),
929  aux_shapes(aux_shapes_), storage_ref_(Storage::_GetSharedRef()),
930  engine_ref_(Engine::_GetSharedRef()) {
931  shandle.ctx = ctx;
932  var = Engine::Get()->NewVariable();
933  // aux_handles always reflect the correct number of aux data
934  for (size_t i = 0; i < aux_shapes.size(); i++) {
935  CheckAndAllocAuxData(i, aux_shapes[i]);
 936  // this line is needed for the case when aux_shapes[i].Size() == 0;
 937  // aux_handles[i] would not be updated and would keep only its default value.
938  aux_handles[i].ctx = ctx;
939  }
940  if (!delay_alloc) {
941  CheckAndAllocData(storage_shape, dtype);
942  }
943  }
944 
945  Chunk(const NDArrayStorageType storage_type_, const TBlob &data,
946  const std::vector<TBlob> &aux_data, int dev_id)
947  : static_data(true), delay_alloc(false), storage_type(storage_type_),
948  storage_ref_(Storage::_GetSharedRef()), engine_ref_(Engine::_GetSharedRef()) {
949  using namespace mshadow;
950  CHECK_NE(storage_type, kDefaultStorage);
951  // init var
952  var = Engine::Get()->NewVariable();
953  // init ctx
954  if (data.dev_mask() == cpu::kDevMask) {
955  ctx = Context::CPU();
956  } else {
957  CHECK_EQ(data.dev_mask(), gpu::kDevMask);
958  ctx = Context::GPU(dev_id);
959  }
960  // init shandle
961  shandle.ctx = ctx;
962  shandle.dptr = data.dptr_;
963  shandle.size = data.shape_.Size() * mshadow_sizeof(data.type_flag_);
964  storage_shape = data.shape_;
965  // init aux handles
966  for (const auto &aux : aux_data) {
967  Storage::Handle aux_handle;
968  aux_handle.ctx = ctx;
969  aux_handle.dptr = aux.dptr_;
970  aux_handle.size = aux.shape_.Size() * mshadow_sizeof(aux.type_flag_);
971  aux_handles.push_back(aux_handle);
972  aux_types.emplace_back(aux.type_flag_);
973  aux_shapes.emplace_back(aux.shape_);
974  }
975  }
976 
978  inline void set_aux_shape(const size_t i, const mxnet::TShape& shape) {
979  aux_shapes[i] = shape;
980  if (storage_shape.ndim() >= 0) {
981  if (storage_type == kRowSparseStorage && i == rowsparse::kIdx) {
982  storage_shape[0] = shape[0];
983  } else if (storage_type == kCSRStorage && i == csr::kIdx) {
984  storage_shape[0] = shape[0];
985  }
986  }
987  }
988 
990  inline void CheckAndAlloc(void) {
991  if (delay_alloc) {
992  shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx);
993 #if MXNET_USE_MKLDNN == 1
994  mkl_mem_ = nullptr;
995 #endif
996  delay_alloc = false;
997  }
998  }
999 
1001  // size is the number of bytes
1002  void CheckAndAlloc(uint64_t dbytes) {
1003  CHECK_EQ(kDefaultStorage, storage_type)
1004  << "CheckAndAlloc(dbytes) is only intended for kDefaultStorage";
1005  dbytes = std::max(dbytes, static_cast<uint64_t>(shandle.size));
1006  if (delay_alloc) {
1007  shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
1008 #if MXNET_USE_MKLDNN == 1
1009  mkl_mem_ = nullptr;
1010 #endif
1011  delay_alloc = false;
1012  } else if (shandle.size < dbytes) {
1013  // free storage
1014  Storage::Get()->Free(shandle);
1015  // init storage
1016  shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
1017 #if MXNET_USE_MKLDNN == 1
1018  mkl_mem_ = nullptr;
1019 #endif
1020  }
1021  }
1023  void Init(const mxnet::TShape &shape, int dtype) {
1024  auto size = shape.Size();
1025  storage_shape = shape;
1026  shandle.size = size * mshadow::mshadow_sizeof(dtype);
1027  this->CheckAndAlloc();
1028  }
1029  inline void CheckAndAlloc(const mxnet::TShape &shape, const mxnet::ShapeVector &aux_shapes,
1030  int dtype) {
1031  // calculate size, perform allocation
1032  if (kRowSparseStorage == storage_type) {
1033  // For row sparse, aux_shape indicates the number of rows to allocate
1034  auto aux_shape = aux_shapes[rowsparse::kIdx];
1035  CheckAndAllocAuxData(rowsparse::kIdx, aux_shape);
1036  mxnet::TShape storage_shape(shape);
1037  storage_shape[0] = aux_shape[0];
1038  CheckAndAllocData(storage_shape, dtype);
1039  } else if (kCSRStorage == storage_type) {
1040  CheckAndAllocAuxData(csr::kIndPtr, aux_shapes[csr::kIndPtr]);
1041  CheckAndAllocAuxData(csr::kIdx, aux_shapes[csr::kIdx]);
1042  CheckAndAllocData(aux_shapes[csr::kIdx], dtype);
1043  } else {
1044  LOG(FATAL) << "Storage type " << storage_type << " not implemented for CheckAndAlloc";
1045  }
1046  }
1047  // create storage handle for data based on shape and dtype, assuming ctx is set
1048  // storage shape is also updated
 1049  // if data is already allocated, try to reuse the storage. Otherwise, free the current one
1050  // and allocate new storage
1051  void CheckAndAllocData(const mxnet::TShape &shape, int dtype);
1052 
1053 #if MXNET_USE_MKLDNN == 1
1054  // Have MKL memory reference to the data in the default storage
1055  // or create memory for MKLDNN.
1056  void SetMKLMem(const mxnet::TShape &shape, int dtype);
1057  // If the data is stored in MKLDNN layout, we reorder data in mkl_mem_ and
1058  // save the result in shandle.
1059  void Reorder2Default();
 1060  // Reorder data to a specified layout.
1061  void MKLDNNDataReorder(const mkldnn::memory::desc &md);
1062  bool IsMKLDNN() const;
1063  bool IsDefault() const;
1064 #endif
1065 
1066  // create storage handle for aux data based on shape
1067  // this function assumes ctx, aux shapes and aux types are set
1068  // aux shape is also updated
 1069  // if aux data is already allocated, try to reuse the storage. Otherwise, free the current one
1070  // and allocate new storage
1071  inline void CheckAndAllocAuxData(size_t i, const mxnet::TShape &shape) {
1072  CHECK_EQ(shape.ndim(), 1) << "shape must be 1D in CheckAndAllocAuxData";
1073  CHECK_NE(storage_type, kUndefinedStorage)
1074  << "storage type cannot be kUndefinedStorage in CheckAndAllocAuxData";
1075  CHECK_NE(storage_type, kDefaultStorage)
1076  << "storage type cannot be kDefaultStorage in CheckAndAllocAuxData";
1077  if (aux_handles.size() <= i) {
1078  aux_handles.resize(i + 1);
1079  }
1080  size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]);
1081  if (aux_handles[i].size < aux_bytes) {
1082  // free storage
1083  Storage::Get()->Free(aux_handles[i]);
1084  // init aux storage
1085  aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx);
1086  }
1087  // init shape
1088  set_aux_shape(i, shape);
1089  }
1091  ~Chunk();
1092  }; // struct Chunk
1093 
1094  void SetTBlob() const;
1095 
1097  std::shared_ptr<Chunk> ptr_{nullptr};
1099  mxnet::TShape shape_;
1101  size_t byte_offset_ = 0;
1103  int dtype_ = -1;
1105  bool reuse_ = false;
1107  NDArrayStorageType storage_type_ = kUndefinedStorage;
1109  nnvm::NodeEntry entry_;
1117  mutable TBlob tblob_;
1118 }; // class NDArray
1119 
1123 size_t num_aux_data(NDArrayStorageType stype);
1124 
1136 void CopyFromTo(const NDArray &from, const NDArray *to, int priority = 0);
1137 
1151 void CopyFromTo(const NDArray &from, const NDArray& to, int priority = 0, bool is_opr = false);
1152 
1159 void ElementwiseSum(const std::vector<NDArray> &source, NDArray *out, int priority = 0);
1160 
1167 NDArray operator+(const NDArray &lhs, const NDArray &rhs);
1174 NDArray operator+(const NDArray &lhs, const real_t &rhs);
1181 NDArray operator-(const NDArray &lhs, const NDArray &rhs);
1188 NDArray operator-(const NDArray &lhs, const real_t &rhs);
 1195 NDArray operator*(const NDArray &lhs, const NDArray &rhs);
1202 NDArray operator*(const NDArray &lhs, const real_t &rhs);
1209 NDArray operator/(const NDArray &lhs, const NDArray &rhs);
1216 NDArray operator/(const NDArray &lhs, const real_t &rhs);
1217 
1222 void RandomSeed(uint32_t seed);
1227 void RandomSeed(Context ctx, uint32_t seed);
1234 void SampleUniform(real_t begin, real_t end, NDArray *out);
1241 void SampleGaussian(real_t mu, real_t sigma, NDArray *out);
1248 void SampleGamma(real_t alpha, real_t beta, NDArray *out);
1254 void SampleExponential(real_t lambda, NDArray *out);
1260 void SamplePoisson(real_t lambda, NDArray *out);
1267 void SampleNegBinomial(int32_t k, real_t p, NDArray *out);
1274 void SampleGenNegBinomial(real_t mu, real_t alpha, NDArray *out);
1275 
1276 
1277 //--------------------------------------------------------------
1278 // The following part are API Registration of NDArray functions.
1279 //--------------------------------------------------------------
1280 
1282 typedef std::function<void (NDArray **used_vars,
1283  real_t *scalars,
1284  NDArray **mutate_vars,
1285  int num_params,
1286  char **param_keys,
1287  char **param_vals)> NDArrayAPIFunction;
1303 };
1306  : public dmlc::FunctionRegEntryBase<NDArrayFunctionReg,
1307  NDArrayAPIFunction> {
1309  unsigned num_use_vars;
1313  unsigned num_scalars;
1320  : num_use_vars(0),
1321  num_mutate_vars(0),
1322  num_scalars(0),
1323  type_mask(0) {}
1330  inline NDArrayFunctionReg &set_function(void (*fsetvalue)(const real_t &rhs,
1331  NDArray *out)) {
1332  body = [fsetvalue] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
1333  int num_params, char **param_keys, char **param_vals) {
1334  (*fsetvalue)(s[0], mutate_vars[0]);
1335  };
1336  num_mutate_vars = 1; num_scalars = 1;
1337  this->add_argument("src", "real_t", "Source input to the function.");
1338  return *this;
1339  }
1346  inline NDArrayFunctionReg &set_function(void(*fternary)(const NDArray &lhs,
1347  const NDArray &mhs,
1348  const NDArray &rhs,
1349  NDArray *out)) {
1350  body = [fternary](NDArray **used_vars,
1351  real_t *s, NDArray **mutate_vars,
1352  int num_params, char **param_keys, char **param_vals) {
1353  (*fternary)(*used_vars[0], *used_vars[1], *used_vars[2], mutate_vars[0]);
1354  };
1355  num_use_vars = 3; num_mutate_vars = 1;
1357  this->add_argument("lhs", "NDArray", "Left operand to the function.");
1358  this->add_argument("mhs", "NDArray", "Middle operand to the function.");
1359  this->add_argument("rhs", "NDArray", "Right operand to the function.");
1360  return *this;
1361  }
1368  inline NDArrayFunctionReg &set_function(void (*fbinary)(const NDArray &lhs,
1369  const NDArray &rhs,
1370  NDArray *out)) {
1371  body = [fbinary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
1372  int num_params, char **param_keys, char **param_vals) {
1373  (*fbinary)(*used_vars[0], *used_vars[1], mutate_vars[0]);
1374  };
1375  num_use_vars = 2; num_mutate_vars = 1;
1377  this->add_argument("lhs", "NDArray", "Left operand to the function.");
1378  this->add_argument("rhs", "NDArray", "Right operand to the function.");
1379  return *this;
1380  }
1387  inline NDArrayFunctionReg &set_function(void (*fscalar)(const NDArray &lhs,
1388  const real_t &rhs,
1389  NDArray *out)) {
1390  body = [fscalar] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
1391  int num_params, char **param_keys, char **param_vals) {
1392  (*fscalar)(*used_vars[0], s[0], mutate_vars[0]);
1393  };
1394  num_use_vars = 1; num_mutate_vars = 1; num_scalars = 1;
1396  this->add_argument("lhs", "NDArray", "Left operand to the function.");
1397  this->add_argument("rhs", "real_t", "Right operand to the function.");
1398  return *this;
1399  }
1406  inline NDArrayFunctionReg &set_function(void (*funary)(const NDArray &src,
1407  NDArray *out)) {
1408  body = [funary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
1409  int num_params, char **param_keys, char **param_vals) {
1410  (*funary)(*used_vars[0], mutate_vars[0]);
1411  };
1412  num_use_vars = 1; num_mutate_vars = 1;
1414  this->add_argument("src", "NDArray", "Source input to the function.");
1415  return *this;
1416  }
1424  void (*fgeneric)(NDArray **used_vars,
1425  real_t *s,
1426  NDArray **mutate_vars,
1427  const std::map<std::string, std::string>& param)) {
1428  body = [fgeneric] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
1429  int num_params, char **param_keys, char **param_vals) {
1430  std::map<std::string, std::string> param;
1431  for (int i = 0; i < num_params; ++i) {
1432  param[param_keys[i]] = param_vals[i];
1433  }
1434  fgeneric(used_vars, s, mutate_vars, param);
1435  };
1436  return *this;
1437  }
1443  inline NDArrayFunctionReg &set_num_use_vars(unsigned n) {
1444  num_use_vars = n; return *this;
1445  }
1452  num_mutate_vars = n; return *this;
1453  }
1459  inline NDArrayFunctionReg &set_num_scalars(unsigned n) {
1460  num_scalars = n; return *this;
1461  }
1467  inline NDArrayFunctionReg &set_type_mask(int tmask) {
1468  type_mask = tmask; return *this;
1469  }
1470 }; // NDArrayFunctionReg
1471 
1483 #define MXNET_REGISTER_NDARRAY_FUN(name) \
1484  DMLC_REGISTRY_REGISTER(::mxnet::NDArrayFunctionReg, NDArrayFunctionReg, name)
1485 
1486 } // namespace mxnet
1487 
1488 namespace dmlc {
1490 DMLC_DECLARE_TRAITS(has_saveload, mxnet::NDArray, true);
1491 } // namespace dmlc
1492 #endif // MXNET_NDARRAY_H_
Definition: ndarray.h:74
const int default_type_flag
type enum value for default real type
Definition: base.h:477
Definition: ndarray.h:63
NDArrayStorageType
Definition: ndarray.h:61
Definition: ndarray.h:54
NDArrayFunctionReg & set_num_mutate_vars(unsigned n)
set the number of mutate variables
Definition: ndarray.h:1451
NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, const TBlob &data, const std::vector< TBlob > &aux_data, int dev_id)
constructing a static NDArray of non-default storage that shares data with TBlob Use with caution: al...
Definition: ndarray.h:174
NDArrayFormatErr
Definition: ndarray.h:68
Engine::VarHandle var() const
Definition: ndarray.h:389
mxnet::TShape shape_
shape of the tensor
Definition: tensor_blob.h:72
Common base class for function registry.
Definition: registry.h:151
ScalarExp< DType > scalar(DType s)
create an scalar expression
Definition: expression.h:104
void RandomSeed(uint32_t seed)
Seed all random number generator in mxnet.
NDArrayStorageType storage_type() const
Definition: ndarray.h:322
Engine that schedules all the operations according to dependency.
const mxnet::TShape & shape() const
Definition: ndarray.h:222
NDArrayFunctionReg()
constructor
Definition: ndarray.h:1319
namespace of mxnet
Definition: api_registry.h:33
Storage manager across multiple devices.
Definition: storage.h:36
NDArray operator*(const NDArray &lhs, const NDArray &rhs)
elementwise multiplication
void Copy(Tensor< cpu, dim, DType > dst, const Tensor< cpu, dim, DType > &src, Stream< cpu > *stream=NULL)
copy data from one tensor to another, with same shape
Definition: tensor_cpu-inl.h:146
virtual void Free(Handle handle)=0
Free storage.
NDArrayFunctionReg & set_num_use_vars(unsigned n)
set the number of use variables
Definition: ndarray.h:1443
void CheckAndAllocData(const mxnet::TShape &storage_shape) const
Definition: ndarray.h:691
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:97
static Context GPU(int32_t dev_id=-1)
int type_mask
information on how function should be called from API
Definition: ndarray.h:1315
NDArrayFunctionReg & set_function(void(*funary)(const NDArray &src, NDArray *out))
set the function body to a unary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1406
NDArray Detach() const
Return a copy of this NDArray without autograd history.
Definition: ndarray.h:650
int type_flag_
type flag of the tensor blob
Definition: tensor_blob.h:74
Definition: optional.h:241
NDArrayFunctionReg & set_num_scalars(unsigned n)
set the number of scalar arguments
Definition: ndarray.h:1459
Definition: ndarray.h:72
unsigned num_mutate_vars
number of variable mutated by this function
Definition: ndarray.h:1311
execution time context. The information needed in runtime for actual execution.
Definition: base.h:350
interface of stream I/O for serialization
Definition: io.h:30
void * dptr
Pointer to the data.
Definition: storage.h:45
NDArrayFunctionReg & set_function(void(*fscalar)(const NDArray &lhs, const real_t &rhs, NDArray *out))
set the function body to a binary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1387
NDArray AsArray(const mxnet::TShape &shape, int dtype) const
Create a NDArray that shares memory with current one The new array must have smaller memory size than...
Definition: ndarray.h:564
Graph node data structure.
base class of engine variables.
Definition: engine.h:44
Definition: ndarray.h:65
#define DMLC_DECLARE_TRAITS(Trait, Type, Value)
macro to quickly declare traits information
Definition: type_traits.h:126
Context ctx
Context information about device and ID.
Definition: storage.h:53
Storage::Handle storage_handle() const
get storage handle
Definition: ndarray.h:358
NDArray()
default constructor
Definition: ndarray.h:85
unsigned num_use_vars
number of variable used by this function
Definition: ndarray.h:1309
int shared_id
Definition: storage.h:58
NDArrayFunctionReg & set_function(void(*fternary)(const NDArray &lhs, const NDArray &mhs, const NDArray &rhs, NDArray *out))
set the function body to a ternary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1346
Definition: ndarray.h:62
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:51
void Init(const mxnet::TShape &shape)
initialize the NDArray, assuming it is not assigned a meaningful shape before
Definition: ndarray.h:186
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:820
RowSparseAuxType
Definition: ndarray.h:58
Definition: ndarray.h:70
bool is_none() const
Definition: ndarray.h:326
all the scalar should go before use_vars
Definition: ndarray.h:1293
void SampleExponential(real_t lambda, NDArray *out)
Sample exponential distribution for each elements of out.
void SparseUpdateChunk(const NDArray &arr) const
Update ndarray chunk storage handles using existing ndarray storage handles Also update the aux_handl...
Definition: ndarray.h:606
size_t Size() const
Definition: tuple.h:521
void * dptr_
pointer to the data
Definition: tensor_blob.h:70
virtual VarHandle NewVariable()=0
Allocate a new variable, the variable can then be used to schedule the operation concurrently via dep...
Definition: ndarray.h:58
whether this function allows the handles in the target to be empty NDArray that are not yet initializ...
Definition: ndarray.h:1302
Definition: ndarray.h:73
C Tensor object, manage memory of DLTensor. This data structure is intended to facilitate the borrowi...
Definition: dlpack.h:157
static Storage * Get()
namespace for dmlc
Definition: array_view.h:12
virtual void WaitForVar(VarHandle var)=0
Wait for a variable.
bool IsView() const
Definition: ndarray.h:200
Context ctx() const
Definition: ndarray.h:307
void CopyFromTo(const NDArray &from, const NDArray *to, int priority=0)
issue an copy operation from one NDArray to another the two ndarray can sit on different devices this...
CSRAuxType
Definition: ndarray.h:54
void SampleGaussian(real_t mu, real_t sigma, NDArray *out)
Sample gaussian distribution for each elements of out.
Definition: ndarray.h:54
const mxnet::TShape & storage_shape() const
Definition: ndarray.h:230
Storage manager across multiple devices.
void WaitToRead() const
Block until all the pending write operations with respect to current NDArray are finished, and read can be performed.
Definition: ndarray.h:368
virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector< VarHandle > const &const_vars, std::vector< VarHandle > const &mutable_vars, FnProperty prop=FnProperty::kNormal, int priority=0, const char *opr_name=nullptr, bool wait=false)=0
Push an asynchronous operation to the engine.
int dtype() const
Definition: ndarray.h:314
bool storage_initialized() const
Returns true if a sparse ndarray's aux_data and storage are initialized Throws an exception if the in...
Definition: ndarray.h:337
Storage handle.
Definition: storage.h:41
static Context CPUShared(int32_t dev_id=0)
Definition: ndarray.h:64
size_t num_aux_data(NDArrayStorageType stype)
NDArrayFunctionReg & set_type_mask(int tmask)
set type mask
Definition: ndarray.h:1467
void WaitToWrite() const
Block until all the pending read/write operations with respect to current NDArray are finished...
Definition: ndarray.h:376
NDArray(const TBlob &data, int dev_id, const std::function< void()> &deleter)
constructing a static NDArray that shares data with TBlob which is with deleter Use with caution: all...
Definition: ndarray.h:145
MSHADOW_XINLINE Shape< 1 > Shape1(index_t s0)
construct a one dimension shape, stride will equal s0
Definition: tensor.h:207
an entry that represents output data from a node
Definition: node.h:51
const mxnet::TShape & aux_shape(size_t index) const
get the shape of aux_data(index)
Definition: ndarray.h:242
Handle Alloc(size_t size, Context ctx)
Allocate a new contiguous memory for a given size.
Definition: storage.h:66
NDArray operator-(const NDArray &lhs, const NDArray &rhs)
elementwise subtraction
Definition: ndarray.h:71
size_t mshadow_sizeof(int type)
get data type size from type enum
Definition: base.h:1472
NDArrayFunctionReg & set_function(void(*fsetvalue)(const real_t &rhs, NDArray *out))
set the function body to a NDArray setvalue function this will also auto set the parameters correctly...
Definition: ndarray.h:1330
NDArray operator+(const NDArray &lhs, const NDArray &rhs)
elementwise add
size_t byte_offset() const
Definition: ndarray.h:393
const mxnet::ShapeVector & aux_shapes() const
Definition: ndarray.h:249
void SampleUniform(real_t begin, real_t end, NDArray *out)
Sample uniform distribution for each elements of out.
void CheckAndAllocAuxData(size_t i, const mxnet::TShape &aux_shape) const
Definition: ndarray.h:696
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:44
Registry entry for NDArrayFunction.
Definition: ndarray.h:1305
NDArrayFunctionReg & set_function(void(*fbinary)(const NDArray &lhs, const NDArray &rhs, NDArray *out))
set the function body to a binary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1368
Dependency engine that schedules operations.
Definition: engine.h:117
static Context CPU(int32_t dev_id=0)
runtime functions for NDArray
Definition: imperative.h:50
int aux_type(size_t i) const
Definition: ndarray.h:317
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:73
A Shape class that is used to represent shape of each tensor.
Definition: tuple.h:438
void ReshapeAndAlloc(const mxnet::TShape &shape)
Allocate the space if the allocation has been delayed or the requested size is bigger than the availa...
Definition: ndarray.h:675
all the use_vars should go before scalar
Definition: ndarray.h:1291
size_t version() const
return var version of the NDArray
Definition: ndarray.h:397
unsigned num_scalars
number of scalars used by this function
Definition: ndarray.h:1313
static Engine * Get()
NDArray(int shared_pid, int shared_id, const mxnet::TShape &shape, int dtype)
create ndarray from shared memory
Definition: ndarray.h:156
int ndim() const
Definition: tuple.h:218
#define MSHADOW_TYPE_SWITCH(type, DType,...)
Definition: base.h:1067
const TBlob & data() const
Definition: ndarray.h:278
NDArray(const mxnet::TShape &shape, Context ctx, bool delay_alloc=false, int dtype=mshadow::default_type_flag)
constructs a new dynamic NDArray
Definition: ndarray.h:95
Definition: ndarray.h:69
bool shape_is_known(const TShape &x)
Definition: tuple.h:693
void CheckAndAlloc() const
Allocate the space if it is delayed allocated. This is an internal function used by system that norma...
Definition: ndarray.h:661
overloaded + operator between half_t and bf16_t
Definition: base.h:327
mshadow::index_t index_t
index type usually use unsigned
Definition: base.h:95
size_t size
Size of the storage.
Definition: storage.h:49
void set_aux_shape(size_t index, const mxnet::TShape &shape) const
For a sparse operation on a csr matrix for example, the size of the column index array is an estimate...
Definition: ndarray.h:269
TBlob aux_data(size_t i) const
Definition: ndarray.h:291
void SampleGenNegBinomial(real_t mu, real_t alpha, NDArray *out)
Sample generalized negative binomial distribution for each elements of out.
Context information about the execution environment.
Definition: base.h:102
void SamplePoisson(real_t lambda, NDArray *out)
Sample Poisson distribution for each elements of out.
ndarray interface
Definition: ndarray.h:82
NDArray(Context ctx, int dtype=mshadow::default_type_flag)
constructs a new dynamic NDArray whose shape is unknown, hence the NDArray is inherently lazily creat...
Definition: ndarray.h:115
NDArray(const TBlob &data, int dev_id)
constructing a static NDArray that shares data with TBlob Use with caution: allocate ONLY ONE NDArray...
Definition: ndarray.h:129
int dev_mask() const
device mask of the corresponding device
Definition: tensor_blob.h:263
void ElementwiseSum(const std::vector< NDArray > &source, NDArray *out, int priority=0)
Perform elementwise sum over each data from source, store result into out.
std::function< void(NDArray **used_vars, real_t *scalars, NDArray **mutate_vars, int num_params, char **param_keys, char **param_vals)> NDArrayAPIFunction
definition of NDArray function
Definition: ndarray.h:1287
Symbol is help class used to represent the operator node in Graph.
Definition: symbolic.h:50
void SampleNegBinomial(int32_t k, real_t p, NDArray *out)
Sample negative binomial distribution for each elements of out.
NDArrayFunctionReg & set_function(void(*fgeneric)(NDArray **used_vars, real_t *s, NDArray **mutate_vars, const std::map< std::string, std::string > &param))
set the function body to a generic NDArray function taking a string parameter map this will also auto set the parameters correctly ...
Definition: ndarray.h:1423
bool IsSame(const NDArray &other) const
Definition: ndarray.h:212
type traits information header
int shared_pid
Id for IPC shared memory.
Definition: storage.h:57
tensor blob class that can be used to hold tensor of any dimension, any device and any data type...
Definition: tensor_blob.h:66
const std::vector< int > & aux_types() const
Definition: ndarray.h:256
void SampleGamma(real_t alpha, real_t beta, NDArray *out)
Sample gamma distribution for each elements of out.
NDArray operator/(const NDArray &lhs, const NDArray &rhs)
elementwise division
NDArrayFunctionTypeMask
mask information on how functions can be exposed
Definition: ndarray.h:1289
void CheckAndAlloc(const mxnet::ShapeVector &aux_shapes) const
Definition: ndarray.h:686