20 #ifndef MXNET_COMMON_CUDA_RTC_VECTORIZATION_INL_H_
21 #define MXNET_COMMON_CUDA_RTC_VECTORIZATION_INL_H_
33 #include "../../utils.h"
// Threads per CUDA block used by the vectorized kernels.
// NOTE(review): a constant with the same name and value also appears later in
// this file (in what looks like a different scope) — confirm they are meant to
// stay in sync.
44 constexpr int vectorized_kernel_thread_num = 512;
// Primary template guard (the template header itself is on lines missing from
// this chunk): vector accesses wider than 32 bytes are rejected at compile time.
48 static_assert(size <= 32, "VectorType needs to have size of at most 32B");
// Specializations map a total byte width to the builtin type used to perform a
// single vectorized load/store of that width. NOTE(review): the `using type`
// lines for the 1/2/4-byte cases fall on lines missing from this chunk.
52 struct VectorType<1> {
57 struct VectorType<2> {
63 struct VectorType<4> {
68 struct VectorType<8> {
// 8-byte accesses go through a 64-bit integer.
69 using type = long long;
73 struct VectorType<16> {
// 16-byte accesses use CUDA's ulonglong2 (two 64-bit lanes).
74 using type = ulonglong2;
78 struct VectorType<32> {
// 32-byte accesses use ulonglong4 (four 64-bit lanes).
79 using type = ulonglong4;
// Elementwise addition helper used by VectorizedStorage::operator+=.
// NOTE(review): the generic return statement and closing braces fall on lines
// missing from this chunk; presumably the generic version returns x + y.
82 template <typename DType>
83 __device__ inline DType add_elem(const DType& x, const DType& y) {
// half overload: promote both operands to float, add, and convert back —
// avoids relying on native half arithmetic.
88 __device__ inline half add_elem(const half& x, const half& y) {
89 return __float2half(__half2float(x) + __half2float(y));
92 /* \brief Helper class that enables storing multiple values of type DType
93 as 1 value of type LType.
// NOTE(review): several interior lines of this class (the union's `aligned`
// member, closing braces) are missing from this chunk; comments below cover
// only the visible statements.
95 template <typename DType, int n>
96 class VectorizedStorage {
// The wide type whose size equals n packed DType elements.
98 using LType = typename VectorType<sizeof(DType) * n>::type;
99 constexpr static int nvec = n;
// Union lets the same bytes be accessed either as one wide LType value
// (for a single vectorized load/store) or as n separate DType elements.
100 union vectorized_storage {
102 DType separate[nvec]; // NOLINT(*)
// Explicit no-op ctor/dtor so the union is usable with non-trivial DType.
104 inline __device__ vectorized_storage() {}
105 inline __device__ ~vectorized_storage() {}
108 inline __device__ VectorizedStorage() {}
// Copy ctor: one wide copy instead of n element copies.
109 inline __device__ VectorizedStorage (const VectorizedStorage<DType, n>& y2) {
110 scratch_.aligned = y2.scratch_.aligned;
// Construct directly from the wide representation.
112 inline __device__ VectorizedStorage (const LType &y2) {
113 scratch_.aligned = y2;
// Elementwise accumulate via add_elem (which has a special half overload).
115 inline __device__ VectorizedStorage<DType, n>& operator+=(
116 const VectorizedStorage<DType, n>& rhs) {
118 for (int i = 0; i < nvec; ++i) {
119 scratch_.separate[i] = add_elem(scratch_.separate[i], rhs.scratch_.separate[i]);
123 inline __device__ ~VectorizedStorage() {}
126 // Returns const LType if DType is const, otherwise plain LType.
127 template <typename DType, typename LType>
128 struct select_const {
// Partial specialization: const input element type implies a const wide type,
// so read-only accessors cannot write through their aligned pointer.
132 template <typename DType, typename LType>
133 struct select_const<const DType, LType> {
134 using type = const LType;
// Minimal remove_const trait (device-code stand-in for std::remove_const).
// NOTE(review): the `using type` lines of both remove_const cases fall on
// lines missing from this chunk.
137 template <typename DType>
138 struct remove_const {
142 template <typename DType>
143 struct remove_const<const DType> {
148 /* \brief Helper class that enables accessing multiple values of type DType
149 as 1 value of type LType. Additional aligned template argument
150 allows performance optimizations if the pointer and the size of
151 the allocation is aligned to sizeof(LType) / sizeof(DType) elements.
// NOTE(review): several interior lines of this class (member declarations,
// closing braces, and the compile-time branch between the aligned and
// unaligned constructor paths) are missing from this chunk; comments below
// describe only the visible statements.
153 template <typename DType, int nvec, bool aligned = false>
154 class VectorizedAccessor {
// Scratch storage is always non-const so loaded values can be mutated locally.
156 using StorageType = VectorizedStorage<typename remove_const<DType>::type,
// LType picks up const from DType, making read-only accessors const-correct.
158 using LType = typename select_const<DType, typename StorageType::LType>::type;
159 StorageType storage_;
// Original (possibly unaligned) element pointer, kept for bounds checks.
162 DType* unaligned_ptr_;
166 inline __device__ VectorizedAccessor(DType* const ptr, const index_t size) {
167 unaligned_ptr_ = ptr;
// Aligned path: the pointer itself is the vector base; number of vectors
// is a plain ceil-div of size by nvec.
170 aligned_ptr_ = reinterpret_cast<LType*>(ptr);
171 n_elems_ = (size + nvec - 1) / nvec;
// Unaligned path: compute how many DType elements the pointer is past the
// previous LType boundary, then rewind to that boundary so every vector
// access is aligned; n_elems_ accounts for the extra leading elements.
173 size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
174 alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType);
175 aligned_ptr_ = reinterpret_cast<LType*>(ptr - alignment_);
176 n_elems_ = (size + alignment_ + nvec - 1) / nvec;
180 /* \brief Alignment of the input pointer in elements. */
181 inline __device__ int alignment() const {
185 /* \brief Access to separate elements. */
186 inline __device__ DType* separate() {
187 return storage_.scratch_.separate;
190 /* \brief Number of aligned elements that span the entire input tensor. */
191 inline __device__ index_t num_aligned_elements() const {
195 /* \brief Load values from the input.
196 \param id Aligned index of the element.
197 \param N size of the tensor.
199 inline __device__ void load(const index_t id, const index_t N) {
// Fast path (presumably guarded by the compile-time `aligned` flag — the
// condition line is missing from this chunk): one wide vector load.
201 storage_.scratch_.aligned = aligned_ptr_[id];
// Interior vectors lie fully inside the tensor, so a single wide load is
// safe even when the base pointer was unaligned.
203 if (id > 0 && id < n_elems_ - 1) {
204 storage_.scratch_.aligned = aligned_ptr_[id];
// Boundary vectors: load lane by lane, checking each address against the
// original allocation [unaligned_ptr_, unaligned_ptr_ + N).
207 for (int j = 0; j < nvec; ++j) {
208 DType* ptr = reinterpret_cast<DType*>(&(aligned_ptr_[id])) + j;
209 if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(unaligned_ptr_) &&
210 reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(unaligned_ptr_ + N)) {
211 storage_.scratch_.separate[j] = *ptr;
// Lanes outside the allocation are value-initialized (zero for builtins).
213 storage_.scratch_.separate[j] = DType();
221 /* \brief Class used for vectorized read-only access. */
// Thin wrapper fixing DType as const so loads go through the read-only
// (const LType) accessor path; adds no state of its own in the visible lines.
222 template <typename DType, int nvec, bool aligned = false>
223 class VectorizedLoader : public VectorizedAccessor<const DType, nvec, aligned> {
225 inline __device__ VectorizedLoader(const DType* ptr, const index_t N) :
226 VectorizedAccessor<const DType, nvec, aligned>(ptr, N) {
230 /* \brief Class used for vectorized writable access. */
231 template <typename DType, int nvec, bool aligned = false>
232 class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> {
234 inline __device__ VectorizedStorer(DType* ptr, const index_t N) :
235 VectorizedAccessor<DType, nvec, aligned>(ptr, N) {
238 /* \brief Store values to the output.
239 \param id Aligned index of the element.
240 \param N size of the tensor.
// Mirror image of VectorizedAccessor::load: wide store when safe, lane-by-lane
// bounds-checked store for boundary vectors. NOTE(review): the branch lines
// selecting between these paths are partially missing from this chunk.
242 inline __device__ void store(const index_t id, const index_t N) {
// Fast path (presumably the compile-time `aligned` case): one wide store.
244 this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
// Interior vectors are fully inside the tensor — wide store is safe.
246 if (id > 0 && id < this->n_elems_ - 1) {
247 this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
// Boundary vectors: store lane by lane, writing only addresses that fall
// inside the original allocation [unaligned_ptr_, unaligned_ptr_ + N).
250 for (int j = 0; j < nvec; ++j) {
251 DType* ptr = reinterpret_cast<DType*>(&(this->aligned_ptr_[id])) + j;
252 if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(this->unaligned_ptr_) &&
253 reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(this->unaligned_ptr_ + N)) {
254 *ptr = this->storage_.scratch_.separate[j];
262 } // namespace vector
// Host-side counterpart of VectorizedAccessor's constructor math: how many
// nvec-wide vectors span `lead_dim` elements starting at `ptr`, including the
// elements before `ptr` on its vector boundary. NOTE(review): the parameter
// lines declaring lead_dim/nvec/size are missing from this chunk; presumably
// `size` is sizeof the element type — confirm against callers.
268 inline index_t get_num_aligned_elements(
const void* ptr,
272 size_t ptr_as_number =
reinterpret_cast<size_t>(ptr);
// Offset (in elements) of ptr past the previous (nvec * size)-byte boundary.
273 int alignment = (ptr_as_number % (nvec * size)) / size;
// Ceil-div including the leading misaligned elements.
274 return (lead_dim + alignment + nvec - 1) / nvec;
// Classifies pointer alignment across all tensors of an operation.
// NOTE(review): enumerator lines (visible below: SAME_ALIGNED / SAME_UNALIGNED /
// DIFFERENT are used) are missing from this chunk.
277 enum class Alignment {
// Byte offset of ptr past the previous size-byte boundary (0 == aligned).
283 inline int CalcAlignment(
const void* ptr,
const int size) {
284 size_t ptr_as_number =
reinterpret_cast<size_t>(ptr);
285 return ptr_as_number % size;
// Decides which vectorization flavor is safe for a set of input/output
// pointers: SAME_ALIGNED (all aligned), SAME_UNALIGNED (all share one
// misalignment), or DIFFERENT (mixed — vectorization must be disabled).
// NOTE(review): several lines are missing from this chunk, including the
// parameter list (lead_dim / other_dim / nvec used below), the init/update of
// `align` and `i`, and null-pointer fallthroughs.
296 template <
typename Params>
297 Alignment CheckAlignment(
const Params& params,
301 const std::vector<TBlob>& inputs,
302 const std::vector<TBlob>& outputs) {
303 using namespace common;
// Every non-null input must have the same alignment (in bytes, relative to
// one nvec-wide element of its dtype).
307 for (
const void* ptr : params.inputs) {
308 if (ptr !=
nullptr) {
309 int new_align = CalcAlignment(ptr,
mshadow_type_info(inputs[i].type_flag_).size * nvec);
313 if (align != new_align) {
314 return Alignment::DIFFERENT;
// Same check over the outputs, against the same running `align`.
322 for (
const void* ptr : params.outputs) {
323 if (ptr !=
nullptr) {
324 int new_align = CalcAlignment(ptr,
mshadow_type_info(outputs[i].type_flag_).size * nvec);
328 if (align != new_align) {
329 return Alignment::DIFFERENT;
// With multiple rows, a leading dimension not divisible by nvec would make
// vectors straddle row boundaries — treat as DIFFERENT (no vectorization).
336 if ((other_dim != 1) && (lead_dim % nvec != 0)) {
337 return Alignment::DIFFERENT;
// Fully aligned and evenly divisible: the fast aligned path is safe.
340 if ((align == 0) && (lead_dim % nvec == 0)) {
341 return Alignment::SAME_ALIGNED;
// Uniform but non-zero misalignment: use the unaligned (bounds-checked) path.
343 return Alignment::SAME_UNALIGNED;
// Threads per block for the vectorized kernel launches below.
// NOTE(review): duplicates the identically-named constant near the top of this
// file (presumably they live in different namespaces) — confirm and keep in sync.
347 constexpr
int vectorized_kernel_thread_num = 512;
// Launcher for an RTC-compiled vectorized kernel: stitches type aliases and
// the alignment decision into the kernel source string, computes the launch
// configuration, and launches. NOTE(review): the function's own name line and
// many interior lines (switch headers, nvec strings, the compile-and-launch
// call) are missing from this chunk; comments cover only the visible code.
368 template <
typename Params>
370 const std::string& kernel_name,
371 const std::string& code,
377 const std::vector<TBlob>& inputs,
378 const std::vector<TBlob>& outputs,
380 const int lead_input_num = 0,
382 const index_t N = lead_dim * other_dim;
// Cap the vector width at 4 elements.
383 nvec = std::min(nvec, 4);
385 auto align = CheckAlignment(params, lead_dim, other_dim, nvec, inputs, outputs);
// Build the RTC preamble: one `using InputTypeK = ...;` alias per input...
386 std::string kernel_builder;
387 kernel_builder.reserve(2560);
391 for (
const auto& input : inputs) {
393 kernel_builder +=
"using InputType";
395 kernel_builder +=
" = ";
396 kernel_builder += type_info.name;
397 kernel_builder +=
";\n";
// ...and one `using OutputTypeK = ...;` alias per output.
403 for (
const auto& output : outputs) {
405 kernel_builder +=
"using OutputType";
407 kernel_builder +=
" = ";
408 kernel_builder += type_info.name;
409 kernel_builder +=
";\n";
// Emit the compile-time `aligned` flag and `nvec` matching CheckAlignment's
// verdict; DIFFERENT alignment falls back to scalar (aligned, nvec = 1).
414 case Alignment::SAME_ALIGNED:
416 "const bool aligned = true;\n"
419 kernel_builder +=
";\n";
421 case Alignment::SAME_UNALIGNED:
423 "const bool aligned = false;\n"
426 kernel_builder +=
";\n";
428 case Alignment::DIFFERENT: {
431 "const bool aligned = true;\n"
432 "const int nvec = 1;\n";
438 kernel_builder += parameters;
// Grid sizing is based on the vector count of the chosen lead input.
441 get_num_aligned_elements(params.inputs[lead_input_num],
445 constexpr
int threads = vectorized_kernel_thread_num;
450 size_t num_elements = other_dim * num_aligned_elements;
451 num_blocks = (num_elements + threads - 1) / threads;
// Clamp to the per-dimension grid limit; presumably the kernel iterates
// (grid-stride style) when the clamp applies — confirm in the kernel source.
452 constexpr
int max_blocks = 65535;
453 num_blocks = std::min(
static_cast<int>(num_blocks), max_blocks);
// Kernel argument pack: params struct plus the size metadata, by address.
455 std::vector<const void*> args = {&params, &lead_dim, &other_dim, &N, &num_aligned_elements};
// 1-D grid and block dimensions for the launch.
459 {
static_cast<unsigned int>(num_blocks), 1, 1},
460 {
static_cast<unsigned int>(threads), 1, 1},
472 #endif // MXNET_USE_CUDA
474 #endif // MXNET_COMMON_CUDA_RTC_VECTORIZATION_INL_H_