mxnet
tensor_gpu-inl.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_TENSOR_GPU_INL_H_
8 #define MSHADOW_TENSOR_GPU_INL_H_
9 #include "./base.h"
10 #include "./tensor.h"
11 
12 namespace mshadow {
13 #if MSHADOW_USE_CUDA
/*!
 * \brief Initialize the GPU tensor engine: validate the requested device and make
 *        it the current CUDA device.
 * \param dev_id requested device ordinal; any negative value selects device 0.
 *
 * Fails through CHECK / MSHADOW_CUDA_CALL when the CUDA runtime cannot enumerate
 * devices, when no device is present, or when the selected id is out of range.
 */
template<>
inline void InitTensorEngine<gpu>(int dev_id) {
  int device_count = 0;
  // Check the runtime status as well as the count, so a driver/runtime failure is
  // reported with its CUDA error text instead of only "device_count > 0" failing.
  const cudaError_t count_status = cudaGetDeviceCount(&device_count);
  CHECK(count_status == cudaSuccess && device_count > 0)
      << "Cannot find CUDA device. Please check CUDA-Configuration"
      << " (cudaGetDeviceCount: " << cudaGetErrorString(count_status) << ")";
  // Negative ids fall back to device 0.
  const int device_id = dev_id < 0 ? 0 : dev_id;
  CHECK_LT(device_id, device_count) << "Incorrect Device ID";
  MSHADOW_CUDA_CALL(cudaSetDevice(device_id));
}
/*!
 * \brief Shut down the GPU tensor engine.
 *
 * The GPU specialization has nothing to release, so this is a no-op.
 */
template<>
inline void ShutdownTensorEngine<gpu>(void) {
  // intentionally empty
}
/*!
 * \brief Make GPU `devid` the current CUDA device.
 * \param devid device ordinal, handed unchanged to cudaSetDevice; errors are
 *        caught by the MSHADOW_CUDA_CALL wrapper.
 */
template<>
inline void SetDevice<gpu>(int devid) {
  MSHADOW_CUDA_CALL(
      cudaSetDevice(devid));
}
/*!
 * \brief Allocate device memory for *obj according to obj->shape_ and set obj->stride_.
 * \param obj tensor whose shape_ is already set; receives dptr_ and stride_.
 * \param pad when true and the last dimension holds at least
 *        MSHADOW_MIN_PAD_RATIO * 32 elements, every row is allocated with
 *        cudaMallocPitch padding and stride_ becomes the padded row length in
 *        elements. Otherwise the data is one contiguous row and stride_ equals
 *        the last dimension.
 */
template<int dim, typename DType>
inline void AllocSpace(Tensor<gpu, dim, DType> *obj, bool pad) {
  void **data = reinterpret_cast<void**>(&(obj->dptr_));
  const index_t ncol = obj->size(dim - 1);
  size_t pitch_bytes;
  // 32 elements is the common CUDA alignment unit; only rows wide enough are padded.
  const bool use_pitch = pad && ncol >= MSHADOW_MIN_PAD_RATIO * 32;
  if (!use_pitch) {
    // Tight layout: everything in a single row of shape_.Size() elements.
    obj->stride_ = ncol;
    MSHADOW_CUDA_CALL(cudaMallocPitch(data, &pitch_bytes,
                                      obj->shape_.Size() * sizeof(DType), 1));
    return;
  }
  // Padded layout: FlatTo2D()[0] rows of ncol elements, each row pitch-aligned.
  MSHADOW_CUDA_CALL(cudaMallocPitch(data, &pitch_bytes,
                                    ncol * sizeof(DType),
                                    obj->shape_.FlatTo2D()[0]));
  obj->stride_ = static_cast<index_t>(pitch_bytes / sizeof(DType));
}
/*!
 * \brief Release the device memory held by *obj and reset obj->dptr_ to NULL.
 * \param obj GPU tensor whose dptr_ is freed with cudaFree (presumably memory
 *        obtained from the gpu AllocSpace; confirm at call sites).
 */
template<int dim, typename DType>
inline void FreeSpace(Tensor<gpu, dim, DType> *obj) {
  MSHADOW_CUDA_CALL(cudaFree(obj->dptr_));
  obj->dptr_ = NULL;
}
/*!
 * \brief Copy between two tensors of identical shape using cudaMemcpy2DAsync.
 * \param _dst destination tensor (device type A).
 * \param _src source tensor (device type B); its shape must equal _dst's shape.
 * \param kind copy direction, forwarded to cudaMemcpy2DAsync.
 * \param stream stream to enqueue the copy on. When NULL, the zero stream is used
 *        and the call synchronizes on it before returning.
 */
template<typename A, typename B, int dim, typename DType>
inline void Copy(Tensor<A, dim, DType> _dst,
                 const Tensor<B, dim, DType> &_src,
                 cudaMemcpyKind kind,
                 Stream<gpu> *stream) {
  CHECK_EQ(_dst.shape_, _src.shape_) << "Copy:shape mismatch";
  // Flatten to rows x cols so each side's own row stride (pitch) is honoured.
  Tensor<A, 2, DType> dst = _dst.FlatTo2D();
  Tensor<B, 2, DType> src = _src.FlatTo2D();
  MSHADOW_CUDA_CALL(cudaMemcpy2DAsync(dst.dptr_, dst.stride_ * sizeof(DType),
                                      src.dptr_, src.stride_ * sizeof(DType),
                                      dst.size(1) * sizeof(DType),
                                      dst.size(0), kind,
                                      Stream<gpu>::GetStream(stream)));
  // use synchronize call behavior for zero stream
  if (stream == NULL) {
    MSHADOW_CUDA_CALL(cudaStreamSynchronize(0));
  }
}
/*!
 * \brief Copy a gpu tensor into a cpu tensor of the same shape (device to host).
 * \param stream stream used for the copy; NULL makes the copy blocking.
 */
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> dst,
                 const Tensor<gpu, dim, DType> &src,
                 Stream<gpu> *stream) {
  const cudaMemcpyKind direction = cudaMemcpyDeviceToHost;
  Copy(dst, src, direction, stream);
}
/*!
 * \brief Copy between two gpu tensors of the same shape (device to device).
 * \param stream stream used for the copy; NULL makes the copy blocking.
 */
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
                 const Tensor<gpu, dim, DType> &src,
                 Stream<gpu> *stream) {
  const cudaMemcpyKind direction = cudaMemcpyDeviceToDevice;
  Copy(dst, src, direction, stream);
}
/*!
 * \brief Copy a cpu tensor into a gpu tensor of the same shape (host to device).
 * \param stream stream used for the copy; NULL makes the copy blocking.
 */
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
                 const Tensor<cpu, dim, DType> &src,
                 Stream<gpu> *stream) {
  const cudaMemcpyKind direction = cudaMemcpyHostToDevice;
  Copy(dst, src, direction, stream);
}
93 #endif // MSHADOW_USE_CUDA
94 } // namespace mshadow
95 
96 // the following part is included only if compiler is nvcc
97 #ifdef __CUDACC__
98 #include "./cuda/tensor_gpu-inl.cuh"
99 
100 namespace mshadow {
// MapExp (gpu): evaluate expression `exp` into `dst` by calling cuda::MapPlan
// with the plans of dst and exp over dshape flattened to 2D. Before that, the
// visible shape check requires eshape == dshape unless eshape[0] == 0.
// NOTE(review): this listing is missing original source lines 105, 107-108 and
// 115 (numbering jumps 104->106, 106->109, 114->116). Not visible here: the head
// of the static type-check statement ending in
// ::Error_All_Tensor_in_Exp_Must_Have_Same_Type(), the declarations of eshape and
// dshape, and the last argument plus closing of the cuda::MapPlan call. Restore
// them from the upstream header before compiling.
101 template<typename Saver, typename R, int dim,
102  typename DType, typename E, int etype>
103 inline void MapExp(TRValue<R, gpu, dim, DType> *dst,
104  const expr::Exp<E, DType, etype> &exp) {
106  ::Error_All_Tensor_in_Exp_Must_Have_Same_Type();
109  CHECK(eshape[0] == 0 || eshape == dshape)
110  << "Assignment: Shape of Tensors are not consistent with target, "
111  << "eshape: " << eshape << " dshape:" << dshape;
112  cuda::MapPlan<Saver>(MakePlan(dst->self()),
113  MakePlan(exp.self()),
114  dshape.FlatTo2D(),
116 }
117 
// MapReduceKeepLowest (gpu): reduce expression `exp` into the 1D target `dst`.
// The visible code flattens the expression shape to 2D (eshape) and requires
// eshape[1] == dshape[0] and eshape[0] != 0. It then calls
// cuda::MapReduceKeepLowest<Saver, Reducer> with the plans, `scale` and eshape.
// NOTE(review): this listing is missing original source lines 120, 123, 125, 127
// and 132. That includes the function-name/dst parameter line (the name is taken
// from the error message), the head of the static reduce type check, the start of
// the eshape declaration, the dshape declaration, and the final argument of the
// cuda call. Restore them from the upstream header before compiling.
118 template<typename Saver, typename Reducer,
119  typename R, typename DType, typename E, int etype>
121  const expr::Exp<E, DType, etype> &exp,
122  DType scale) {
124  ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
126  ::Check(exp.self()).FlatTo2D();
128  CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension do not match";
129  CHECK_NE(eshape[0], 0U) << "can not reduce over empty tensor";
130  cuda::MapReduceKeepLowest<Saver, Reducer>
131  (MakePlan(dst->self()), MakePlan(exp.self()), scale, eshape,
133 }
134 
// MapReduceKeepHighDim (gpu): reduce expression `exp` into the 1D target `dst`,
// keeping dimension `dimkeep` of the expression (eshape[dimkeep] must equal
// dshape[0]). The expression shape is regrouped into a 4D shape whose index 1 is
// the kept dimension, and cuda::MapReduceKeepDim1 performs the reduction.
// NOTE(review): this listing is missing original source lines 137, 140, 145 and
// 155. Those are the function-name/dst parameter line (the name is taken from the
// error message), the head of the static reduce type check, the dshape
// declaration, and the final argument of the cuda call. Restore them from the
// upstream header before compiling.
135 template<typename Saver, typename Reducer, int dimkeep,
136  typename R, typename DType, typename E, int etype>
138  const expr::Exp<E, DType, etype> &exp,
139  DType scale) {
141  ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
142  typedef Shape<expr::ExpInfo<E>::kDim> EShape;
143  EShape eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
144  ::Check(exp.self());
146  CHECK_EQ(eshape[dimkeep], dshape[0]) << "MapReduceKeepHighDim::reduction dimension do not match";
147  // use an equivalent 4D view: (dims before dimkeep, eshape[dimkeep], dims after dimkeep up to kSubdim, eshape[kSubdim]); ProdShape presumably multiplies dims in [begin, end) - confirm in tensor.h
148  Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep),
149  eshape[dimkeep],
150  eshape.ProdShape(dimkeep + 1, EShape::kSubdim),
151  eshape[EShape::kSubdim]);
152  // reduce the 4D view while keeping its dimension 1, which holds eshape[dimkeep]
153  cuda::MapReduceKeepDim1<Saver, Reducer>
154  (MakePlan(dst->self()), MakePlan(exp.self()), scale, pshape,
156 }
/*!
 * \brief GPU softmax over a 2D tensor (dst from src); forwards to cuda::Softmax.
 */
