mxnet
registry.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
43 // Acknowledgement: This file originates from incubator-tvm
44 #ifndef MXNET_RUNTIME_REGISTRY_H_
45 #define MXNET_RUNTIME_REGISTRY_H_
46 
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "packed_func.h"
50 
51 namespace mxnet {
52 namespace runtime {
53 
55 class Registry {
56  public:
61  MXNET_DLL Registry& set_body(PackedFunc f); // NOLINT(*)
66  Registry& set_body(PackedFunc::FType f) { // NOLINT(*)
67  return set_body(PackedFunc(f));
68  }
83  template <typename FType, typename FLambda>
84  Registry& set_body_typed(FLambda f) {
85  return set_body(TypedPackedFunc<FType>(f).packed());
86  }
87 
110  template <typename R, typename... Args>
111  Registry& set_body_typed(R (*f)(Args...)) {
112  return set_body(TypedPackedFunc<R(Args...)>(f));
113  }
114 
136  template <typename T, typename R, typename... Args>
137  Registry& set_body_method(R (T::*f)(Args...)) {
138  return set_body_typed<R(T, Args...)>([f](T target, Args... params) -> R {
139  // call method pointer
140  return (target.*f)(params...);
141  });
142  }
143 
165  template <typename T, typename R, typename... Args>
166  Registry& set_body_method(R (T::*f)(Args...) const) {
167  return set_body_typed<R(T, Args...)>([f](const T target, Args... params) -> R {
168  // call method pointer
169  return (target.*f)(params...);
170  });
171  }
172 
204  template <typename TObjectRef,
205  typename TNode,
206  typename R,
207  typename... Args,
208  typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
209  Registry& set_body_method(R (TNode::*f)(Args...)) {
210  return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
211  TNode* target = ref.operator->();
212  // call method pointer
213  return (target->*f)(params...);
214  });
215  }
216 
248  template <typename TObjectRef,
249  typename TNode,
250  typename R,
251  typename... Args,
252  typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
253  Registry& set_body_method(R (TNode::*f)(Args...) const) {
254  return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
255  const TNode* target = ref.operator->();
256  // call method pointer
257  return (target->*f)(params...);
258  });
259  }
260 
267  MXNET_DLL static Registry& Register(const std::string& name, bool override = false); // NOLINT(*)
273  MXNET_DLL static bool Remove(const std::string& name);
280  MXNET_DLL static const PackedFunc* Get(const std::string& name); // NOLINT(*)
285  MXNET_DLL static std::vector<std::string> ListNames();
286 
287  // Internal class.
288  struct Manager;
289 
290  protected:
292  std::string name_;
295  friend struct Manager;
296 };
297 
/*!
 * \brief Marks a declaration as possibly unused, silencing unused-variable
 *        warnings for the registration statics defined below. Expands to
 *        nothing when __GNUC__ is not defined.
 */
#if defined(__GNUC__)
#define MXNET_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define MXNET_ATTRIBUTE_UNUSED
#endif

/*!
 * \brief Token-paste helpers. MXNET_STR_CONCAT macro-expands its arguments
 *        before pasting (through the MXNET_STR_CONCAT_ indirection), so an
 *        argument such as __COUNTER__ is replaced by its value rather than
 *        pasted literally.
 */
#define MXNET_STR_CONCAT_(x, y) x##y
#define MXNET_STR_CONCAT(x, y) MXNET_STR_CONCAT_(x, y)

/*!
 * \brief Declaration prefix for a registration variable. Its last token is
 *        the identifier stem to which MXNET_REGISTER_GLOBAL pastes a unique
 *        counter value. The stem contains no double underscore, since such
 *        identifiers are reserved for the implementation in C++.
 */
#define MXNET_FUNC_REG_VAR_DEF \
  static MXNET_ATTRIBUTE_UNUSED ::mxnet::runtime::Registry& mxnet_registry_entry_

