46 #include <unordered_map> 47 #include <type_traits> 50 #include "./logging.h" 96 inline void swap(
LuaRef& other);
104 inline T Get()
const;
117 inline T* GetUDataPtr()
const;
119 inline bool is_nil()
const;
126 template<
typename... Args>
127 inline LuaRef operator()(Args&& ...args)
const;
134 inline LuaRef operator[](
const std::string& key)
const;
142 inline LuaRef operator[](
size_t index)
const;
151 inline LuaRef& SetField(
const std::string& key,
const T& value);
159 inline void SetByPopStack_(
LuaState* s);
189 inline LuaRef Eval(
const char* lua_code);
197 return this->Eval(lua_code.c_str());
207 inline LuaRef Convert(
const T& value);
213 inline LuaRef operator[](
const std::string& key);
219 inline void SetGlobalField(
const std::string& key,
const LuaRef& value);
228 static inline LuaState* ThreadLocalState();
259 inline void PRun_(
F f);
288 #define LUA_CALL(x) \ 290 LOG(FATAL) << "Lua Call Error:" << lua_tostring(L, -1); \ 304 namespace lua_stack {
305 inline int lua_abs_index(lua_State* L,
int index) {
306 if (index > 0 || index <= LUA_REGISTRYINDEX)
return index;
307 return lua_gettop(L) + index + 1;
314 struct NumberHandler {
315 static inline T Get(lua_State* L,
int index,
LuaState* s) {
316 CHECK_EQ(lua_type(L, index), LUA_TNUMBER)
317 <<
"Attempt to get number but type is \'" 318 << lua_typename(L, lua_type(L, index)) <<
'\'';
319 if (std::is_integral<T>::value) {
320 return static_cast<T
>(lua_tointeger(L, index));
322 return static_cast<T
>(lua_tonumber(L, index));
325 static inline void Push(lua_State* L,
const T& v) {
326 if (std::is_integral<T>::value) {
327 lua_pushinteger(L, static_cast<lua_Integer>(v));
329 lua_pushnumber(L, static_cast<lua_Number>(v));
334 template<
typename ContainerType>
336 using K =
typename ContainerType::key_type;
337 using V =
typename ContainerType::mapped_type;
338 static inline ContainerType Get(lua_State* L,
int index,
LuaState* s) {
340 CHECK(lua_istable(L, index))
341 <<
"Expected a table but get " 342 << lua_typename(L, lua_type(L, index)) <<
'\'';
343 int tid = lua_abs_index(L, index);
345 while (lua_next(L, -2)) {
346 ret[Handler<K>::Get(L, -2, s)] = Handler<V>::Pop(L, -1, s);
352 static inline void Push(lua_State* L,
const ContainerType& v) {
353 lua_createtable(L, v.size(), 0);
354 for (
const auto& kv : v) {
355 Handler<K>::Push(L, kv.first);
356 Handler<V>::Push(L, kv.second);
362 struct UndefinedHandler {
367 :
public std::conditional<std::is_arithmetic<T>::value,
369 UndefinedHandler>::type {
373 struct Handler<std::string> {
374 static inline std::string Get(lua_State* L,
int index,
LuaState* s) {
375 CHECK_EQ(lua_type(L, index), LUA_TSTRING);
376 return std::string(lua_tostring(L, index));
378 static inline void Push(lua_State* L,
const std::string& v) {
379 lua_pushstring(L, v.c_str());
384 struct Handler<std::vector<T> > {
385 static inline std::vector<T> Get(lua_State* L,
int index,
LuaState* s) {
387 CHECK(lua_istable(L, index))
388 <<
"Expected a table but get " 389 << lua_typename(L, lua_type(L, index)) <<
'\'';
390 int tid = lua_abs_index(L, index);
392 while (lua_next(L, tid)) {
393 CHECK_EQ(Handler<size_t>::Get(L, -2, s), ret.size() + 1)
394 <<
"Target table is not an array";
395 ret.push_back(Handler<T>::Get(L, -1, s));
401 static inline void Push(lua_State* L,
const std::vector<T>& v) {
402 lua_createtable(L, v.size(), 0);
403 for (
size_t i = 0; i < v.size(); ++i) {
404 Handler<T>::Push(L, v[i]);
405 lua_rawseti(L, -2, i + 1);
410 template<
typename K,
typename V>
411 struct Handler<std::unordered_map<K, V> >
412 :
public MapHandler<std::unordered_map<K, V> > {
416 struct Handler<LuaRef> {
417 static inline LuaRef Get(lua_State* L,
int index,
LuaState* s) {
419 lua_pushvalue(L, index);
424 static inline void Push(lua_State* L,
const LuaRef& v) {
429 <<
"Cannot pass LuaRef on a different LuaState's function";
430 lua_rawgeti(L, LUA_REGISTRYINDEX, v.ref_);
436 struct Handler<std::nullptr_t> {
437 static inline LuaRef Get(lua_State* L,
int index,
LuaState* s) {
438 LOG(FATAL) <<
"not supported";
441 static inline void Push(lua_State* L,
const std::nullptr_t& v) {
450 inline void operator()(
const T& v)
const {
451 Handler<T>::Push(L, v);
458 L_ = luaL_newstate();
460 <<
"Failed to create new lua state";
465 if (option_ != kThreadLocal && L_ !=
nullptr) {
474 CHECK_NE(opt, kThreadLocal)
475 <<
"use LuaState::ThreadLocalState() to get the thread local state";
480 CHECK(state_ ==
nullptr);
481 lua_State* L = s->
L_;
482 if (!lua_isnil(L, -1)) {
483 ref_ = lua_ref(L, LUA_REGISTRYINDEX);
491 struct LuaState::StackReset {
501 if (option_ != kLocking) {
502 StackReset reset{L_, lua_gettop(L_)};
503 if (option_ == kThreadLocal) {
504 CHECK_EQ(ThreadLocalState(),
this)
505 <<
"Invoke lua from a different thread in ThreadLocal mode.";
508 CHECK_EQ(reset.top, lua_gettop(L_));
510 std::lock_guard<std::mutex> lock(mutex_);
511 StackReset reset{L_, lua_gettop(L_)};
513 CHECK_EQ(reset.top, lua_gettop(L_));
523 this->PRun_([
this, lua_code, &ret](lua_State* L) {
524 luaL_loadstring(L, lua_code);
525 CHECK_EQ(lua_pcall(L, 0, 1, 0), 0)
526 <<
"Lua call error: " << lua_tostring(L, -1) <<
'\n' 538 this->PRun_([
this, &value, &ret](lua_State* L) {
547 this->PRun_([
this, &key, &ret](lua_State* L) {
548 lua_getglobal(L, key.c_str());
555 const std::string& key,
const LuaRef& value) {
556 this->PRun_([
this, &key, &value](lua_State* L) {
557 lua_rawgeti(L, LUA_REGISTRYINDEX, value.ref_);
558 lua_setglobal(L, key.c_str());
563 if (other.state_ !=
nullptr) {
564 state_ = other.state_;
565 state_->
PRun_([
this, &other](lua_State* L) {
566 lua_rawgeti(L, LUA_REGISTRYINDEX, other.ref_);
567 ref_ = luaL_ref(L, LUA_REGISTRYINDEX);
574 state_ = other.state_;
575 other.state_ =
nullptr;
589 std::swap(state_, other.state_);
590 std::swap(ref_, other.ref_);
594 if (state_ !=
nullptr) {
595 state_->PRun_([
this](lua_State* L) {
596 luaL_unref(L, LUA_REGISTRYINDEX, ref_);
602 return state_ ==
nullptr;
607 r.state_->
PRun_([&os, &r](lua_State* L) {
608 lua_rawgeti(L, LUA_REGISTRYINDEX, r.ref_);
609 int type = lua_type(L, -1);
612 os <<
"lua_string:'" << lua_tostring(L, -1) <<
"'";
break;
614 os <<
"lua_bool:" << (lua_toboolean(L, -1) ?
"true" :
"false");
break;
616 os <<
"lua_number:" << lua_tonumber(L, -1);
break;
618 os <<
"lua[ref=" << r.ref_ <<
']' << lua_typename(L, type);
break;
630 CHECK(state_ !=
nullptr) <<
"Get:: LuaRef is nil";
632 state_->PRun_([&ret,
this](lua_State* L) {
633 lua_rawgeti(L, LUA_REGISTRYINDEX, ref_);
642 CHECK(state_ !=
nullptr) <<
"Get:: LuaRef is nil";
644 state_->PRun_([&ret,
this](lua_State* L) {
645 lua_rawgeti(L, LUA_REGISTRYINDEX, ref_);
646 ret =
reinterpret_cast<T*
>(lua_touserdata(L, -1));
653 template<
bool stop, std::size_t I,
typename F,
typename ...Args>
654 struct for_each_dispatcher_ {
655 static inline void run(
const std::tuple<Args...>& args, F f) {
656 f(std::get<I>(args));
657 for_each_dispatcher_<(I + 1) ==
sizeof...(Args), (I+1), F, Args...>::run(args, f);
661 template<std::size_t I,
typename F,
typename ...Args>
662 struct for_each_dispatcher_<true, I, F, Args...> {
663 static inline void run(
const std::tuple<Args...>& args, F f) {
668 template<
typename F,
typename ...Args>
669 inline void for_each(
const std::tuple<Args...>& args, F f) {
670 for_each_dispatcher_<
sizeof...(Args) == 0, 0, F, Args...>::run(args, f);
673 template<
typename... Args>
675 CHECK(state_ !=
nullptr) <<
"LuaRef is nil";
676 auto targ = std::make_tuple(std::forward<Args>(args)...);
677 size_t nargs =
sizeof...(Args);
679 state_->PRun_([
this, nargs, &targ, &ret](lua_State* L) {
680 lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
681 CHECK(lua_isfunction(L, -1))
682 <<
"Expect to invoke a function but type='" 683 << lua_typename(L, lua_type(L, -1)) <<
'\'';
684 for_each(targ, lua_stack::PushArg{L});
685 LUA_CALL(lua_pcall(L, nargs, 1, 0));
693 CHECK(state_ !=
nullptr) <<
"LuaRef is nil";
694 state_->PRun_([
this, &key, &value](lua_State* L) {
695 lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
696 CHECK(lua_istable(L, -1))
697 <<
"Expect a table but type='" 698 << lua_typename(L, lua_type(L, -1)) <<
'\'';
700 lua_setfield(L, -2, key.c_str());
707 CHECK(state_ !=
nullptr) <<
"LuaRef is nil";
709 state_->PRun_([
this, &key, &ret](lua_State* L) {
710 lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
711 CHECK(lua_istable(L, -1))
712 <<
"Expect a table but type='" 713 << lua_typename(L, lua_type(L, -1)) <<
'\'';
714 lua_getfield(L, -1, key.c_str());
722 CHECK(state_ !=
nullptr) <<
"LuaRef is nil";
724 state_->PRun_([
this, index, &ret](lua_State* L) {
725 lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
726 CHECK(lua_istable(L, -1))
727 <<
"Expect a table but type='" 728 << lua_typename(L, lua_type(L, -1)) <<
'\'';
729 lua_rawgeti(L, -1, index);
739 #endif // DMLC_LUA_H_ LuaRef Convert(const T &value)
convert a C++ value to a Lua value
LuaRef Eval(const std::string &lua_code)
evaluate a piece of Lua code and return the first result.
Definition: lua.h:196
static T * Get()
Definition: thread_local.h:38
bool SameLuaState(lua_State *L) const
Definition: lua.h:264
T * GetUDataPtr() const
Get the userdata pointer from the LuaRef.
A thread-local store for thread-local variables. Returns a thread-local singleton of type T...
Definition: thread_local.h:35
void SetByPopStack_(LuaState *s)
Set the LuaRef to the value on top of the stack. This reference must be nil. This is an API used by developer...
void PRun_(F f)
Protected run of f; this is used by API developers. Always call this to access the Lua state. f must not dest...
Option option_
internal option; defaults to thread local
Definition: lua.h:278
LuaRef & SetField(const std::string &key, const T &value)
Set a field of the Lua table. The reference must be a table.
static LuaState * Create_(Option option)
namespace for dmlc
Definition: array_view.h:12
lua_State * L_
internal lua state
Definition: lua.h:280
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:72
Option
options to be provided to the Lua state
Definition: lua.h:176
Portable thread local storage.
C++11 header-only interface to easily interact with Lua and Torch. This code evolved from torch pl...
static LuaState * ThreadLocalState()
void SetGlobalField(const std::string &key, const LuaRef &value)
Set the value in the global table.
T Get() const
Get content out as type T.
LuaRef()=default
construct a nil ref
void for_each(const F &f, Args &&...args)
Definition: packed_func.h:974
std::ostream & operator<<(std::ostream &os, const optional< T > &t)
serialize an optional object to a string.
Definition: optional.h:141
LuaRef & operator=(LuaRef &&other)
move-assignment operator from other
A Lua state.
Definition: lua.h:173
std::mutex mutex_
internal lock protecting the state
Definition: lua.h:282
LuaRef Eval(const char *lua_code)
evaluate a piece of Lua code and return the first result.
void swap(LuaRef &other)
swap content with another ref
a reference to a Lua object
Definition: lua.h:64
LuaRef operator[](const std::string &key) const
Get a field from the Lua table. The reference must be a table.
LuaRef operator()(Args &&...args) const
invoke the LuaRef as a function
LuaRef operator[](const std::string &key)
get a global field from the state