template<typename DType>
inline void Softmax(Tensor<gpu, 2, DType> dst,
                    const Tensor<gpu, 2, DType>& src) {
  cuda::Softmax(dst, src);  // kernel launch lives in cuda/tensor_gpu-inl.cuh
}
162 
/*!
 * \brief GPU softmax over a 3D tensor (dst from src); forwards to the 3D
 *        overload of cuda::Softmax, which defines the normalized axis.
 */
template<typename DType>
inline void Softmax(Tensor<gpu, 3, DType> dst,
                    const Tensor<gpu, 3, DType>& src) {
  cuda::Softmax(dst, src);  // kernel launch lives in cuda/tensor_gpu-inl.cuh
}
168 
/*!
 * \brief GPU softmax gradient for a 2D src with a 1D label tensor.
 *
 * Thin wrapper: all work is done by cuda::SoftmaxGrad.
 */
template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                        const Tensor<gpu, 2, DType> &src,
                        const Tensor<gpu, 1, DType> &label) {
  cuda::SoftmaxGrad(dst, src, label);
}
175 
/*!
 * \brief GPU softmax gradient with smoothing for a 2D src and 1D label.
 * \param alpha smoothing factor forwarded unchanged to cuda::SmoothSoftmaxGrad
 *        (presumably the label-smoothing weight; confirm in the .cuh kernel).
 */
