Go to the documentation of this file.
24 #ifndef MXNET_KVSTORE_H_
25 #define MXNET_KVSTORE_H_
29 #include <unordered_map>
33 #include "../../src/kvstore/gradient_compression.h"
35 #if MXNET_USE_DIST_KVSTORE
37 #endif // MXNET_USE_DIST_KVSTORE
76 inline const std::string&
type() {
86 const std::vector<std::pair<std::string, std::string>>& kwargs) = 0;
// Initialize a list of key-value pairs in the store, addressed by integer keys.
// Pure virtual: the concrete store type supplies the implementation.
//   keys   - integer keys to initialize
//   values - NDArray values; presumably values[i] belongs to keys[i] (TODO confirm)
104 virtual void Init(
const std::vector<int>& keys,
const std::vector<NDArray>& values) = 0;
// Initialize a list of key-value pairs in the store, addressed by string keys.
// String-key overload of Init; pure virtual.
//   str_keys - string keys to initialize
//   values   - NDArray values; presumably values[i] belongs to str_keys[i] (TODO confirm)
110 virtual void Init(
const std::vector<std::string>& str_keys,
111 const std::vector<NDArray>& values) = 0;
// Push a list of key-value pairs into the store, addressed by integer keys.
// Pure virtual.
//   keys     - integer keys to push to
//   values   - NDArray values to push
//   priority - defaults to 0; NOTE(review): ordering semantics are not visible
//              here, confirm against the implementation
148 virtual void Push(
const std::vector<int>& keys,
149 const std::vector<NDArray>& values,
150 int priority = 0) = 0;
// Push a list of key-value pairs into the store, addressed by string keys.
// String-key overload of Push; pure virtual.
//   str_keys - string keys to push to
//   values   - NDArray values to push
//   priority - defaults to 0; NOTE(review): ordering semantics are not visible
//              here, confirm against the implementation
158 virtual void Push(
const std::vector<std::string>& str_keys,
159 const std::vector<NDArray>& values,
160 int priority = 0) = 0;
185 virtual void Pull(
const std::vector<int>& keys,
186 const std::vector<NDArray*>& values,
188 bool ignore_sparse =
true) = 0;
196 virtual void Pull(
const std::vector<std::string>& str_keys,
197 const std::vector<NDArray*>& values,
199 bool ignore_sparse =
true) = 0;
// Broadcast a list of key-value pairs from the store, addressed by integer keys.
// Pure virtual.
//   vkeys    - keys paired with `values` (presumably the keys being sent; TODO confirm)
//   okeys    - keys paired with `outs` (presumably the keys being received; TODO confirm)
//   values   - input NDArrays
//   outs     - pointers to the NDArrays that receive the result
//   priority - defaults to 0
209 virtual void Broadcast(
const std::vector<int>& vkeys,
210 const std::vector<int>& okeys,
211 const std::vector<NDArray>& values,
212 const std::vector<NDArray*>& outs,
213 int priority = 0) = 0;
// Broadcast a list of key-value pairs from the store, addressed by string keys.
// String-key overload of Broadcast; pure virtual.
//   str_vkeys - keys paired with `values` (presumably the keys being sent; TODO confirm)
//   str_okeys - keys paired with `outs` (presumably the keys being received; TODO confirm)
//   values    - input NDArrays
//   outs      - pointers to the NDArrays that receive the result
//   priority  - defaults to 0
224 virtual void Broadcast(
const std::vector<std::string>& str_vkeys,
225 const std::vector<std::string>& str_okeys,
226 const std::vector<NDArray>& values,
227 const std::vector<NDArray*>& outs,
228 int priority = 0) = 0;
// Push and pull a list of key-value pairs from the store, addressed by integer keys.
// Pure virtual.
//   vkeys    - keys paired with `values` (presumably the pushed keys; TODO confirm)
//   okeys    - keys paired with `outs` (presumably the pulled keys; TODO confirm)
//   values   - NDArrays to push
//   outs     - pointers to the NDArrays that receive the pulled data
//   priority - defaults to 0
238 virtual void PushPull(
const std::vector<int>& vkeys,
239 const std::vector<int>& okeys,
240 const std::vector<NDArray>& values,
241 const std::vector<NDArray*>& outs,
242 int priority = 0) = 0;
// Push and pull a list of key-value pairs from the store, addressed by string keys.
// String-key overload of PushPull; pure virtual.
//   str_vkeys - keys paired with `values` (presumably the pushed keys; TODO confirm)
//   str_okeys - keys paired with `outs` (presumably the pulled keys; TODO confirm)
//   values    - NDArrays to push
//   outs      - pointers to the NDArrays that receive the pulled data
//   priority  - defaults to 0
253 virtual void PushPull(
const std::vector<std::string>& str_vkeys,
254 const std::vector<std::string>& str_okeys,
255 const std::vector<NDArray>& values,
256 const std::vector<NDArray*>& outs,
257 int priority = 0) = 0;
267 const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
268 int priority = 0) = 0;
// Pull a list of key-value pairs from the store, addressed by string keys.
// The NDArrays pulled back will be in row_sparse storage. Pure virtual.
//   str_keys   - string keys to pull
//   val_rowids - (NDArray*, NDArray) pairs; presumably the destination array
//                and the row ids to fetch for it (TODO confirm)
//   priority   - defaults to 0
278 virtual void PullRowSparse(
const std::vector<std::string>& str_keys,
279 const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
280 int priority = 0) = 0;
300 CHECK(updater) <<
"invalid updater";
314 CHECK(updater) <<
"invalid updater";
326 static void InitPSEnv(
const std::unordered_map<std::string, std::string>& envs) {
327 #if MXNET_USE_DIST_KVSTORE
328 ps::Environment::Init(envs);
330 LOG(FATAL) <<
"compile with USE_DIST_KVSTORE=1 to init parameter server's environment";
331 #endif // MXNET_USE_DIST_KVSTORE
340 #if MXNET_USE_DIST_KVSTORE
341 const char* role_str = ps::Environment::Get()->find(
"DMLC_ROLE");
342 return (role_str ==
nullptr) || (!strcmp(role_str,
"worker"));
345 #endif // MXNET_USE_DIST_KVSTORE
354 #if MXNET_USE_DIST_KVSTORE
355 const char* role_str = ps::Environment::Get()->find(
"DMLC_ROLE");
356 return (role_str !=
nullptr) && (!strcmp(role_str,
"server"));
359 #endif // MXNET_USE_DIST_KVSTORE
363 #if MXNET_USE_DIST_KVSTORE
365 LOG(FATAL) <<
"barrier_before_exit takes effect only on worker nodes";
368 LOG(FATAL) <<
"compile with USE_DIST_KVSTORE=1 to enable barrier";
378 #if MXNET_USE_DIST_KVSTORE
379 const char* role_str = ps::Environment::Get()->find(
"DMLC_ROLE");
380 return (role_str !=
nullptr) && (!strcmp(role_str,
"scheduler"));
383 #endif // MXNET_USE_DIST_KVSTORE
444 const std::string& params) {
445 LOG(INFO) <<
"Unable to pass server the profiler command. If you are using "
446 <<
"distributed kvstore, you need to compile with USE_DIST_KVSTORE=1."
447 <<
"If you are training on single machine, then there is no server process"
448 <<
"to profile. Please profile the worker process instead.";
// Prototype of a server controller: a callable taking an int and a string.
// NOTE(review): the arguments presumably correspond to the (cmd_id, cmd_body)
// pair sent via SendCommandToServers; confirm against the server implementation.
454 typedef std::function<void(
int,
const std::string&)>
Controller;
500 #endif // MXNET_KVSTORE_H_
virtual void set_updater(const StrUpdater &updater)
set an updater with string keys
Definition: kvstore.h:313
namespace of mxnet
Definition: api_registry.h:33
virtual void PullRowSparse(const std::vector< int > &str_keys, const std::vector< std::pair< NDArray *, NDArray >> &val_rowids, int priority=0)=0
pull a list of key-value pairs from the store. The NDArray pulled back will be in row_sparse storage ...
virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type, const std::string &params)
Sends server profiler commands to all server nodes. Only the worker with rank=0 sends the command whic...
Definition: kvstore.h:443
void set_barrier_before_exit(const bool barrier_before_exit)
Definition: kvstore.h:362
virtual ~KVStore()
virtual destructor
Definition: kvstore.h:59
static void InitPSEnv(const std::unordered_map< std::string, std::string > &envs)
initialize ps-lite environment variables
Definition: kvstore.h:326
StrUpdater str_updater_
the user-defined updater with string keys
Definition: kvstore.h:480
static bool IsWorkerNode()
Definition: kvstore.h:339
virtual void SetGradientCompression(const std::vector< std::pair< std::string, std::string >> &kwargs)=0
Set parameters to use low-bit compressed gradients.
const std::string & type()
return the type
Definition: kvstore.h:76
virtual void Pull(const std::vector< int > &keys, const std::vector< NDArray * > &values, int priority=0, bool ignore_sparse=true)=0
pull a list of key-value pairs from the store
virtual void SendCommandToServers(int cmd_id, const std::string &cmd_body)
Send a command to all server nodes.
Definition: kvstore.h:435
Updater updater_
the user-defined updater
Definition: kvstore.h:475
std::function< void(const std::string &, const NDArray &, NDArray *)> StrUpdater
the prototype of user-defined updater with string keys
Definition: kvstore.h:289
std::string type_
the kvstore type
Definition: kvstore.h:485
static bool IsServerNode()
Definition: kvstore.h:353
virtual int get_num_dead_node(int node_id, int timeout=60) const
Definition: kvstore.h:411
std::shared_ptr< kvstore::GradientCompression > gradient_compression_
Gradient compression object; starts with GC_NONE mode. Used if SetGradientCompression sets the type....
Definition: kvstore.h:491
static KVStore * Create(const char *type="local")
Factory function to create a new KVStore.
std::function< void(int, const std::string &)> Controller
the prototype of a server controller
Definition: kvstore.h:454
std::atomic< bool > barrier_before_exit_
whether to do barrier when finalize
Definition: kvstore.h:496
ndarray interface
Definition: ndarray.h:82
distributed key-value store
Definition: kvstore.h:56
virtual int get_rank() const
Definition: kvstore.h:392
virtual void RunServer(const Controller &controller)
Run as server (or scheduler)
Definition: kvstore.h:469
defines serializable interface of dmlc
virtual void Push(const std::vector< int > &keys, const std::vector< NDArray > &values, int priority=0)=0
push a list of key-value pairs into the store
std::function< void(int, const NDArray &, NDArray *)> Updater
the prototype of user-defined updater
Definition: kvstore.h:285
virtual void set_updater(const Updater &updater)
set an updater
Definition: kvstore.h:299
virtual void Init(const std::vector< int > &keys, const std::vector< NDArray > &values)=0
Initialize a list of key-value pairs in the store.
virtual int get_group_size() const
Definition: kvstore.h:399
virtual void Barrier()
global barrier among all worker machines
Definition: kvstore.h:422
virtual void PushPull(const std::vector< int > &vkeys, const std::vector< int > &okeys, const std::vector< NDArray > &values, const std::vector< NDArray * > &outs, int priority=0)=0
push and pull a list of key-value pairs from the store
KVStoreServerProfilerCommand
enum to denote types of commands kvstore sends to server regarding profiler. kSetConfig sets profiler ...
Definition: kvstore.h:48
NDArray interface that handles array arithmetic.
virtual void Broadcast(const std::vector< int > &vkeys, const std::vector< int > &okeys, const std::vector< NDArray > &values, const std::vector< NDArray * > &outs, int priority=0)=0
broadcast a list of key-value pairs from the store
static bool IsSchedulerNode()
Definition: kvstore.h:377