/*!
 * \brief Register a function globally.
 * \code
 *   MXNET_REGISTER_GLOBAL("MyPrint")
 *   .set_body([](MXNetArgs args, MXNetRetValue* rv) {
 *   });
 * \endcode
 * Each expansion defines a uniquely named (via __COUNTER__) static
 * reference initialized with Registry::Register(OpName) plus any chained
 * set_body* calls, so the registration runs when that static is initialized.
 */
#define MXNET_REGISTER_GLOBAL(OpName) \
  MXNET_STR_CONCAT(MXNET_FUNC_REG_VAR_DEF, __COUNTER__) = \
  ::mxnet::runtime::Registry::Register(OpName)
322 
323 } // namespace runtime
324 } // namespace mxnet
325 #endif // MXNET_RUNTIME_REGISTRY_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::runtime::Registry::set_body_method
Registry & set_body_method(R(T::*f)(Args...) const)
set the body of the function to be the passed method pointer. Note that this will ignore default argument values and always require all arguments to be provided.
Definition: registry.h:166
mxnet::runtime::Registry::set_body
Registry & set_body(PackedFunc::FType f)
set the body of the function to be f
Definition: registry.h:66
mxnet::runtime::Registry::set_body_typed
Registry & set_body_typed(R(*f)(Args...))
set the body of the function to the given function pointer. Note that this doesn't work with lambdas; for those, give the function type explicitly via set_body_typed<FType>.
Definition: registry.h:111
mxnet::runtime::Registry::Remove
static MXNET_DLL bool Remove(const std::string &name)
Erase a global function from the registry, if it exists.
mxnet::runtime::Registry::func_
PackedFunc func_
internal packed function
Definition: registry.h:294
mxnet::runtime::Registry
Registry for global function.
Definition: registry.h:55
packed_func.h
Type-erased function used across MXNET API.
mxnet::runtime::Registry::set_body_method
Registry & set_body_method(R(TNode::*f)(Args...) const)
set the body of the function to be the passed method pointer. Used when calling a method on a Node subclass through an ObjectRef subclass.
Definition: registry.h:253
mxnet::runtime::TypedPackedFunc
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:152
mxnet::runtime::Registry::set_body_typed
Registry & set_body_typed(FLambda f)
set the body of the function to be TypedPackedFunc.
Definition: registry.h:84
mxnet::runtime::Registry::ListNames
static MXNET_DLL std::vector< std::string > ListNames()
Get the names of the currently registered global functions.
mxnet::runtime::Registry::Get
static const MXNET_DLL PackedFunc * Get(const std::string &name)
Get the global function by name.
MXNET_DLL
#define MXNET_DLL
MXNET_DLL prefix for windows.
Definition: c_api.h:53
mxnet::runtime::Registry::Register
static MXNET_DLL Registry & Register(const std::string &name, bool override=false)
Register a function with given name.
mxnet::runtime::Registry::name_
std::string name_
name of the function
Definition: registry.h:292
mxnet::runtime::PackedFunc::FType
std::function< void(MXNetArgs args, MXNetRetValue *rv)> FType
The internal std::function.
Definition: packed_func.h:100
mxnet::runtime::Registry::set_body_method
Registry & set_body_method(R(T::*f)(Args...))
set the body of the function to be the passed method pointer. Note that this will ignore default argument values and always require all arguments to be provided.
Definition: registry.h:137
mxnet::runtime::Registry::Manager
friend struct Manager
Definition: registry.h:295
mxnet::runtime::PackedFunc
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:80
mxnet::runtime::Registry::set_body
MXNET_DLL Registry & set_body(PackedFunc f)
set the body of the function to be f
mxnet::runtime::Registry::set_body_method
Registry & set_body_method(R(TNode::*f)(Args...))
set the body of the function to be the passed method pointer. Used when calling a method on a Node subclass through an ObjectRef subclass.
Definition: registry.h:209