#ifndef MXNET_COMMON_EXEC_UTILS_H_
#define MXNET_COMMON_EXEC_UTILS_H_
#include <dmlc/logging.h>
#include <nnvm/graph.h>

#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../common/utils.h"
#include "../imperative/exec_pass.h"
#if MXNET_USE_ONEDNN == 1
#define DEFAULT_DATA(x) x.IsDefaultData()
#else
#define DEFAULT_DATA(x) (x.storage_type() == kDefaultStorage)
#endif
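/*!
 * \brief Set up default-storage input blobs: for every input in \p src that is not in
 *        default (dense) storage, allocate or reuse a default-storage temporary, record the
 *        (input index -> temporary index) mapping in \p idx_map, and expose the temporary's
 *        TBlob in \p blobs. Returns true if at least one input required such a cast.
 */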
inline bool SetupDefaultBlobsIn(const std::vector<NDArray>& src,
                                const std::vector<NDArray>* bufs,
                                std::vector<TBlob>* blobs,
                                std::vector<NDArray>* temp_src,
                                std::vector<NDArray>* temp_dst,
                                std::unordered_map<uint32_t, uint32_t>* idx_map) {
  bool require_cast = false;
  for (size_t i = 0; i < src.size(); i++) {
    const auto& nd = src[i];
    if (!DEFAULT_DATA(nd)) {
      // record where the temporary for input i lives in temp_dst
      (*idx_map)[i] = temp_dst->size();
      NDArray temp =
          bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), true, nd.dtype());
#if MXNET_USE_ONEDNN == 1
      CHECK(temp.IsDefaultData());
#endif
      temp_src->emplace_back(nd);
      temp_dst->emplace_back(temp);
      blobs->emplace_back(temp.data());
      require_cast = true;
    } else {
      blobs->push_back(nd.data());
    }
  }
  return require_cast;
}
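/*!
 * \brief Set up default-storage output blobs: non-default-storage outputs are redirected to
 *        default-storage temporaries (recorded in \p temp_src / \p temp_dst), adjusting the
 *        write request when oneDNN layouts make in-place writes impossible. Returns true if
 *        at least one output required such a cast.
 */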
inline bool SetupDefaultBlobsOut(const std::vector<NDArray>& src,
                                 const std::vector<NDArray>* bufs,
                                 std::vector<OpReqType>* req,
                                 std::vector<TBlob>* blobs,
                                 std::vector<NDArray>* temp_src,
                                 std::vector<NDArray>* temp_dst) {
  bool require_cast = false;
  for (size_t i = 0; i < src.size(); i++) {
    const auto& nd = src[i];
#if MXNET_USE_ONEDNN == 1
    if (req->at(i) == kWriteInplace && nd.IsDNNLData()) {
      // A temporary output array will be generated below, so the input and output
      // are no longer the same array; the request type must change accordingly.
      req->at(i) = kWriteTo;
    }
#endif
    if (!DEFAULT_DATA(nd)) {
#if MXNET_USE_ONEDNN == 1
      NDArray temp;
      if (bufs != nullptr) {
        temp = bufs->at(i);
      } else if (kAddTo == req->at(i)) {
        temp = nd.IsDNNLData() ? nd.Reorder2Default() : nd;
      } else {
        temp = NDArray(nd.shape(), nd.ctx(), true, nd.dtype());
      }
      CHECK(temp.IsDefaultData());
#else
      NDArray temp =
          bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), true, nd.dtype());
#endif
      temp_src->emplace_back(nd);
      temp_dst->emplace_back(temp);
      blobs->emplace_back(temp.data());
      require_cast = true;
    } else {
      blobs->push_back(nd.data());
    }
  }
  return require_cast;
}
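/*!
 * \brief Set up default blobs for both inputs and outputs of an operator, and register the
 *        temporaries of mutable inputs so they can be cast back after execution.
 */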
inline void SetupDefaultBlobsInOut(const std::vector<NDArray>& ndinputs,
                                   const std::vector<NDArray>& ndoutputs,
                                   const std::vector<NDArray>* in_bufs,
                                   const std::vector<NDArray>* out_bufs,
                                   std::vector<OpReqType>* req,
                                   std::vector<TBlob>* input_blobs,
                                   std::vector<TBlob>* output_blobs,
                                   std::vector<NDArray>* pre_temp_src,
                                   std::vector<NDArray>* pre_temp_dst,
                                   std::vector<NDArray>* post_temp_src,
                                   std::vector<NDArray>* post_temp_dst,
                                   std::unordered_map<uint32_t, uint32_t>* in_temp_idx_map,
                                   const std::vector<uint32_t>& mutate_idx) {
  // populate input blobs
  SetupDefaultBlobsIn(ndinputs, in_bufs, input_blobs, pre_temp_src, pre_temp_dst, in_temp_idx_map);
  // populate output blobs
  SetupDefaultBlobsOut(ndoutputs, out_bufs, req, output_blobs, post_temp_dst, post_temp_src);
  // add mutable inputs to the post-execution cast list
  for (const auto idx : mutate_idx) {
    auto map_iter = in_temp_idx_map->find(idx);
    if (map_iter != in_temp_idx_map->end()) {
      post_temp_src->push_back(pre_temp_dst->at(map_iter->second));
      post_temp_dst->push_back(ndinputs[idx]);
    }
  }
}
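/*!
 * \brief Cast each NDArray in \p src into the corresponding NDArray in \p dst, dispatching
 *        to the GPU or CPU CastStorage implementation depending on \p is_gpu.
 */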
inline void CastNonDefaultStorage(const std::vector<NDArray>& src,
                                  const std::vector<NDArray>& dst,
                                  const OpContext& ctx,
                                  const bool is_gpu) {
  CHECK_EQ(dst.size(), src.size());
  for (size_t i = 0; i < src.size(); i++) {
    if (is_gpu) {
#if MXNET_USE_CUDA
      CastStorageDispatch<gpu>(ctx, src[i], dst[i]);
#else
      LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif
    } else {
      CastStorageDispatch<cpu>(ctx, src[i], dst[i]);
    }
  }
}
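/*!
 * \brief Default type inference: pick the first defined dtype among the outputs, then the
 *        inputs, and assign it to every entry. Returns false if no dtype is defined yet.
 */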
inline bool SameType(const nnvm::NodeAttrs& attrs,
                     std::vector<int>* iattr,
                     std::vector<int>* oattr) {
  // find the first defined dtype, preferring outputs over inputs
  int def_v = -1;
  for (int v : *oattr) {
    if (v != -1) { def_v = v; break; }
  }
  if (def_v == -1) {
    for (int v : *iattr) {
      if (v != -1) { def_v = v; break; }
    }
  }
  if (def_v == -1) return false;
  // propagate the dtype to every input and output entry
  for (int& v : *oattr) v = def_v;
  for (int& v : *iattr) v = def_v;
  return true;
}
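/*!
 * \brief Default storage type inference: assign kDefaultStorage to every undefined entry and
 *        select fallback dispatch when any input or output storage type is non-default.
 */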
inline bool DefaultStorageType(const nnvm::NodeAttrs& attrs,
                               const int dev_mask,
                               DispatchMode* dispatch_mode,
                               std::vector<int>* iattr,
                               std::vector<int>* oattr) {
  bool fallback = false;
  for (int& v : *oattr) {
    if (v == -1) v = kDefaultStorage;
    if (v != kDefaultStorage) fallback = true;
  }
  for (int& v : *iattr) {
    if (v == -1) v = kDefaultStorage;
    if (v != kDefaultStorage) fallback = true;
  }
  if (*dispatch_mode == DispatchMode::kUndefined) {
    *dispatch_mode = fallback ? DispatchMode::kFComputeFallback : DispatchMode::kFCompute;
  }
  return true;
}
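/*! \brief Human-readable string for a storage id produced by memory planning. */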
inline std::string storage_str(int storage_id) {
  std::string str;
  if (storage_id == -1) {
    str = "var (-1)";
  } else if (storage_id == -2) {
    str = "external storage (-2)";
  } else {
    str = "group " + std::to_string(storage_id);
  }
  return str;
}
/*!
 * \brief Log the static memory plan of the graph: for every node in the (optionally
 *        restricted) node range, print the shape and size in KB of each input and output entry.
 */
inline void LogMemoryPlan(const nnvm::Graph& g) {
  const auto& idx    = g.indexed_graph();
  const auto& vshape = g.GetAttr<mxnet::ShapeVector>("shape");
  const auto& vtype  = g.GetAttr<nnvm::DTypeVector>("dtype");
  // restrict logging to the node range stored on the graph, if any
  uint32_t node_start = 0, node_end = idx.num_nodes();
  if (g.attrs.count("node_range")) {
    const auto& range = g.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
    node_start        = range.first;
    node_end          = range.second;
  }
  for (uint32_t nid = node_start; nid < node_end; ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) {
      LOG(INFO) << "node " << nid << " var";
    } else {
      LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name;
      for (const auto& e : inode.inputs) {
        auto eid          = idx.entry_id(e);
        size_t kilo_bytes = vshape[eid].Size() * mshadow::mshadow_sizeof(vtype[eid]) / 1024;
        LOG(INFO) << "\t\tinput " << eid << ": " << vshape[eid] << " (" << kilo_bytes << " KB)";
      }
      for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
        uint32_t eid      = idx.entry_id(nid, index);
        size_t kilo_bytes = vshape[eid].Size() * mshadow::mshadow_sizeof(vtype[eid]) / 1024;
        LOG(INFO) << "\t\toutput " << eid << ": " << vshape[eid] << " (" << kilo_bytes << " KB)";
      }
    }
  }
}
/*!
 * \brief Log the result of storage type inference: for every node in the (optionally
 *        restricted) node range, print its dispatch mode and the storage type of each input
 *        and output entry.
 */
inline void LogInferStorage(const nnvm::Graph& g) {
  const auto& idx            = g.indexed_graph();
  const auto& vstorage_type  = g.GetAttr<StorageTypeVector>("storage_type");
  const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
  uint32_t node_start = 0, node_end = idx.num_nodes();
  if (g.attrs.count("node_range")) {
    const auto& range = g.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
    node_start        = range.first;
    node_end          = range.second;
  }
  for (uint32_t nid = node_start; nid < node_end; ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) {
      LOG(INFO) << "node " << nid << " var";
    } else {
      LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name << ": "
                << dispatch_mode_string(dispatch_modes[nid]);
      for (const auto& e : inode.inputs) {
        auto eid = idx.entry_id(e);
        LOG(INFO) << "\t\tinput " << eid << ": " << stype_string(vstorage_type[eid]);
      }
      for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
        uint32_t eid = idx.entry_id(nid, index);
        LOG(INFO) << "\t\toutput " << eid << ": " << stype_string(vstorage_type[eid]);
      }
    }
  }
}
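/*!
 * \brief Reuse the NDArray registered under \p name in \p shared_buffer when its storage type
 *        is shareable and the buffered array is large enough; otherwise allocate a new
 *        zero-initialized NDArray (registering it for future sharing when possible).
 *        Shareable storage is default storage, plus row_sparse storage when
 *        \p enable_row_sparse_sharing is true.
 */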
inline NDArray ReshapeOrCreate(const std::string& name,
                               const mxnet::TShape& dest_arg_shape,
                               const int dest_arg_dtype,
                               const NDArrayStorageType dest_arg_stype,
                               const Context& ctx,
                               std::unordered_map<std::string, NDArray>* shared_buffer,
                               bool enable_row_sparse_sharing) {
  bool stype_shareable = dest_arg_stype == kDefaultStorage;
  if (enable_row_sparse_sharing) {
    stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage;
  }
  auto it = shared_buffer->find(name);
  if (it != shared_buffer->end()) {
    // check whether the buffered array is large enough for sharing
    bool size_shareable = it->second.shape().Size() >= dest_arg_shape.Size();
    if (size_shareable && stype_shareable) {  // memory can be reused
      CHECK_EQ(it->second.dtype(), dest_arg_dtype)
          << "Requested arg array's dtype does not match that of the reusable ndarray";
      CHECK_EQ(it->second.storage_type(), dest_arg_stype)
          << "Requested arg array's stype does not match that of the reusable ndarray";
      return it->second.Reshape(dest_arg_shape);
    } else if (stype_shareable) {
      LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape
                   << ", which is larger than already allocated shape " << it->second.shape()
                   << ". Need to re-allocate. Consider putting default bucket key to be "
                   << "the bucket taking the largest input for better memory sharing.";
      // not large enough: allocate a bigger array and keep it for sharing
      it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
      return it->second;
    } else {
      // storage type is not shareable: allocate without touching the buffer
      return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
    }
  } else {
    auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype);
    if (stype_shareable) {
      shared_buffer->emplace(name, ret);
    }
    return ret;
  }
}
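/*!
 * \brief Assign a context to every node of the graph. If no group2ctx map is given, all nodes
 *        run on \p default_ctx; otherwise devices are planned from \p ctx_map together with
 *        the contexts of the input, gradient, and auxiliary-state arrays, and the result is
 *        stored in the graph's "context" attribute.
 */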
inline nnvm::Graph AssignContext(nnvm::Graph g,
                                 const Context& default_ctx,
                                 const std::map<std::string, Context>& ctx_map,
                                 const std::vector<Context>& in_arg_ctxes,
                                 const std::vector<Context>& arg_grad_ctxes,
                                 const std::vector<Context>& aux_state_ctxes,
                                 const std::vector<OpReqType>& grad_req_types,
                                 size_t num_forward_inputs,
                                 size_t num_forward_outputs) {
  const auto& idx           = g.indexed_graph();
  const auto& mutable_nodes = idx.mutable_input_nodes();
  // without a group2ctx map, every node runs on the default context
  if (ctx_map.size() == 0) {
    g.attrs["context"] =
        std::make_shared<nnvm::any>(exec::ContextVector(idx.num_nodes(), default_ctx));
    for (const auto& x : in_arg_ctxes) {
      CHECK(x == default_ctx) << "Input array is in " << x
                              << " while binding with ctx=" << default_ctx
                              << ". All arguments must be in global context (" << default_ctx
                              << ") unless group2ctx is specified for cross-device graph.";
    }
    for (const auto& x : arg_grad_ctxes) {
      CHECK(x == default_ctx) << "Gradient array is in " << x
                              << " while binding with ctx=" << default_ctx
                              << ". All gradients must be in global context (" << default_ctx
                              << ") unless group2ctx is specified for cross-device graph.";
    }
    return g;
  }
  // otherwise, plan devices from the group2ctx map
  std::map<Context, int> ctx2id;                   // map a context to its device id
  std::vector<Context> ctx_list;                   // index is the device id
  nnvm::DeviceVector device(idx.num_nodes(), -1);  // device id of each node, -1 = unassigned
  nnvm::DeviceAssignMap device_map;                // map an arg name to a device id
  for (auto& kv : ctx_map) {
    if (ctx2id.count(kv.second) == 0) {
      ctx2id[kv.second] = static_cast<int>(ctx_list.size());
      ctx_list.push_back(kv.second);
    }
    device_map[kv.first] = ctx2id.at(kv.second);
  }
  // assign devices to the forward input nodes from the arg/aux context lists
  size_t arg_top = 0, aux_top = 0;
  for (size_t i = 0; i < num_forward_inputs; ++i) {
    const uint32_t nid = idx.input_nodes().at(i);
    Context ctx;
    if (mutable_nodes.count(nid)) {  // auxiliary state node
      CHECK_LT(aux_top, aux_state_ctxes.size());
      ctx = aux_state_ctxes[aux_top];
      ++aux_top;
    } else {  // regular argument node
      CHECK_LT(arg_top, in_arg_ctxes.size());
      ctx = in_arg_ctxes[arg_top];
      ++arg_top;
    }
    if (ctx2id.count(ctx) == 0) {
      ctx2id[ctx] = static_cast<int>(ctx_list.size());
      ctx_list.push_back(ctx);
    }
    device[nid] = ctx2id.at(ctx);
  }
  // assign devices to gradient outputs, skipping entries whose grad request is kNullOp
  size_t arg_grad_offset = 0;
  CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs)
      << "insufficient number of grad_reqs";
  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
    while (grad_req_types[arg_grad_offset] == kNullOp)
      ++arg_grad_offset;
    const uint32_t nid = idx.outputs()[i].node_id;
    Context ctx        = arg_grad_ctxes[arg_grad_offset];
    if (ctx2id.count(ctx) == 0) {
      ctx2id[ctx] = static_cast<int>(ctx_list.size());
      ctx_list.push_back(ctx);
    }
    int devid = ctx2id.at(ctx);
    if (device[nid] != -1) {
      CHECK_EQ(device[nid], devid) << "device of same output not equal to each other";
    } else {
      device[nid] = devid;
    }
  }
  // run device placement and translate the assigned device ids back into contexts
  g.attrs["device"] = std::make_shared<dmlc::any>(std::move(device));
  g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy");
  const auto& assigned_devices = g.GetAttr<nnvm::DeviceVector>("device");
  exec::ContextVector vcontext;
  for (auto context : assigned_devices) {
    if (context == -1) {
      vcontext.push_back(default_ctx);
    } else {
      vcontext.push_back(ctx_list[context]);
    }
  }
  // after device placement, check that each gradient is stored on the device
  // of the node that computes it
  const auto& new_idx = g.indexed_graph();
  arg_grad_offset     = 0;
  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
    while (grad_req_types[arg_grad_offset] == kNullOp)
      ++arg_grad_offset;
    const uint32_t nid = new_idx.outputs()[i].node_id;
    Context ctx        = arg_grad_ctxes[arg_grad_offset];
    CHECK(ctx == vcontext[nid]) << "Trying to save gradient to " << ctx
                                << " while its source node \"" << new_idx[nid].source->attrs.name
                                << "\" computes it on " << vcontext[nid]
                                << ". Check your ctx in NDArray allocation.";
  }
  g.attrs["context"] = std::make_shared<nnvm::any>(std::move(vcontext));
  return g;
}

#endif  // MXNET_COMMON_EXEC_UTILS_H_