mxnet
contrib.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_CPP_CONTRIB_H_
26 #define MXNET_CPP_CONTRIB_H_
27 
28 #include <iostream>
29 #include <string>
30 #include <map>
31 #include <vector>
32 #include "mxnet-cpp/symbol.h"
33 
34 namespace mxnet {
35 namespace cpp {
36 namespace details {
37 
/*!
 * \brief Split a string on every occurrence of a delimiter.
 *
 * Empty fields are preserved: splitting "a;;b;" on ";" yields
 * {"a", "", "b", ""}, and an empty input yields {""}. An empty delimiter
 * yields the whole input as a single element.
 *
 * \param str       the string to split
 * \param delimiter the separator; may be longer than one character
 * \return the pieces of \p str between delimiters, in order
 */
inline std::vector<std::string> split(const std::string& str, const std::string& delimiter) {
  std::vector<std::string> splitted;
  if (delimiter.empty()) {
    // find("") matches at every position, so there is nothing meaningful to
    // split on; return the input unchanged instead of looping/throwing.
    splitted.push_back(str);
    return splitted;
  }
  size_t last = 0;
  size_t next = 0;
  while ((next = str.find(delimiter, last)) != std::string::npos) {
    splitted.push_back(str.substr(last, next - last));
    // Skip the whole delimiter, not just its first character.
    last = next + delimiter.size();
  }
  splitted.push_back(str.substr(last));
  return splitted;
}
55 
56 } // namespace details
57 
58 namespace contrib {
59 
// Attribute key listing the ';'-separated parameter names of a TensorRT
// subgraph node. Must stay identical to the name used in
// https://github.com/apache/mxnet/blob/1c874cfc807cee755c38f6486e8e0f4d94416cd8/src/operator/subgraph/tensorrt/tensorrt-inl.h#L190
static const std::string TENSORRT_SUBGRAPH_PARAM_IDENTIFIER("subgraph_params_names");  // NOLINT
// Prefix prepended to each parameter name when its handle is attached as a
// node attribute. Must stay identical to the prefix used in the file below
// (NOTE: this link tracks master, so the line number may drift):
// https://github.com/apache/mxnet/blob/master/src/operator/subgraph/tensorrt/tensorrt.cc#L244
static const std::string TENSORRT_SUBGRAPH_PARAM_PREFIX("subgraph_param_");  // NOLINT
75 inline void InitTensorRTParams(const mxnet::cpp::Symbol& symbol,
76  std::map<std::string, mxnet::cpp::NDArray>* argParams,
77  std::map<std::string, mxnet::cpp::NDArray>* auxParams) {
78  mxnet::cpp::Symbol internals = symbol.GetInternals();
79  mx_uint numSymbol = internals.GetNumOutputs();
80  for (mx_uint i = 0; i < numSymbol; ++i) {
81  std::map<std::string, std::string> attrs = internals[i].ListAttributes();
82  if (attrs.find(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER) != attrs.end()) {
83  std::string new_params_names;
84  std::map<std::string, mxnet::cpp::NDArray> tensorrtParams;
85  std::vector<std::string> keys =
86  details::split(attrs[TENSORRT_SUBGRAPH_PARAM_IDENTIFIER], ";");
87  for (const auto& key : keys) {
88  if (argParams->find(key) != argParams->end()) {
89  new_params_names += key + ";";
90  tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*argParams)[key];
91  argParams->erase(key);
92  } else if (auxParams->find(key) != auxParams->end()) {
93  new_params_names += key + ";";
94  tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*auxParams)[key];
95  auxParams->erase(key);
96  }
97  }
98  std::map<std::string, std::string> new_attrs = {};
99  for (const auto& kv : tensorrtParams) {
100  // passing the ndarray address into TRT node attributes to get the weight
101  uint64_t address = reinterpret_cast<uint64_t>(kv.second.GetHandle());
102  new_attrs[kv.first] = std::to_string(address);
103  }
104  if (!new_attrs.empty()) {
105  internals[i].SetAttributes(new_attrs);
106  internals[i].SetAttribute(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER,
107  new_params_names.substr(0, new_params_names.length() - 1));
108  }
109  }
110  }
111 }
112 
113 } // namespace contrib
114 } // namespace cpp
115 } // namespace mxnet
116 
117 #endif // MXNET_CPP_CONTRIB_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::common::cuda::rtc::util::to_string
std::string to_string(OpReqType req)
Convert OpReqType to string.
mxnet::cpp::Symbol::SetAttribute
void SetAttribute(const std::string &key, const std::string &value)
set key-value attribute to the symbol
mxnet::cpp::contrib::InitTensorRTParams
void InitTensorRTParams(const mxnet::cpp::Symbol &symbol, std::map< std::string, mxnet::cpp::NDArray > *argParams, std::map< std::string, mxnet::cpp::NDArray > *auxParams)
Definition: contrib.h:75
symbol.h
definition of symbol
mxnet::cpp::Symbol::GetNumOutputs
mx_uint GetNumOutputs() const
mxnet::cpp::Symbol::GetInternals
Symbol GetInternals() const
save Symbol into a JSON string \return the symbol whose outputs are all the internals.
mxnet::cpp::Symbol::ListAttributes
std::map< std::string, std::string > ListAttributes() const
mxnet::cpp::Symbol::SetAttributes
void SetAttributes(const std::map< std::string, std::string > &attrs)
set a series of key-value attribute to the symbol
mxnet::cpp::details::split
std::vector< std::string > split(const std::string &str, const std::string &delimiter)
Definition: contrib.h:44
mx_uint
uint32_t mx_uint
manually define unsigned int
Definition: c_api.h:65
mxnet::cpp::Symbol
Symbol interface.
Definition: symbol.h:73