mxnet
kvstore.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_KVSTORE_H_
25 #define MXNET_KVSTORE_H_
26 #include <dmlc/io.h>
27 #include <vector>
28 #include <utility>
29 #include <unordered_map>
30 #include <string>
31 #include <functional>
32 #include <atomic>
33 #include "../../src/kvstore/gradient_compression.h"
34 #include "./ndarray.h"
35 #if MXNET_USE_DIST_KVSTORE
36 #include "ps/ps.h"
37 #endif // MXNET_USE_DIST_KVSTORE
38 
39 namespace mxnet {
40 
49 
56 class KVStore {
57  public:
59  virtual ~KVStore() {}
60 
71  static KVStore* Create(const char* type = "local");
72 
76  inline const std::string& type() {  // accessor for the stored kvstore type string
77  return type_;  // returns a reference to the member itself, so no copy is made
78  }
79 
85  virtual void SetGradientCompression(
86  const std::vector<std::pair<std::string, std::string>>& kwargs) = 0;
87 
104  virtual void Init(const std::vector<int>& keys, const std::vector<NDArray>& values) = 0;
110  virtual void Init(const std::vector<std::string>& str_keys,
111  const std::vector<NDArray>& values) = 0;
148  virtual void Push(const std::vector<int>& keys,
149  const std::vector<NDArray>& values,
150  int priority = 0) = 0;
151 
158  virtual void Push(const std::vector<std::string>& str_keys,
159  const std::vector<NDArray>& values,
160  int priority = 0) = 0;
185  virtual void Pull(const std::vector<int>& keys,
186  const std::vector<NDArray*>& values,
187  int priority = 0,
188  bool ignore_sparse = true) = 0;
196  virtual void Pull(const std::vector<std::string>& str_keys,
197  const std::vector<NDArray*>& values,
198  int priority = 0,
199  bool ignore_sparse = true) = 0;
200 
209  virtual void Broadcast(const std::vector<int>& vkeys,
210  const std::vector<int>& okeys,
211  const std::vector<NDArray>& values,
212  const std::vector<NDArray*>& outs,
213  int priority = 0) = 0;
214 
224  virtual void Broadcast(const std::vector<std::string>& str_vkeys,
225  const std::vector<std::string>& str_okeys,
226  const std::vector<NDArray>& values,
227  const std::vector<NDArray*>& outs,
228  int priority = 0) = 0;
229 
238  virtual void PushPull(const std::vector<int>& vkeys,
239  const std::vector<int>& okeys,
240  const std::vector<NDArray>& values,
241  const std::vector<NDArray*>& outs,
242  int priority = 0) = 0;
243 
253  virtual void PushPull(const std::vector<std::string>& str_vkeys,
254  const std::vector<std::string>& str_okeys,
255  const std::vector<NDArray>& values,
256  const std::vector<NDArray*>& outs,
257  int priority = 0) = 0;
266  virtual void PullRowSparse(const std::vector<int>& str_keys,
267  const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
268  int priority = 0) = 0;
269 
278  virtual void PullRowSparse(const std::vector<std::string>& str_keys,
279  const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
280  int priority = 0) = 0;
281 
285  typedef std::function<void(int, const NDArray&, NDArray*)> Updater;
289  typedef std::function<void(const std::string&, const NDArray&, NDArray*)> StrUpdater;
299  virtual void set_updater(const Updater& updater) {  // installs the int-key updater callback
300  CHECK(updater) << "invalid updater";  // the check fails with this message when the std::function is empty
301  updater_ = updater;  // keeps a copy of the callback in the member
302  }
303 
313  virtual void set_updater(const StrUpdater& updater) {  // installs the string-key updater callback
314  CHECK(updater) << "invalid updater";  // the check fails with this message when the std::function is empty
315  str_updater_ = updater;  // stored separately from the int-key updater
316  }
317 
318  /******************************************************
319  * the following are used for multi-machines.
320  ******************************************************/
321 
326  static void InitPSEnv(const std::unordered_map<std::string, std::string>& envs) {  // passes envs to ps-lite
327 #if MXNET_USE_DIST_KVSTORE
328  ps::Environment::Init(envs);  // the map is forwarded unchanged
329 #else
330  LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to init parameter server's environment";  // unsupported in this build
331 #endif // MXNET_USE_DIST_KVSTORE
332  }
333 
339  static bool IsWorkerNode() {  // true when this process should act as a worker
340 #if MXNET_USE_DIST_KVSTORE
341  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");  // role lookup via ps-lite's environment
342  return (role_str == nullptr) || (!strcmp(role_str, "worker"));  // an unset DMLC_ROLE also counts as worker
343 #else
344  return true;  // without dist kvstore the answer is always true
345 #endif // MXNET_USE_DIST_KVSTORE
346  }
347 
353  static bool IsServerNode() {  // true only when DMLC_ROLE is explicitly "server"
354 #if MXNET_USE_DIST_KVSTORE
355  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");  // role lookup via ps-lite's environment
356  return (role_str != nullptr) && (!strcmp(role_str, "server"));  // an unset role is never a server
357 #else
358  return false;  // without dist kvstore the answer is always false
359 #endif // MXNET_USE_DIST_KVSTORE
360  }
361 
362  void set_barrier_before_exit(const bool barrier_before_exit) {  // updates the barrier-before-exit flag
363 #if MXNET_USE_DIST_KVSTORE
364  if (!IsWorkerNode())  // any non-worker role is reported as a fatal error
365  LOG(FATAL) << "barrier_before_exit takes effect only on worker nodes";
366  barrier_before_exit_ = barrier_before_exit;  // store into the std::atomic<bool> member
367 #else
368  LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to enable barrier";  // the flag cannot be set in this build
369 #endif
370  }
371 
377  static bool IsSchedulerNode() {  // true only when DMLC_ROLE is explicitly "scheduler"
378 #if MXNET_USE_DIST_KVSTORE
379  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");  // role lookup via ps-lite's environment
380  return (role_str != nullptr) && (!strcmp(role_str, "scheduler"));  // an unset role is never a scheduler
381 #else
382  return false;  // without dist kvstore the answer is always false
383 #endif // MXNET_USE_DIST_KVSTORE
384  }
385 
392  virtual int get_rank() const {  // overridable rank query
393  return 0;  // this base implementation always reports rank 0
394  }
395 
399  virtual int get_group_size() const {  // overridable group-size query
400  return 1;  // this base implementation always reports a group of one
401  }
402 
411  virtual int get_num_dead_node(int node_id, int timeout = 60) const {  // overridable dead-node query
412  return 0;  // this base implementation ignores both arguments and reports none
413  }
414 
422  virtual void Barrier() {}  // this base implementation is a no-op
423 
435  virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) {}  // base implementation is a no-op
436 
437  /*!
438  * \brief Base implementation for server profiler commands: nothing is sent,
439  * it only logs a notice explaining why no server process can be profiled.
440  * \param type the profiler command kind (unused by this implementation)
441  * \param params the command parameters (unused by this implementation)
442  */
443  virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type,
444  const std::string& params) {
445  LOG(INFO) << "Unable to pass server the profiler command. If you are using "
446  << "distributed kvstore, you need to compile with USE_DIST_KVSTORE=1. "
447  << "If you are training on single machine, then there is no server process "
448  << "to profile. Please profile the worker process instead.";
449  }
450 
454  typedef std::function<void(int, const std::string&)> Controller;
455 
469  virtual void RunServer(const Controller& controller) {}  // base implementation is a no-op
470 
471  protected:
472  /*!
473  * \brief the user-defined updater (int keys); assigned by set_updater
474  */
475  Updater updater_;
476 
477  /*!
478  * \brief the user-defined updater with string keys; assigned by set_updater
479  */
480  StrUpdater str_updater_;
481 
482  /*!
483  * \brief the kvstore type; exposed through type()
484  */
485  std::string type_;
486 
487  /*!
488  * \brief gradient compression object held through a shared pointer.
489  * NOTE(review): presumably configured via SetGradientCompression - confirm in implementations.
490  */
491  std::shared_ptr<kvstore::GradientCompression> gradient_compression_;
492 
493  /*!
494  * \brief whether to do a barrier when finalizing; initialized to true
495  */
496  std::atomic<bool> barrier_before_exit_{true};
497 };
498 
499 } // namespace mxnet
500 #endif // MXNET_KVSTORE_H_
mxnet::KVStore::set_updater
virtual void set_updater(const StrUpdater &updater)
set an updater with string keys
Definition: kvstore.h:313
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::KVStore::PullRowSparse
virtual void PullRowSparse(const std::vector< int > &str_keys, const std::vector< std::pair< NDArray *, NDArray >> &val_rowids, int priority=0)=0
pull a list of key-value pairs from the store. The NDArray pulled back will be in row_sparse storage ...
mxnet::KVStore::SetServerProfilerCommand
virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type, const std::string &params)
Sends server profiler commands to all server nodes Only the worker with rank=0 sends the command whic...
Definition: kvstore.h:443
mxnet::KVStore::set_barrier_before_exit
void set_barrier_before_exit(const bool barrier_before_exit)
Definition: kvstore.h:362
mxnet::KVStore::~KVStore
virtual ~KVStore()
virtual destructor
Definition: kvstore.h:59
mxnet::KVStore::InitPSEnv
static void InitPSEnv(const std::unordered_map< std::string, std::string > &envs)
initialize ps-lite environment variables
Definition: kvstore.h:326
mxnet::KVStore::str_updater_
StrUpdater str_updater_
the user-defined updater with string keys
Definition: kvstore.h:480
mxnet::KVStore::IsWorkerNode
static bool IsWorkerNode()
Definition: kvstore.h:339
mxnet::KVStore::SetGradientCompression
virtual void SetGradientCompression(const std::vector< std::pair< std::string, std::string >> &kwargs)=0
Set parameters to use low-bit compressed gradients.
mxnet::KVStore::type
const std::string & type()
return the type
Definition: kvstore.h:76
mxnet::KVStore::Pull
virtual void Pull(const std::vector< int > &keys, const std::vector< NDArray * > &values, int priority=0, bool ignore_sparse=true)=0
pull a list of key-value pairs from the store
mxnet::KVStore::SendCommandToServers
virtual void SendCommandToServers(int cmd_id, const std::string &cmd_body)
Send a command to all server nodes.
Definition: kvstore.h:435
mxnet::KVStore::updater_
Updater updater_
the user-defined updater
Definition: kvstore.h:475
mxnet::KVStore::StrUpdater
std::function< void(const std::string &, const NDArray &, NDArray *)> StrUpdater
the prototype of user-defined updater with string keys
Definition: kvstore.h:289
mxnet::KVStore::type_
std::string type_
the kvstore type
Definition: kvstore.h:485
mxnet::KVStore::IsServerNode
static bool IsServerNode()
Definition: kvstore.h:353
mxnet::KVStoreServerProfilerCommand::kSetConfig
@ kSetConfig
mxnet::KVStore::get_num_dead_node
virtual int get_num_dead_node(int node_id, int timeout=60) const
Definition: kvstore.h:411
mxnet::KVStoreServerProfilerCommand::kPause
@ kPause
mxnet::KVStore::gradient_compression_
std::shared_ptr< kvstore::GradientCompression > gradient_compression_
Gradient compression object starts with GC_NONE mode Used if SetGradientCompression sets the type....
Definition: kvstore.h:491
mxnet::KVStore::Create
static KVStore * Create(const char *type="local")
Factory function to create a new KVStore.
mxnet::KVStore::Controller
std::function< void(int, const std::string &)> Controller
the prototype of a server controller
Definition: kvstore.h:454
mxnet::KVStore::barrier_before_exit_
std::atomic< bool > barrier_before_exit_
whether to do barrier when finalize
Definition: kvstore.h:496
mxnet::NDArray
ndarray interface
Definition: ndarray.h:82
mxnet::KVStore
distributed key-value store
Definition: kvstore.h:56
mxnet::KVStore::get_rank
virtual int get_rank() const
Definition: kvstore.h:392
mxnet::KVStore::RunServer
virtual void RunServer(const Controller &controller)
Run as server (or scheduler)
Definition: kvstore.h:469
mxnet::KVStoreServerProfilerCommand::kState
@ kState
io.h
defines serializable interface of dmlc
mxnet::KVStore::Push
virtual void Push(const std::vector< int > &keys, const std::vector< NDArray > &values, int priority=0)=0
push a list of key-value pairs into the store
mxnet::KVStore::Updater
std::function< void(int, const NDArray &, NDArray *)> Updater
the prototype of user-defined updater
Definition: kvstore.h:285
mxnet::KVStore::set_updater
virtual void set_updater(const Updater &updater)
set an updater
Definition: kvstore.h:299
mxnet::KVStore::Init
virtual void Init(const std::vector< int > &keys, const std::vector< NDArray > &values)=0
Initialize a list of key-value pair to the store.
mxnet::KVStore::get_group_size
virtual int get_group_size() const
Definition: kvstore.h:399
mxnet::KVStore::Barrier
virtual void Barrier()
global barrier among all worker machines
Definition: kvstore.h:422
mxnet::KVStore::PushPull
virtual void PushPull(const std::vector< int > &vkeys, const std::vector< int > &okeys, const std::vector< NDArray > &values, const std::vector< NDArray * > &outs, int priority=0)=0
push and pull a list of key-value pairs from the store
mxnet::KVStoreServerProfilerCommand
KVStoreServerProfilerCommand
enum to denote types of commands kvstore sends to server regarding profiler kSetConfig sets profiler ...
Definition: kvstore.h:48
mxnet::KVStoreServerProfilerCommand::kDump
@ kDump
ndarray.h
NDArray interface that handles array arithmetic.
mxnet::KVStore::Broadcast
virtual void Broadcast(const std::vector< int > &vkeys, const std::vector< int > &okeys, const std::vector< NDArray > &values, const std::vector< NDArray * > &outs, int priority=0)=0
broadcast a list of key-value pairs from the store
mxnet::KVStore::IsSchedulerNode
static bool IsSchedulerNode()
Definition: kvstore.h:377