24 #ifndef MXNET_COMMON_CUDA_CUDNN_CXX_H_
25 #define MXNET_COMMON_CUDA_CUDNN_CXX_H_
28 #if MXNET_USE_CUDNN == 1
36 #if !defined(__CUDACC__) // Can be removed when CUDA 10 support is dropped.
38 #endif // !defined(__CUDACC__)
41 #include <unordered_set>
47 STATIC_ASSERT_CUDNN_VERSION_GE(8002);
// Deleter for std::unique_ptr that releases a cuDNN backend descriptor.
52 struct DescriptorDestroyer {
// The 'pointer' typedef makes unique_ptr store the raw cudnnBackendDescriptor_t
// handle itself, rather than a pointer to one.
53 using pointer = cudnnBackendDescriptor_t;
// Destruction errors are reported via the NONFATAL macro (presumably logged,
// not thrown — name suggests it is safe to run from a destructor; confirm).
55 void operator()(cudnnBackendDescriptor_t desc) {
56 CUDNN_CALL_NONFATAL(cudnnBackendDestroyDescriptor(desc));
// NOTE(review): the struct's closing brace is not visible in this excerpt
// (original numbering jumps 56 -> 60); confirm against the full file.
// RAII owner of a cuDNN backend descriptor handle.
60 using Descriptor = std::unique_ptr<cudnnBackendDescriptor_t, DescriptorDestroyer>;
// Non-owning view of a Descriptor: copies the raw handle without taking
// ownership. The referenced Descriptor must outlive this object.
62 struct WeakDescriptor {
63 cudnnBackendDescriptor_t desc =
nullptr;
65 explicit WeakDescriptor(
const Descriptor& other) : desc(other.get()) {}
// NOTE(review): get()'s body (presumably 'return desc;') and the struct's
// closing brace are missing from this excerpt — confirm against the full file.
66 cudnnBackendDescriptor_t get()
const {
// AttrType<T> maps a C++ value type to the matching cudnnBackendAttributeType_t
// enumerator; SetAttr/GetAttr below use AttrType<T>::type when talking to the
// cuDNN backend attribute API.
// NOTE(review): the primary 'template <typename T> struct AttrType;'
// declaration, the 'template <>' headers, and the closing '};' of each
// specialization are not visible in this excerpt (numbering gaps) — confirm
// against the full file.
75 struct AttrType<int64_t> {
76 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_INT64;
80 struct AttrType<void*> {
81 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_VOID_PTR;
85 struct AttrType<float> {
86 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_FLOAT;
90 struct AttrType<double> {
91 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_DOUBLE;
95 struct AttrType<cudnnHandle_t> {
96 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_HANDLE;
100 struct AttrType<bool> {
101 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_BOOLEAN;
105 struct AttrType<cudnnDataType_t> {
106 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_DATA_TYPE;
110 struct AttrType<cudnnConvolutionMode_t> {
111 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_CONVOLUTION_MODE;
// 'PROPOGATION' [sic] — this matches the enumerator spelling in the cuDNN
// headers themselves; do not "fix" it here.
115 struct AttrType<cudnnNanPropagation_t> {
116 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_NAN_PROPOGATION;
120 struct AttrType<cudnnPointwiseMode_t> {
121 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_POINTWISE_MODE;
125 struct AttrType<cudnnBackendHeurMode_t> {
126 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_HEUR_MODE;
130 struct AttrType<cudnnBackendNumericalNote_t> {
131 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_NUMERICAL_NOTE;
// These attribute types only exist from the indicated cuDNN versions onward.
134 #if CUDNN_VERSION >= 8100
136 struct AttrType<cudnnReduceTensorOp_t> {
137 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_REDUCTION_OPERATOR_TYPE;
139 #if CUDNN_VERSION >= 8200
141 struct AttrType<cudnnBackendBehaviorNote_t> {
142 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_BEHAVIOR_NOTE;
144 #endif  // CUDNN_VERSION >= 8200
145 #endif  // CUDNN_VERSION >= 8100
148 struct AttrType<cudnnBackendKnobType_t> {
149 static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_KNOB_TYPE;
// SetAttr overloads for descriptor-valued attributes (defined out of line):
// owning Descriptor, non-owning WeakDescriptor, and vector-of-Descriptor.
152 void SetAttr(
const Descriptor& desc, cudnnBackendAttributeName_t name,
const Descriptor& val);
153 void SetAttr(
const Descriptor& desc, cudnnBackendAttributeName_t name,
const WeakDescriptor& val);
154 void SetAttr(
const Descriptor& desc,
155 cudnnBackendAttributeName_t name,
156 const std::vector<Descriptor>& val);
// Sets a single scalar attribute on 'desc'; the cuDNN element type is derived
// from T via AttrType<T>.
158 template <
typename T>
159 void SetAttr(
const Descriptor& desc, cudnnBackendAttributeName_t name, T val) {
160 CUDNN_CALL(cudnnBackendSetAttribute(desc.get(), name, AttrType<T>::type, 1, &val));
// NOTE(review): closing brace missing from this excerpt (numbering gap).
// Sets an array attribute from a std::vector; passes the element count and a
// pointer to contiguous storage straight through to cuDNN.
163 template <
typename T>
164 void SetAttr(
const Descriptor& desc, cudnnBackendAttributeName_t name,
const std::vector<T>& val) {
165 CUDNN_CALL(cudnnBackendSetAttribute(desc.get(), name, AttrType<T>::type, val.size(), val.data()));
// NOTE(review): closing brace missing from this excerpt (numbering gap).
// Sets an array attribute from a fixed-size std::array; N is known at compile
// time but cuDNN receives it as a runtime element count.
168 template <
typename T,
size_t N>
169 void SetAttr(
const Descriptor& desc,
170 cudnnBackendAttributeName_t name,
171 const std::array<T, N>& val) {
172 CUDNN_CALL(cudnnBackendSetAttribute(desc.get(), name, AttrType<T>::type, val.size(), val.data()));
// NOTE(review): closing brace missing from this excerpt (numbering gap).
// Recursion terminator for the variadic SetAttrs below: no attributes left.
175 inline void SetAttrs(
const Descriptor& desc) {}
// Applies a flat (name, value, name, value, ...) pack of attributes to 'desc'
// by peeling one pair per recursive call and dispatching to the matching
// SetAttr overload.
177 template <
typename T,
typename... Attrs>
178 void SetAttrs(
const Descriptor& desc, cudnnBackendAttributeName_t name, T&& val, Attrs&&... rest) {
179 SetAttr(desc, name, std::forward<T>(val));
180 SetAttrs(desc, std::forward<Attrs>(rest)...);
// NOTE(review): closing brace missing from this excerpt (numbering gap).
// Creates n raw (unowned) backend descriptors of the given type — caller is
// responsible for their lifetime. Defined out of line.
183 std::vector<cudnnBackendDescriptor_t> MakeRawDescriptors(
size_t n,
184 cudnnBackendDescriptorType_t type);
// Creates a single owned (RAII) descriptor of the given type. Defined out of line.
186 Descriptor Make(cudnnBackendDescriptorType_t type);
// Convenience factory: creates a descriptor of 'type' and applies the given
// (name, value, ...) attribute pairs via SetAttrs. Does NOT finalize it.
188 template <
typename... Attrs>
189 Descriptor Make(cudnnBackendDescriptorType_t type, Attrs&&... attrs) {
190 auto desc = Make(type);
191 SetAttrs(desc, std::forward<Attrs>(attrs)...);
// NOTE(review): 'return desc;' and the closing brace are missing from this
// excerpt (numbering gap) — confirm against the full file.
// Like Make(type, attrs...) but additionally finalizes the descriptor, making
// it ready for use by cuDNN.
195 template <
typename... Attrs>
196 Descriptor MakeFinalized(cudnnBackendDescriptorType_t type, Attrs&&... attrs) {
197 auto desc = Make(type, std::forward<Attrs>(attrs)...);
198 CUDNN_CALL(cudnnBackendFinalize(desc.get()));
// NOTE(review): 'return desc;' and the closing brace are missing from this
// excerpt (numbering gap) — confirm against the full file.
// Reads a single scalar attribute of type T from 'desc'; CHECKs that exactly
// one element was returned.
202 template <
typename T>
203 T GetAttr(
const Descriptor& desc, cudnnBackendAttributeName_t name) {
// NOTE(review): the declaration of the result variable 'ret' (presumably
// 'T ret;') is missing from this excerpt, as are the trailing 'return ret;'
// and closing brace — confirm against the full file.
205 int64_t ret_count = 0;
206 CUDNN_CALL(cudnnBackendGetAttribute(desc.get(), name, AttrType<T>::type, 1, &ret_count, &ret));
207 CHECK_EQ(ret_count, 1);
// Reads all elements of an array attribute: first queries the element count
// (elementCount = 0, null buffer), then fetches that many values.
211 template <
typename T>
212 std::vector<T> GetAllAttrs(
const Descriptor& desc, cudnnBackendAttributeName_t name) {
// NOTE(review): the declaration of 'count' (presumably 'int64_t count = 0;')
// is missing from this excerpt, as are 'return ret;' and the closing brace —
// confirm against the full file.
214 CUDNN_CALL(cudnnBackendGetAttribute(desc.get(), name, AttrType<T>::type, 0, &count,
nullptr));
215 std::vector<T> ret(count);
216 CUDNN_CALL(cudnnBackendGetAttribute(
217 desc.get(), name, AttrType<T>::type, ret.size(), &count, ret.data()));
// Reads at most 'max_n' elements of an array attribute in one call, avoiding
// the extra size query that GetAllAttrs performs.
221 template <
typename T>
222 std::vector<T> GetSomeAttrs(
size_t max_n,
223 const Descriptor& desc,
224 cudnnBackendAttributeName_t name) {
// NOTE(review): the declaration of 'count' and the tail of the function
// (presumably shrinking 'ret' to 'count' and returning it) are missing from
// this excerpt — confirm against the full file.
226 std::vector<T> ret(max_n);
227 CUDNN_CALL(cudnnBackendGetAttribute(
228 desc.get(), name, AttrType<T>::type, ret.size(), &count, ret.data()));
// Descriptor-valued counterparts of the templates above (defined out of line).
// These take an explicit cudnnBackendDescriptorType_t because the element
// descriptors must be created with the right type before cuDNN fills them in.
233 Descriptor GetAttr(
const Descriptor& desc,
234 cudnnBackendAttributeName_t name,
235 cudnnBackendDescriptorType_t type);
237 std::vector<Descriptor> GetAllAttrs(
const Descriptor& desc,
238 cudnnBackendAttributeName_t name,
239 cudnnBackendDescriptorType_t type);
241 std::vector<Descriptor> GetSomeAttrs(
size_t max_n,
242 const Descriptor& desc,
243 cudnnBackendAttributeName_t name,
244 cudnnBackendDescriptorType_t type);
// Computes fully-packed strides for a tensor whose dimensions are traversed in
// the given 'order' (a permutation of dimension indices, outermost first):
// the last dimension in 'order' gets stride 1, and each earlier one gets the
// next dimension's extent times its stride. The loop condition 'i--' is a
// pre-test decrement, so the body runs for i = dims.size()-2 down to 0.
247 template <
typename T>
248 std::vector<T> PackedStrides(
const std::vector<size_t>& order,
const std::vector<T>& dims) {
249 CHECK_EQ(order.size(), dims.size());
250 std::vector<T> ret(dims.size(), 1);
251 for (
size_t i = dims.size() - 1; i--;)
252 ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
// NOTE(review): 'return ret;' and the closing brace are missing from this
// excerpt (numbering gap) — confirm against the full file.
// Filter predicate over engine "notes" (numerical/behavior annotations):
// checks that every entry of 'require_notes' is present in 'notes' and that no
// entry of 'exclude_notes' is.
258 template <
typename Note>
259 inline bool IsCompatible(
const std::vector<Note>& notes,
260 const std::vector<Note>& require_notes,
261 const std::vector<Note>& exclude_notes) {
// NOTE(review): the bodies of both 'if' statements (presumably 'return false;')
// and the function's trailing 'return true;' / closing brace are missing from
// this excerpt (numbering gaps) — confirm against the full file.
262 for (
auto rn : require_notes) {
263 auto it = std::find(notes.begin(), notes.end(), rn);
264 if (it == notes.end())
267 for (
auto en : exclude_notes) {
268 auto it = std::find(notes.begin(), notes.end(), en);
269 if (it != notes.end())
// Queries cuDNN heuristics for execution plans of 'op_graph', filtered by
// workspace limit, excluded engine ids, and required/excluded numerical (and,
// on cuDNN >= 8.2, behavior) notes. 'max_workspace' is an out-parameter —
// presumably the largest workspace among the returned plans; confirm at the
// definition. Defined out of line.
278 std::vector<Descriptor> GetPlans(cudnnBackendHeurMode_t h_mode,
279 cudnnHandle_t handle,
280 const Descriptor& op_graph,
281 size_t workspace_limit,
282 size_t* max_workspace,
283 const std::unordered_set<int64_t>& excl_engines,
284 const std::vector<cudnnBackendNumericalNote_t>& req_numeric,
285 const std::vector<cudnnBackendNumericalNote_t>& excl_numeric,
// NOTE(review): the '#if CUDNN_VERSION >= 8200' directive below is split
// across two physical lines in this excerpt (invalid as written), and its
// matching '#endif' is missing (numbering gap 288 -> 290) — confirm against
// the full file.
286 #
if CUDNN_VERSION >= 8200
287 const std::vector<cudnnBackendBehaviorNote_t>& req_behavior,
288 const std::vector<cudnnBackendBehaviorNote_t>& excl_behavior,
290 bool verbose_filter);
// Host-only auto-tuning helpers (std::optional/std::function are unavailable
// under old nvcc, hence the __CUDACC__ guard).
292 #if !defined(__CUDACC__)  // Can be removed when CUDA 10 support is dropped.
// A Sampler consumes one timing measurement (msec) per call and eventually
// yields an aggregate value; std::nullopt presumably means "keep sampling" —
// confirm at MakeAvgSampler's definition.
297 using Sampler = std::function<std::optional<float>(
float)>;
// Averaging sampler over n measurements, discarding 'warmups' initial runs and
// aborting on samples above 'max_cutoff_msec'. Defined out of line.
303 Sampler MakeAvgSampler(
size_t n,
float max_cutoff_msec = 1000.0,
size_t warmups = 1);
// Times the given plans and returns the best performers.
// NOTE(review): part of this declaration's parameter list (original lines 313
// and 316, e.g. a result limit and the sampler) is missing from this excerpt —
// confirm against the full file.
312 std::vector<FindResult> FindTopPlans(std::vector<Descriptor>&& plans,
314 cudnnHandle_t handle,
315 const Descriptor& var_pack,
317 #endif  // !defined(__CUDACC__)
// Human-readable summary of an execution plan, for logging/debugging.
// Defined out of line.
319 std::string PlanStr(
const Descriptor& plan);
324 #endif // MXNET_USE_CUDNN == 1
326 #endif // MXNET_COMMON_CUDA_CUDNN_CXX_H_