25 #ifndef MXNET_CPP_CONTRIB_H_
26 #define MXNET_CPP_CONTRIB_H_
/*!
 * \brief Split a string on every occurrence of a delimiter.
 *
 * Empty fields are preserved: "a;;b" -> {"a", "", "b"}, "a;" -> {"a", ""},
 * and "" -> {""}. The result therefore always has at least one element.
 *
 * \param str string to be split
 * \param delimiter separator; may be longer than one character. If it is
 *        empty, \p str is returned unsplit as a single field.
 * \return the fields of \p str in order of appearance
 */
inline std::vector<std::string>
split(const std::string& str, const std::string& delimiter) {
  std::vector<std::string> splitted;
  // An empty delimiter matches at every position, so find() would never
  // return npos; treat it as "no split" instead of looping forever.
  if (delimiter.empty()) {
    splitted.push_back(str);
    return splitted;
  }
  std::string::size_type last = 0;
  std::string::size_type next = 0;
  while ((next = str.find(delimiter, last)) != std::string::npos) {
    splitted.push_back(str.substr(last, next - last));
    // Skip the whole delimiter, not just one character, so a multi-character
    // delimiter does not leak its tail into the following field.
    last = next + delimiter.size();
  }
  // Trailing field (possibly empty when str ends with the delimiter).
  splitted.push_back(str.substr(last));
  return splitted;
}
// Node attribute key whose value is a ';'-joined list of parameter names
// belonging to a TensorRT subgraph node.
static const std::string TENSORRT_SUBGRAPH_PARAM_IDENTIFIER("subgraph_params_names");
// Prefix prepended to each parameter name when it is used as a key for the
// TensorRT subgraph's parameter map.
static const std::string TENSORRT_SUBGRAPH_PARAM_PREFIX("subgraph_param_");
76 std::map<std::string, mxnet::cpp::NDArray>* argParams,
77 std::map<std::string, mxnet::cpp::NDArray>* auxParams) {
80 for (
mx_uint i = 0; i < numSymbol; ++i) {
81 std::map<std::string, std::string> attrs = internals[i].
ListAttributes();
82 if (attrs.find(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER) != attrs.end()) {
83 std::string new_params_names;
84 std::map<std::string, mxnet::cpp::NDArray> tensorrtParams;
85 std::vector<std::string> keys =
87 for (
const auto& key : keys) {
88 if (argParams->find(key) != argParams->end()) {
89 new_params_names += key +
";";
90 tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*argParams)[key];
91 argParams->erase(key);
92 }
else if (auxParams->find(key) != auxParams->end()) {
93 new_params_names += key +
";";
94 tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*auxParams)[key];
95 auxParams->erase(key);
98 std::map<std::string, std::string> new_attrs = {};
99 for (
const auto& kv : tensorrtParams) {
101 uint64_t address =
reinterpret_cast<uint64_t
>(kv.second.GetHandle());
104 if (!new_attrs.empty()) {
106 internals[i].
SetAttribute(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER,
107 new_params_names.substr(0, new_params_names.length() - 1));
117 #endif // MXNET_CPP_CONTRIB_H_