46 #include <unordered_map>
47 #include <type_traits>
50 #include "./logging.h"
// Excerpt of the LuaRef interface (neighbouring declarations, including the
// template header for Get, are elided). Get<T>() converts the referenced Lua
// value to T; presumably implemented out of line via lua_stack::Handler<T>.
104 inline T
Get()
const;
// Reports whether this reference points to no Lua value; presumably the
// out-of-line `return state_ == nullptr;` further down -- confirm.
119 inline bool is_nil()
const;
126 template<
typename... Args>
197 return this->
Eval(lua_code.c_str());
288 #define LUA_CALL(x) \
290 LOG(FATAL) << "Lua Call Error:" << lua_tostring(L, -1); \
304 namespace lua_stack {
// Converts a top-relative (negative) stack index into an absolute one, so the
// slot can still be addressed after more values are pushed. Positive indices
// and pseudo-indices (LUA_REGISTRYINDEX and anything below it) are returned
// unchanged.
// NOTE(review): index 0 is not a valid Lua stack index; here it falls
// through and becomes lua_gettop(L) + 1.
305 inline int lua_abs_index(lua_State* L,
int index) {
306 if (index > 0 || index <= LUA_REGISTRYINDEX)
return index;
307 return lua_gettop(L) + index + 1;
// Converter between Lua numbers and an arithmetic C++ type T (the
// `template` header line is elided from this excerpt).
314 struct NumberHandler {
// Reads the value at `index`. CHECK-fails (naming the actual Lua type) unless
// it is LUA_TNUMBER. Integral T goes through lua_tointeger, all other T
// through lua_tonumber; the result is static_cast to T with no range check.
// NOTE(review): how lua_tointeger treats non-integral numbers (truncate vs.
// return 0) depends on the Lua version -- confirm for the version in use.
// `s` is not used in the visible lines.
315 static inline T Get(lua_State* L,
int index, LuaState* s) {
316 CHECK_EQ(lua_type(L, index), LUA_TNUMBER)
317 <<
"Attempt to get number but type is \'"
318 << lua_typename(L, lua_type(L, index)) <<
'\'';
319 if (std::is_integral<T>::value) {
320 return static_cast<T
>(lua_tointeger(L, index));
322 return static_cast<T
>(lua_tonumber(L, index));
// Pushes `v`: integral T as lua_Integer, all other T as lua_Number.
325 static inline void Push(lua_State* L,
const T& v) {
326 if (std::is_integral<T>::value) {
327 lua_pushinteger(L,
static_cast<lua_Integer
>(v));
329 lua_pushnumber(L,
static_cast<lua_Number
>(v));
// Generic converter between a Lua table and an associative container that
// exposes key_type / mapped_type and operator[] (used below for
// std::unordered_map<K, V>).
334 template<
typename ContainerType>
336 using K =
typename ContainerType::key_type;
337 using V =
typename ContainerType::mapped_type;
// Reads the table at `index` (CHECK-fails with the actual Lua type otherwise)
// by walking it with lua_next and converting each key with Handler<K> and
// each value with Handler<V>.
338 static inline ContainerType Get(lua_State* L,
int index, LuaState* s) {
340 CHECK(lua_istable(L, index))
341 <<
"Expected a table but get "
342 << lua_typename(L, lua_type(L, index)) <<
'\'';
343 int tid = lua_abs_index(L, index);
// NOTE(review): `tid` holds the table's absolute index, yet the loop passes
// -2 to lua_next. -2 names the table only when it sits directly below the
// iteration key, i.e. when `index` was -1. Handler<std::vector<T>>::Get uses
// lua_next(L, tid); this loop probably should too -- TODO confirm.
345 while (lua_next(L, -2)) {
// NOTE(review): no handler visible in this excerpt defines Pop (they expose
// Get/Push); Handler<V>::Get was probably intended -- verify, since if Pop
// does not exist this line fails to compile once Get is instantiated.
346 ret[Handler<K>::Get(L, -2, s)] = Handler<V>::Pop(L, -1, s);
// Pushes a new table, then each key and value of `v`; the store into the
// table happens on lines elided from this excerpt (presumably lua_settable).
// NOTE(review): lua_createtable's second argument pre-sizes the array part;
// for keyed data the third (hash-part) hint is usually the relevant one.
352 static inline void Push(lua_State* L,
const ContainerType& v) {
353 lua_createtable(L, v.size(), 0);
354 for (
const auto& kv : v) {
355 Handler<K>::Push(L, kv.first);
356 Handler<V>::Push(L, kv.second);
// Fallback selected for types with no dedicated Handler (its members are
// elided from this excerpt).
362 struct UndefinedHandler {
// Primary Handler<T> (header partly elided): arithmetic T derives from the
// handler named on the elided line -- presumably NumberHandler<T>, confirm --
// and every other T derives from UndefinedHandler.
367 :
public std::conditional<std::is_arithmetic<T>::value,
369 UndefinedHandler>::type {
// Converter between Lua strings and std::string.
// NOTE(review): std::string(const char*) and lua_pushstring both stop at the
// first NUL byte, so strings with embedded '\0' are truncated in both
// directions; lua_tolstring / lua_pushlstring would preserve them.
373 struct Handler<
std::string> {
// CHECK-fails unless the value is exactly LUA_TSTRING (numbers are not
// coerced to strings here).
374 static inline std::string Get(lua_State* L,
int index, LuaState* s) {
375 CHECK_EQ(lua_type(L, index), LUA_TSTRING);
376 return std::string(lua_tostring(L, index));
378 static inline void Push(lua_State* L,
const std::string& v) {
379 lua_pushstring(L, v.c_str());
// Converter between an array-style Lua table and std::vector<T>.
384 struct Handler<
std::vector<T> > {
// Reads the table at `index` (CHECK-fails otherwise). Each visited key must
// equal the number of elements read so far plus one (1, 2, 3, ...), else it
// CHECK-fails with "Target table is not an array"; values go through
// Handler<T>::Get.
// NOTE(review): the Lua manual leaves lua_next traversal order unspecified,
// even for numeric keys, so this ordering check relies on implementation
// behaviour -- confirm it holds for the Lua version in use.
385 static inline std::vector<T> Get(lua_State* L,
int index, LuaState* s) {
387 CHECK(lua_istable(L, index))
388 <<
"Expected a table but get "
389 << lua_typename(L, lua_type(L, index)) <<
'\'';
390 int tid = lua_abs_index(L, index);
392 while (lua_next(L, tid)) {
393 CHECK_EQ(Handler<size_t>::Get(L, -2, s), ret.size() + 1)
394 <<
"Target table is not an array";
395 ret.push_back(Handler<T>::Get(L, -1, s));
// Pushes a new table pre-sized for v.size() array slots and stores v[i] at
// Lua index i + 1 with lua_rawseti (bypassing metamethods).
401 static inline void Push(lua_State* L,
const std::vector<T>& v) {
402 lua_createtable(L, v.size(), 0);
403 for (
size_t i = 0; i < v.size(); ++i) {
404 Handler<T>::Push(L, v[i]);
405 lua_rawseti(L, -2, i + 1);
// std::unordered_map<K, V> reuses the generic table <-> map conversion of
// MapHandler.
410 template<
typename K,
typename V>
411 struct Handler<
std::unordered_map<K, V> >
412 :
public MapHandler<std::unordered_map<K, V> > {
// Converter between stack values and LuaRef handles.
416 struct Handler<LuaRef> {
// Pushes a copy of the value at `index` and hands that copy to the LuaRef
// via SetByPopStack_; the original stack slot is not touched.
417 static inline LuaRef Get(lua_State* L,
int index, LuaState* s) {
419 lua_pushvalue(L, index);
420 ret.SetByPopStack_(s);
// CHECK-fails if `v` belongs to a different LuaState, then pushes the
// referenced value from the registry.
// NOTE(review): v.state_ is dereferenced; whether a nil LuaRef is handled
// on the elided lines before it is not visible here -- confirm.
424 static inline void Push(lua_State* L,
const LuaRef& v) {
428 CHECK(v.state_->SameLuaState(L))
429 <<
"Cannot pass LuaRef on a different LuaState's function";
430 lua_rawgeti(L, LUA_REGISTRYINDEX, v.ref_);
// nullptr can be pushed (Push body elided -- presumably lua_pushnil,
// confirm), but reading a value back as nullptr_t is a fatal error.
436 struct Handler<
std::nullptr_t> {
437 static inline LuaRef Get(lua_State* L,
int index, LuaState* s) {
438 LOG(FATAL) <<
"not supported";
441 static inline void Push(lua_State* L,
const std::nullptr_t& v) {
// Visitor used with for_each over an argument tuple: pushes each element onto
// the stored lua_State `L` with the matching Handler<T>::Push.
450 inline void operator()(
const T& v)
const {
451 Handler<T>::Push(L, v);
// Creates the underlying Lua state; a CHECK (its first line is elided)
// aborts with the message below if creation fails.
458 L_ = luaL_newstate();
460 <<
"Failed to create new lua state";
// Fragment of a separate check (surrounding code elided) that directs callers
// to LuaState::ThreadLocalState().
475 <<
"use LuaState::ThreadLocalState() to get the thread local state";
// Takes the value on top of `s`'s stack into this (necessarily nil) LuaRef.
// A non-nil value is stored in the registry; the nil path is elided here.
// NOTE(review): lua_ref is a legacy compatibility macro that is not available
// in every Lua version. Elsewhere this file uses luaL_ref / luaL_unref with
// LUA_REGISTRYINDEX directly; luaL_ref would be consistent here too.
480 CHECK(state_ ==
nullptr);
481 lua_State* L = s->L_;
482 if (!lua_isnil(L, -1)) {
483 ref_ = lua_ref(L, LUA_REGISTRYINDEX);
// Helper (members elided) that records a lua_State and its stack height in
// `top`; presumably restores that height when it goes out of scope -- confirm.
491 struct LuaState::StackReset {
// PRun_ fragments: each overload snapshots lua_gettop(L_) before running the
// callback and CHECKs afterwards that the stack height is unchanged.
502 StackReset reset{
L_, lua_gettop(
L_)};
// Message of an elided check that rejects calls from another thread in
// ThreadLocal mode.
505 <<
"Invoke lua from a different thread in ThreadLocal mode.";
508 CHECK_EQ(reset.top, lua_gettop(
L_));
511 StackReset reset{
L_, lua_gettop(
L_)};
513 CHECK_EQ(reset.top, lua_gettop(
L_));
// Compiles and runs `lua_code` in protected mode, keeping one result, and
// wraps that result in the returned LuaRef.
// NOTE(review): the status returned by luaL_loadstring is ignored. On a
// syntax error the message string is left on the stack and lua_pcall then
// tries to call it, so the CHECK reports "attempt to call a string value"
// instead of the real compile error. Checking the load status first would
// surface the actual message.
523 this->
PRun_([
this, lua_code, &ret](lua_State* L) {
524 luaL_loadstring(L, lua_code);
525 CHECK_EQ(lua_pcall(L, 0, 1, 0), 0)
526 <<
"Lua call error: " << lua_tostring(L, -1) <<
'\n'
530 ret.SetByPopStack_(
this);
// Converts `value` to a Lua value via Handler<T>::Push and returns a LuaRef
// that owns it.
538 this->
PRun_([
this, &value, &ret](lua_State* L) {
539 lua_stack::Handler<T>::Push(L, value);
540 ret.SetByPopStack_(
this);
// Looks up the global named `key` and returns it as a LuaRef. SetByPopStack_
// only stores non-nil values, so an undefined global presumably yields a nil
// LuaRef.
547 this->
PRun_([
this, &key, &ret](lua_State* L) {
548 lua_getglobal(L, key.c_str());
549 ret.SetByPopStack_(
this);
// Sets global `key` to the value referenced by `value`.
// NOTE(review): unlike Handler<LuaRef>::Push, no visible line checks that
// `value` belongs to this LuaState or is non-nil before its ref_ is used;
// confirm whether the elided lines cover this.
555 const std::string& key,
const LuaRef& value) {
556 this->
PRun_([
this, &key, &value](lua_State* L) {
557 lua_rawgeti(L, LUA_REGISTRYINDEX, value.ref_);
558 lua_setglobal(L, key.c_str());
// Copy construction: for a non-nil source, pushes the referenced value and
// takes a fresh registry reference, so each object owns its own ref.
563 if (other.state_ !=
nullptr) {
564 state_ = other.state_;
565 state_->
PRun_([
this, &other](lua_State* L) {
566 lua_rawgeti(L, LUA_REGISTRYINDEX, other.ref_);
567 ref_ = luaL_ref(L, LUA_REGISTRYINDEX);
// Move construction: takes over other's state and leaves `other` nil, so
// only this object will release the registry reference.
574 state_ = other.state_;
575 other.state_ =
nullptr;
// Move assignment: build a temporary from `other` and swap with it; the
// temporary's destructor releases our previous reference.
579 LuaRef(std::move(other)).swap(*
this);
// Copy assignment: same copy-and-swap pattern using the copy constructor.
584 LuaRef(other).swap(*
this);
// Exchanges the owning state and registry reference with `other`.
589 std::swap(state_, other.state_);
590 std::swap(ref_, other.ref_);
// Destruction: a non-nil reference releases its registry slot.
594 if (state_ !=
nullptr) {
595 state_->
PRun_([
this](lua_State* L) {
596 luaL_unref(L, LUA_REGISTRYINDEX, ref_);
// A LuaRef is nil exactly when it has no owning state.
602 return state_ ==
nullptr;
// Writes a short debug description of the referenced value: strings,
// booleans and numbers are printed by value, and anything else as its ref id
// and Lua type name.
605 std::ostream &
operator<<(std::ostream &os,
const LuaRef &r) {
// NOTE(review): r.state_ is dereferenced here; whether a nil LuaRef is
// guarded on the elided line just before is not visible -- confirm.
607 r.state_->PRun_([&os, &r](lua_State* L) {
608 lua_rawgeti(L, LUA_REGISTRYINDEX, r.ref_);
609 int type = lua_type(L, -1);
// The switch on `type` and its case labels are elided from this excerpt.
612 os <<
"lua_string:'" << lua_tostring(L, -1) <<
"'";
break;
614 os <<
"lua_bool:" << (lua_toboolean(L, -1) ?
"true" :
"false");
break;
616 os <<
"lua_number:" << lua_tonumber(L, -1);
break;
618 os <<
"lua[ref=" << r.ref_ <<
']' << lua_typename(L, type);
break;
// Converts the referenced value to T via Handler<T>::Get; CHECK-fails on a
// nil LuaRef.
630 CHECK(state_ !=
nullptr) <<
"Get:: LuaRef is nil";
632 state_->
PRun_([&ret,
this](lua_State* L) {
633 lua_rawgeti(L, LUA_REGISTRYINDEX, ref_);
634 ret = lua_stack::Handler<T>::Get(L, -1, state_);
// Returns the referenced userdata block as T*; CHECK-fails on a nil LuaRef.
// NOTE(review): the value's type is not checked. lua_touserdata returns NULL
// for non-userdata values, so callers may receive nullptr.
642 CHECK(state_ !=
nullptr) <<
"Get:: LuaRef is nil";
644 state_->
PRun_([&ret,
this](lua_State* L) {
645 lua_rawgeti(L, LUA_REGISTRYINDEX, ref_);
646 ret =
reinterpret_cast<T*
>(lua_touserdata(L, -1));
// Compile-time loop over a tuple: applies `f` to element I, then recurses on
// I + 1, setting `stop` once I + 1 reaches the tuple size.
653 template<
bool stop, std::size_t I,
typename F,
typename ...Args>
654 struct for_each_dispatcher_ {
655 static inline void run(
const std::tuple<Args...>& args,
F f) {
656 f(std::get<I>(args));
657 for_each_dispatcher_<(I + 1) ==
sizeof...(Args), (I+1),
F, Args...>::run(args, f);
// Terminating specialisation selected when stop == true (body elided here;
// presumably does nothing).
661 template<std::size_t I,
typename F,
typename ...Args>
662 struct for_each_dispatcher_<true, I,
F, Args...> {
663 static inline void run(
const std::tuple<Args...>& args,
F f) {
// Applies `f` to every tuple element in order; an empty tuple starts directly
// in the terminating specialisation.
668 template<
typename F,
typename ...Args>
669 inline void for_each(
const std::tuple<Args...>& args,
F f) {
670 for_each_dispatcher_<
sizeof...(Args) == 0, 0,
F, Args...>::run(args, f);
// Calls the referenced Lua function with `args` (signature line elided) and
// returns its first result as a LuaRef.
673 template<
typename... Args>
// CHECK-fails on a nil LuaRef. The arguments are first captured into a tuple
// (make_tuple stores decayed copies).
675 CHECK(state_ !=
nullptr) <<
"LuaRef is nil";
676 auto targ = std::make_tuple(std::forward<Args>(args)...);
677 size_t nargs =
sizeof...(Args);
// Inside the state: push the function (CHECK that it is one), push each
// argument through PushArg, then pcall via LUA_CALL keeping one result.
679 state_->
PRun_([
this, nargs, &targ, &ret](lua_State* L) {
680 lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
681 CHECK(lua_isfunction(L, -1))
682 <<
"Expect to invoke a function but type='"
683 << lua_typename(L, lua_type(L, -1)) <<
'\'';
684 for_each(targ, lua_stack::PushArg{L});
685 LUA_CALL(lua_pcall(L, nargs, 1, 0));
686 ret.SetByPopStack_(state_);
// Sets field `key` of the referenced table to `value`, converted with
// Handler<T>::Push. CHECK-fails if the LuaRef is nil or not a table.
// lua_setfield honours __newindex metamethods, unlike raw access.
693 CHECK(state_ !=
nullptr) <<
"LuaRef is nil";
694 state_->
PRun_([
this, &key, &value](lua_State* L) {
695 lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
696 CHECK(lua_istable(L, -1))
697 <<
"Expect a table but type='"
698 << lua_typename(L, lua_type(L, -1)) <<
'\'';
699 lua_stack::Handler<T>::Push(L, value);
700 lua_setfield(L, -2, key.c_str());
// Returns field `key` of the referenced table as a LuaRef. CHECK-fails if the
// LuaRef is nil or not a table. lua_getfield honours __index metamethods.
707 CHECK(state_ !=
nullptr) <<
"LuaRef is nil";
709 state_->
PRun_([
this, &key, &ret](lua_State* L) {
710 lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
711 CHECK(lua_istable(L, -1))
712 <<
"Expect a table but type='"
713 << lua_typename(L, lua_type(L, -1)) <<
'\'';
714 lua_getfield(L, -1, key.c_str());
715 ret.SetByPopStack_(state_);
// Returns element `index` of the referenced table as a LuaRef. CHECK-fails if
// the LuaRef is nil or not a table.
// NOTE(review): this uses lua_rawgeti, which bypasses __index metamethods,
// while the string-key overload uses lua_getfield, which honours them.
722 CHECK(state_ !=
nullptr) <<
"LuaRef is nil";
724 state_->
PRun_([
this, index, &ret](lua_State* L) {
725 lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
726 CHECK(lua_istable(L, -1))
727 <<
"Expect a table but type='"
728 << lua_typename(L, lua_type(L, -1)) <<
'\'';
729 lua_rawgeti(L, -1, index);
730 ret.SetByPopStack_(state_);
739 #endif // DMLC_LUA_H_