template<typename DType>
inline void SmoothSoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                              const Tensor<gpu, 2, DType> &src,
                              const Tensor<gpu, 1, DType> &label,
                              const float alpha) {
  cuda::SmoothSoftmaxGrad(dst, src, label, alpha);
}
183 
/*!
 * \brief GPU softmax gradient for a 2D src and 1D label with an ignore label.
 * \param ignore_label label value forwarded to cuda::SoftmaxGrad (presumably
 *        rows with this label are skipped; confirm in the .cuh kernel).
 */
template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                        const Tensor<gpu, 2, DType> &src,
                        const Tensor<gpu, 1, DType> &label,
                        const DType &ignore_label) {
  cuda::SoftmaxGrad(dst, src, label, ignore_label);
}
191 
/*!
 * \brief GPU smoothed softmax gradient for a 2D src and 1D label with an ignore label.
 *
 * Both `ignore_label` and `alpha` are passed straight through to
 * cuda::SmoothSoftmaxGrad.
 */
template<typename DType>
inline void SmoothSoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                              const Tensor<gpu, 2, DType> &src,
                              const Tensor<gpu, 1, DType> &label,
                              const DType &ignore_label,
                              const float alpha) {
  cuda::SmoothSoftmaxGrad(dst, src, label, ignore_label, alpha);
}
200 
/*!
 * \brief GPU softmax gradient for a 3D src with a 2D label tensor.
 *
 * Thin wrapper around the 3D overload of cuda::SoftmaxGrad.
 */
