mxnet
special_functions-inl.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_COMMON_CUDA_RTC_SPECIAL_FUNCTIONS_INL_H_
21 #define MXNET_COMMON_CUDA_RTC_SPECIAL_FUNCTIONS_INL_H_
22 
23 #include <cfloat>
24 #include <string>
25 
26 namespace mxnet {
27 namespace common {
28 namespace cuda {
29 namespace rtc {
30 
31 // This code is based on the Cephes Library available at http://www.netlib.org/cephes
32 // The original author, Stephen Moshier, has kindly given permission to use this code
33 // in mxnet. (See email below).
34 //
35 // Date: Tue, 13 Sep 2016 09:28:20 -0400
36 // From: Stephen Moshier
37 // To: Flunkert, Valentin
38 // Subject: Re: cephes code in mxnet
39 //
40 // Hello Valentin,
41 //
42 // Thank you for writing. You are welcome to use and modify the Cephes code
43 // and distribute it under the Apache license.
44 //
45 // Good luck with your project,
46 // Steve Moshier
47 //
48 // Cephes Math Library Release 2.2: June, 1992
49 // Copyright 1984, 1987, 1992 by Stephen L. Moshier
50 // Direct inquiries to 30 Frost Street, Cambridge, MA 02140
51 //
52 const char special_functions_definitions[] = R"code(
53 namespace op {
54 
55 namespace special_functions {
56 
// Trigamma function psi1(x) = d^2/dx^2 ln Gamma(x).
// Only the primary template is declared here; the double and float
// explicit specializations provide the definitions.
template<typename DType>
__device__ inline static DType trigamma(DType x);

// Double-precision trigamma.
//
// For x < 0.5 the reflection formula
//   psi1(x) = pi^2 / sin^2(pi x) - psi1(1 - x)
// moves the argument to 1 - x. The argument is then advanced six steps with
// the recurrence psi1(x) = 1/x^2 + psi1(x + 1), and the asymptotic series
//   psi1(x) ~ 1/x + 1/(2x^2) + 1/(6x^3) - 1/(30x^5) + 1/(42x^7)
// is evaluated at the shifted argument.
//
// sinpi(x) is used instead of sin(PI * x). It does not round the product
// PI * x first, which keeps the reflection accurate for large |x>. It also
// vanishes at integer x, so the poles x = 0, -1, -2, ... give +inf rather
// than an arbitrary large finite value.
template<>
__device__ inline double trigamma<double>(double x) {
  const double PI = 3.14159265358979323846;
  double sign = +1;
  double result = 0;
  if (x < 0.5) {
    sign = -1;
    const double sin_pi_x = sinpi(x);
    result -= (PI * PI) / (sin_pi_x * sin_pi_x);
    x = 1 - x;
  }
  // Recurrence: push the argument far enough right for the asymptotic series.
  for (int i = 0; i < 6; ++i) {
    result += 1 / (x * x);
    x += 1;
  }
  const double ixx = 1 / (x * x);
  result += (1 + 1 / (2 * x) + ixx * (1. / 6 - ixx * (1. / 30 - ixx * (1. / 42)))) / x;
  return sign * result;
}
79 
// Single-precision trigamma, psi1(x) = d^2/dx^2 ln Gamma(x).
//
// Same algorithm as the double version:
// - reflection psi1(x) = pi^2 / sin^2(pi x) - psi1(1 - x) for x < 0.5;
// - six steps of psi1(x) = 1/x^2 + psi1(x + 1);
// - the asymptotic series 1/x + 1/(2x^2) + 1/(6x^3) - 1/(30x^5) + 1/(42x^7).
//
// sinpif(x) replaces sinf(PI * x). This avoids the rounding of PI * x for
// large |x|, and the result vanishes at integer x, so the poles give +inf.
template<>
__device__ inline float trigamma<float>(float x) {
  const float PI = static_cast<float>(3.14159265358979323846);
  float sign = +1;
  float result = 0;
  if (x < 0.5f) {
    sign = -1;
    const float sin_pi_x = sinpif(x);
    result -= (PI * PI) / (sin_pi_x * sin_pi_x);
    x = 1 - x;
  }
  // Recurrence: push the argument far enough right for the asymptotic series.
  for (int i = 0; i < 6; ++i) {
    result += 1 / (x * x);
    x += 1;
  }
  const float ixx = 1 / (x * x);
  result += (1 + 1 / (2 * x) + ixx * (1.f / 6 - ixx * (1.f / 30 - ixx * (1.f / 42)))) / x;
  return sign * result;
}
99 
struct cephes {
  /*
   * Evaluates the polynomial
   *   coef[0] * x^N + coef[1] * x^(N-1) + ... + coef[N]
   * with Horner's scheme.
   *
   * `coef` must hold N + 1 coefficients, highest degree first.
   * For N <= 0 the result is coef[0]; the loop never reads past coef[N].
   */
  template <typename DType>
  __device__ inline static DType polevl(DType x, const DType coef[], int N) {
    DType ans = coef[0];
    for (int i = 1; i <= N; ++i) {
      ans = ans * x + coef[i];
    }
    return ans;
  }

  /*
   * Helper for psi: returns the Bernoulli-series correction
   *   sum_{k>=1} B_{2k} / (2k * s^(2k))
   * for an argument s >= 10. It is declared here and specialized for double
   * and float outside the struct, because the two precisions use different
   * coefficient tables and cut-offs.
   */
  template<typename DType>
  __device__ inline static DType psi_helper(DType s);

  /*
   * Value psi returns at its poles (x = 0, -1, -2, ...): the largest finite
   * value of the precision, as Cephes does (MAXNUM / MAXNUMF).
   * It is overloaded on the argument type so that the float path never
   * converts DBL_MAX to float, which is out of float's range.
   */
  __device__ inline static double psi_pole_value(double) { return DBL_MAX; }
  __device__ inline static float psi_pole_value(float) { return 3.402823466e+38f; }

  /*
   * Psi (digamma) function
   *
   *   psi(x) = d/dx ln|Gamma(x)|
   *
   * is the logarithmic derivative of the gamma function.
   *
   * For a positive integer n <= 10 the closed form
   *
   *   psi(n) = -EUL + sum_{k=1}^{n-1} 1/k
   *
   * is used, where EUL is the Euler-Mascheroni constant.
   *
   * For x <= 0 the argument is made positive with the reflection formula
   *
   *   psi(1 - x) = psi(x) + pi * cot(pi * x).
   *
   * Any other positive argument is shifted to s >= 10 with the recurrence
   * psi(x + 1) = psi(x) + 1/x. Then the asymptotic expansion
   *
   *   psi(s) = ln(s) - 1/(2s) - sum_{k>=1} B_{2k} / (2k * s^(2k))
   *
   * is applied, where B_{2k} are Bernoulli numbers. The sum is evaluated
   * by psi_helper.
   *
   * ACCURACY, as reported by Cephes for the single-precision psif
   * (absolute error, relative when |psi| > 1):
   *   arithmetic   domain   # trials   peak      rms
   *   IEEE         -33,0     30000     8.2e-7    1.2e-7
   *   IEEE         0,33     100000     7.3e-7    7.7e-8
   *
   * SINGULARITIES:
   *   At x = 0, -1, -2, ... the largest finite value of DType is returned
   *   (see psi_pole_value).
   */
  template<typename DType>
  __device__ inline static DType psi(DType x) {
    const DType EUL(0.57721566490153286061);
    const DType PI(3.14159265358979323846);

    bool negative = false;
    // Holds pi * cot(pi * x) when the reflection formula is used.
    DType nz = DType(0);

    if (x <= DType(0)) {
      negative = true;
      const DType q = x;
      DType p = ::floor(q);
      if (p == q) {
        return psi_pole_value(x);
      }
      // Remove the zeros of tan(PI x) by subtracting the nearest integer
      // from x. cot has period pi, so the result is unchanged.
      nz = q - p;
      if (nz != DType(0.5)) {
        if (nz > DType(0.5)) {
          p += DType(1);
          nz = q - p;
        }
        nz = PI / ::tan(PI * nz);
      } else {
        nz = DType(0);  // cot(pi / 2) == 0
      }
      x = DType(1) - x;
    }

    DType y;
    if (x <= DType(10) && x == ::floor(x)) {
      // Positive integer up to 10: use the harmonic-sum closed form.
      y = DType(0);
      const int n = static_cast<int>(x);
      for (int i = 1; i < n; ++i) {
        y += DType(1) / static_cast<DType>(i);
      }
      y -= EUL;
    } else {
      // Shift the argument to s >= 10, accumulating the recurrence terms.
      DType s = x;
      DType w = DType(0);
      while (s < DType(10)) {
        w += DType(1) / s;
        s += DType(1);
      }
      // ::log resolves to the precision of DType; logf would truncate the
      // double path to single precision.
      y = ::log(s) - (DType(0.5) / s) - psi_helper(s) - w;
    }

    if (negative) {
      y -= nz;
    }
    return y;
  }
};
246 
247 
// Double-precision Bernoulli-series correction for cephes::psi.
//
// With z = 1/s^2 it returns z * P(z). P is the degree-6 polynomial whose
// coefficients, highest power first, are B_{2k} / (2k) for k = 7 down to 1.
// For s >= 1e17 (or NaN) the correction is taken to be 0.
template<>
__device__ inline double cephes::psi_helper<double>(double s) {
  const double bernoulli_terms[] = {
    8.33333333333333333333E-2,
    -2.10927960927960927961E-2,
    7.57575757575757575758E-3,
    -4.16666666666666666667E-3,
    3.96825396825396825397E-3,
    -8.33333333333333333333E-3,
    8.33333333333333333333E-2
  };

  if (!(s < 1.0e17)) {
    return 0.0;
  }
  const double inv_s2 = 1.0 / (s * s);
  return inv_s2 * cephes::polevl<double>(inv_s2, bernoulli_terms, 6);
}
268 
// Single-precision Bernoulli-series correction for cephes::psi.
//
// With z = 1/s^2 it returns z * P(z). P is the degree-3 polynomial whose
// coefficients, highest power first, are B_{2k} / (2k) for k = 4 down to 1.
// For s >= 1e8 (or NaN) the correction is taken to be 0.
//
// All literals carry the `f` suffix so this float path does not promote to
// double arithmetic. 1.0e8f is exactly representable, so the cut-off is
// unchanged.
template<>
__device__ inline float cephes::psi_helper<float>(float s) {
  const float A[] = {
    -4.16666666666666666667E-3f,
    3.96825396825396825397E-3f,
    -8.33333333333333333333E-3f,
    8.33333333333333333333E-2f
  };

  if (s < 1.0e8f) {
    const float z = 1.0f / (s * s);
    return z * cephes::polevl<float>(z, A, 3);
  }
  return 0.0f;
}
286 } // namespace special_functions
287 } // namespace op
288 )code";
289 
290 } // namespace rtc
291 } // namespace cuda
292 } // namespace common
293 } // namespace mxnet
294 
295 #endif // MXNET_COMMON_CUDA_RTC_SPECIAL_FUNCTIONS_INL_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::common::cuda::rtc::special_functions_definitions
const char special_functions_definitions[]
Definition: special_functions-inl.h:52