mxnet
registry.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
43 // Acknowledgement: This file originates from incubator-tvm
44 #ifndef MXNET_RUNTIME_REGISTRY_H_
45 #define MXNET_RUNTIME_REGISTRY_H_
46 
47 #include <string>
48 #include <vector>
49 #include "packed_func.h"
50 
51 namespace mxnet {
52 namespace runtime {
53 
55 class Registry {
56  public:
61  MXNET_DLL Registry& set_body(PackedFunc f); // NOLINT(*)
66  Registry& set_body(PackedFunc::FType f) { // NOLINT(*)
67  return set_body(PackedFunc(f));
68  }
83  template<typename FType, typename FLambda>
84  Registry& set_body_typed(FLambda f) {
85  return set_body(TypedPackedFunc<FType>(f).packed());
86  }
87 
109  template<typename R, typename ...Args>
110  Registry& set_body_typed(R (*f)(Args...)) {
111  return set_body(TypedPackedFunc<R(Args...)>(f));
112  }
113 
134  template<typename T, typename R, typename ...Args>
135  Registry& set_body_method(R (T::*f)(Args...)) {
136  return set_body_typed<R(T, Args...)>([f](T target, Args... params) -> R {
137  // call method pointer
138  return (target.*f)(params...);
139  });
140  }
141 
162  template<typename T, typename R, typename ...Args>
163  Registry& set_body_method(R (T::*f)(Args...) const) {
164  return set_body_typed<R(T, Args...)>([f](const T target, Args... params) -> R {
165  // call method pointer
166  return (target.*f)(params...);
167  });
168  }
169 
200  template<typename TObjectRef, typename TNode, typename R, typename ...Args,
201  typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
202  Registry& set_body_method(R (TNode::*f)(Args...)) {
203  return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
204  TNode* target = ref.operator->();
205  // call method pointer
206  return (target->*f)(params...);
207  });
208  }
209 
240  template<typename TObjectRef, typename TNode, typename R, typename ...Args,
241  typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
242  Registry& set_body_method(R (TNode::*f)(Args...) const) {
243  return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
244  const TNode* target = ref.operator->();
245  // call method pointer
246  return (target->*f)(params...);
247  });
248  }
249 
256  MXNET_DLL static Registry& Register(const std::string& name, bool override = false); // NOLINT(*)
262  MXNET_DLL static bool Remove(const std::string& name);
269  MXNET_DLL static const PackedFunc* Get(const std::string& name); // NOLINT(*)
274  MXNET_DLL static std::vector<std::string> ListNames();
275 
276  // Internal class.
277  struct Manager;
278 
279  protected:
281  std::string name_;
284  friend struct Manager;
285 };
286 
/*! \brief Marks a declaration as possibly unused (GCC/Clang); expands to nothing elsewhere. */
#if defined(__GNUC__)
#define MXNET_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define MXNET_ATTRIBUTE_UNUSED
#endif

// Two-level token pasting: the outer macro fully expands its arguments
// (e.g. __COUNTER__ or MXNET_FUNC_REG_VAR_DEF) before the inner macro pastes
// the last token of x to the first token of y with ##.
// Parameters are plain `x`/`y`: identifiers containing "__" are reserved.
#define MXNET_STR_CONCAT_(x, y) x##y
#define MXNET_STR_CONCAT(x, y) MXNET_STR_CONCAT_(x, y)

// Declaration prefix for the static Registry& created by each registration.
// Its trailing identifier __mk_MXNET receives a unique numeric suffix from
// MXNET_REGISTER_GLOBAL. The generated name is kept unchanged for compatibility.
#define MXNET_FUNC_REG_VAR_DEF \
  static MXNET_ATTRIBUTE_UNUSED ::mxnet::runtime::Registry& __mk_ ## MXNET

/*!
 * \brief Register a function globally.
 *
 * Expands to a uniquely named static Registry& bound to
 * Registry::Register(OpName), so set_body* calls can be chained on it:
 * \code
 *   MXNET_REGISTER_GLOBAL("MyPrint")
 *   .set_body([](MXNetArgs args, MXNetRetValue* rv) {
 *     // ...
 *   });
 * \endcode
 * Uniqueness comes from __COUNTER__ (a compiler extension supported by
 * GCC, Clang and MSVC).
 */
#define MXNET_REGISTER_GLOBAL(OpName) \
  MXNET_STR_CONCAT(MXNET_FUNC_REG_VAR_DEF, __COUNTER__) = \
  ::mxnet::runtime::Registry::Register(OpName)
311 
312 } // namespace runtime
313 } // namespace mxnet
314 #endif // MXNET_RUNTIME_REGISTRY_H_
Registry & set_body_method(R(TNode::*f)(Args...) const)
set the body of the function to be the passed method pointer. Used when calling a method on a Node su...
Definition: registry.h:242
std::function< void(MXNetArgs args, MXNetRetValue *rv)> FType
The internal std::function.
Definition: packed_func.h:97
Registry & set_body_method(R(T::*f)(Args...) const)
set the body of the function to be the passed method pointer. Note that this will ignore default arg ...
Definition: registry.h:163
MXNET_DLL Registry & set_body(PackedFunc f)
set the body of the function to be f
namespace of mxnet
Definition: api_registry.h:33
static MXNET_DLL Registry & Register(const std::string &name, bool override=false)
Register a function with given name.
static MXNET_DLL std::vector< std::string > ListNames()
Get the names of currently registered global function.
Registry & set_body_typed(R(*f)(Args...))
set the body of the function to the given function pointer. Note that this doesn't work with lambdas...
Definition: registry.h:110
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:77
Registry & set_body(PackedFunc::FType f)
set the body of the function to be f
Definition: registry.h:66
friend struct Manager
Definition: registry.h:284
Registry & set_body_method(R(T::*f)(Args...))
set the body of the function to be the passed method pointer. Note that this will ignore default arg ...
Definition: registry.h:135
Registry for global function.
Definition: registry.h:55
Please refer to TypedPackedFunc<R(Args...)>.
Definition: packed_func.h:149
Registry & set_body_typed(FLambda f)
set the body of the function to be TypedPackedFunc.
Definition: registry.h:84
static MXNET_DLL const PackedFunc * Get(const std::string &name)
Get the global function by name.
#define MXNET_DLL
MXNET_DLL prefix for windows.
Definition: c_api.h:54
Registry & set_body_method(R(TNode::*f)(Args...))
set the body of the function to be the passed method pointer. Used when calling a method on a Node su...
Definition: registry.h:202
std::string name_
name of the function
Definition: registry.h:281
PackedFunc func_
internal packed function
Definition: registry.h:283
static MXNET_DLL bool Remove(const std::string &name)
Erase global function from registry, if exist.
Type-erased function used across MXNET API.