mxnet
range.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXTENSION_RANGE_H_
8 #define MSHADOW_EXTENSION_RANGE_H_
9 
10 #include "../extension.h"
11 
12 namespace mshadow {
13 namespace expr {
23 template<typename DType>
24 struct RangeExp:
25  public Exp<RangeExp<DType>, DType, type::kMapper> {
26  const DType start_;
27  const DType stop_;
28  const DType step_;
29  const int repeat_;
31  RangeExp(DType start, DType stop, DType step, int repeat)
32  : start_(start), stop_(stop), step_(step), repeat_(repeat) {}
33 };
34 
35 template<typename DType>
36 inline RangeExp<DType>
37 range(DType start, DType stop, DType step = 1, int repeat = 1) {
38  return RangeExp<DType>(start, stop, step, repeat);
39 }
40 
41 //----------------------
42 // Execution plan
43 //----------------------
44 template<typename DType>
45 struct Plan<RangeExp<DType>, DType> {
46  public:
47  explicit Plan(const RangeExp<DType> &e)
48  : start_(e.start_),
49  stop_(e.stop_),
50  step_(e.step_),
51  repeat_(e.repeat_) {
52  }
53  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
54  return start_ + static_cast<DType>((static_cast<int>(x) / repeat_)) * step_;
55  }
56 
57  private:
58  const DType start_;
59  const DType stop_;
60  const DType step_;
61  const int repeat_;
62 };
63 
64 template<typename DType>
65 inline Plan<RangeExp<DType>, DType>
67  return Plan<RangeExp<DType>, DType>(exp);
68 }
69 
70 
/*!
 * \brief Number of elements produced by range(start, stop, step, repeat),
 *        i.e. repeat * ceil((stop - start) / step), for integral DType.
 *
 * Uses exact integer ceiling division. The rounding adjustment must follow
 * the sign of step, because built-in integer division truncates toward zero:
 *   step > 0 (so stop - start > 0): ceil(d / step) == (d - 1) / step + 1
 *   step < 0 (so stop - start < 0): ceil(d / step) == (d + 1) / step + 1
 * Applying the positive-step form to a negative step overcounts; for example,
 * range(5, 0, -1) would report 7 elements instead of 5.
 * Assumes the caller validated step != 0 and that start lies strictly before
 * stop in the step direction (ShapeCheck enforces both).
 * \return number of output elements
 */
template<typename DType>
inline int RangeOutSize(DType start, DType stop, DType step, int repeat) {
  if (step > 0) {
    return repeat * ((stop - start - 1) / step + 1);
  }
  return repeat * ((stop - start + 1) / step + 1);
}
75 
76 template<>
77 inline int RangeOutSize<float>(float start, float stop, float step, int repeat) {
78  double d_start = static_cast<double>(start);
79  double d_stop = static_cast<double>(stop);
80  double d_step = static_cast<double>(step);
81  return repeat * static_cast<int>(ceil((d_stop - d_start) / d_step));
82 }
83 
84 template<>
85 inline int RangeOutSize<double>(double start, double stop, double step, int repeat) {
86  return repeat * static_cast<int>(ceil((stop - start) / step));
87 }
88 
89 
90 template<int dim, typename DType>
91 struct ShapeCheck<dim, RangeExp<DType> > {
92  inline static Shape<dim>
93  Check(const RangeExp<DType> &t) {
94  CHECK(dim == 1)
95  << "RangeExp only support 1 dimension output, received " << dim;
96  CHECK(t.step_ != 0)
97  << "RangeExp does not support step=0, received " << t.step_;
98  CHECK(t.repeat_ > 0)
99  << "RangeExp only supports repeat > 0, received " << t.repeat_;
100  if (t.step_ > 0) {
101  CHECK(t.start_ < t.stop_) << "RangeExp does not support (start, stop, step) = "
102  << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")";
103  } else {
104  CHECK(t.start_ > t.stop_) << "RangeExp does not support (start, stop, step)= "
105  << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")";
106  }
107  return Shape1(RangeOutSize<DType>(t.start_, t.stop_, t.step_, t.repeat_));
108  }
109 };
110 
111 template<typename DType>
112 struct ExpInfo<RangeExp<DType> > {
113  static const int kDim = 1;
114  static const int kDevMask = 0xffff;
115 };
116 } // namespace expr
117 } // namespace mshadow
118 #endif // MSHADOW_EXTENSION_RANGE_H_
Definition: expr_engine-inl.h:40
int RangeOutSize(DType start, DType stop, DType step, int repeat)
Definition: range.h:72
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: range.h:53
Plan(const RangeExp< DType > &e)
Definition: range.h:47
const DType start_
Definition: range.h:26
#define MSHADOW_XINLINE
Definition: base.h:204
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
const int repeat_
Definition: range.h:29
const DType stop_
Definition: range.h:27
int32_t index_t
type that will be used for index
Definition: base.h:291
MSHADOW_XINLINE Shape< 1 > Shape1(index_t s0)
construct a one-dimensional shape; its stride equals s0
Definition: tensor.h:188
runtime shape-checking template: gets the shape of an expression and reports an error on shape mismatch ...
Definition: expr_engine-inl.h:346
const DType step_
Definition: range.h:28
RangeExp< DType > range(DType start, DType stop, DType step=1, int repeat=1)
Definition: range.h:37
int RangeOutSize< double >(double start, double stop, double step, int repeat)
Definition: range.h:85
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
RangeExp(DType start, DType stop, DType step, int repeat)
constructor
Definition: range.h:31
Generate a range vector similar to python: range(start, stop[, step][, repeat]). If step is positive...
Definition: range.h:24
namespace for mshadow
Definition: base.h:282
static Shape< dim > Check(const RangeExp< DType > &t)
Definition: range.h:93
int RangeOutSize< float >(float start, float stop, float step, int repeat)
Definition: range.h:77