28 #ifndef MXNET_TENSOR_BLOB_H_ 29 #define MXNET_TENSOR_BLOB_H_ 31 #include <dmlc/logging.h> 89 template<
typename DType>
91 : dptr_(dptr), shape_(shape),
92 type_flag_(
mshadow::DataType<DType>::kFlag) {
93 SetDLTensor(dev_mask,
dev_id);
104 : dptr_(dptr), shape_(shape), type_flag_(type_flag) {
105 SetDLTensor(dev_mask,
dev_id);
112 : dptr_(dltensor.data),
114 type_flag_(DLDataTypeTransform(dltensor.dtype)),
115 dltensor_(dltensor) {
117 if (dltensor.
strides !=
nullptr) {
120 const int64_t *shape = dltensor.
shape;
121 const int64_t *strides = dltensor.
strides;
124 if (strides[ndim - 1] != 1) {
127 for (
int i = ndim - 2; i >= 0; --i) {
128 if (strides[i] != shape[i + 1] * strides[i + 1]) {
135 LOG(FATAL) <<
"Unsupported DLPack because MXNet only support compact tensor now";
147 template<
typename Device,
int dim,
typename DType>
155 TBlob(
const TBlob &src): dptr_(src.dptr_), shape_(src.shape_), type_flag_(src.type_flag_) {
166 template<
typename Device,
int dim,
typename DType>
171 SetDLTensor(Device::kDevMask, -1);
198 CHECK_EQ(this->shape_.
Size(), shape.
Size()) <<
"Shape size mismatch " 199 << this->shape_.
Size() <<
" v.s. " << shape.
Size();
210 template<
typename Device,
typename DType>
213 CHECK(Device::kDevMask == this->
dev_mask())
214 <<
"TBlob.get: device type do not match specified type";
216 <<
"TBlob.get_with_shape: data type do not match specified type." 230 template<
typename Device,
typename DType>
233 return this->get_with_shape<Device, 1, DType>(
238 return shape_.
ndim();
250 inline size_t Size(
void)
const {
251 return shape_.
Size();
254 template<
typename DType>
257 <<
"TBlob.get_with_shape: data type do not match specified type." 260 return static_cast<DType*
>(
dptr_);
287 template<
typename Device,
int dim,
typename DType>
289 CHECK(Device::kDevMask == this->
dev_mask())
290 <<
"TBlob.get: device type do not match specified type";
292 shape_.get<dim>(), shape_[shape_.
ndim() - 1], stream);
304 template<
typename Device,
int dim,
typename DType>
308 CHECK(Device::kDevMask == this->
dev_mask())
309 <<
"TBlob.get: device type do not match specified type";
310 CHECK_EQ(this->
CheckContiguous(),
true) <<
"TBlob.get_reshape: must be contiguous";
311 CHECK_EQ(this->shape_.
Size(),
static_cast<size_t>(shape.
Size()))
312 <<
"TBlob.get_with_shape: new and old shape do not match total elements";
314 shape[dim - 1], stream);
325 template<
typename Device,
typename DType>
328 return this->get_with_shape<Device, 3, DType>(
329 this->shape_.FlatTo3D(axis), stream);
341 template<
typename Device,
typename DType>
343 int axis_begin,
int axis_end,
345 return this->get_with_shape<Device, 3, DType>(
346 this->shape_.FlatTo3D(axis_begin, axis_end), stream);
357 template<
typename Device,
int dim,
typename DType>
363 for (
int i = 0; i < dim -
ndim(); ++i) {
367 for (
int i = 0; i <
ndim() - dim + 1; ++i) {
368 shape[0] *= shape_[i];
371 for (
int i = std::max(0,
ndim() - dim + 1); i <
ndim(); ++i) {
372 shape[i -
ndim() + dim] = shape_[i];
374 return this->get_with_shape<Device, dim, DType>(shape, stream);
378 static DLDataType DTypeTransform(
int type_flag) {
390 LOG(FATAL) <<
"Unknown type_flag=" << type_flag;
395 static int DLDataTypeTransform(
DLDataType dldata_type) {
396 if (dldata_type.
lanes != 1) {
397 LOG(FATAL) <<
"Unsupported DLDataType whose lanes != 1";
399 switch (dldata_type.
code) {
401 switch (dldata_type.
bits) {
408 switch (dldata_type.
bits) {
413 switch (dldata_type.
bits) {
419 switch (dldata_type.
bits) {
426 LOG(FATAL) <<
"Unknown DLDataType{" << dldata_type.
code 427 <<
", " << dldata_type.
bits 428 <<
", " << dldata_type.
lanes <<
"}";
436 dltensor_.
dtype = DTypeTransform(type_flag_);
456 namespace parameter {
460 :
public FieldEntryBase<FieldEntry<mxnet::TShape>, mxnet::TShape> {
466 virtual void Check(
void *head)
const {
469 if (expect_ndim_ != 0 && v.
ndim() != expect_ndim_) {
470 std::ostringstream os;
471 os <<
"value " << v <<
"for Parameter " << this->key_
472 <<
" has wrong dimensions, expected dimension=" << expect_ndim_;
473 throw dmlc::ParamError(os.str());
475 if (enforce_nonzero_) {
476 for (
int i = 0; i < v.
ndim(); ++i) {
478 std::ostringstream os;
479 os <<
"value " << v <<
"for Parameter " << this->key_
480 <<
" is invalid, the input shape must be nonzero in all dimensions";
481 throw dmlc::ParamError(os.str());
487 this->enforce_nonzero_ =
true;
497 bool enforce_nonzero_;
505 #endif // MXNET_TENSOR_BLOB_H_ #define DMLC_DECLARE_TYPE_NAME(Type, Name)
macro to quickly declare traits information
Definition: type_traits.h:133
TBlob & operator=(const mshadow::Tensor< Device, dim, DType > &src)
assignment from tensor
Definition: tensor_blob.h:167
The common header of DLPack.
MSHADOW_XINLINE index_t Size(void) const
Definition: tensor.h:145
mxnet::TShape shape_
shape of the tensor
Definition: tensor_blob.h:72
A dynamic sized array data structure that is optimized for storing small number of elements with same...
Definition: tuple.h:51
constexpr const int kTVMNDArrayTypeCode
Definition: tensor_blob.h:49
DType * dptr_
pointer to the data
Definition: tensor.h:435
TBlob(const DLTensor &dltensor)
constructor that constructs a TBlob from a DLTensor
Definition: tensor_blob.h:111
c++17 compatible optional class.
Definition: optional.h:43
mshadow::Tensor< Device, 3, DType > FlatTo3D(int axis_begin, int axis_end, mshadow::Stream< Device > *stream=nullptr) const
flatten the tensor to 3 dimensions, collapsing the dimension ranges [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim).
Definition: tensor_blob.h:342
namespace of mxnet
Definition: api_registry.h:33
mshadow::Tensor< Device, 2, DType > FlatTo2D(mshadow::Stream< Device > *stream=nullptr) const
flatten the tensor to 2 dimensions, collapsing the higher dimensions together
Definition: tensor_blob.h:211
A Device context for Tensor and operator.
Definition: dlpack.h:69
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:437
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:97
TBlob(void)
default constructor, default copy assign will work
Definition: tensor_blob.h:77
Definition: tensor_blob.h:459
int type_flag_
type flag of the tensor blob
Definition: tensor_blob.h:74
FieldEntry< mxnet::TShape > & set_expect_ndim(int ndim)
Definition: tensor_blob.h:490
FieldEntry< mxnet::TShape > & enforce_nonzero()
Definition: tensor_blob.h:486
FieldEntryBase< FieldEntry< mxnet::TShape >, mxnet::TShape > Parent
Definition: tensor_blob.h:464
const dim_t * data() const
Definition: tuple.h:550
TBlob(void *dptr, const mxnet::TShape &shape, int dev_mask, int type_flag, int dev_id=-1)
constructor that constructs a TBlob from contiguous memory
Definition: tensor_blob.h:103
uint8_t code
Type code of base types. We keep it uint8_t instead of DLDataTypeCode for minimal memory footprint...
Definition: dlpack.h:100
int device_id
The device index.
Definition: dlpack.h:73
mshadow::Tensor< Device, 3, DType > FlatTo3D(int axis, mshadow::Stream< Device > *stream=nullptr) const
flatten the tensor to 3 dimensions, collapsing the dimensions before and after the specified axis...
Definition: tensor_blob.h:326
constexpr const int kGPU
Definition: tensor_blob.h:44
Lightweight JSON Reader/Writer that read save into C++ data structs. This includes STL composites and...
CPU device.
Definition: dlpack.h:40
index_t size(index_t idx) const
return size of i-th dimension, start counting from highest dimension. return type needs to be a signe...
Definition: tensor_blob.h:246
size_t Size() const
Definition: tuple.h:521
void * dptr_
pointer to the data
Definition: tensor_blob.h:70
std::string dtype_string(const int dtype)
Definition: base.h:1479
namespace for dmlc
Definition: array_view.h:12
uint8_t bits
Number of bits, common choices are 8, 16, 32.
Definition: dlpack.h:104
int64_t * strides
strides of the tensor (in number of elements, not bytes) can be NULL, indicating tensor is compact an...
Definition: dlpack.h:145
DLDataType dtype
The data type of the pointer.
Definition: dlpack.h:138
DType * dptr() const
get pointer in dtype
Definition: tensor_blob.h:255
DLDeviceType
The device type in DLContext.
Definition: dlpack.h:38
mshadow::Tensor< Device, dim, DType > get_with_shape(const mshadow::Shape< dim > &shape, mshadow::Stream< Device > *stream=nullptr) const
fetch a tensor in the given shape. If the size does not match the stored size, an error will be issued ...
Definition: tensor_blob.h:305
int ndim(void) const
return number of dimension of the tensor inside
Definition: tensor_blob.h:237
DLDeviceType device_type
The device type used in the device.
Definition: dlpack.h:71
TBlob reshape(const mxnet::TShape &shape) const
reshape to shape
Definition: tensor_blob.h:197
mshadow::Tensor< Device, dim, DType > FlatToKD(mshadow::Stream< Device > *stream=nullptr) const
flatten the tensor to specified number of dimensions, collapse the highest dimensions or pad with hig...
Definition: tensor_blob.h:358
MSHADOW_XINLINE Shape< 1 > Shape1(index_t s0)
construct a one dimension shape, stride will equal s0
Definition: tensor.h:207
TBlob(const mshadow::Tensor< Device, dim, DType > &src)
constructor from tensor
Definition: tensor_blob.h:148
A dynamic sized array data structure that is optimized for storing small number of elements with same...
Definition: tuple.h:58
TBlob(DType *dptr, const mxnet::TShape &shape, int dev_mask, int dev_id=-1)
constructor that constructs a TBlob from contiguous memory
Definition: tensor_blob.h:90
int ndim
Number of dimensions.
Definition: dlpack.h:136
TBlob & operator=(const TBlob &src)
assignment from TBlob (copy assignment)
Definition: tensor_blob.h:179
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:44
void * data
The opaque data pointer points to the allocated data. This will be CUDA device pointer or cl_mem hand...
Definition: dlpack.h:132
const DLTensor & dltensor() const
return the corresponding DLTensor
Definition: tensor_blob.h:274
A Shape class that is used to represent shape of each tensor.
Definition: tuple.h:438
virtual void Check(void *head) const
Definition: tensor_blob.h:466
mshadow::Tensor< Device, 1, DType > FlatTo1D(mshadow::Stream< Device > *stream=nullptr) const
flatten the tensor to 1 dimension, collapsing all the dimensions together.
Definition: tensor_blob.h:231
int ndim() const
Definition: tuple.h:218
bool CheckContiguous(void) const
Definition: tensor_blob.h:189
overloaded + operator between half_t and bf16_t
Definition: base.h:327
mshadow::index_t index_t
index type, usually unsigned
Definition: base.h:95
constexpr const int kCPU
Definition: tensor_blob.h:43
DLContext ctx
The device context of the tensor.
Definition: dlpack.h:134
CUDA GPU device.
Definition: dlpack.h:42
int64_t * shape
The shape of the tensor.
Definition: dlpack.h:140
The data type the tensor can hold.
Definition: dlpack.h:94
general tensor
Definition: tensor.h:421
uint16_t lanes
Number of lanes in the type, used for vector types.
Definition: dlpack.h:106
Plain C Tensor object, does not manage memory.
Definition: dlpack.h:112
FieldEntry()
Definition: tensor_blob.h:462
ndarray interface
Definition: ndarray.h:82
uint64_t byte_offset
The offset in bytes to the beginning pointer to data.
Definition: dlpack.h:147
int dev_mask() const
device mask of the corresponding device
Definition: tensor_blob.h:263
TBlob(const TBlob &src)
constructor from TBlob (copy constructor)
Definition: tensor_blob.h:155
tensor blob class that can be used to hold a tensor of any dimension, any device and any data type...
Definition: tensor_blob.h:66
int dev_id() const
device index of the corresponding device
Definition: tensor_blob.h:267
computation stream structure, used for asynchronous computations
Definition: tensor.h:384
size_t Size(void) const
total number of elements in the tensor
Definition: tensor_blob.h:250