mxnet
data_type.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*
20  * \file data_type.h
21  * \brief Primitive runtime data type.
22  */
23 // Acknowledgement: This file originates from incubator-tvm
24 // Acknowledgement: MXNetDataType structure design originates from Halide.
25 #ifndef MXNET_RUNTIME_DATA_TYPE_H_
26 #define MXNET_RUNTIME_DATA_TYPE_H_
27 
29 #include <dmlc/logging.h>
30 #include <type_traits>
31 
32 
33 namespace mxnet {
34 namespace runtime {
42  public:
44  enum TypeCode {
49  };
56  explicit MXNetDataType(DLDataType dtype)
57  : data_(dtype) {}
64  MXNetDataType(int code, int bits, int lanes) {  // each field is narrowed to its DLDataType width; out-of-range values wrap, nothing is checked
65  data_.lanes = static_cast<decltype(data_.lanes)>(lanes);
66  data_.bits = static_cast<decltype(data_.bits)>(bits);
67  data_.code = static_cast<decltype(data_.code)>(code);
68  }
70  int code() const {  // type code, widened from its uint8_t storage
71  return int{data_.code};
72  }
74  int bits() const {  // bit width as stored in DLDataType::bits, widened to int
75  return int{data_.bits};
76  }
78  int bytes() const {  // bits() rounded up to whole bytes; lanes() is not factored in
79  return (bits() + 7) >> 3;
80  }
82  int lanes() const {  // lane count, widened from its uint16_t storage
83  return int{data_.lanes};
84  }
86  bool is_scalar() const {  // true for exactly one lane
87  return data_.lanes == 1;
88  }
90  bool is_bool() const {  // bool is represented as a 1-bit unsigned integer (matches Bool())
91  return bits() == 1 && is_uint();
92  }
94  bool is_float() const {  // float type code, regardless of width or lane count
95  return code() == kFloat;
96  }
98  bool is_int() const {  // signed-integer type code, regardless of width or lane count
99  return code() == kInt;
100  }
102  bool is_uint() const {  // unsigned-integer type code; also true for bool (1-bit uint)
103  return code() == kUInt;
104  }
106  bool is_handle() const {  // handle type code, as produced by Handle()
107  return code() == kHandle;
108  }
110  bool is_vector() const {  // more than one lane
111  return data_.lanes > 1;
112  }
119  return MXNetDataType(data_.code, data_.bits, lanes);
120  }
127  return MXNetDataType(data_.code, bits, data_.lanes);
128  }
134  return with_lanes(1);
135  }
141  bool operator==(const MXNetDataType& other) const {  // equal iff code, bits and lanes all match
142  return
143  code() == other.code() &&
144  bits() == other.bits() &&
145  lanes() == other.lanes();
146  }
152  bool operator!=(const MXNetDataType& other) const {  // logical negation of operator==
153  return !(*this == other);
154  }
159  operator DLDataType () const {  // non-explicit: converts implicitly to (a copy of) the wrapped DLPack dtype
160  return DLDataType{data_};
161  }
162 
169  static MXNetDataType Int(int bits, int lanes = 1) {  // signed integer; built from kInt, the same code is_int() tests
170  return MXNetDataType(kInt, bits, lanes);
171  }
178  static MXNetDataType UInt(int bits, int lanes = 1) {  // unsigned integer; built from kUInt, the same code is_uint() tests
179  return MXNetDataType(kUInt, bits, lanes);
180  }
187  static MXNetDataType Float(int bits, int lanes = 1) {  // floating point; built from kFloat, the same code is_float() tests
188  return MXNetDataType(kFloat, bits, lanes);
189  }
195  static MXNetDataType Bool(int lanes = 1) {  // bool is a 1-bit unsigned integer; is_bool() recognises exactly this
196  return UInt(1, lanes);
197  }
204  static MXNetDataType Handle(int bits = 64, int lanes = 1) {  // handle type code; width defaults to 64 bits, one lane
205  return {kHandle, bits, lanes};
206  }
207 
208  private:
209  DLDataType data_;
210 };
211 
212 } // namespace runtime
213 
215 
216 } // namespace mxnet
217 #endif // MXNET_RUNTIME_DATA_TYPE_H_
Definition: c_runtime_api.h:46
bool is_uint() const
Definition: data_type.h:102
int lanes() const
Definition: data_type.h:82
Definition: dlpack.h:81
Definition: data_type.h:48
static MXNetDataType Handle(int bits=64, int lanes=1)
Construct a handle type.
Definition: data_type.h:204
namespace of mxnet
Definition: api_registry.h:33
bool is_float() const
Definition: data_type.h:94
bool is_handle() const
Definition: data_type.h:106
MXNetDataType()
default constructor
Definition: data_type.h:51
bool is_vector() const
Definition: data_type.h:110
Definition: dlpack.h:80
Definition: data_type.h:45
MXNetDataType(int code, int bits, int lanes)
Constructor.
Definition: data_type.h:64
Definition: dlpack.h:82
MXNetDataType with_lanes(int lanes) const
Create a new data type by changing the number of lanes to the specified value.
Definition: data_type.h:118
uint8_t code
Type code of base types. We keep it uint8_t instead of DLDataTypeCode for minimal memory footprint, but the value should be one of the DLDataTypeCode enum values.
Definition: dlpack.h:100
MXNetDataType element_of() const
Get the scalar version of the type.
Definition: data_type.h:133
static MXNetDataType Int(int bits, int lanes=1)
Construct an int type.
Definition: data_type.h:169
bool is_scalar() const
Definition: data_type.h:86
bool is_bool() const
Definition: data_type.h:90
int bits() const
Definition: data_type.h:74
uint8_t bits
Number of bits, common choices are 8, 16, 32.
Definition: dlpack.h:104
static MXNetDataType Bool(int lanes=1)
Construct a bool type.
Definition: data_type.h:195
Runtime primitive data type.
Definition: data_type.h:41
MXNetDataType(DLDataType dtype)
Constructor.
Definition: data_type.h:56
static MXNetDataType Float(int bits, int lanes=1)
Construct a float type.
Definition: data_type.h:187
Definition: data_type.h:46
TypeCode
Type code for the MXNetDataType.
Definition: data_type.h:44
bool operator!=(const MXNetDataType &other) const
NotEqual comparator.
Definition: data_type.h:152
static MXNetDataType UInt(int bits, int lanes=1)
Construct a uint type.
Definition: data_type.h:178
int bytes() const
Definition: data_type.h:78
bool operator==(const MXNetDataType &other) const
Equal comparator.
Definition: data_type.h:141
Definition: data_type.h:47
runtime::MXNetDataType MXNetDataType
Definition: data_type.h:214
The data type the tensor can hold.
Definition: dlpack.h:94
uint16_t lanes
Number of lanes in the type, used for vector types.
Definition: dlpack.h:106
bool is_int() const
Definition: data_type.h:98
MXNetDataType with_bits(int bits) const
Create a new data type by changing the number of bits to the specified value.
Definition: data_type.h:126
int code() const
Definition: data_type.h:70