mxnet
contrib.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_CONTRIB_H_
27 #define MXNET_CPP_CONTRIB_H_
28 
29 #include <iostream>
30 #include <string>
31 #include <map>
32 #include <vector>
33 #include "mxnet-cpp/symbol.h"
34 
35 namespace mxnet {
36 namespace cpp {
37 namespace details {
38 
/*!
 * \brief Split a string into the pieces separated by a delimiter.
 *
 * Adjacent delimiters produce empty pieces, and the piece after the last
 * delimiter is always appended. For example, splitting "a;;b;" on ";" yields
 * {"a", "", "b", ""}.
 * \param str the string to split.
 * \param delimiter the separator. It may be longer than one character. An empty
 *        delimiter yields a single piece containing the whole of str.
 * \return the pieces of str in order; never empty.
 */
inline std::vector<std::string> split(const std::string& str, const std::string& delimiter) {
  std::vector<std::string> splitted;
  if (delimiter.empty()) {
    // find("") matches at every position, so the scan below would never
    // consume any input; treat the whole string as a single piece instead.
    splitted.push_back(str);
    return splitted;
  }
  size_t last = 0;
  size_t next = 0;
  while ((next = str.find(delimiter, last)) != std::string::npos) {
    splitted.push_back(str.substr(last, next - last));
    // Skip the whole delimiter, not just its first character.
    last = next + delimiter.size();
  }
  splitted.push_back(str.substr(last));
  return splitted;
}
56 
57 } // namespace details
58 
59 namespace contrib {
60 
// Attribute key under which a TensorRT subgraph node lists the names of its
// parameters (';'-separated). It must stay identical to the key used by the
// TensorRT subgraph backend; see
// https://github.com/apache/incubator-mxnet/blob/1c874cfc807cee755c38f6486e8e0f4d94416cd8/src/operator/subgraph/tensorrt/tensorrt-inl.h#L190
static const std::string TENSORRT_SUBGRAPH_PARAM_IDENTIFIER("subgraph_params_names");
// Prefix prepended to a parameter name to form the attribute key that carries
// that parameter's NDArray handle address. It must stay identical to the prefix
// used by the backend; see (line number on master may drift)
// https://github.com/apache/incubator-mxnet/blob/master/src/operator/subgraph/tensorrt/tensorrt.cc#L244
static const std::string TENSORRT_SUBGRAPH_PARAM_PREFIX("subgraph_param_");
73  inline void InitTensorRTParams(const mxnet::cpp::Symbol& symbol,
74  std::map<std::string, mxnet::cpp::NDArray> *argParams,
75  std::map<std::string, mxnet::cpp::NDArray> *auxParams) {
76  mxnet::cpp::Symbol internals = symbol.GetInternals();
77  mx_uint numSymbol = internals.GetNumOutputs();
78  for (mx_uint i = 0; i < numSymbol; ++i) {
79  std::map<std::string, std::string> attrs = internals[i].ListAttributes();
80  if (attrs.find(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER) != attrs.end()) {
81  std::string new_params_names;
82  std::map<std::string, mxnet::cpp::NDArray> tensorrtParams;
83  std::vector<std::string> keys = details::split(
84  attrs[TENSORRT_SUBGRAPH_PARAM_IDENTIFIER], ";");
85  for (const auto& key : keys) {
86  if (argParams->find(key) != argParams->end()) {
87  new_params_names += key + ";";
88  tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*argParams)[key];
89  argParams->erase(key);
90  } else if (auxParams->find(key) != auxParams->end()) {
91  new_params_names += key + ";";
92  tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*auxParams)[key];
93  auxParams->erase(key);
94  }
95  }
96  std::map<std::string, std::string> new_attrs = {};
97  for (const auto& kv : tensorrtParams) {
98  // passing the ndarray address into TRT node attributes to get the weight
99  uint64_t address = reinterpret_cast<uint64_t>(kv.second.GetHandle());
100  new_attrs[kv.first] = std::to_string(address);
101  }
102  if (!new_attrs.empty()) {
103  internals[i].SetAttributes(new_attrs);
104  internals[i].SetAttribute(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER,
105  new_params_names.substr(0, new_params_names.length() - 1));
106  }
107  }
108  }
109 }
110 
111 } // namespace contrib
112 } // namespace cpp
113 } // namespace mxnet
114 
115 #endif // MXNET_CPP_CONTRIB_H_
definition of symbol
Symbol GetInternals() const
Return a symbol whose outputs are all the internal outputs of this symbol.
namespace of mxnet
Definition: api_registry.h:33
void SetAttributes(const std::map< std::string, std::string > &attrs)
set a series of key-value attribute to the symbol
void InitTensorRTParams(const mxnet::cpp::Symbol &symbol, std::map< std::string, mxnet::cpp::NDArray > *argParams, std::map< std::string, mxnet::cpp::NDArray > *auxParams)
Definition: contrib.h:73
std::map< std::string, std::string > ListAttributes() const
void SetAttribute(const std::string &key, const std::string &value)
set key-value attribute to the symbol
std::vector< std::string > split(const std::string &str, const std::string &delimiter)
Definition: contrib.h:45
mx_uint GetNumOutputs() const
Symbol interface.
Definition: symbol.h:72
uint32_t mx_uint
manually define unsigned int
Definition: c_api.h:58