mxnet
lua.h
Go to the documentation of this file.
1 
29 #ifndef DMLC_LUA_H_
30 #define DMLC_LUA_H_
31 
32 extern "C" {
33 #include <lua.h>
34 #include <luaT.h>
35 #include <lualib.h>
36 }
37 
38 #include <string>
39 #include <stdexcept>
40 #include <tuple>
41 #include <mutex>
42 #include <memory>
43 #include <vector>
44 #include <utility>
45 #include <algorithm>
46 #include <unordered_map>
47 #include <type_traits>
48 
49 #include "./base.h"
50 #include "./logging.h"
51 #include "./thread_local.h"
52 
53 namespace dmlc {
54 
// Forward declarations: LuaRef (below) refers to both of these before they
// are defined.
class LuaState;

namespace lua_stack {
// Per-type conversion between C++ values and the Lua stack; the primary
// template and its specializations are defined further down in this header.
template<class T>
struct Handler;
}  // namespace lua_stack
62 
64 class LuaRef {
65  public:
67  LuaRef() = default;
72  inline LuaRef(LuaRef&& other); // NOLINT(*)
77  inline LuaRef(const LuaRef& other); // NOLINT(*)
83  inline LuaRef& operator=(LuaRef&& other);
89  inline LuaRef& operator=(const LuaRef& other);
91  inline ~LuaRef();
96  inline void swap(LuaRef& other); // NOLINT(*)
103  template<typename T>
104  inline T Get() const;
116  template<typename T>
117  inline T* GetUDataPtr() const;
119  inline bool is_nil() const;
126  template<typename... Args>
127  inline LuaRef operator()(Args&& ...args) const;
134  inline LuaRef operator[](const std::string& key) const;
142  inline LuaRef operator[](size_t index) const;
150  template<typename T>
151  inline LuaRef& SetField(const std::string& key, const T& value); // NOLINT(*)
159  inline void SetByPopStack_(LuaState* s);
160 
161  private:
162  // friend with luastate
163  friend struct lua_stack::Handler<LuaRef>;
164  friend class LuaState;
165  friend std::ostream &operator<<(std::ostream &os, const LuaRef &r);
167  LuaState* state_{nullptr};
169  int ref_;
170 };
171 
173 class LuaState {
174  public:
176  enum Option {
180  };
182  inline ~LuaState();
189  inline LuaRef Eval(const char* lua_code);
196  inline LuaRef Eval(const std::string& lua_code) {
197  return this->Eval(lua_code.c_str());
198  }
206  template<typename T>
207  inline LuaRef Convert(const T& value);
213  inline LuaRef operator[](const std::string& key);
219  inline void SetGlobalField(const std::string& key, const LuaRef& value);
228  static inline LuaState* ThreadLocalState();
248  static inline LuaState* Create_(Option option);
249 
258  template<typename F>
259  inline void PRun_(F f);
264  inline bool SameLuaState(lua_State *L) const {
265  return L_ == L;
266  }
267 
268  protected:
269  struct StackReset;
270  friend class LuaRef;
271  friend struct ThreadLocalStore<LuaState>;
275  inline LuaState();
276 
278  Option option_{kThreadLocal};
280  lua_State* L_;
282  std::mutex mutex_;
283 };
284 
285 // implementations after this line
287 
// Run a Lua C API call that returns non-zero on failure (e.g. lua_pcall) and
// report the error message on top of the stack through LOG(FATAL).
// Expects a `lua_State* L` in the calling scope. The do/while(0) wrapper
// makes the macro a single statement, so it is safe inside if/else.
#define LUA_CALL(x)                                              \
  do {                                                           \
    if ((x)) {                                                   \
      LOG(FATAL) << "Lua Call Error:" << lua_tostring(L, -1);    \
    }                                                            \
  } while (0)
292 
304 namespace lua_stack {
305 inline int lua_abs_index(lua_State* L, int index) {
306  if (index > 0 || index <= LUA_REGISTRYINDEX) return index;
307  return lua_gettop(L) + index + 1;
308 }
309 
// Redeclaration of the conversion handler; the primary template is defined below.
template<class T> struct Handler;
312 
313 template<typename T>
314 struct NumberHandler {
315  static inline T Get(lua_State* L, int index, LuaState* s) {
316  CHECK_EQ(lua_type(L, index), LUA_TNUMBER)
317  << "Attempt to get number but type is \'"
318  << lua_typename(L, lua_type(L, index)) << '\'';
319  if (std::is_integral<T>::value) {
320  return static_cast<T>(lua_tointeger(L, index));
321  } else {
322  return static_cast<T>(lua_tonumber(L, index));
323  }
324  }
325  static inline void Push(lua_State* L, const T& v) {
326  if (std::is_integral<T>::value) {
327  lua_pushinteger(L, static_cast<lua_Integer>(v));
328  } else {
329  lua_pushnumber(L, static_cast<lua_Number>(v));
330  }
331  }
332 };
333 
334 template<typename ContainerType>
335 struct MapHandler {
336  using K = typename ContainerType::key_type;
337  using V = typename ContainerType::mapped_type;
338  static inline ContainerType Get(lua_State* L, int index, LuaState* s) {
339  ContainerType ret;
340  CHECK(lua_istable(L, index))
341  << "Expected a table but get "
342  << lua_typename(L, lua_type(L, index)) << '\'';
343  int tid = lua_abs_index(L, index);
344  lua_pushnil(L);
345  while (lua_next(L, -2)) {
346  ret[Handler<K>::Get(L, -2, s)] = Handler<V>::Pop(L, -1, s);
347  lua_pop(L, 1);
348  }
349  lua_settop(L, tid);
350  return ret;
351  }
352  static inline void Push(lua_State* L, const ContainerType& v) {
353  lua_createtable(L, v.size(), 0);
354  for (const auto& kv : v) {
355  Handler<K>::Push(L, kv.first);
356  Handler<V>::Push(L, kv.second);
357  lua_settable(L, -3);
358  }
359  }
360 };
361 
362 struct UndefinedHandler {
363 };
364 
365 template<typename T>
366 struct Handler
367  : public std::conditional<std::is_arithmetic<T>::value,
368  NumberHandler<T>,
369  UndefinedHandler>::type {
370 };
371 
372 template<>
373 struct Handler<std::string> {
374  static inline std::string Get(lua_State* L, int index, LuaState* s) {
375  CHECK_EQ(lua_type(L, index), LUA_TSTRING);
376  return std::string(lua_tostring(L, index));
377  }
378  static inline void Push(lua_State* L, const std::string& v) {
379  lua_pushstring(L, v.c_str());
380  }
381 };
382 
383 template<typename T>
384 struct Handler<std::vector<T> > {
385  static inline std::vector<T> Get(lua_State* L, int index, LuaState* s) {
386  std::vector<T> ret;
387  CHECK(lua_istable(L, index))
388  << "Expected a table but get "
389  << lua_typename(L, lua_type(L, index)) << '\'';
390  int tid = lua_abs_index(L, index);
391  lua_pushnil(L);
392  while (lua_next(L, tid)) {
393  CHECK_EQ(Handler<size_t>::Get(L, -2, s), ret.size() + 1)
394  << "Target table is not an array";
395  ret.push_back(Handler<T>::Get(L, -1, s));
396  lua_pop(L, 1);
397  }
398  lua_settop(L, tid);
399  return ret;
400  }
401  static inline void Push(lua_State* L, const std::vector<T>& v) {
402  lua_createtable(L, v.size(), 0);
403  for (size_t i = 0; i < v.size(); ++i) {
404  Handler<T>::Push(L, v[i]);
405  lua_rawseti(L, -2, i + 1);
406  }
407  }
408 };
409 
410 template<typename K, typename V>
411 struct Handler<std::unordered_map<K, V> >
412  : public MapHandler<std::unordered_map<K, V> > {
413 };
414 
415 template<>
416 struct Handler<LuaRef> {
417  static inline LuaRef Get(lua_State* L, int index, LuaState* s) {
418  LuaRef ret;
419  lua_pushvalue(L, index);
420  ret.SetByPopStack_(s);
421  return ret;
422  }
423 
424  static inline void Push(lua_State* L, const LuaRef& v) {
425  if (v.is_nil()) {
426  lua_pushnil(L);
427  } else {
428  CHECK(v.state_->SameLuaState(L))
429  << "Cannot pass LuaRef on a different LuaState's function";
430  lua_rawgeti(L, LUA_REGISTRYINDEX, v.ref_);
431  }
432  }
433 };
434 
435 template<>
436 struct Handler<std::nullptr_t> {
437  static inline LuaRef Get(lua_State* L, int index, LuaState* s) {
438  LOG(FATAL) << "not supported";
439  return LuaRef();
440  }
441  static inline void Push(lua_State* L, const std::nullptr_t& v) {
442  lua_pushnil(L);
443  }
444 };
445 
446 // generic functor to call push the arguments.
447 struct PushArg {
448  lua_State* L;
449  template<typename T>
450  inline void operator()(const T& v) const {
451  Handler<T>::Push(L, v);
452  }
453 };
454 
455 } // namespace lua_stack
456 
457 inline LuaState::LuaState() {
458  L_ = luaL_newstate();
459  CHECK(L_ != nullptr)
460  << "Failed to create new lua state";
461  luaL_openlibs(L_);
462 }
463 
464 inline LuaState::~LuaState() {
465  if (option_ != kThreadLocal && L_ != nullptr) {
466  // never close threadlocal, for save destruction.
467  lua_close(L_);
468  }
469 }
470 
471 inline LuaState* LuaState::Create_(Option opt) {
472  LuaState* s = new LuaState();
473  s->option_ = opt;
474  CHECK_NE(opt, kThreadLocal)
475  << "use LuaState::ThreadLocalState() to get the thread local state";
476  return s;
477 }
478 
479 inline void LuaRef::SetByPopStack_(LuaState* s) {
480  CHECK(state_ == nullptr);
481  lua_State* L = s->L_;
482  if (!lua_isnil(L, -1)) {
483  ref_ = lua_ref(L, LUA_REGISTRYINDEX);
484  state_ = s;
485  } else {
486  lua_pop(L, 1);
487  }
488 }
489 
490 // RAII guard to reset stack
491 struct LuaState::StackReset {
492  lua_State* L;
493  int top;
494  ~StackReset() {
495  lua_settop(L, top);
496  }
497 };
498 
499 template<typename F>
500 inline void LuaState::PRun_(F f) {
501  if (option_ != kLocking) {
502  StackReset reset{L_, lua_gettop(L_)};
503  if (option_ == kThreadLocal) {
504  CHECK_EQ(ThreadLocalState(), this)
505  << "Invoke lua from a different thread in ThreadLocal mode.";
506  }
507  f(L_);
508  CHECK_EQ(reset.top, lua_gettop(L_));
509  } else {
510  std::lock_guard<std::mutex> lock(mutex_);
511  StackReset reset{L_, lua_gettop(L_)};
512  f(L_);
513  CHECK_EQ(reset.top, lua_gettop(L_));
514  }
515 }
516 
519 }
520 
521 inline LuaRef LuaState::Eval(const char* lua_code) {
522  LuaRef ret;
523  this->PRun_([this, lua_code, &ret](lua_State* L) {
524  luaL_loadstring(L, lua_code);
525  CHECK_EQ(lua_pcall(L, 0, 1, 0), 0)
526  << "Lua call error: " << lua_tostring(L, -1) << '\n'
527  << "---------\n"
528  << lua_code
529  << "\n----------";
530  ret.SetByPopStack_(this);
531  });
532  return ret;
533 }
534 
535 template<typename T>
536 inline LuaRef LuaState::Convert(const T& value) {
537  LuaRef ret;
538  this->PRun_([this, &value, &ret](lua_State* L) {
539  lua_stack::Handler<T>::Push(L, value);
540  ret.SetByPopStack_(this);
541  });
542  return ret;
543 }
544 
545 inline LuaRef LuaState::operator[](const std::string& key) {
546  LuaRef ret;
547  this->PRun_([this, &key, &ret](lua_State* L) {
548  lua_getglobal(L, key.c_str());
549  ret.SetByPopStack_(this);
550  });
551  return ret;
552 }
553 
554 inline void LuaState::SetGlobalField(
555  const std::string& key, const LuaRef& value) {
556  this->PRun_([this, &key, &value](lua_State* L) {
557  lua_rawgeti(L, LUA_REGISTRYINDEX, value.ref_);
558  lua_setglobal(L, key.c_str());
559  });
560 }
561 
562 inline LuaRef::LuaRef(const LuaRef& other) {
563  if (other.state_ != nullptr) {
564  state_ = other.state_;
565  state_->PRun_([this, &other](lua_State* L) {
566  lua_rawgeti(L, LUA_REGISTRYINDEX, other.ref_);
567  ref_ = luaL_ref(L, LUA_REGISTRYINDEX);
568  });
569  }
570 }
571 
572 inline LuaRef::LuaRef(LuaRef&& other) {
573  ref_ = other.ref_;
574  state_ = other.state_;
575  other.state_ = nullptr;
576 }
577 
578 inline LuaRef& LuaRef::operator=(LuaRef&& other) {
579  LuaRef(std::move(other)).swap(*this);
580  return *this;
581 }
582 
583 inline LuaRef& LuaRef::operator=(const LuaRef& other) {
584  LuaRef(other).swap(*this);
585  return *this;
586 }
587 
588 inline void LuaRef::swap(LuaRef& other) { // NOLINT(*)
589  std::swap(state_, other.state_);
590  std::swap(ref_, other.ref_);
591 }
592 
593 inline LuaRef::~LuaRef() {
594  if (state_ != nullptr) {
595  state_->PRun_([this](lua_State* L) {
596  luaL_unref(L, LUA_REGISTRYINDEX, ref_);
597  });
598  }
599 }
600 
601 inline bool LuaRef::is_nil() const {
602  return state_ == nullptr;
603 }
604 
605 std::ostream &operator<<(std::ostream &os, const LuaRef &r) {
606  if (!r.is_nil()) {
607  r.state_->PRun_([&os, &r](lua_State* L) {
608  lua_rawgeti(L, LUA_REGISTRYINDEX, r.ref_);
609  int type = lua_type(L, -1);
610  switch (type) {
611  case LUA_TSTRING:
612  os << "lua_string:'" << lua_tostring(L, -1) << "'"; break;
613  case LUA_TBOOLEAN:
614  os << "lua_bool:" << (lua_toboolean(L, -1) ? "true" : "false"); break;
615  case LUA_TNUMBER:
616  os << "lua_number:" << lua_tonumber(L, -1); break;
617  default:
618  os << "lua[ref=" << r.ref_ << ']' << lua_typename(L, type); break;
619  }
620  lua_pop(L, 1);
621  });
622  } else {
623  os << "lua_nil";
624  }
625  return os;
626 }
627 
628 template<typename T>
629 inline T LuaRef::Get() const {
630  CHECK(state_ != nullptr) << "Get:: LuaRef is nil";
631  T ret;
632  state_->PRun_([&ret, this](lua_State* L) {
633  lua_rawgeti(L, LUA_REGISTRYINDEX, ref_);
634  ret = lua_stack::Handler<T>::Get(L, -1, state_);
635  lua_pop(L, 1);
636  });
637  return ret;
638 }
639 
640 template<typename T>
641 inline T* LuaRef::GetUDataPtr() const {
642  CHECK(state_ != nullptr) << "Get:: LuaRef is nil";
643  T* ret;
644  state_->PRun_([&ret, this](lua_State* L) {
645  lua_rawgeti(L, LUA_REGISTRYINDEX, ref_);
646  ret = reinterpret_cast<T*>(lua_touserdata(L, -1));
647  lua_pop(L, 1);
648  });
649  return ret;
650 }
651 
// Compile-time recursion over tuple elements. `stop` becomes true once the
// index I reaches sizeof...(Args).
template<bool stop, std::size_t I, typename F, typename ...Args>
struct for_each_dispatcher_ {
  static inline void run(const std::tuple<Args...>& args, F f) {
    f(std::get<I>(args));
    constexpr std::size_t kNext = I + 1;
    for_each_dispatcher_<kNext == sizeof...(Args), kNext, F, Args...>::run(args, f);
  }
};

