mxnet
sse-inl.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_PACKET_SSE_INL_H_
27 #define MSHADOW_PACKET_SSE_INL_H_
28 
29 #include <emmintrin.h>
30 #include "../base.h"
31 #include "../packet-inl.h"
32 
33 namespace mshadow {
34 namespace packet {
35 template<>
36 struct Packet<float, kSSE2> {
37  public:
39  static constexpr index_t size = 4;
41  __m128 data_;
42  // enable default copy constructor
43  Packet(void) {}
44  // constructor from the intrinsic type
45  explicit Packet(__m128 data) : data_(data) {}
46  // create a fill with the target value s
48  return Packet<float, kSSE2>(_mm_set1_ps(s));
49  }
50  // load from address
51  MSHADOW_CINLINE static Packet<float, kSSE2> Load(const float* src) {
52  return Packet<float, kSSE2>(_mm_load_ps(src));
53  }
54  // load from address
56  return Packet<float, kSSE2>(_mm_loadu_ps(src));
57  }
58  // fill it with value s
60  data_ = _mm_set1_ps(s);
61  return *this;
62  }
63  // store data into dst
64  MSHADOW_CINLINE void Store(float* dst) const {
65  _mm_store_ps(dst, data_);
66  }
67  // get the sum of all contents
68  MSHADOW_CINLINE float Sum() const {
69  __m128 ans = _mm_add_ps(data_, _mm_movehl_ps(data_, data_));
70  __m128 rst = _mm_add_ss(ans, _mm_shuffle_ps(ans, ans, 1));
71 #if defined(_MSC_VER) && (_MSC_VER <= 1500) && defined(_WIN64)
72  return rst.m128_f32[0];
73 #else
74  float rr = _mm_cvtss_f32(rst);
75  return rr;
76 #endif
77  }
78 };
79 
80 
82 template<>
83 struct Packet<double, kSSE2> {
85  static constexpr index_t size = 2;
86  // internal data
87  __m128d data_;
88  // constructor
89  Packet(void) {}
90  explicit Packet(__m128d data) : data_(data) {}
91  // create a fill with the target value s
93  return Packet<double, kSSE2>(_mm_set1_pd(s));
94  }
95  // load from address
96  MSHADOW_CINLINE static Packet<double, kSSE2> Load(const double* src) {
97  return Packet<double, kSSE2>(_mm_load_pd(src));
98  }
100  return Packet<double, kSSE2>(_mm_loadu_pd(src));
101  }
102  // fill it with value s
104  data_ = _mm_set1_pd(s);
105  return *this;
106  }
107  // store data into dst
108  MSHADOW_CINLINE void Store(double* dst) const {
109  _mm_store_pd(dst, data_);
110  }
111  // get sum of all content
112  inline double Sum(void) const {
113  __m128d tmp = _mm_add_sd(data_, _mm_unpackhi_pd(data_, data_));
114 #if defined(_MSC_VER) && (_MSC_VER <= 1500) && defined(_WIN64)
115  return tmp.m128d_f64[0];
116 #else
117  double ans = _mm_cvtsd_f64(tmp);
118  return ans;
119 #endif
120  }
121 };
122 
124  const Packet<float, kSSE2>& rhs) {
125  return Packet<float, kSSE2>(_mm_add_ps(lhs.data_, rhs.data_));
126 }
127 
129  const Packet<double, kSSE2>& rhs) {
130  return Packet<double, kSSE2>(_mm_add_pd(lhs.data_, rhs.data_));
131 }
132 
134  const Packet<float, kSSE2>& rhs) {
135  return Packet<float, kSSE2>(_mm_sub_ps(lhs.data_, rhs.data_));
136 }
137 
139  const Packet<double, kSSE2>& rhs) {
140  return Packet<double, kSSE2>(_mm_sub_pd(lhs.data_, rhs.data_));
141 }
142 
144  const Packet<float, kSSE2>& rhs) {
145  return Packet<float, kSSE2>(_mm_mul_ps(lhs.data_, rhs.data_));
146 }
147 
149  const Packet<double, kSSE2>& rhs) {
150  return Packet<double, kSSE2>(_mm_mul_pd(lhs.data_, rhs.data_));
151 }
152 
153 
155  const Packet<float, kSSE2>& rhs) {
156  return Packet<float, kSSE2>(_mm_div_ps(lhs.data_, rhs.data_));
157 }
158 
160  const Packet<double, kSSE2>& rhs) {
161  return Packet<double, kSSE2>(_mm_div_pd(lhs.data_, rhs.data_));
162 }
163 
164 } // namespace packet
165 } // namespace mshadow
166 #endif // MSHADOW_PACKET_SSE_INL_H_
vector real type for double
Definition: sse-inl.h:83
static MSHADOW_CINLINE Packet< float, kSSE2 > Fill(float s)
Definition: sse-inl.h:47
MSHADOW_CINLINE Packet< float, kSSE2 > & operator=(float s)
Definition: sse-inl.h:59
static MSHADOW_CINLINE Packet< float, kSSE2 > LoadUnAligned(const float *src)
Definition: sse-inl.h:55
MSHADOW_CINLINE Packet< DType, kPlain > operator-(const Packet< DType, kPlain > &lhs, const Packet< DType, kPlain > &rhs)
Definition: plain-inl.h:78
MSHADOW_CINLINE Packet< DType, kPlain > operator/(const Packet< DType, kPlain > &lhs, const Packet< DType, kPlain > &rhs)
Definition: plain-inl.h:89
MSHADOW_CINLINE void Store(float *dst) const
Definition: sse-inl.h:64
static MSHADOW_CINLINE Packet< double, kSSE2 > Load(const double *src)
Definition: sse-inl.h:96
Packet(__m128 data)
Definition: sse-inl.h:45
__m128d data_
Definition: sse-inl.h:87
MSHADOW_CINLINE float Sum() const
Definition: sse-inl.h:68
int32_t index_t
type that will be used for index
Definition: base.h:336
Packet(__m128d data)
Definition: sse-inl.h:90
MSHADOW_CINLINE Packet< DType, kPlain > operator*(const Packet< DType, kPlain > &lhs, const Packet< DType, kPlain > &rhs)
Definition: plain-inl.h:83
Definition: packet-inl.h:44
MSHADOW_CINLINE Packet< double, kSSE2 > & operator=(double s)
Definition: sse-inl.h:103
double Sum(void) const
Definition: sse-inl.h:112
__m128 data_
The internal data.
Definition: sse-inl.h:41
MSHADOW_CINLINE Packet< DType, kPlain > operator+(const Packet< DType, kPlain > &lhs, const Packet< DType, kPlain > &rhs)
Definition: plain-inl.h:72
#define MSHADOW_CINLINE
cpu force inline
Definition: base.h:226
static MSHADOW_CINLINE Packet< float, kSSE2 > Load(const float *src)
Definition: sse-inl.h:51
overloaded + operator between half_t and bf16_t
Definition: base.h:327
static MSHADOW_CINLINE Packet< double, kSSE2 > LoadUnAligned(const double *src)
Definition: sse-inl.h:99
Packet(void)
Definition: sse-inl.h:43
static MSHADOW_CINLINE Packet< double, kSSE2 > Fill(double s)
Definition: sse-inl.h:92
MSHADOW_CINLINE void Store(double *dst) const
Definition: sse-inl.h:108
Generic packet type.
Definition: packet-inl.h:60
Packet(void)
Definition: sse-inl.h:89