vectorization-inl.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef MXNET_COMMON_CUDA_RTC_VECTORIZATION_INL_H_
#define MXNET_COMMON_CUDA_RTC_VECTORIZATION_INL_H_

#include <mxnet/base.h>

#if MXNET_USE_CUDA

#include <sstream>
#include <string>
#include <vector>
#include <algorithm>

#include "../rtc.h"
#include "../../utils.h"

namespace mxnet {
namespace common {
namespace cuda {
namespace rtc {

const char vectorization_support_string[] = R"code(

namespace vector {

constexpr int vectorized_kernel_thread_num = 512;

template <int size>
struct VectorType {
  static_assert(size <= 32, "VectorType needs to have size of at most 32B");
};

template <>
struct VectorType<1> {
  using type = char;
};

template <>
struct VectorType<2> {
  using type = short;
};


template <>
struct VectorType<4> {
  using type = int;
};

template <>
struct VectorType<8> {
  using type = long long;
};

template <>
struct VectorType<16> {
  using type = ulonglong2;
};

template <>
struct VectorType<32> {
  using type = ulonglong4;
};
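
// Example (illustrative compile-time check): VectorType maps a total byte
// count to a built-in type that can be moved with a single load/store
// instruction, e.g. four packed floats (16 bytes) travel as one ulonglong2.
static_assert(sizeof(VectorType<sizeof(float) * 4>::type) == sizeof(ulonglong2),
              "four packed floats should be movable as a single 16-byte value");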

template <typename DType>
__device__ inline DType add_elem(const DType& x, const DType& y) {
  return x + y;
}

template <>
__device__ inline half add_elem(const half& x, const half& y) {
  return __float2half(__half2float(x) + __half2float(y));
}

/* \brief Helper class that enables storing multiple values of type DType
          as 1 value of type LType.
*/
template <typename DType, int n>
class VectorizedStorage {
 public:
  using LType = typename VectorType<sizeof(DType) * n>::type;
  constexpr static int nvec = n;
  union vectorized_storage {
    LType aligned;
    DType separate[nvec];  // NOLINT(*)

    inline __device__ vectorized_storage() {}
    inline __device__ ~vectorized_storage() {}
  } scratch_;

  inline __device__ VectorizedStorage() {}
  inline __device__ VectorizedStorage(const VectorizedStorage<DType, n>& y2) {
    scratch_.aligned = y2.scratch_.aligned;
  }
  inline __device__ VectorizedStorage(const LType& y2) {
    scratch_.aligned = y2;
  }
  inline __device__ VectorizedStorage<DType, n>& operator+=(
      const VectorizedStorage<DType, n>& rhs) {
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      scratch_.separate[i] = add_elem(scratch_.separate[i], rhs.scratch_.separate[i]);
    }
    return *this;
  }
  inline __device__ ~VectorizedStorage() {}
};
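
// Example (illustrative sketch; example_packed_add is a hypothetical helper):
// adding two packed groups of elements. The union lets the data move as one
// wide value while the arithmetic stays element-wise via add_elem.
template <typename DType, int nvec>
__device__ inline VectorizedStorage<DType, nvec> example_packed_add(
    const VectorizedStorage<DType, nvec>& a,
    const VectorizedStorage<DType, nvec>& b) {
  VectorizedStorage<DType, nvec> ret(a.scratch_.aligned);  // one wide copy
  ret += b;  // nvec element-wise additions
  return ret;
}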

// Returns const LType if DType is const
template <typename DType, typename LType>
struct select_const {
  using type = LType;
};

template <typename DType, typename LType>
struct select_const<const DType, LType> {
  using type = const LType;
};

template <typename DType>
struct remove_const {
  using type = DType;
};

template <typename DType>
struct remove_const<const DType> {
  using type = DType;
};


/* \brief Helper class that enables accessing multiple values of type DType
          as 1 value of type LType. The additional aligned template argument
          allows performance optimizations if the pointer and the size of
          the allocation are aligned to sizeof(LType) / sizeof(DType) elements.
*/
template <typename DType, int nvec, bool aligned = false>
class VectorizedAccessor {
 public:
  using StorageType = VectorizedStorage<typename remove_const<DType>::type,
                                        nvec>;
  using LType = typename select_const<DType, typename StorageType::LType>::type;
  StorageType storage_;

  LType* aligned_ptr_;
  DType* unaligned_ptr_;
  int alignment_;
  index_t n_elems_;

  inline __device__ VectorizedAccessor(DType* const ptr, const index_t size) {
    unaligned_ptr_ = ptr;
    if (aligned) {
      alignment_ = 0;
      aligned_ptr_ = reinterpret_cast<LType*>(ptr);
      n_elems_ = (size + nvec - 1) / nvec;
    } else {
      size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
      alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType);
      aligned_ptr_ = reinterpret_cast<LType*>(ptr - alignment_);
      n_elems_ = (size + alignment_ + nvec - 1) / nvec;
    }
  }

  /* \brief Alignment of the input pointer in elements. */
  inline __device__ int alignment() const {
    return alignment_;
  }

  /* \brief Access to separate elements. */
  inline __device__ DType* separate() {
    return storage_.scratch_.separate;
  }

  /* \brief Number of aligned elements that span the entire input tensor. */
  inline __device__ index_t num_aligned_elements() const {
    return n_elems_;
  }

  /* \brief Load values from the input.
     \param id Aligned index of the element.
     \param N size of the tensor.
  */
  inline __device__ void load(const index_t id, const index_t N) {
    if (aligned) {
      storage_.scratch_.aligned = aligned_ptr_[id];
    } else {
      if (id > 0 && id < n_elems_ - 1) {
        storage_.scratch_.aligned = aligned_ptr_[id];
      } else {
#pragma unroll
        for (int j = 0; j < nvec; ++j) {
          DType* ptr = reinterpret_cast<DType*>(&(aligned_ptr_[id])) + j;
          if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(unaligned_ptr_) &&
              reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(unaligned_ptr_ + N)) {
            storage_.scratch_.separate[j] = *ptr;
          } else {
            storage_.scratch_.separate[j] = DType();
          }
        }
      }
    }
  }
};
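
// Worked example (illustrative, assuming float data and nvec == 4): for a
// pointer one element past a 16-byte boundary and size == 10, alignment_ == 1
// and n_elems_ == (10 + 1 + 4 - 1) / 4 == 3. Only id == 1 takes the single
// aligned 16-byte access; id == 0 and id == 2 fall back to the element-wise
// path, which skips lanes outside [ptr, ptr + 10).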

/* \brief Class used for vectorized read-only access. */
template <typename DType, int nvec, bool aligned = false>
class VectorizedLoader : public VectorizedAccessor<const DType, nvec, aligned> {
 public:
  inline __device__ VectorizedLoader(const DType* ptr, const index_t N) :
    VectorizedAccessor<const DType, nvec, aligned>(ptr, N) {
  }
};

/* \brief Class used for vectorized writable access. */
template <typename DType, int nvec, bool aligned = false>
class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> {
 public:
  inline __device__ VectorizedStorer(DType* ptr, const index_t N) :
    VectorizedAccessor<DType, nvec, aligned>(ptr, N) {
  }

  /* \brief Store values to the output.
     \param id Aligned index of the element.
     \param N size of the tensor.
  */
  inline __device__ void store(const index_t id, const index_t N) {
    if (aligned) {
      this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
    } else {
      if (id > 0 && id < this->n_elems_ - 1) {
        this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
      } else {
#pragma unroll
        for (int j = 0; j < nvec; ++j) {
          DType* ptr = reinterpret_cast<DType*>(&(this->aligned_ptr_[id])) + j;
          if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(this->unaligned_ptr_) &&
              reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(this->unaligned_ptr_ + N)) {
            *ptr = this->storage_.scratch_.separate[j];
          }
        }
      }
    }
  }
};
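
// Example (illustrative sketch; example_vectorized_copy is a hypothetical
// helper showing the typical loader/storer pattern, with pointers passed
// directly rather than through the params structure used by the RTC kernels):
template <typename DType, typename OType, int nvec, bool aligned>
__device__ inline void example_vectorized_copy(const DType* in, OType* out,
                                               const index_t N) {
  VectorizedLoader<DType, nvec, aligned> loader(in, N);
  VectorizedStorer<OType, nvec, aligned> storer(out, N);
  const index_t M = loader.num_aligned_elements();
  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M;
       tid += gridDim.x * blockDim.x) {
    loader.load(tid, N);  // one vectorized (or boundary-safe) read
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      storer.separate()[i] = loader.separate()[i];
    }
    storer.store(tid, N);  // one vectorized (or boundary-safe) write
  }
}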

}  // namespace vector

)code";

namespace {

inline index_t get_num_aligned_elements(const void* ptr,
                                        const index_t lead_dim,
                                        const int nvec,
                                        const int size) {
  size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
  int alignment = (ptr_as_number % (nvec * size)) / size;
  return (lead_dim + alignment + nvec - 1) / nvec;
}
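
// Worked example (illustrative): for a float tensor whose pointer sits 4 bytes
// past a 16-byte boundary, with lead_dim = 1000, nvec = 4 and size = 4,
// alignment == (4 % 16) / 4 == 1, so each row needs
// (1000 + 1 + 4 - 1) / 4 == 251 vectorized accesses.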

enum class Alignment {
  SAME_ALIGNED,    // All tensors aligned
  SAME_UNALIGNED,  // All tensors have the same misalignment
  DIFFERENT        // Tensors have different alignment
};

inline int CalcAlignment(const void* ptr, const int size) {
  size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
  return ptr_as_number % size;
}

/* \brief Check alignment of the inputs and outputs when using vectorized accesses.
   \param params Structure containing arrays with inputs' and outputs' pointers
   \param lead_dim Leading dimension of the tensors.
   \param other_dim The size of the other dimensions of the tensors.
   \param nvec Length of the vector.
   \param inputs Inputs to the operator.
   \param outputs Outputs of the operator.
*/
template <typename Params>
Alignment CheckAlignment(const Params& params,
                         const index_t lead_dim,
                         const index_t other_dim,
                         const int nvec,
                         const std::vector<TBlob>& inputs,
                         const std::vector<TBlob>& outputs) {
  using namespace common;
  int align = -1;

  size_t i = 0;
  for (const void* ptr : params.inputs) {
    if (ptr != nullptr) {
      int new_align = CalcAlignment(ptr, mshadow_type_info(inputs[i].type_flag_).size * nvec);
      if (align == -1) {
        align = new_align;
      } else {
        if (align != new_align) {
          return Alignment::DIFFERENT;
        }
      }
    }
    ++i;
  }

  i = 0;
  for (const void* ptr : params.outputs) {
    if (ptr != nullptr) {
      int new_align = CalcAlignment(ptr, mshadow_type_info(outputs[i].type_flag_).size * nvec);
      if (align == -1) {
        align = new_align;
      } else {
        if (align != new_align) {
          return Alignment::DIFFERENT;
        }
      }
    }
    ++i;
  }

  if ((other_dim != 1) && (lead_dim % nvec != 0)) {
    return Alignment::DIFFERENT;
  }

  if ((align == 0) && (lead_dim % nvec == 0)) {
    return Alignment::SAME_ALIGNED;
  } else {
    return Alignment::SAME_UNALIGNED;
  }
}
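
// Example (illustrative): two float tensors that both start 4 bytes past a
// 16-byte boundary yield the same non-zero CalcAlignment value, so the check
// returns SAME_UNALIGNED and the generated kernel uses the boundary-safe
// load/store path. If only one of them started on the boundary, the result
// would be DIFFERENT and the launcher below falls back to nvec = 1.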

constexpr int vectorized_kernel_thread_num = 512;

}  // namespace

/* \brief Launcher helper for the kernels using vectorization. */
template <typename Params>
void VectorizedKernelRTCLauncher(const std::string& parameters,
                                 const std::string& kernel_name,
                                 const std::string& code,
                                 int nvec,
                                 const index_t lead_dim,
                                 const index_t other_dim,
                                 mshadow::Stream<gpu>* s,
                                 const Params params,
                                 const std::vector<TBlob>& inputs,
                                 const std::vector<TBlob>& outputs,
                                 const int dev_id,
                                 const int lead_input_num = 0,
                                 const index_t blocks = 0) {
  const index_t N = lead_dim * other_dim;
  nvec = std::min(nvec, 4);  // Use at most 4-wide vectors
  if (N != 0) {
    auto align = CheckAlignment(params, lead_dim, other_dim, nvec, inputs, outputs);
    std::string kernel_builder;
    kernel_builder.reserve(2560);

    // Fill input types
    int counter = 0;
    for (const auto& input : inputs) {
      const auto& type_info = common::mshadow_type_info(input.type_flag_);
      kernel_builder += "using InputType";
      kernel_builder += std::to_string(counter);
      kernel_builder += " = ";
      kernel_builder += type_info.name;
      kernel_builder += ";\n";
      ++counter;
    }

    // Fill output types
    counter = 0;
    for (const auto& output : outputs) {
      const auto& type_info = common::mshadow_type_info(output.type_flag_);
      kernel_builder += "using OutputType";
      kernel_builder += std::to_string(counter);
      kernel_builder += " = ";
      kernel_builder += type_info.name;
      kernel_builder += ";\n";
      ++counter;
    }

    switch (align) {
      case Alignment::SAME_ALIGNED:
        kernel_builder +=
            "const bool aligned = true;\n"
            "const int nvec = ";
        kernel_builder += std::to_string(nvec);
        kernel_builder += ";\n";
        break;
      case Alignment::SAME_UNALIGNED:
        kernel_builder +=
            "const bool aligned = false;\n"
            "const int nvec = ";
        kernel_builder += std::to_string(nvec);
        kernel_builder += ";\n";
        break;
      case Alignment::DIFFERENT: {
        // If the pointers are aligned differently we cannot vectorize
        kernel_builder +=
            "const bool aligned = true;\n"
            "const int nvec = 1;\n";
        nvec = 1;
        break;
      }
    }

    kernel_builder += parameters;

    index_t num_aligned_elements =
        get_num_aligned_elements(params.inputs[lead_input_num],
                                 lead_dim,
                                 nvec,
                                 common::mshadow_type_info(inputs[lead_input_num].type_flag_).size);
    constexpr int threads = vectorized_kernel_thread_num;
    index_t num_blocks;
    if (blocks != 0) {
      num_blocks = blocks;
    } else {
      size_t num_elements = other_dim * num_aligned_elements;
      num_blocks = (num_elements + threads - 1) / threads;
      constexpr int max_blocks = 65535;
      num_blocks = std::min(static_cast<int>(num_blocks), max_blocks);
    }
    std::vector<const void*> args = {&params, &lead_dim, &other_dim, &N, &num_aligned_elements};
    auto function = common::cuda::rtc::get_function(kernel_builder, kernel_name, code, dev_id);

    common::cuda::rtc::launch(function,
                              {static_cast<unsigned int>(num_blocks), 1, 1},
                              {static_cast<unsigned int>(threads), 1, 1},
                              0,
                              s,
                              &args);
  }
}
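
/* Example (illustrative sketch; example_params, example_copy_kernel and the
   surrounding variables are hypothetical, not part of the MXNet API): a caller
   pairs the launcher with a small POD params structure holding the raw device
   pointers and with kernel code written against the vector:: helpers above.

     struct example_params {
       const void* inputs[1];
       void* outputs[1];
     };

     example_params params = {{input_blob.dptr_}, {output_blob.dptr_}};
     VectorizedKernelRTCLauncher("",                      // extra parameters
                                 "example_copy_kernel",   // kernel name
                                 example_kernel_code,     // code using vector::
                                 4,                       // request 4-wide vectors
                                 length,                  // lead_dim
                                 1,                       // other_dim
                                 s, params,
                                 {input_blob}, {output_blob},
                                 dev_id);

   The launcher prepends the InputTypeN / OutputTypeN aliases together with the
   nvec and aligned constants to the kernel code, compiles it with
   get_function() and launches it with the params structure, lead_dim,
   other_dim, N and the number of aligned elements as kernel arguments. */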

}  // namespace rtc
}  // namespace cuda
}  // namespace common
}  // namespace mxnet

#endif  // MXNET_USE_CUDA

#endif  // MXNET_COMMON_CUDA_RTC_VECTORIZATION_INL_H_