// Terminal case: nothing left to visit.
template<std::size_t I, typename F, typename ...Args>
struct for_each_dispatcher_<true, I, F, Args...> {
  static inline void run(const std::tuple<Args...>&, F) {}
};

/*!
 * \brief call f on every element of a tuple, first to last.
 * \param args the tuple to visit
 * \param f the functor; passed by value to every recursion level
 */
template<typename F, typename ...Args>
inline void for_each(const std::tuple<Args...>& args, F f) {
  constexpr bool kEmpty = sizeof...(Args) == 0;
  for_each_dispatcher_<kEmpty, 0, F, Args...>::run(args, f);
}
672 
673 template<typename... Args>
674 inline LuaRef LuaRef::operator()(Args&& ...args) const {
675  CHECK(state_ != nullptr) << "LuaRef is nil";
676  auto targ = std::make_tuple(std::forward<Args>(args)...);
677  size_t nargs = sizeof...(Args);
678  LuaRef ret;
679  state_->PRun_([this, nargs, &targ, &ret](lua_State* L) {
680  lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
681  CHECK(lua_isfunction(L, -1))
682  << "Expect to invoke a function but type='"
683  << lua_typename(L, lua_type(L, -1)) << '\'';
684  for_each(targ, lua_stack::PushArg{L});
685  LUA_CALL(lua_pcall(L, nargs, 1, 0));
686  ret.SetByPopStack_(state_);
687  });
688  return ret;
689 }
690 
691 template<typename T>
692 inline LuaRef& LuaRef::SetField(const std::string& key, const T& value) { // NOLINT(*)
693  CHECK(state_ != nullptr) << "LuaRef is nil";
694  state_->PRun_([this, &key, &value](lua_State* L) {
695  lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
696  CHECK(lua_istable(L, -1))
697  << "Expect a table but type='"
698  << lua_typename(L, lua_type(L, -1)) << '\'';
699  lua_stack::Handler<T>::Push(L, value);
700  lua_setfield(L, -2, key.c_str());
701  lua_pop(L, 1);
702  });
703  return *this;
704 }
705 
706 inline LuaRef LuaRef::operator[](const std::string& key) const {
707  CHECK(state_ != nullptr) << "LuaRef is nil";
708  LuaRef ret;
709  state_->PRun_([this, &key, &ret](lua_State* L) {
710  lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
711  CHECK(lua_istable(L, -1))
712  << "Expect a table but type='"
713  << lua_typename(L, lua_type(L, -1)) << '\'';
714  lua_getfield(L, -1, key.c_str());
715  ret.SetByPopStack_(state_);
716  lua_pop(L, 1);
717  });
718  return ret;
719 }
720 
721 inline LuaRef LuaRef::operator[](size_t index) const {
722  CHECK(state_ != nullptr) << "LuaRef is nil";
723  LuaRef ret;
724  state_->PRun_([this, index, &ret](lua_State* L) {
725  lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
726  CHECK(lua_istable(L, -1))
727  << "Expect a table but type='"
728  << lua_typename(L, lua_type(L, -1)) << '\'';
729  lua_rawgeti(L, -1, index);
730  ret.SetByPopStack_(state_);
731  lua_pop(L, 1);
732  });
733  return ret;
734 }
735 
737 } // namespace dmlc
738 
739 #endif // DMLC_LUA_H_
LuaRef Convert(const T &value)
convert a C++ type to lua type
~LuaRef()
destructor
LuaRef Eval(const std::string &lua_code)
evaluate a piece of lua code, return the first result.
Definition: lua.h:196
static T * Get()
Definition: thread_local.h:38
bool SameLuaState(lua_State *L) const
Definition: lua.h:264
Definition: lua.h:178
Definition: lua.h:179
T * GetUDataPtr() const
Get user data pointer from LuaRef.
A threadlocal store to store threadlocal variables. Will return a thread local singleton of type T...
Definition: thread_local.h:35
LuaState()
constructor
void SetByPopStack_(LuaState *s)
Set LuaRef to the value on top of the stack. This LuaRef must be nil. This is an API used by developer...
void PRun_(F f)
Protected run of f; this is used by API developers. Always call this to access the Lua state; f must not dest...
Definition: lua.h:60
Option option_
internal option, default to thread local
Definition: lua.h:278
LuaRef & SetField(const std::string &key, const T &value)
Set field of lua table. The reference must be a table.
static LuaState * Create_(Option option)
namespace for dmlc
Definition: array_view.h:12
lua_State * L_
internal lua state
Definition: lua.h:280
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:72
Option
options to be provided in lua state
Definition: lua.h:176
Portable thread local storage.
C++11 header-only interface to easily interact with Lua and Torch. This code evolved from torch pl...
static LuaState * ThreadLocalState()
void SetGlobalField(const std::string &key, const LuaRef &value)
Set the value to the global table.
T Get() const
Get content out as type T.
LuaRef()=default
construct a nil ref
void for_each(const F &f, Args &&...args)
Definition: packed_func.h:974
std::ostream & operator<<(std::ostream &os, const optional< T > &t)
serialize an optional object to string.
Definition: optional.h:141
LuaRef & operator=(LuaRef &&other)
assign operator from other
~LuaState()
destructor
A Lua state.
Definition: lua.h:173
std::mutex mutex_
internal lock about the state
Definition: lua.h:282
LuaRef Eval(const char *lua_code)
evaluate a piece of lua code, return the first result.
void swap(LuaRef &other)
swap content with another ref
a reference to a Lua object
Definition: lua.h:64
LuaRef operator[](const std::string &key) const
Get field from the lua table. The reference must be a table.
LuaRef operator()(Args &&...args) const
invoke the LuaRef as function
LuaRef operator[](const std::string &key)
get global field from the state
bool is_nil() const