mxnet
rtc.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_RTC_H_
21 #define MXNET_RTC_H_
22 #include "./base.h"
23 #if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
24 #include <nvrtc.h>
25 #include <cuda.h>
26 
27 #include <vector>
28 #include <string>
29 #include <memory>
30 #include <utility>
31 #include <unordered_map>
32 #include <unordered_set>
33 #include "./ndarray.h"
34 
35 namespace mxnet {
36 namespace rtc {
37 
39 class CudaModule {
40  private:
42  struct Chunk {
48  Chunk(const char* source,
49  const std::vector<std::string>& options,
50  const std::vector<std::string>& exports);
52  ~Chunk();
59  CUfunction GetFunction(const std::string& mangled_name, const Context& ctx);
61  nvrtcProgram prog_;
63  char* ptx_;
65  std::unordered_map<int, CUmodule> mod_;
67  std::unordered_set<std::string> exports_;
68  };
70  std::shared_ptr<Chunk> ptr_;
71 
72  public:
74  struct ArgType {
76  bool is_ndarray;
78  bool is_const;
80  mshadow::TypeFlag dtype;
81  };
83  class Kernel {
84  public:
86  void Launch(const Context& ctx, const std::vector<dmlc::any>& args,
87  uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z,
88  uint32_t block_dim_x, uint32_t block_dim_y, uint32_t block_dim_z,
89  uint32_t shared_mem);
91  const std::vector<ArgType>& signature() { return signature_; }
92 
93  private:
94  friend class CudaModule;
101  Kernel(const std::shared_ptr<Chunk>& mod,
102  const std::string& mangled_name,
103  const std::vector<ArgType>& signature);
105  std::string mangled_name_;
107  std::vector<ArgType> signature_;
109  std::shared_ptr<Chunk> mod_;
111  std::unordered_map<int, CUfunction> func_;
112  };
118  CudaModule(const char* source,
119  const std::vector<std::string>& options,
120  const std::vector<std::string>& exports)
121  : ptr_(std::make_shared<Chunk>(source, options, exports)) {}
128  std::shared_ptr<Kernel> GetKernel(const std::string& name,
129  const std::vector<ArgType>& signature);
130 };
131 
132 } // namespace rtc
133 } // namespace mxnet
134 
135 #endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
136 #endif // MXNET_RTC_H_
namespace of mxnet
Definition: base.h:89
Definition: optional.h:241
TypeFlag
data type flag
Definition: base.h:306