/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tensor_gpu-inl.h
 * \brief implementation of GPU host code
 * \author Bing Xu, Tianqi Chen
 */
#ifndef MSHADOW_TENSOR_GPU_INL_H_
#define MSHADOW_TENSOR_GPU_INL_H_
#include "./base.h"
#include "./tensor.h"

namespace mshadow {
#if MSHADOW_USE_CUDA
template<>
inline void InitTensorEngine<gpu>(int dev_id) {
  cudaDeviceProp prop;
  int device_id = 0;
  int device_count = 0;
  cudaGetDeviceCount(&device_count);
  CHECK_GT(device_count, 0) << "Cannot find CUDA device. Please check CUDA configuration.";
  if (dev_id < 0) {
    device_id = 0;
  } else {
    device_id = dev_id;
  }
  CHECK_LT(device_id, device_count) << "Incorrect device ID";
  MSHADOW_CUDA_CALL(cudaSetDevice(device_id));
  MSHADOW_CUDA_CALL(cudaGetDeviceProperties(&prop, device_id));
}
template<>
inline void ShutdownTensorEngine<gpu>(void) {
}
template<>
inline void SetDevice<gpu>(int devid) {
  MSHADOW_CUDA_CALL(cudaSetDevice(devid));
}
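/*!
 * Illustrative sketch (added for exposition, not part of the original
 * source): a typical engine lifecycle around the three routines above.
 * \code
 *   InitTensorEngine<gpu>(0);      // bind this thread to device 0 (-1 also falls back to 0)
 *   // ... allocate tensors and run GPU computations ...
 *   ShutdownTensorEngine<gpu>();   // currently a no-op for gpu
 * \endcode
 */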
template<int dim, typename DType>
inline void AllocSpace(Tensor<gpu, dim, DType> *obj, bool pad) {
  size_t pitch;
  // a common choice for the CUDA memory alignment unit is 32
  if (pad && obj->size(dim - 1) >= MSHADOW_MIN_PAD_RATIO * 32) {
    MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast<void**>(&(obj->dptr_)), &pitch,
                                      obj->size(dim - 1) * sizeof(DType),
                                      obj->shape_.FlatTo2D()[0]));
    obj->stride_ = static_cast<index_t>(pitch / sizeof(DType));
  } else {
    obj->stride_ = obj->size(dim - 1);
    MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast<void**>(&(obj->dptr_)), &pitch,
                                      obj->shape_.Size() * sizeof(DType), 1));
  }
}
template<int dim, typename DType>
inline void FreeSpace(Tensor<gpu, dim, DType> *obj) {
  MSHADOW_CUDA_CALL(cudaFree(obj->dptr_));
  obj->dptr_ = NULL;
}
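/*!
 * Illustrative sketch (added for exposition, not part of the original
 * source): an allocate/use/free cycle for a padded 2-D GPU tensor; the
 * shape values here are arbitrary assumptions.
 * \code
 *   Tensor<gpu, 2, float> t(Shape2(100, 64));
 *   AllocSpace(&t, true);   // pitched allocation; t.stride_ may exceed 64
 *   // ... fill and use t ...
 *   FreeSpace(&t);          // frees t.dptr_ and resets it to NULL
 * \endcode
 */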
template<typename A, typename B, int dim, typename DType>
inline void Copy(Tensor<A, dim, DType> _dst,
                 Tensor<B, dim, DType> _src,
                 cudaMemcpyKind kind,
                 Stream<gpu> *stream) {
  CHECK_EQ(_dst.shape_, _src.shape_) << "Copy: shape mismatch";
  Tensor<A, 2, DType> dst = _dst.FlatTo2D();
  Tensor<B, 2, DType> src = _src.FlatTo2D();
  MSHADOW_CUDA_CALL(cudaMemcpy2DAsync(dst.dptr_, dst.stride_ * sizeof(DType),
                                      src.dptr_, src.stride_ * sizeof(DType),
                                      dst.size(1) * sizeof(DType),
                                      dst.size(0), kind,
                                      Stream<gpu>::GetStream(stream)));
  // use synchronous behavior for the default (zero) stream
  if (stream == NULL) {
    MSHADOW_CUDA_CALL(cudaStreamSynchronize(0));
  }
}
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> dst,
                 const Tensor<gpu, dim, DType> &src,
                 Stream<gpu> *stream) {
  Copy(dst, src, cudaMemcpyDeviceToHost, stream);
}
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
                 const Tensor<gpu, dim, DType> &src,
                 Stream<gpu> *stream) {
  Copy(dst, src, cudaMemcpyDeviceToDevice, stream);
}
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
                 const Tensor<cpu, dim, DType> &src,
                 Stream<gpu> *stream) {
  Copy(dst, src, cudaMemcpyHostToDevice, stream);
}
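/*!
 * Illustrative sketch (added for exposition, not part of the original
 * source): staging data through host memory with the overloads above; a
 * NULL stream makes each copy synchronous, as implemented in the generic
 * Copy. Names and shapes are assumptions for the example.
 * \code
 *   Tensor<cpu, 2, float> host(Shape2(8, 8));
 *   Tensor<gpu, 2, float> device(Shape2(8, 8));
 *   AllocSpace(&host, false);
 *   AllocSpace(&device, false);
 *   Copy(device, host, static_cast<Stream<gpu>*>(NULL));  // host -> device
 *   Copy(host, device, static_cast<Stream<gpu>*>(NULL));  // device -> host
 *   FreeSpace(&host);
 *   FreeSpace(&device);
 * \endcode
 */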
#endif  // MSHADOW_USE_CUDA
}  // namespace mshadow

// the following part is included only if the compiler is nvcc
#ifdef __CUDACC__
#include "./cuda/tensor_gpu-inl.cuh"

namespace mshadow {
template<typename Saver, typename R, int dim,
         typename DType, typename E, int etype>
inline void MapExp(TRValue<R, gpu, dim, DType> *dst,
                   const expr::Exp<E, DType, etype> &exp) {
  expr::TypeCheckPass<expr::TypeCheck<gpu, dim, DType, E>::kMapPass>
      ::Error_All_Tensor_in_Exp_Must_Have_Same_Type();
  Shape<dim> eshape = expr::ShapeCheck<dim, E>::Check(exp.self());
  Shape<dim> dshape = expr::ShapeCheck<dim, R>::Check(dst->self());
  CHECK(eshape[0] == 0 || eshape == dshape)
      << "Assignment: shapes of tensors are not consistent with target, "
      << "eshape: " << eshape << " dshape: " << dshape;
  cuda::MapPlan<Saver>(MakePlan(dst->self()),
                       MakePlan(exp.self()),
                       dshape.FlatTo2D(),
                       Stream<gpu>::GetStream(expr::StreamInfo<gpu, R>::Get(dst->self())));
}

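/*!
 * Illustrative sketch (added for exposition, not part of the original
 * source): MapExp is what tensor expression assignment lowers to, so the
 * two lines below are equivalent; lhs, rhs, and dst are assumed,
 * same-shaped GPU tensors.
 * \code
 *   dst = lhs + rhs * 2.0f;                       // expression sugar
 *   MapExp<sv::saveto>(&dst, lhs + rhs * 2.0f);   // what it expands to
 * \endcode
 */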
template<typename Saver, typename Reducer,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepLowest(TRValue<R, gpu, 1, DType> *dst,
                                const expr::Exp<E, DType, etype> &exp,
                                DType scale) {
  expr::TypeCheckPass<expr::TypeCheck<gpu, 1, DType, E>::kRedPass>
      ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
  Shape<2> eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
      ::Check(exp.self()).FlatTo2D();
  Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
  CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimensions do not match";
  CHECK_NE(eshape[0], 0U) << "cannot reduce over an empty tensor";
  cuda::MapReduceKeepLowest<Saver, Reducer>
      (MakePlan(dst->self()), MakePlan(exp.self()), scale, eshape,
       Stream<gpu>::GetStream(expr::StreamInfo<gpu, R>::Get(dst->self())));
}

template<typename Saver, typename Reducer, int dimkeep,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepHighDim(TRValue<R, gpu, 1, DType> *dst,
                                 const expr::Exp<E, DType, etype> &exp,
                                 DType scale) {
  expr::TypeCheckPass<expr::TypeCheck<gpu, dimkeep, DType, E>::kRedPass>
      ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
  typedef Shape<expr::ExpInfo<E>::kDim> EShape;
  EShape eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
      ::Check(exp.self());
  Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
  CHECK_EQ(eshape[dimkeep], dshape[0]) << "MapReduceKeepHighDim::reduction dimensions do not match";
  // use the equivalent 4-D form
  Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep),
                           eshape[dimkeep],
                           eshape.ProdShape(dimkeep + 1, EShape::kSubdim),
                           eshape[EShape::kSubdim]);
  // call the equivalent map-reduce that keeps dimension 1
  cuda::MapReduceKeepDim1<Saver, Reducer>
      (MakePlan(dst->self()), MakePlan(exp.self()), scale, pshape,
       Stream<gpu>::GetStream(expr::StreamInfo<gpu, R>::Get(dst->self())));
}
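/*!
 * Illustrative sketch (added for exposition, not part of the original
 * source): these two routines back the reduction expressions; for
 * example, summing each column of an assumed (n, k) matrix into a
 * length-k vector.
 * \code
 *   MapReduceKeepLowest<sv::saveto, red::sum>(&vec, mat, 1.0f);
 * \endcode
 */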
template<typename DType>
inline void Softmax(Tensor<gpu, 2, DType> dst,
                    const Tensor<gpu, 2, DType>& src) {
  cuda::Softmax(dst, src);
}

template<typename DType>
inline void Softmax(Tensor<gpu, 3, DType> dst,
                    const Tensor<gpu, 3, DType>& src) {
  cuda::Softmax(dst, src);
}
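/*!
 * Illustrative sketch (added for exposition, not part of the original
 * source): row-wise softmax of an assumed (batch, num_class) score matrix.
 * \code
 *   Softmax(prob, score);  // prob[i][j] = exp(score[i][j]) / sum_j exp(score[i][j])
 * \endcode
 */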

template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                        const Tensor<gpu, 2, DType> &src,
                        const Tensor<gpu, 1, DType> &label) {
  cuda::SoftmaxGrad(dst, src, label);
}

template<typename DType>
inline void SmoothSoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                              const Tensor<gpu, 2, DType> &src,
                              const Tensor<gpu, 1, DType> &label,
                              const float alpha) {
  cuda::SmoothSoftmaxGrad(dst, src, label, alpha);
}

template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                        const Tensor<gpu, 2, DType> &src,
                        const Tensor<gpu, 1, DType> &label,
                        const DType &ignore_label) {
  cuda::SoftmaxGrad(dst, src, label, ignore_label);
}

template<typename DType>
inline void SmoothSoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                              const Tensor<gpu, 2, DType> &src,
                              const Tensor<gpu, 1, DType> &label,
                              const DType &ignore_label,
                              const float alpha) {
  cuda::SmoothSoftmaxGrad(dst, src, label, ignore_label, alpha);
}

template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
                        const Tensor<gpu, 3, DType> &src,
                        const Tensor<gpu, 2, DType> &label) {
  cuda::SoftmaxGrad(dst, src, label);
}

template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
                        const Tensor<gpu, 3, DType> &src,
                        const Tensor<gpu, 2, DType> &label,
                        const DType &ignore_label) {
  cuda::SoftmaxGrad(dst, src, label, ignore_label);
}

template<bool clip, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
                        const Tensor<gpu, 1, IndexType>& index,
                        const Tensor<gpu, 2, DType> &src) {
  cuda::AddTakeGrad<clip, IndexType, DType>(dst, index, src);
}

template<bool clip, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
                        Tensor<gpu, 2, AType> temp,
                        const Tensor<gpu, 1, IndexType>& index,
                        const Tensor<gpu, 2, DType> &src) {
  cuda::AddTakeGrad<clip, IndexType, DType>(dst, temp, index, src);
}

template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
                                  const Tensor<gpu, 1, IndexType>& sorted,
                                  const Tensor<gpu, 1, IndexType>& index,
                                  const Tensor<gpu, 2, DType> &src) {
  cuda::AddTakeGradLargeBatch(dst, sorted, index, src);
}

template<typename KDType, typename VDType>
inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values,
                      bool is_ascend) {
  cuda::SortByKey(keys, values, is_ascend);
}
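/*!
 * Illustrative sketch (added for exposition, not part of the original
 * source): stable key-value sort on device, e.g. grouping equal embedding
 * indices before AddTakeGradLargeBatch; keys and vals are assumed 1-D GPU
 * tensors of equal length.
 * \code
 *   SortByKey(keys, vals, true);  // ascending; vals are permuted alongside keys
 * \endcode
 */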

template<typename IndexType, typename DType>
inline void IndexFill(Tensor<gpu, 2, DType> dst,
                      const Tensor<gpu, 1, IndexType>& index,
                      const Tensor<gpu, 2, DType> &src) {
  cuda::IndexFill(dst, index, src);
}
}  // namespace mshadow
#endif  // __CUDACC__
#endif  // MSHADOW_TENSOR_GPU_INL_H_