mxnet
src
common
cuda
rtc
special_functions-inl.h
Go to the documentation of this file.
1
/*
2
* Licensed to the Apache Software Foundation (ASF) under one
3
* or more contributor license agreements. See the NOTICE file
4
* distributed with this work for additional information
5
* regarding copyright ownership. The ASF licenses this file
6
* to you under the Apache License, Version 2.0 (the
7
* "License"); you may not use this file except in compliance
8
* with the License. You may obtain a copy of the License at
9
*
10
* http://www.apache.org/licenses/LICENSE-2.0
11
*
12
* Unless required by applicable law or agreed to in writing,
13
* software distributed under the License is distributed on an
14
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
* KIND, either express or implied. See the License for the
16
* specific language governing permissions and limitations
17
* under the License.
18
*/
19
20
#ifndef MXNET_COMMON_CUDA_RTC_SPECIAL_FUNCTIONS_INL_H_
21
#define MXNET_COMMON_CUDA_RTC_SPECIAL_FUNCTIONS_INL_H_
22
23
#include <cfloat>
24
#include <string>
25
26
namespace
mxnet
{
27
namespace
common {
28
namespace
cuda {
29
namespace
rtc {
30
31
// This code is based on the Cephes Library available at http://www.netlib.org/cephes
32
// The original author, Stephen Moshier, has kindly given permission to use this code
33
// in mxnet. (See email below).
34
//
35
// Date: Tue, 13 Sep 2016 09:28:20 -0400
36
// From: Stephen Moshier
37
// To: Flunkert, Valentin
38
// Subject: Re: cephes code in mxnet
39
//
40
// Hello Valentin,
41
//
42
// Thank you for writing. You are welcome to use and modify the Cephes code
43
// and distribute it under the Apache license.
44
//
45
// Good luck with your project,
46
// Steve Moshier
47
//
48
// Cephes Math Library Release 2.2: June, 1992
49
// Copyright 1984, 1987, 1992 by Stephen L. Moshier
50
// Direct inquiries to 30 Frost Street, Cambridge, MA 02140
51
//
52
// Device-side source text for special functions (trigamma and the
// Cephes-based digamma/psi) shared by runtime-compiled kernels.
//
// Everything between the raw-string delimiters is CUDA C++ stored as text;
// the host compiler never compiles it. NOTE(review): presumably this text is
// prepended to kernels built by the runtime-compilation (`rtc`) machinery,
// which must also make DBL_MAX, sin/sinf, log/logf, floor and tan available
// -- confirm against the callers.
const char special_functions_definitions[] = R"code(
namespace op {

namespace special_functions {

// Trigamma function psi_1(x), the second derivative of ln Gamma(x).
//
// For x < 0.5 the reflection formula
//   \psi_1(x) = \pi^2 / \sin^2(\pi x) - \psi_1(1 - x)
// moves the argument to 1 - x. The recurrence
//   \psi_1(x) = \psi_1(x + 6) + \sum_{i=0}^{5} 1 / (x + i)^2
// then shifts it up by 6, after which the asymptotic series
//   \psi_1(x) \approx 1/x + 1/(2x^2) + 1/(6x^3) - 1/(30x^5) + 1/(42x^7)
// is summed. There is no explicit pole check at non-positive integers; there
// the reflection term divides by sin(pi x), which is zero or tiny, so the
// result is infinite or very large in magnitude.
template<typename DType>
__device__ inline static DType trigamma(DType x);

template<>
__device__ inline double trigamma<double>(double x) {
  const double PI(3.14159265358979323846);
  double sign = +1;
  double result = 0;
  if (x < 0.5) {
    // Reflection: accumulate -pi^2/sin^2(pi x), evaluate at 1 - x, and flip
    // the sign of the total at the end.
    sign = -1;
    const double sin_pi_x = sin(PI * x);
    result -= (PI * PI) / (sin_pi_x * sin_pi_x);
    x = 1 - x;
  }
  // Recurrence: peel off six 1/x^2 terms so the asymptotic series is accurate.
  for (int i = 0; i < 6; ++i) {
    result += 1 / (x * x);
    x += 1;
  }
  const double ixx = 1 / (x * x);
  result += (1 + 1 / (2 * x) + ixx * (1. / 6 - ixx * (1. / 30 - ixx * (1. / 42)))) / x;
  return sign * result;
}

template<>
__device__ inline float trigamma<float>(float x) {
  const float PI(3.14159265358979323846);
  float sign = +1;
  float result = 0;
  if (x < 0.5f) {
    // Reflection: accumulate -pi^2/sin^2(pi x), evaluate at 1 - x, and flip
    // the sign of the total at the end.
    sign = -1;
    const float sin_pi_x = sinf(PI * x);
    result -= (PI * PI) / (sin_pi_x * sin_pi_x);
    x = 1 - x;
  }
  // Recurrence: peel off six 1/x^2 terms so the asymptotic series is accurate.
  for (int i = 0; i < 6; ++i) {
    result += 1 / (x * x);
    x += 1;
  }
  const float ixx = 1 / (x * x);
  result += (1 + 1 / (2 * x) + ixx * (1.f / 6 - ixx * (1.f / 30 - ixx * (1.f / 42)))) / x;
  return sign * result;
}

// Routines ported from the Cephes Math Library.
struct cephes {
  /*
   * Evaluates the polynomial
   *   coef[0] * x^N + coef[1] * x^(N-1) + ... + coef[N]
   * with Horner's rule. `coef` must hold N + 1 coefficients, highest degree
   * first. For N <= 0 the result is coef[0].
   */
  template <typename DType>
  __device__ inline static DType polevl(DType x, const DType coef[], int N) {
    DType ans = coef[0];
    for (int i = 1; i <= N; ++i) {
      ans = ans * x + coef[i];
    }
    return ans;
  }

  /*
   * Natural logarithm at the precision of the argument type. Overloaded so
   * that the double instantiation of psi does not go through the
   * single-precision logf.
   */
  __device__ inline static float log_of(float s) { return logf(s); }
  __device__ inline static double log_of(double s) { return log(s); }

  /*
   * Returns the truncated Bernoulli-number tail of psi's asymptotic
   * expansion, z * P(z) with z = 1/s^2, using double- or float-specific
   * coefficients. It returns 0 once s is large enough for the tail to be
   * negligible. Specialized for double and float only.
   */
  template<typename DType>
  __device__ inline static DType psi_helper(DType s);

  /*
   * Psi (digamma) function, ported from Cephes (psi / psif).
   *
   * DESCRIPTION:
   *
   *   \psi(x) = \frac{d}{dx} \ln \Gamma(x)
   *
   * is the logarithmic derivative of the gamma function.
   * For a positive integer n <= 10,
   *
   *   \psi(n) = -\gamma + \sum_{k=1}^{n-1} \frac{1}{k},
   *
   * where \gamma is the Euler-Mascheroni constant (EUL below). If x <= 0,
   * the argument is made positive with the reflection formula
   *
   *   \psi(1 - x) = \psi(x) + \pi \cot(\pi x).
   *
   * For general positive x, the argument is raised to at least 10 using the
   * recurrence \psi(x + 1) = \psi(x) + 1/x. Then the asymptotic expansion
   *
   *   \psi(x) = \ln x - \frac{1}{2x} - \sum_{k=1}^{\infty} \frac{B_{2k}}{2k x^{2k}}
   *
   * is applied, where the B_{2k} are Bernoulli numbers. The truncated sum
   * is evaluated by psi_helper.
   *
   * ACCURACY (as reported by Cephes for the single-precision psif):
   *   Absolute error, relative when |psi| > 1:
   *   arithmetic   domain    # trials   peak      rms
   *   IEEE         -33,0      30000     8.2e-7    1.2e-7
   *   IEEE          0,33     100000     7.3e-7    7.7e-8
   *
   * SINGULARITIES:
   *   At non-positive integers the function has a pole, and DBL_MAX is
   *   returned (converted to DType). NOTE(review): DBL_MAX is not
   *   representable as float, so the float instantiation is expected to
   *   yield +inf rather than Cephes' MAXNUMF -- confirm this is intended.
   */
  template<typename DType>
  __device__ inline static DType psi(DType x) {
    const DType EUL(0.57721566490153286061);  // Euler-Mascheroni constant
    const DType PI(3.14159265358979323846);

    bool negative = false;
    DType nz(0);  // reflection correction pi * cot(pi * x), subtracted at the end

    if (x <= DType(0)) {
      negative = true;
      const DType q = x;
      DType p = ::floor(q);
      if (p == q) {
        return DBL_MAX;
      }
      // Remove the zeros of tan(PI x) by subtracting the nearest integer
      // from x.
      nz = q - p;
      if (nz != DType(0.5)) {
        if (nz > DType(0.5)) {
          p += DType(1);
          nz = q - p;
        }
        nz = PI / ::tan(PI * nz);
      } else {
        // cot(pi/2) == 0: no correction needed.
        nz = DType(0);
      }
      x = DType(1) - x;
    }

    DType y;
    if (x <= DType(10) && x == ::floor(x)) {
      // Exact harmonic-sum formula for small positive integers.
      y = DType(0);
      const int n = static_cast<int>(x);
      for (int i = 1; i < n; ++i) {
        y += DType(1) / DType(i);
      }
      y -= EUL;
    } else {
      // Raise the argument to >= 10, accumulating the recurrence terms in w.
      DType s = x;
      DType w(0);
      while (s < DType(10)) {
        w += DType(1) / s;
        s += DType(1);
      }
      y = log_of(s) - DType(0.5) / s - psi_helper(s) - w;
    }

    if (negative) {
      y -= nz;
    }
    return y;
  }
};

template<>
__device__ inline double cephes::psi_helper<double>(double s) {
  // Bernoulli-series coefficients, highest power of z first.
  const double A[] = {
    8.33333333333333333333E-2,
    -2.10927960927960927961E-2,
    7.57575757575757575758E-3,
    -4.16666666666666666667E-3,
    3.96825396825396825397E-3,
    -8.33333333333333333333E-3,
    8.33333333333333333333E-2
  };

  if (s < 1.0e17) {
    const double z = 1.0 / (s * s);
    return z * cephes::polevl<double>(z, A, 6);
  }
  return 0.0;
}

template<>
__device__ inline float cephes::psi_helper<float>(float s) {
  // Bernoulli-series coefficients, highest power of z first.
  const float A[] = {
    -4.16666666666666666667E-3f,
    3.96825396825396825397E-3f,
    -8.33333333333333333333E-3f,
    8.33333333333333333333E-2f
  };

  if (s < 1.0e8f) {
    const float z = 1.0f / (s * s);
    return z * cephes::polevl<float>(z, A, 3);
  }
  return 0.0f;
}

}  // namespace special_functions
}  // namespace op
)code";
289
290
}
// namespace rtc
291
}
// namespace cuda
292
}
// namespace common
293
}
// namespace mxnet
294
295
#endif // MXNET_COMMON_CUDA_RTC_SPECIAL_FUNCTIONS_INL_H_
mxnet
namespace of mxnet
Definition:
api_registry.h:33
mxnet::common::cuda::rtc::special_functions_definitions
const char special_functions_definitions[]
Definition:
special_functions-inl.h:52
Generated on Thu Jan 5 2023 03:47:40 for mxnet by
1.8.17