mxnet
kvstore.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_KVSTORE_H_
25 #define MXNET_KVSTORE_H_
26 #include <dmlc/io.h>
27 #include <vector>
28 #include <unordered_map>
29 #include <string>
30 #include <functional>
31 #include <atomic>
32 #include "./ndarray.h"
33 #if MXNET_USE_DIST_KVSTORE
34 #include "ps/ps.h"
35 #endif // MXNET_USE_DIST_KVSTORE
36 
37 namespace mxnet {
44 class KVStore {
45  public:
47  virtual ~KVStore() {}
48 
59  static KVStore *Create(const char *type = "local");
60 
64  inline const std::string& type() { return type_; }
65 
82  virtual void Init(const std::vector<int>& keys,
83  const std::vector<NDArray>& values) = 0;
89  virtual void Init(const std::vector<std::string>& str_keys,
90  const std::vector<NDArray>& values) = 0;
127  virtual void Push(const std::vector<int>& keys,
128  const std::vector<NDArray>& values,
129  int priority = 0) = 0;
130 
137  virtual void Push(const std::vector<std::string>& str_keys,
138  const std::vector<NDArray>& values,
139  int priority = 0) = 0;
163  virtual void Pull(const std::vector<int>& keys,
164  const std::vector<NDArray*>& values,
165  int priority = 0) = 0;
172  virtual void Pull(const std::vector<std::string>& str_keys,
173  const std::vector<NDArray*>& values,
174  int priority = 0) = 0;
175 
176 
180  typedef std::function<void(int, const NDArray&, NDArray*)> Updater;
190  virtual void set_updater(const Updater& updater) {
191  CHECK(updater) << "invalid updater";
192  updater_ = updater;
193  }
194 
195  /******************************************************
196  * the following are used for multi-machines.
197  ******************************************************/
198 
203  static void InitPSEnv(const std::unordered_map<std::string, std::string>& envs) {
204 #if MXNET_USE_DIST_KVSTORE
205  ps::Environment::Init(envs);
206 #else
207  LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to init parameter server's environment";
208 #endif // MXNET_USE_DIST_KVSTORE
209  }
210 
216  static bool IsWorkerNode() {
217 #if MXNET_USE_DIST_KVSTORE
218  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
219  return (role_str == nullptr) || (!strcmp(role_str, "worker"));
220 #else
221  return true;
222 #endif // MXNET_USE_DIST_KVSTORE
223  }
224 
230  static bool IsServerNode() {
231 #if MXNET_USE_DIST_KVSTORE
232  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
233  return (role_str != nullptr) && (!strcmp(role_str, "server"));
234 #else
235  return false;
236 #endif // MXNET_USE_DIST_KVSTORE
237  }
238 
239  void set_barrier_before_exit(const bool barrier_before_exit) {
240 #if MXNET_USE_DIST_KVSTORE
241  if (!IsWorkerNode()) LOG(FATAL) << "barrier_before_exit takes effect only on worker nodes";
242  barrier_before_exit_ = barrier_before_exit;
243 #else
244  LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to enable barrier";
245 #endif
246  }
247 
253  static bool IsSchedulerNode() {
254 #if MXNET_USE_DIST_KVSTORE
255  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
256  return (role_str != nullptr) && (!strcmp(role_str, "scheduler"));
257 #else
258  return false;
259 #endif // MXNET_USE_DIST_KVSTORE
260  }
261 
268  virtual int get_rank() const {
269  return 0;
270  }
271 
275  virtual int get_group_size() const {
276  return 1;
277  }
278 
287  virtual int get_num_dead_node(int node_id, int timeout = 60) const {
288  return 0;
289  }
290 
298  virtual void Barrier() { }
299 
311  virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { }
312 
316  typedef std::function<void(int, const std::string&)> Controller;
317 
331  virtual void RunServer(const Controller& controller) { }
332 
333  protected:
337  Updater updater_;
338 
342  std::string type_;
343 
347  std::atomic<bool> barrier_before_exit_{true};
348 };
349 
350 } // namespace mxnet
351 #endif // MXNET_KVSTORE_H_
distributed key-value store
Definition: kvstore.h:44
std::function< void(int, const NDArray &, NDArray *)> Updater
the prototype of user-defined updater
Definition: kvstore.h:180
namespace of mxnet
Definition: base.h:126
virtual int get_rank() const
Definition: kvstore.h:268
static KVStore * Create(const char *type="local")
Factory function to create a new KVStore.
Updater updater_
the user-defined updater
Definition: kvstore.h:337
const std::string & type()
return the type
Definition: kvstore.h:64
virtual void Pull(const std::vector< int > &keys, const std::vector< NDArray * > &values, int priority=0)=0
pull a list of key-value pairs from the store
static bool IsSchedulerNode()
Definition: kvstore.h:253
virtual void Barrier()
global barrier among all worker machines
Definition: kvstore.h:298
static void InitPSEnv(const std::unordered_map< std::string, std::string > &envs)
initialize ps-lite environment variables
Definition: kvstore.h:203
virtual void Init(const std::vector< int > &keys, const std::vector< NDArray > &values)=0
Initialize a list of key-value pairs to the store.
static bool IsWorkerNode()
Definition: kvstore.h:216
virtual ~KVStore()
virtual destructor
Definition: kvstore.h:47
void set_barrier_before_exit(const bool barrier_before_exit)
Definition: kvstore.h:239
virtual int get_num_dead_node(int node_id, int timeout=60) const
Definition: kvstore.h:287
virtual void RunServer(const Controller &controller)
Run as server (or scheduler)
Definition: kvstore.h:331
virtual void Push(const std::vector< int > &keys, const std::vector< NDArray > &values, int priority=0)=0
push a list of key-value pairs into the store
virtual void SendCommandToServers(int cmd_id, const std::string &cmd_body)
Send a command to all server nodes.
Definition: kvstore.h:311
std::string type_
the kvstore type
Definition: kvstore.h:342
std::function< void(int, const std::string &)> Controller
the prototype of a server controller
Definition: kvstore.h:316
virtual void set_updater(const Updater &updater)
set an updater
Definition: kvstore.h:190
virtual int get_group_size() const
Definition: kvstore.h:275
std::atomic< bool > barrier_before_exit_
whether to do a barrier when finalizing
Definition: kvstore.h:347
static bool IsServerNode()
Definition: kvstore.h:230