mxnet
data_type.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*
20  * \file data_type.h
21  * \brief Primitive runtime data type.
22  */
23 // Acknowledgement: This file originates from incubator-tvm
24 // Acknowledgement: MXNetDataType structure design originates from Halide.
25 #ifndef MXNET_RUNTIME_DATA_TYPE_H_
26 #define MXNET_RUNTIME_DATA_TYPE_H_
27 
29 #include <dmlc/logging.h>
30 #include <type_traits>
31 
32 namespace mxnet {
33 namespace runtime {
41  public:
43  enum TypeCode {
48  };
55  explicit MXNetDataType(DLDataType dtype) : data_(dtype) {}
62  MXNetDataType(int code, int bits, int lanes) {
63  data_.code = static_cast<uint8_t>(code);
64  data_.bits = static_cast<uint8_t>(bits);
65  data_.lanes = static_cast<uint16_t>(lanes);
66  }
68  int code() const {
69  return static_cast<int>(data_.code);
70  }
72  int bits() const {
73  return static_cast<int>(data_.bits);
74  }
76  int bytes() const {
77  return (bits() + 7) / 8;
78  }
80  int lanes() const {
81  return static_cast<int>(data_.lanes);
82  }
84  bool is_scalar() const {
85  return lanes() == 1;
86  }
88  bool is_bool() const {
89  return code() == MXNetDataType::kUInt && bits() == 1;
90  }
92  bool is_float() const {
93  return code() == MXNetDataType::kFloat;
94  }
96  bool is_int() const {
97  return code() == MXNetDataType::kInt;
98  }
100  bool is_uint() const {
101  return code() == MXNetDataType::kUInt;
102  }
104  bool is_handle() const {
105  return code() == MXNetDataType::kHandle;
106  }
108  bool is_vector() const {
109  return lanes() > 1;
110  }
117  return MXNetDataType(data_.code, data_.bits, lanes);
118  }
125  return MXNetDataType(data_.code, bits, data_.lanes);
126  }
132  return with_lanes(1);
133  }
139  bool operator==(const MXNetDataType& other) const {
140  return data_.code == other.data_.code && data_.bits == other.data_.bits &&
141  data_.lanes == other.data_.lanes;
142  }
148  bool operator!=(const MXNetDataType& other) const {
149  return !operator==(other);
150  }
155  operator DLDataType() const {
156  return data_;
157  }
158 
165  static MXNetDataType Int(int bits, int lanes = 1) {
166  return MXNetDataType(kDLInt, bits, lanes);
167  }
174  static MXNetDataType UInt(int bits, int lanes = 1) {
175  return MXNetDataType(kDLUInt, bits, lanes);
176  }
183  static MXNetDataType Float(int bits, int lanes = 1) {
184  return MXNetDataType(kDLFloat, bits, lanes);
185  }
191  static MXNetDataType Bool(int lanes = 1) {
192  return MXNetDataType::UInt(1, lanes);
193  }
200  static MXNetDataType Handle(int bits = 64, int lanes = 1) {
201  return MXNetDataType(kHandle, bits, lanes);
202  }
203 
204  private:
205  DLDataType data_;
206 };
207 
208 } // namespace runtime
209 
211 
212 } // namespace mxnet
213 #endif // MXNET_RUNTIME_DATA_TYPE_H_
mxnet::runtime::MXNetDataType::operator!=
bool operator!=(const MXNetDataType &other) const
NotEqual comparator.
Definition: data_type.h:148
mxnet
namespace of mxnet
Definition: api_registry.h:33
kDLFloat
@ kDLFloat
Definition: dlpack.h:82
DLDataType
The data type the tensor can hold.
Definition: dlpack.h:94
c_runtime_api.h
mxnet::runtime::MXNetDataType::MXNetDataType
MXNetDataType(int code, int bits, int lanes)
Constructor.
Definition: data_type.h:62
mxnet::runtime::MXNetDataType::bits
int bits() const
Definition: data_type.h:72
mxnet::runtime::MXNetDataType::bytes
int bytes() const
Definition: data_type.h:76
mxnet::runtime::MXNetDataType::is_uint
bool is_uint() const
Definition: data_type.h:100
mxnet::runtime::MXNetDataType::code
int code() const
Definition: data_type.h:68
mxnet::runtime::MXNetDataType::with_bits
MXNetDataType with_bits(int bits) const
Create a new data type by changing the number of bits to a specified value.
Definition: data_type.h:124
mxnet::runtime::MXNetDataType::kUInt
@ kUInt
Definition: data_type.h:45
mxnet::runtime::MXNetDataType::MXNetDataType
MXNetDataType(DLDataType dtype)
Constructor.
Definition: data_type.h:55
mxnet::runtime::MXNetDataType::lanes
int lanes() const
Definition: data_type.h:80
mxnet::runtime::MXNetDataType::Bool
static MXNetDataType Bool(int lanes=1)
Construct a bool type.
Definition: data_type.h:191
mxnet::runtime::MXNetDataType::operator==
bool operator==(const MXNetDataType &other) const
Equal comparator.
Definition: data_type.h:139
mxnet::runtime::MXNetDataType::TypeCode
TypeCode
Type code for the MXNetDataType.
Definition: data_type.h:43
DLDataType::code
uint8_t code
Type code of base types. We keep it uint8_t instead of DLDataTypeCode for minimal memory footprint, but the value should be one of the DLDataTypeCode enum values.
Definition: dlpack.h:100
mxnet::runtime::MXNetDataType::kHandle
@ kHandle
Definition: data_type.h:47
DLDataType::bits
uint8_t bits
Number of bits, common choices are 8, 16, 32.
Definition: dlpack.h:104
kDLInt
@ kDLInt
Definition: dlpack.h:80
mxnet::runtime::MXNetDataType::is_vector
bool is_vector() const
Definition: data_type.h:108
kDLUInt
@ kDLUInt
Definition: dlpack.h:81
mxnet::runtime::MXNetDataType::MXNetDataType
MXNetDataType()
default constructor
Definition: data_type.h:50
mxnet::runtime::MXNetDataType::is_bool
bool is_bool() const
Definition: data_type.h:88
mxnet::runtime::MXNetDataType::is_float
bool is_float() const
Definition: data_type.h:92
kHandle
@ kHandle
Definition: c_runtime_api.h:45
mxnet::runtime::MXNetDataType::with_lanes
MXNetDataType with_lanes(int lanes) const
Create a new data type by changing the number of lanes to a specified value.
Definition: data_type.h:116
mxnet::MXNetDataType
runtime::MXNetDataType MXNetDataType
Definition: data_type.h:210
mxnet::runtime::MXNetDataType::is_scalar
bool is_scalar() const
Definition: data_type.h:84
mxnet::runtime::MXNetDataType::is_int
bool is_int() const
Definition: data_type.h:96
mxnet::runtime::MXNetDataType
Runtime primitive data type.
Definition: data_type.h:40
mxnet::runtime::MXNetDataType::Handle
static MXNetDataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:200
mxnet::runtime::MXNetDataType::UInt
static MXNetDataType UInt(int bits, int lanes=1)
Construct a uint type.
Definition: data_type.h:174
DLDataType::lanes
uint16_t lanes
Number of lanes in the type, used for vector types.
Definition: dlpack.h:106
mxnet::runtime::MXNetDataType::element_of
MXNetDataType element_of() const
Get the scalar version of the type.
Definition: data_type.h:131
mxnet::runtime::MXNetDataType::Float
static MXNetDataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:183
mxnet::runtime::MXNetDataType::kInt
@ kInt
Definition: data_type.h:44
mxnet::runtime::MXNetDataType::is_handle
bool is_handle() const
Definition: data_type.h:104
mxnet::runtime::MXNetDataType::kFloat
@ kFloat
Definition: data_type.h:46
mxnet::runtime::MXNetDataType::Int
static MXNetDataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:165