mxnet
op_map.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_OP_MAP_H_
27 #define MXNET_CPP_OP_MAP_H_
28 
29 #include <map>
30 #include <string>
31 #include "mxnet-cpp/base.h"
32 #include "dmlc/logging.h"
33 
34 namespace mxnet {
35 namespace cpp {
36 
42 class OpMap {
43  public:
47  inline OpMap() {
48  mx_uint num_symbol_creators = 0;
49  AtomicSymbolCreator* symbol_creators = nullptr;
50  int r = MXSymbolListAtomicSymbolCreators(&num_symbol_creators, &symbol_creators);
51  CHECK_EQ(r, 0);
52  for (mx_uint i = 0; i < num_symbol_creators; i++) {
53  const char* name;
54  const char* description;
55  mx_uint num_args;
56  const char** arg_names;
57  const char** arg_type_infos;
58  const char** arg_descriptions;
59  const char* key_var_num_args;
60  r = MXSymbolGetAtomicSymbolInfo(symbol_creators[i],
61  &name,
62  &description,
63  &num_args,
64  &arg_names,
65  &arg_type_infos,
66  &arg_descriptions,
67  &key_var_num_args);
68  CHECK_EQ(r, 0);
69  symbol_creators_[name] = symbol_creators[i];
70  }
71 
72  nn_uint num_ops;
73  const char** op_names;
74  r = NNListAllOpNames(&num_ops, &op_names);
75  CHECK_EQ(r, 0);
76  for (nn_uint i = 0; i < num_ops; i++) {
77  OpHandle handle;
78  r = NNGetOpHandle(op_names[i], &handle);
79  CHECK_EQ(r, 0);
80  op_handles_[op_names[i]] = handle;
81  }
82  }
83 
90  inline AtomicSymbolCreator GetSymbolCreator(const std::string& name) {
91  if (symbol_creators_.count(name) == 0)
92  return GetOpHandle(name);
93  return symbol_creators_[name];
94  }
95 
102  inline OpHandle GetOpHandle(const std::string& name) {
103  return op_handles_[name];
104  }
105 
106  private:
107  std::map<std::string, AtomicSymbolCreator> symbol_creators_;
108  std::map<std::string, OpHandle> op_handles_;
109 };
110 
111 } // namespace cpp
112 } // namespace mxnet
113 
114 #endif // MXNET_CPP_OP_MAP_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
OpHandle
void * OpHandle
handle to a function that takes param and creates symbol
Definition: c_api.h:44
mxnet::cpp::OpMap::OpMap
OpMap()
Create an OpMap instance, loading all symbol creators and op handles.
Definition: op_map.h:47
MXSymbolGetAtomicSymbolInfo
MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, const char **name, const char **description, uint32_t *num_args, const char ***arg_names, const char ***arg_type_infos, const char ***arg_descriptions, const char **key_var_num_args, const char **return_type DEFAULT(NULL))
Get the detailed information about atomic symbol.
MXSymbolListAtomicSymbolCreators
MXNET_DLL int MXSymbolListAtomicSymbolCreators(uint32_t *out_size, AtomicSymbolCreator **out_array)
list all the available AtomicSymbolEntry
NNListAllOpNames
NNVM_DLL int NNListAllOpNames(nn_uint *out_size, const char ***out_array)
List all the available operator names, including entries.
nn_uint
unsigned int nn_uint
manually define unsigned int
Definition: c_api.h:41
mxnet::cpp::OpMap
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name....
Definition: op_map.h:42
mxnet::cpp::OpMap::GetOpHandle
OpHandle GetOpHandle(const std::string &name)
Get an op handle with its name.
Definition: op_map.h:102
NNGetOpHandle
NNVM_DLL int NNGetOpHandle(const char *op_name, OpHandle *op_out)
Get operator handle given name.
AtomicSymbolCreator
void * AtomicSymbolCreator
handle to a function that takes param and creates symbol
Definition: c_api.h:78
mxnet::cpp::OpMap::GetSymbolCreator
AtomicSymbolCreator GetSymbolCreator(const std::string &name)
Get a symbol creator with its name.
Definition: op_map.h:90
base.h
base definitions for mxnetcpp
mx_uint
uint32_t mx_uint
manually define unsigned int
Definition: c_api.h:65