tensor_gpu-inl.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tensor_gpu-inl.h
 * \brief implementation of GPU host code
 * \author Bing Xu, Tianqi Chen
 */
#ifndef MSHADOW_TENSOR_GPU_INL_H_
#define MSHADOW_TENSOR_GPU_INL_H_
#include "./base.h"
#include "./tensor.h"

namespace mshadow {
#if MSHADOW_USE_CUDA
template<>
inline void InitTensorEngine<gpu>(int dev_id) {
  cudaDeviceProp prop;
  int device_id = 0;
  int device_count = 0;
  cudaGetDeviceCount(&device_count);
  CHECK_GT(device_count, 0) << "Cannot find CUDA device. Please check CUDA configuration.";
  if (dev_id < 0) {
    device_id = 0;
  } else {
    device_id = dev_id;
  }
  CHECK_LT(device_id, device_count) << "Incorrect device ID";
  MSHADOW_CUDA_CALL(cudaSetDevice(device_id));
  MSHADOW_CUDA_CALL(cudaGetDeviceProperties(&prop, device_id));
}
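
// Illustrative sketch (not part of the original header): typical engine
// setup and teardown, assuming at least one CUDA-capable device is visible.
//
//   mshadow::InitTensorEngine<mshadow::gpu>(0);      // select and probe device 0
//   /* ... allocate tensors and run computations ... */
//   mshadow::ShutdownTensorEngine<mshadow::gpu>();   // a no-op on GPU, see below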
template<>
inline void ShutdownTensorEngine<gpu>(void) {
}
template<>
inline void SetDevice<gpu>(int devid) {
  MSHADOW_CUDA_CALL(cudaSetDevice(devid));
}
template<int dim, typename DType>
inline void AllocSpace(Tensor<gpu, dim, DType> *obj, bool pad) {
  size_t pitch;
  // a common choice for the CUDA memory alignment unit is 32
  if (pad && obj->size(dim - 1) >= MSHADOW_MIN_PAD_RATIO * 32) {
    MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast<void**>(&(obj->dptr_)), &pitch,
                                      obj->size(dim - 1) * sizeof(DType),
                                      obj->shape_.FlatTo2D()[0]));
    obj->stride_ = static_cast<index_t>(pitch / sizeof(DType));
  } else {
    obj->stride_ = obj->size(dim - 1);
    MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast<void**>(&(obj->dptr_)), &pitch,
                                      obj->shape_.Size() * sizeof(DType), 1));
  }
}
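
// Illustrative sketch (not part of the original header): allocating a padded
// 2-D GPU tensor. After AllocSpace, stride_ may exceed size(1), because
// cudaMallocPitch rounds each row up to the device alignment unit.
//
//   mshadow::Tensor<mshadow::gpu, 2, float> t(mshadow::Shape2(100, 60));
//   mshadow::AllocSpace(&t, true);   // pitched allocation: t.stride_ >= 60
//   /* ... */
//   mshadow::FreeSpace(&t);          // defined below; resets t.dptr_ to NULL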
template<int dim, typename DType>
inline void FreeSpace(Tensor<gpu, dim, DType> *obj) {
  MSHADOW_CUDA_CALL(cudaFree(obj->dptr_));
  obj->dptr_ = NULL;
}
template<typename A, typename B, int dim, typename DType>
inline void Copy(Tensor<A, dim, DType> _dst,
                 Tensor<B, dim, DType> _src,
                 cudaMemcpyKind kind,
                 Stream<gpu> *stream) {
  CHECK_EQ(_dst.shape_, _src.shape_) << "Copy:shape mismatch";
  Tensor<A, 2, DType> dst = _dst.FlatTo2D();
  Tensor<B, 2, DType> src = _src.FlatTo2D();
  MSHADOW_CUDA_CALL(cudaMemcpy2DAsync(dst.dptr_, dst.stride_ * sizeof(DType),
                                      src.dptr_, src.stride_ * sizeof(DType),
                                      dst.size(1) * sizeof(DType),
                                      dst.size(0), kind,
                                      Stream<gpu>::GetStream(stream)));
  // use synchronous call behavior for the zero (default) stream
  if (stream == NULL) {
    MSHADOW_CUDA_CALL(cudaStreamSynchronize(0));
  }
}
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> dst,
                 const Tensor<gpu, dim, DType> &src,
                 Stream<gpu> *stream) {
  Copy(dst, src, cudaMemcpyDeviceToHost, stream);
}
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
                 const Tensor<gpu, dim, DType> &src,
                 Stream<gpu> *stream) {
  Copy(dst, src, cudaMemcpyDeviceToDevice, stream);
}
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
                 const Tensor<cpu, dim, DType> &src,
                 Stream<gpu> *stream) {
  Copy(dst, src, cudaMemcpyHostToDevice, stream);
}
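
// Illustrative sketch (not part of the original header): round-tripping data
// between host and device with the overloads above. A NULL stream makes the
// underlying cudaMemcpy2DAsync effectively synchronous (see the generic Copy).
//
//   mshadow::Tensor<mshadow::cpu, 2, float> host(mshadow::Shape2(4, 4));
//   mshadow::Tensor<mshadow::gpu, 2, float> dev(mshadow::Shape2(4, 4));
//   mshadow::AllocSpace(&host); mshadow::AllocSpace(&dev);
//   mshadow::Copy(dev, host, static_cast<mshadow::Stream<mshadow::gpu>*>(NULL));
//   mshadow::Copy(host, dev, static_cast<mshadow::Stream<mshadow::gpu>*>(NULL));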
#endif  // MSHADOW_USE_CUDA
}  // namespace mshadow

// the following part is included only if compiler is nvcc
#ifdef __CUDACC__
#include "./cuda/tensor_gpu-inl.cuh"

namespace mshadow {
template<typename Saver, typename R, int dim,
         typename DType, typename E, int etype>
inline void MapExp(TRValue<R, gpu, dim, DType> *dst,
                   const expr::Exp<E, DType, etype> &exp) {
  expr::TypeCheckPass<expr::TypeCheck<gpu, dim, DType, E>::kMapPass>
      ::Error_All_Tensor_in_Exp_Must_Have_Same_Type();
  Shape<dim> eshape = expr::ShapeCheck<dim, E>::Check(exp.self());
  Shape<dim> dshape = expr::ShapeCheck<dim, R>::Check(dst->self());
  CHECK(eshape[0] == 0 || eshape == dshape)
      << "Assignment: Shape of Tensors are not consistent with target, "
      << "eshape: " << eshape << " dshape:" << dshape;
  cuda::MapPlan<Saver>(MakePlan(dst->self()),
                       MakePlan(exp.self()),
                       dshape.FlatTo2D(),
                       Stream<gpu>::GetStream(expr::StreamInfo<gpu, R>::Get(dst->self())));
}
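
// Note (illustrative, not part of the original header): MapExp is normally
// reached through tensor assignment rather than called directly. An
// expression such as
//
//   dst = src1 * 2.0f + src2;   // dst, src1, src2 : Tensor<gpu, 2, float>
//
// type-checks the expression, verifies shapes, and launches cuda::MapPlan
// over the flattened 2-D view via this function.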

template<typename Saver, typename Reducer,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepLowest(TRValue<R, gpu, 1, DType> *dst,
                                const expr::Exp<E, DType, etype> &exp,
                                DType scale) {
  expr::TypeCheckPass<expr::TypeCheck<gpu, 1, DType, E>::kRedPass>
      ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
  Shape<2> eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
      ::Check(exp.self()).FlatTo2D();
  Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
  CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension does not match";
  CHECK_NE(eshape[0], 0U) << "cannot reduce over an empty tensor";
  cuda::MapReduceKeepLowest<Saver, Reducer>
      (MakePlan(dst->self()), MakePlan(exp.self()), scale, eshape,
       Stream<gpu>::GetStream(expr::StreamInfo<gpu, R>::Get(dst->self())));
}
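
// Illustrative sketch (not part of the original header): this overload backs
// 1-D reductions over the lowest dimension, e.g. via the sum_rows expression
// from mshadow's extension headers (assumed available here):
//
//   mshadow::Tensor<mshadow::gpu, 1, float> col_sum(mshadow::Shape1(n));
//   mshadow::Tensor<mshadow::gpu, 2, float> mat(mshadow::Shape2(m, n));
//   /* ... fill mat ... */
//   col_sum = mshadow::expr::sum_rows(mat);   // reduces over dimension 0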

template<typename Saver, typename Reducer, int dimkeep,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepHighDim(TRValue<R, gpu, 1, DType> *dst,
                                 const expr::Exp<E, DType, etype> &exp,
                                 DType scale) {
  expr::TypeCheckPass<expr::TypeCheck<gpu, dimkeep, DType, E>::kRedPass>
      ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
  typedef Shape<expr::ExpInfo<E>::kDim> EShape;
  EShape eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
      ::Check(exp.self());
  Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
  CHECK_EQ(eshape[dimkeep], dshape[0]) << "MapReduceKeepHighDim::reduction dimension does not match";
  // use the equivalent 4-D form
  Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep),
                           eshape[dimkeep],
                           eshape.ProdShape(dimkeep + 1, EShape::kSubdim),
                           eshape[EShape::kSubdim]);
  // call the equivalent reduction over dimension 1
  cuda::MapReduceKeepDim1<Saver, Reducer>
      (MakePlan(dst->self()), MakePlan(exp.self()), scale, pshape,
       Stream<gpu>::GetStream(expr::StreamInfo<gpu, R>::Get(dst->self())));
}
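
// Note: the Shape4 rearrangement above views the expression as
// (product of dims before dimkeep, eshape[dimkeep],
//  product of dims after dimkeep excluding the innermost, innermost dim),
// so the single cuda::MapReduceKeepDim1 kernel can serve any dimkeep.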
template<typename DType>
inline void Softmax(Tensor<gpu, 2, DType> dst,
                    const Tensor<gpu, 2, DType>& src) {
  cuda::Softmax(dst, src);
}

template<typename DType>
inline void Softmax(Tensor<gpu, 3, DType> dst,
                    const Tensor<gpu, 3, DType>& src) {
  cuda::Softmax(dst, src);
}
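
// Illustrative sketch (not part of the original header): row-wise softmax on
// the GPU; per the CPU counterpart's documentation,
// dst[i][j] = exp(energy[i][j]) / sum_j exp(energy[i][j]).
//
//   mshadow::Tensor<mshadow::gpu, 2, float> energy(mshadow::Shape2(batch, n_class));
//   mshadow::Tensor<mshadow::gpu, 2, float> prob(mshadow::Shape2(batch, n_class));
//   /* ... fill energy ... */
//   mshadow::Softmax(prob, energy);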

template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                        const Tensor<gpu, 2, DType> &src,
                        const Tensor<gpu, 1, DType> &label) {
  cuda::SoftmaxGrad(dst, src, label);
}

template<typename DType>
inline void SmoothSoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                              const Tensor<gpu, 2, DType> &src,
                              const Tensor<gpu, 1, DType> &label,
                              const float alpha) {
  cuda::SmoothSoftmaxGrad(dst, src, label, alpha);
}

template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                        const Tensor<gpu, 2, DType> &src,
                        const Tensor<gpu, 1, DType> &label,
                        const DType &ignore_label) {
  cuda::SoftmaxGrad(dst, src, label, ignore_label);
}

template<typename DType>
inline void SmoothSoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                              const Tensor<gpu, 2, DType> &src,
                              const Tensor<gpu, 1, DType> &label,
                              const DType &ignore_label,
                              const float alpha) {
  cuda::SmoothSoftmaxGrad(dst, src, label, ignore_label, alpha);
}

template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
                        const Tensor<gpu, 3, DType> &src,
                        const Tensor<gpu, 2, DType> &label) {
  cuda::SoftmaxGrad(dst, src, label);
}

template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
                        const Tensor<gpu, 3, DType> &src,
                        const Tensor<gpu, 2, DType> &label,
                        const DType &ignore_label) {
  cuda::SoftmaxGrad(dst, src, label, ignore_label);
}

template<bool clip, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
                        const Tensor<gpu, 1, IndexType>& index,
                        const Tensor<gpu, 2, DType> &src) {
  cuda::AddTakeGrad<clip, IndexType, DType>(dst, index, src);
}
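
// Semantics (per the CPU counterpart's documentation): dst[index[i]] += src[i],
// i.e. gradient accumulation for an embedding matrix. The `clip` template flag
// selects how out-of-range indices are treated (clipped into range when true;
// see cuda::AddTakeGrad for the exact behavior).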

template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
                                  const Tensor<gpu, 1, IndexType>& sorted,
                                  const Tensor<gpu, 1, IndexType>& index,
                                  const Tensor<gpu, 2, DType> &src) {
  cuda::AddTakeGradLargeBatch(dst, sorted, index, src);
}

template<typename KDType, typename VDType>
inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values,
                      bool is_ascend) {
  cuda::SortByKey(keys, values, is_ascend);
}
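
// Illustrative sketch (not part of the original header): sorting a key tensor
// and permuting a value tensor alongside it (a stable sort, per the CPU
// counterpart's documentation).
//
//   mshadow::Tensor<mshadow::gpu, 1, float> keys(mshadow::Shape1(n));
//   mshadow::Tensor<mshadow::gpu, 1, int> vals(mshadow::Shape1(n));
//   /* ... fill keys and vals ... */
//   mshadow::SortByKey(keys, vals, true);   // ascending by key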

template<typename IndexType, typename DType>
inline void IndexFill(Tensor<gpu, 2, DType> dst,
                      const Tensor<gpu, 1, IndexType>& index,
                      const Tensor<gpu, 2, DType> &src) {
  cuda::IndexFill(dst, index, src);
}
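
// Semantics (per the CPU counterpart's documentation): writes row src[i] into
// dst[index[i]]; unlike AddTakeGrad above, destination rows are overwritten
// rather than accumulated.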
}  // namespace mshadow
#endif  // __CUDACC__
#endif  // MSHADOW_TENSOR_GPU_INL_H_