template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
                        const Tensor<gpu, 3, DType> &src,
                        const Tensor<gpu, 2, DType> &label) {
  cuda::SoftmaxGrad(dst, src, label);
}
207 
/*!
 * \brief GPU softmax gradient for a 3D src and 2D label with an ignore label.
 *
 * `ignore_label` is forwarded unchanged to the 3D overload of cuda::SoftmaxGrad.
 */
template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
                        const Tensor<gpu, 3, DType> &src,
                        const Tensor<gpu, 2, DType> &label,
                        const DType &ignore_label) {
  cuda::SoftmaxGrad(dst, src, label, ignore_label);
}
215 
/*!
 * \brief GPU gradient accumulation into an embedding matrix: dst[index[i]] += src[i].
 * \tparam clip forwarded to cuda::AddTakeGrad (presumably selects how
 *         out-of-range indices are handled; confirm in the .cuh kernel).
 */
template<bool clip, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
                        const Tensor<gpu, 1, IndexType>& index,
                        const Tensor<gpu, 2, DType> &src) {
  // Template arguments are spelled out so `clip` reaches the kernel wrapper.
  cuda::AddTakeGrad<clip, IndexType, DType>(dst, index, src);
}
222 
/*!
 * \brief GPU gradient accumulation into an embedding matrix for large batches:
 *        dst[sorted[i]] += src[index[i]]; forwards to cuda::AddTakeGradLargeBatch.
 * \param dst embedding gradient to accumulate into.
 * \param sorted row indices into dst (presumably sorted ascending; confirm).
 * \param index row indices into src paired with `sorted`.
 * \param src gradient rows to accumulate.
 */
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
                                  const Tensor<gpu, 1, IndexType>& sorted,
                                  const Tensor<gpu, 1, IndexType>& index,
                                  const Tensor<gpu, 2, DType> &src) {
  cuda::AddTakeGradLargeBatch(dst, sorted, index, src);
}
230 
/*!
 * \brief GPU sort of key/value pairs stored in separate 1D tensors; forwards to
 *        cuda::SortByKey (documented for this API as a stable sort).
 * \param keys sort keys, reordered in place.
 * \param values values permuted alongside the keys.
 * \param is_ascend sort in ascending order when true, descending otherwise.
 */
template<typename KDType, typename VDType>
inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values,
                      bool is_ascend) {
  cuda::SortByKey(keys, values, is_ascend);
}
236 
/*!
 * \brief GPU IndexFill: write rows of `src` into the rows of `dst` selected by
 *        `index`; all work is done by cuda::IndexFill.
 */
