26 #ifndef MXNET_CPP_CONTRIB_H_ 27 #define MXNET_CPP_CONTRIB_H_ 45 inline std::vector<std::string>
split(
const std::string& str,
const std::string& delimiter) {
46 std::vector<std::string> splitted;
49 while ((next = str.find(delimiter, last)) != std::string::npos) {
50 splitted.push_back(str.substr(last, next - last));
53 splitted.push_back(str.substr(last));
// Attribute key whose value lists, separated by ';', the names of the
// parameters a TensorRT subgraph node uses.
// NOTE(review): presumably must match the key the TensorRT subgraph operator
// reads on the backend side — keep the two in sync.
static const std::string TENSORRT_SUBGRAPH_PARAM_IDENTIFIER{"subgraph_params_names"};

// Prefix prepended to a parameter name to build the per-parameter attribute
// key under which that parameter's handle is passed to the subgraph node.
// NOTE(review): presumably must match the backend's expected prefix — confirm.
static const std::string TENSORRT_SUBGRAPH_PARAM_PREFIX{"subgraph_param_"};
74 std::map<std::string, mxnet::cpp::NDArray> *argParams,
75 std::map<std::string, mxnet::cpp::NDArray> *auxParams) {
78 for (
mx_uint i = 0; i < numSymbol; ++i) {
79 std::map<std::string, std::string> attrs = internals[i].
ListAttributes();
80 if (attrs.find(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER) != attrs.end()) {
81 std::string new_params_names;
82 std::map<std::string, mxnet::cpp::NDArray> tensorrtParams;
84 attrs[TENSORRT_SUBGRAPH_PARAM_IDENTIFIER],
";");
85 for (
const auto& key : keys) {
86 if (argParams->find(key) != argParams->end()) {
87 new_params_names += key +
";";
88 tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*argParams)[key];
89 argParams->erase(key);
90 }
else if (auxParams->find(key) != auxParams->end()) {
91 new_params_names += key +
";";
92 tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*auxParams)[key];
93 auxParams->erase(key);
96 std::map<std::string, std::string> new_attrs = {};
97 for (
const auto& kv : tensorrtParams) {
99 uint64_t address =
reinterpret_cast<uint64_t
>(kv.second.GetHandle());
100 new_attrs[kv.first] = std::to_string(address);
102 if (!new_attrs.empty()) {
104 internals[i].
SetAttribute(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER,
105 new_params_names.substr(0, new_params_names.length() - 1));
115 #endif // MXNET_CPP_CONTRIB_H_
Symbol GetInternals() const
save Symbol into a JSON string the symbol whose outputs are all the internals.
namespace of mxnet
Definition: api_registry.h:33
void SetAttributes(const std::map< std::string, std::string > &attrs)
set a series of key-value attribute to the symbol
void InitTensorRTParams(const mxnet::cpp::Symbol &symbol, std::map< std::string, mxnet::cpp::NDArray > *argParams, std::map< std::string, mxnet::cpp::NDArray > *auxParams)
Definition: contrib.h:73
std::map< std::string, std::string > ListAttributes() const
void SetAttribute(const std::string &key, const std::string &value)
set key-value attribute to the symbol
std::vector< std::string > split(const std::string &str, const std::string &delimiter)
Definition: contrib.h:45
mx_uint GetNumOutputs() const
Symbol interface.
Definition: symbol.h:72
uint32_t mx_uint
manually define unsigned int
Definition: c_api.h:58