template<typename IndexType, typename DType>
inline void IndexFill(Tensor<gpu, 2, DType> dst,
                      const Tensor<gpu, 1, IndexType>& index,
                      const Tensor<gpu, 2, DType> &src) {
  cuda::IndexFill(dst, index, src);
}
243 } // namespace mshadow
244 #endif // __CUDACC__
245 #endif // MSHADOW_TENSOR_GPU_INL_H_
void FreeSpace(Tensor< cpu, dim, DType > *obj)
CPU/GPU: free the space of the tensor; sets obj.dptr_ to NULL.
Definition: tensor_cpu-inl.h:122
void IndexFill(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 1, IndexType > &index, const Tensor< cpu, 2, DType > &src)
CPU/GPU: Fill the values of the destination matrix to specific rows in the source matrix...
Definition: tensor_cpu-inl.h:526
void SoftmaxGrad(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 2, DType > &src, const Tensor< cpu, 1, DType > &label)
CPU/GPU: softmax gradient.
Definition: tensor_cpu-inl.h:288
void SmoothSoftmaxGrad(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 2, DType > &src, const Tensor< cpu, 1, DType > &label, const float alpha)
Definition: tensor_cpu-inl.h:305
PaddingExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > pad(const Exp< SrcExp, DType, etype > &src, index_t pad)
padding expression: pad an image with zeros on the boundaries; padding affects shape[0] and shape[1]
Definition: pad.h:53
DType * dptr_
pointer to the data
Definition: tensor.h:416
Tensor RValue, this is the super type of all kinds of possible tensors.
Definition: tensor.h:391
used to help static type check
Definition: expr_engine-inl.h:312
void Copy(Tensor< cpu, dim, DType > dst, const Tensor< cpu, dim, DType > &src, Stream< cpu > *stream=NULL)
copy data from one tensor to another, with same shape
Definition: tensor_cpu-inl.h:127
void MapExp(TRValue< R, cpu, dim, DType > *dst, const expr::Exp< E, DType, etype > &exp)
CPU/GPU: map an expression to a tensor; this function calls MapPlan.
Definition: tensor_cpu-inl.h:189
Definition: stream_gpu-inl.h:19
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:418
MSHADOW_XINLINE Shape< 4 > Shape4(index_t s0, index_t s1, index_t s2, index_t s3)
construct a four dimension shape, stride will equal s0
Definition: tensor.h:222
void SortByKey(Tensor< cpu, 1, KDType > keys, Tensor< cpu, 1, VDType > values, bool is_ascend=true)
CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!) ...
Definition: tensor_cpu-inl.h:537
void Softmax(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 2, DType > &energy)
CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j])) ...
Definition: tensor_cpu-inl.h:465
#define MSHADOW_CUDA_CALL(func)
Protected cuda call in mshadow.
Definition: base.h:252
void MapReduceKeepLowest(TRValue< R, cpu, 1, DType > *dst, const expr::Exp< E, DType, etype > &exp, DType scale=1)
CPU/GPU: map an expression and reduce it to a 1D tensor in the lowest dimension (dimension 0) ...
Definition: tensor_cpu-inl.h:205
static Shape< dim > Check(const E &t)
MSHADOW_XINLINE Tensor< Device, 2, DType > FlatTo2D(void) const
flatten the tensor to two dimensions, collapsing the higher dimensions together
Definition: tensor.h:501
Definition: expr_engine-inl.h:327
int32_t index_t
type that will be used for index
Definition: base.h:291
void AllocSpace(Tensor< cpu, dim, DType > *obj, bool pad=MSHADOW_ALLOC_PAD)
CPU/GPU: allocate space for the tensor according to the shape in obj; this function is responsible t...
Definition: tensor_cpu-inl.h:98
void ShutdownTensorEngine< gpu >(void)
Definition: tensor_gpu-inl.h:31
void AddTakeGradLargeBatch(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 1, IndexType > &sorted, const Tensor< cpu, 1, IndexType > &index, const Tensor< cpu, 2, DType > &src)
CPU/GPU: Gradient accumulation into the embedding matrix: dst[sorted[i]] += src[index[i]]. Called when the bat...
Definition: tensor_cpu-inl.h:516
runtime shape checking template get the shape of an expression, report error if shape mismatch ...
Definition: expr_engine-inl.h:346
void InitTensorEngine< gpu >(int dev_id)
Definition: tensor_gpu-inl.h:15
void MapReduceKeepHighDim(TRValue< R, cpu, 1, DType > *dst, const expr::Exp< E, DType, etype > &exp, DType scale=1)
CPU/GPU: map an expression and reduce it to a 1D tensor in the third dimension (dimension 2) ...
Definition: tensor_cpu-inl.h:232
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const SubType & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
void SetDevice< gpu >(int devid)
Definition: tensor_gpu-inl.h:34
void AddTakeGrad(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 1, IndexType > &index, const Tensor< cpu, 2, DType > &src)
CPU/GPU: Gradient accumulation into the embedding matrix: dst[index[i]] += src[i]. Called when the featuredim ...
Definition: tensor_cpu-inl.h:498
namespace for mshadow
Definition: base.h:282
MSHADOW_XINLINE index_t size(int idx) const
return size of i-th dimension, start counting from highest dimension
Definition: tensor.h:487
index_t stride_
stores the stride information in the x dimension; this is used to deal with pitch allocation in GPU or ss...
Definition: tensor.h:423
general tensor
Definition: tensor.h:402
#define MSHADOW_MIN_PAD_RATIO
the x dimension of the data must be larger than pad_size * ratio for padded memory to be allocated; otherwise use tight a...
Definition: base.h:65
computation stream structure, used for asynchronous computations
Definition: tensor.h:365