Source code for mxnet.io

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Data iterators for common data formats."""
from __future__ import absolute_import
from collections import OrderedDict, namedtuple

import sys
import ctypes
import logging
import threading
try:
    import h5py
except ImportError:
    h5py = None
import numpy as np
from .base import _LIB
from .base import c_str_array, mx_uint, py_str
from .base import DataIterHandle, NDArrayHandle
from .base import mx_real_t
from .base import check_call, build_param_doc as _build_param_doc
from .ndarray import NDArray
from .ndarray.sparse import CSRNDArray
from .ndarray.sparse import array as sparse_array
from .ndarray import _ndarray_cls
from .ndarray import array
from .ndarray import concatenate

class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
    """Store name, shape, type and layout information of the data or the label.

    Only `name` and `shape` are fields of the underlying namedtuple; `dtype`
    and `layout` are attached as plain instance attributes.

    The `layout` describes how the axes in `shape` should be interpreted,
    for example for image data setting `layout=NCHW` indicates
    that the first axis is number of examples in the batch(N),
    C is number of channels, H is the height and W is the width of the image.

    For sequential data a typical `layout` is ``NTC``, where
    N is number of examples in the batch, T the temporal axis representing time
    and C is the number of channels. Note that the default used by this class
    is ``'NCHW'``.

    Parameters
    ----------
    cls : DataDesc
         The class.
    name : str
         Data name.
    shape : tuple of int
         Data shape.
    dtype : np.dtype, optional
         Data type. Defaults to ``mx_real_t``.
    layout : str, optional
         Data layout. Defaults to ``'NCHW'``.
    """
    def __new__(cls, name, shape, dtype=mx_real_t, layout='NCHW'):  # pylint: disable=super-on-old-class
        # The first argument of ``super`` must be the class whose parent is
        # looked up and the second the (sub)class being instantiated; the
        # reversed order only works when ``cls is DataDesc`` and raises
        # ``TypeError`` for subclasses.
        ret = super(DataDesc, cls).__new__(cls, name, shape)
        ret.dtype = dtype
        ret.layout = layout
        return ret

    def __repr__(self):
        return "DataDesc[%s,%s,%s,%s]" % (self.name, self.shape, self.dtype,
                                          self.layout)

    @staticmethod
    def get_batch_axis(layout):
        """Get the dimension that corresponds to the batch size.

        When data parallelism is used, the data will be automatically split and
        concatenated along the batch-size dimension. Axis can be -1, which means
        the whole array will be copied for each data-parallelism device.

        Parameters
        ----------
        layout : str
            layout string. For example, "NCHW".

        Returns
        -------
        int
            An axis indicating the batch_size dimension: 0 when `layout` is
            None, otherwise the position of ``'N'`` in `layout` (-1 if absent).
        """
        if layout is None:
            return 0
        return layout.find('N')

    @staticmethod
    def get_list(shapes, types):
        """Get DataDesc list from attribute lists.

        Parameters
        ----------
        shapes : a tuple of (name, shape)
        types : a tuple of (name, type), or None
            When None, every descriptor uses the default dtype.

        Returns
        -------
        list of DataDesc
            One descriptor per entry of `shapes`, in the same order.
        """
        if types is not None:
            type_dict = dict(types)
            return [DataDesc(x[0], x[1], type_dict[x[0]]) for x in shapes]
        return [DataDesc(x[0], x[1]) for x in shapes]
class DataBatch(object):
    """A data batch.

    MXNet's data iterator returns a batch of data for each `next` call.
    This data contains `batch_size` number of examples.

    If the input data consists of images, then shape of these images depend on
    the `layout` attribute of `DataDesc` object in `provide_data` parameter.

    If `layout` is set to 'NCHW' then, images should be stored in a 4-D matrix
    of shape ``(batch_size, num_channel, height, width)``.
    If `layout` is set to 'NHWC' then, images should be stored in a 4-D matrix
    of shape ``(batch_size, height, width, num_channel)``.
    The channels are often in RGB order.

    Parameters
    ----------
    data : list of `NDArray`, each array containing `batch_size` examples.
          A list of input data.
    label : list of `NDArray`, each array often containing a 1-dimensional array. optional
          A list of input labels.
    pad : int, optional
          The number of examples padded at the end of a batch. It is used when the
          total number of examples read is not divisible by the `batch_size`.
          These extra padded examples are ignored in prediction.
    index : numpy.array, optional
          The example indices in this batch.
    bucket_key : int, optional
          The bucket key, used for bucketing module.
    provide_data : list of `DataDesc`, optional
          A list of `DataDesc` objects. `DataDesc` is used to store
          name, shape, type and layout information of the data.
          The *i*-th element describes the name and shape of ``data[i]``.
    provide_label : list of `DataDesc`, optional
          A list of `DataDesc` objects. `DataDesc` is used to store
          name, shape, type and layout information of the label.
          The *i*-th element describes the name and shape of ``label[i]``.
    """
    def __init__(self, data, label=None, pad=None, index=None,
                 bucket_key=None, provide_data=None, provide_label=None):
        if data is not None:
            assert isinstance(data, (list, tuple)), "Data must be list of NDArrays"
        if label is not None:
            assert isinstance(label, (list, tuple)), "Label must be list of NDArrays"
        self.data = data
        self.label = label
        self.pad = pad
        self.index = index

        self.bucket_key = bucket_key
        self.provide_data = provide_data
        self.provide_label = provide_label

    def __str__(self):
        # Both ``data`` and ``label`` are allowed to be None by the constructor
        # (``label`` even defaults to None), so report None instead of
        # failing while trying to iterate over it.
        data_shapes = None
        if self.data is not None:
            data_shapes = [d.shape for d in self.data]
        label_shapes = None
        if self.label is not None:
            label_shapes = [l.shape for l in self.label]
        return "{}: data shapes: {} label shapes: {}".format(
            self.__class__.__name__,
            data_shapes,
            label_shapes)
class DataIter(object):
    """The base class for an MXNet data iterator.

    All I/O in MXNet is handled by specializations of this class. Data iterators
    in MXNet are similar to standard-iterators in Python. On each call to `next`
    they return a `DataBatch` which represents the next batch of data. When
    there is no more data to return, it raises a `StopIteration` exception.

    The hook methods `iter_next`, `getdata`, `getlabel` and `getpad` do nothing
    in this base class (they return None); subclasses are expected to
    override them.

    Parameters
    ----------
    batch_size : int, optional
        The batch size, namely the number of items in the batch.

    See Also
    --------
    NDArrayIter : Data-iterator for MXNet NDArray or numpy-ndarray objects.
    CSVIter : Data-iterator for csv data.
    LibSVMIter : Data-iterator for libsvm data.
    ImageIter : Data-iterator for images.
    """
    def __init__(self, batch_size=0):
        self.batch_size = batch_size

    def __iter__(self):
        return self

    def __next__(self):
        # Python 3 iterator protocol; delegates to the Python 2 style ``next``.
        return self.next()

    def reset(self):
        """Reset the iterator to the begin of the data."""

    def next(self):
        """Get next data batch from iterator.

        Returns
        -------
        DataBatch
            The data of next batch.

        Raises
        ------
        StopIteration
            If the end of the data is reached.
        """
        if not self.iter_next():
            raise StopIteration
        return DataBatch(data=self.getdata(), label=self.getlabel(),
                         pad=self.getpad(), index=self.getindex())

    def iter_next(self):
        """Move to the next batch.

        Returns
        -------
        boolean
            Whether the move is successful.
        """

    def getdata(self):
        """Get data of current batch.

        Returns
        -------
        list of NDArray
            The data of the current batch.
        """

    def getlabel(self):
        """Get label of the current batch.

        Returns
        -------
        list of NDArray
            The label of the current batch.
        """

    def getindex(self):
        """Get index of the current batch.

        Returns
        -------
        index : numpy.array
            The indices of examples in the current batch.
        """
        return None

    def getpad(self):
        """Get the number of padding examples in the current batch.

        Returns
        -------
        int
            Number of padding examples in the current batch.
        """
class ResizeIter(DataIter):
    """Resize a data iterator to a given number of batches.

    An epoch of this iterator always yields `size` batches. When the wrapped
    iterator runs out of data before that, it is reset and reading continues
    from its beginning.

    Parameters
    ----------
    data_iter : DataIter
        The data iterator to be resized.
    size : int
        The number of batches per epoch to resize to.
    reset_internal : bool
        Whether to reset internal iterator on ResizeIter.reset.

    Examples
    --------
    >>> nd_iter = mx.io.NDArrayIter(mx.nd.ones((100,10)), batch_size=25)
    >>> resize_iter = mx.io.ResizeIter(nd_iter, 2)
    >>> for batch in resize_iter:
    ...     print(batch.data)
    [<NDArray 25x10 @cpu(0)>]
    [<NDArray 25x10 @cpu(0)>]
    """
    def __init__(self, data_iter, size, reset_internal=True):
        super(ResizeIter, self).__init__()
        self.data_iter = data_iter
        self.size = size
        self.reset_internal = reset_internal
        # Number of batches handed out in the current epoch.
        self.cur = 0
        self.current_batch = None

        # Mirror the descriptive attributes of the wrapped iterator.
        self.provide_data = data_iter.provide_data
        self.provide_label = data_iter.provide_label
        self.batch_size = data_iter.batch_size
        if hasattr(data_iter, 'default_bucket_key'):
            self.default_bucket_key = data_iter.default_bucket_key

    def reset(self):
        """Restart the epoch counter, optionally resetting the wrapped iterator."""
        self.cur = 0
        if self.reset_internal:
            self.data_iter.reset()

    def iter_next(self):
        """Fetch the next batch; return False once `size` batches were served."""
        if self.cur == self.size:
            return False
        source = self.data_iter
        try:
            batch = source.next()
        except StopIteration:
            # Wrapped iterator is exhausted: wrap around to its beginning.
            source.reset()
            batch = source.next()
        self.current_batch = batch
        self.cur += 1
        return True

    def getdata(self):
        """Return the data of the most recently fetched batch."""
        return self.current_batch.data

    def getlabel(self):
        """Return the label of the most recently fetched batch."""
        return self.current_batch.label

    def getindex(self):
        """Return the index of the most recently fetched batch."""
        return self.current_batch.index

    def getpad(self):
        """Return the pad count of the most recently fetched batch."""
        return self.current_batch.pad
class PrefetchingIter(DataIter):
    """Performs pre-fetch for other data iterators.

    This iterator will create another thread to perform ``iter_next`` and then
    store the data in memory. It potentially accelerates the data read, at the
    cost of more memory usage.

    One worker thread is started per wrapped iterator. Each worker waits on its
    ``data_taken`` event, reads one batch into ``next_batch[i]`` and then
    signals ``data_ready[i]``; the consumer side waits on ``data_ready`` and
    sets ``data_taken`` again once the batch has been picked up.

    Parameters
    ----------
    iters : DataIter or list of DataIter
        The data iterators to be pre-fetched.
    rename_data : None or list of dict
        The *i*-th element is a renaming map for the *i*-th iter, in the form of
        {'original_name' : 'new_name'}. Should have one entry for each entry
        in iter[i].provide_data.
    rename_label : None or list of dict
        Similar to ``rename_data``.

    Examples
    --------
    >>> iter1 = mx.io.NDArrayIter({'data':mx.nd.ones((100,10))}, batch_size=25)
    >>> iter2 = mx.io.NDArrayIter({'data':mx.nd.ones((100,10))}, batch_size=25)
    >>> piter = mx.io.PrefetchingIter([iter1, iter2],
    ...                               rename_data=[{'data': 'data_1'}, {'data': 'data_2'}])
    >>> print(piter.provide_data)
    [DataDesc[data_1,(25, 10L),<type 'numpy.float32'>,NCHW],
     DataDesc[data_2,(25, 10L),<type 'numpy.float32'>,NCHW]]
    """
    def __init__(self, iters, rename_data=None, rename_label=None):
        super(PrefetchingIter, self).__init__()
        if not isinstance(iters, list):
            iters = [iters]
        self.n_iter = len(iters)
        assert self.n_iter > 0
        self.iters = iters
        self.rename_data = rename_data
        self.rename_label = rename_label
        # Batch size is the leading dimension of the first data descriptor.
        self.batch_size = self.provide_data[0][1][0]
        self.data_ready = [threading.Event() for _ in range(self.n_iter)]
        self.data_taken = [threading.Event() for _ in range(self.n_iter)]
        # Start with every slot marked as "taken" so the workers fetch at once.
        for event in self.data_taken:
            event.set()
        self.started = True
        self.current_batch = [None for _ in range(self.n_iter)]
        self.next_batch = [None for _ in range(self.n_iter)]

        def prefetch_func(self, i):
            """Thread entry"""
            while True:
                self.data_taken[i].wait()
                if not self.started:
                    break
                try:
                    self.next_batch[i] = self.iters[i].next()
                except StopIteration:
                    # None marks the end of the epoch for this iterator.
                    self.next_batch[i] = None
                self.data_taken[i].clear()
                self.data_ready[i].set()

        self.prefetch_threads = [
            threading.Thread(target=prefetch_func, args=[self, i])
            for i in range(self.n_iter)
        ]
        for thread in self.prefetch_threads:
            # ``Thread.setDaemon`` is deprecated; the attribute is equivalent.
            thread.daemon = True
            thread.start()

    def __del__(self):
        # ``__init__`` may have failed before the events/threads were created,
        # so do not assume those attributes exist.
        self.started = False
        for event in getattr(self, 'data_taken', ()):
            event.set()
        for thread in getattr(self, 'prefetch_threads', ()):
            thread.join()

    @property
    def provide_data(self):
        """Concatenated data descriptors of all wrapped iterators.

        When ``rename_data`` is given, `DataDesc` entries are renamed through
        the corresponding map; entries that are not `DataDesc` instances are
        converted with ``DataDesc(*entry)`` without renaming.
        """
        if self.rename_data is None:
            return sum([i.provide_data for i in self.iters], [])
        return sum([[
            DataDesc(r[x.name], x.shape, x.dtype)
            if isinstance(x, DataDesc) else DataDesc(*x)
            for x in i.provide_data
        ] for r, i in zip(self.rename_data, self.iters)], [])

    @property
    def provide_label(self):
        """Concatenated label descriptors of all wrapped iterators.

        Renaming through ``rename_label`` follows the same rules as
        `provide_data`.
        """
        if self.rename_label is None:
            return sum([i.provide_label for i in self.iters], [])
        return sum([[
            DataDesc(r[x.name], x.shape, x.dtype)
            if isinstance(x, DataDesc) else DataDesc(*x)
            for x in i.provide_label
        ] for r, i in zip(self.rename_label, self.iters)], [])

    def reset(self):
        """Reset all wrapped iterators and restart prefetching."""
        # Wait until every worker has finished its pending read so that the
        # underlying iterators are not reset while being read from.
        for event in self.data_ready:
            event.wait()
        for data_iter in self.iters:
            data_iter.reset()
        for event in self.data_ready:
            event.clear()
        for event in self.data_taken:
            event.set()

    def iter_next(self):
        """Assemble the next combined batch.

        Returns
        -------
        bool
            False when the wrapped iterators are exhausted, True otherwise.

        Raises
        ------
        AssertionError
            If the wrapped iterators do not end at the same time or report
            different pad values.
        """
        for event in self.data_ready:
            event.wait()
        if self.next_batch[0] is None:
            for batch in self.next_batch:
                assert batch is None, "Number of entry mismatches between iterators"
            return False
        for batch in self.next_batch:
            # A None here means a later iterator ended before the first one;
            # report it through the intended assertion rather than through
            # an AttributeError on ``None.pad``.
            assert batch is not None and batch.pad == self.next_batch[0].pad, \
                "Number of entry mismatches between iterators"
        self.current_batch = DataBatch(sum([batch.data for batch in self.next_batch], []),
                                       sum([batch.label for batch in self.next_batch], []),
                                       self.next_batch[0].pad,
                                       self.next_batch[0].index,
                                       provide_data=self.provide_data,
                                       provide_label=self.provide_label)
        # Hand the slots back to the workers so they fetch the following batch.
        for event in self.data_ready:
            event.clear()
        for event in self.data_taken:
            event.set()
        return True

    def next(self):
        """Return the next combined `DataBatch` or raise `StopIteration`."""
        if self.iter_next():
            return self.current_batch
        raise StopIteration

    def getdata(self):
        """Return the data of the current combined batch."""
        return self.current_batch.data

    def getlabel(self):
        """Return the label of the current combined batch."""
        return self.current_batch.label

    def getindex(self):
        """Return the index of the current combined batch."""
        return self.current_batch.index

    def getpad(self):
        """Return the pad count of the current combined batch."""
        return self.current_batch.pad
def _init_data(data, allow_empty, default_name):
    """Convert data into canonical form.

    Parameters
    ----------
    data : None, array, list of arrays or dict of str to array
        Arrays may be ``NDArray``, ``numpy.ndarray``, ``h5py.Dataset`` (when
        h5py is importable) or anything accepted by ``array``.
    allow_empty : bool
        Whether None / an empty list is acceptable.
    default_name : str
        Name for a single array; several arrays given as a list are named
        ``'_<i>_<default_name>'``.

    Returns
    -------
    list of (str, array)
        Name/array pairs. Values that are neither ``NDArray`` nor
        ``h5py.Dataset`` are converted with ``array``.

    Raises
    ------
    TypeError
        If the input has an unsupported container type or a value cannot be
        converted.
    """
    assert (data is not None) or allow_empty
    if data is None:
        data = []

    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
                  if h5py else (np.ndarray, NDArray)):
        data = [data]
    if isinstance(data, list):
        if not allow_empty:
            assert(len(data) > 0)
        if len(data) == 1:
            data = OrderedDict([(default_name, data[0])]) # pylint: disable=redefined-variable-type
        else:
            data = OrderedDict( # pylint: disable=redefined-variable-type
                [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
    if not isinstance(data, dict):
        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
                "a list of them or dict with them as values")
    # Collect the converted values in a fresh mapping so that a dict supplied
    # by the caller is not modified as a side effect.
    converted = OrderedDict()
    for k, v in data.items():
        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
            try:
                v = array(v)
            except Exception:
                # ``Exception`` rather than a bare ``except`` so that
                # KeyboardInterrupt / SystemExit are not turned into TypeError.
                raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \
                                "should be NDArray, numpy.ndarray or h5py.Dataset")
        converted[k] = v

    return list(converted.items())


def _has_instance(data, dtype):
    """Return True if ``data`` has instance of ``dtype``.

    This function is called after _init_data.
    ``data`` is a list of (str, NDArray)"""
    return any(isinstance(arr, dtype) for _, arr in data)


def _shuffle(data, idx):
    """Shuffle the data.

    Parameters
    ----------
    data : list of (str, array)
        Name/array pairs as produced by ``_init_data``.
    idx : array of int
        Row order to apply.

    Returns
    -------
    list of (str, array)
        ``h5py.Dataset`` values are passed through untouched; ``CSRNDArray``
        and other arrays are re-indexed by `idx` on their original context.
    """
    shuffle_data = []

    for k, v in data:
        if h5py and isinstance(v, h5py.Dataset):
            # Datasets are left in place; no reordering is applied here.
            shuffle_data.append((k, v))
        elif isinstance(v, CSRNDArray):
            shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
        else:
            shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))

    return shuffle_data
class NDArrayIter(DataIter):
    """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, ``h5py.Dataset``
    ``mx.nd.sparse.CSRNDArray`` or ``scipy.sparse.csr_matrix``.

    Example usage:
    ----------
    >>> data = np.arange(40).reshape((10,2,2))
    >>> labels = np.ones([10, 1])
    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='discard')
    >>> for batch in dataiter:
    ...     print(batch.data[0].shape)
    ...
    (3L, 2L, 2L)
    (3L, 2L, 2L)
    (3L, 2L, 2L)

    In the above example, data is shuffled as `shuffle` parameter is set to `True`
    and remaining examples are discarded as `last_batch_handle` parameter is set to `discard`.

    Usage of `last_batch_handle` parameter:

    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='pad')
    >>> batchidx = 0
    >>> for batch in dataiter:
    ...     batchidx += 1
    ...
    >>> batchidx  # Padding added after the examples read are over. So, 10/3+1 batches are created.
    4
    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, True, last_batch_handle='discard')
    >>> batchidx = 0
    >>> for batch in dataiter:
    ...     batchidx += 1
    ...
    >>> batchidx  # Remaining examples are discarded. So, 10/3 batches are created.
    3

    `NDArrayIter` also supports multiple input and labels.

    >>> data = {'data1':np.zeros(shape=(10,2,2)), 'data2':np.zeros(shape=(20,2,2))}
    >>> label = {'label1':np.zeros(shape=(10,1)), 'label2':np.zeros(shape=(20,1))}
    >>> dataiter = mx.io.NDArrayIter(data, label, 3, True, last_batch_handle='discard')

    `NDArrayIter` also supports ``mx.nd.sparse.CSRNDArray``
    with `last_batch_handle` set to `discard`.

    >>> csr_data = mx.nd.array(np.arange(40).reshape((10,4))).tostype('csr')
    >>> labels = np.ones([10, 1])
    >>> dataiter = mx.io.NDArrayIter(csr_data, labels, 3, last_batch_handle='discard')
    >>> [batch.data[0] for batch in dataiter]
    [<CSRNDArray 3x4 @cpu(0)>, <CSRNDArray 3x4 @cpu(0)>, <CSRNDArray 3x4 @cpu(0)>]

    Parameters
    ----------
    data: array or list of array or dict of string to array
        The input data.
    label: array or list of array or dict of string to array, optional
        The input label.
    batch_size: int
        Batch size of data.
    shuffle: bool, optional
        Whether to shuffle the data.
        Only supported if no h5py.Dataset inputs are used.
    last_batch_handle : str, optional
        How to handle the last batch. This parameter can be 'pad', 'discard' or
        'roll_over'. 'roll_over' is intended for training and can cause problems
        if used for prediction.
    data_name : str, optional
        The data name.
    label_name : str, optional
        The label name.
    """
    def __init__(self, data, label=None, batch_size=1, shuffle=False,
                 last_batch_handle='pad', data_name='data',
                 label_name='softmax_label'):
        super(NDArrayIter, self).__init__(batch_size)

        self.data = _init_data(data, allow_empty=False, default_name=data_name)
        self.label = _init_data(label, allow_empty=True, default_name=label_name)

        # Sparse inputs can only be sliced, not padded or rolled over.
        uses_csr = (_has_instance(self.data, CSRNDArray) or
                    _has_instance(self.label, CSRNDArray))
        if uses_csr and last_batch_handle != 'discard':
            raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
                                      " with `last_batch_handle` set to `discard`.")

        self.idx = np.arange(self.data[0][1].shape[0])
        if shuffle:
            np.random.shuffle(self.idx)
            self.data = _shuffle(self.data, self.idx)
            self.label = _shuffle(self.label, self.idx)

        if last_batch_handle == 'discard':
            # Drop the trailing examples that do not fill a whole batch.
            num_examples = self.data[0][1].shape[0]
            new_n = num_examples - num_examples % batch_size
            self.idx = self.idx[:new_n]

        self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
        self.num_source = len(self.data_list)
        self.num_data = self.idx.shape[0]
        assert self.num_data >= batch_size, \
            "batch_size needs to be smaller than data size."
        # ``iter_next`` advances before reading, so start one batch "before" 0.
        self.cursor = -batch_size
        self.batch_size = batch_size
        self.last_batch_handle = last_batch_handle

    def _describe(self, sources):
        """Build per-batch `DataDesc` entries for (name, array) pairs."""
        return [
            DataDesc(name, (self.batch_size,) + tuple(arr.shape[1:]), arr.dtype)
            for name, arr in sources
        ]

    @property
    def provide_data(self):
        """The name and shape of data provided by this iterator."""
        return self._describe(self.data)

    @property
    def provide_label(self):
        """The name and shape of label provided by this iterator."""
        return self._describe(self.label)

    def hard_reset(self):
        """Ignore roll over data and set to start."""
        self.cursor = -self.batch_size

    def reset(self):
        """Rewind the cursor, keeping the roll-over offset when applicable."""
        carried = 0
        if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
            # Examples already consumed from the head by the previous epoch's
            # last batch are skipped in the new epoch.
            carried = (self.cursor % self.num_data) % self.batch_size
        self.cursor = carried - self.batch_size

    def iter_next(self):
        """Advance the cursor by one batch; False once past the data."""
        self.cursor += self.batch_size
        return self.cursor < self.num_data

    def next(self):
        """Return the next `DataBatch` (with ``index=None``) or raise `StopIteration`."""
        if not self.iter_next():
            raise StopIteration
        return DataBatch(data=self.getdata(), label=self.getlabel(),
                         pad=self.getpad(), index=None)

    @staticmethod
    def _read_h5py(dataset, indices):
        """Read ``indices`` from an h5py dataset, which only supports
        indices in increasing order, then re-index the rows that were read."""
        ordered = sorted(indices)
        positions = list(indices)
        return array(dataset[ordered][[positions.index(i) for i in ordered]])

    def _getdata(self, data_source):
        """Load data from underlying arrays, internal use only."""
        assert(self.cursor < self.num_data), "DataIter needs reset."
        start = self.cursor
        stop = start + self.batch_size
        batch = []
        if stop <= self.num_data:
            # The whole batch lies inside the data.
            for entry in data_source:
                source = entry[1]
                if isinstance(source, (np.ndarray, NDArray)):
                    batch.append(source[start:stop])
                else:
                    batch.append(self._read_h5py(source, self.idx[start:stop]))
            return batch

        # The batch runs past the end: complete it with examples from the head.
        pad = self.batch_size - self.num_data + self.cursor
        for entry in data_source:
            source = entry[1]
            if isinstance(source, (np.ndarray, NDArray)):
                batch.append(concatenate([source[start:], source[:pad]]))
            else:
                batch.append(concatenate([
                    self._read_h5py(source, self.idx[start:]),
                    self._read_h5py(source, self.idx[:pad]),
                ]))
        return batch

    def getdata(self):
        """Return the data arrays of the current batch."""
        return self._getdata(self.data)

    def getlabel(self):
        """Return the label arrays of the current batch."""
        return self._getdata(self.label)

    def getpad(self):
        """Return how many examples of the current batch are padding.

        Non-zero only for ``last_batch_handle='pad'`` when the batch runs past
        the end of the data.
        """
        overflow = self.cursor + self.batch_size - self.num_data
        if self.last_batch_handle == 'pad' and overflow > 0:
            return overflow
        return 0
class MXDataIter(DataIter):
    """A python wrapper a C++ data iterator.

    This iterator is the Python wrapper to all native C++ data iterators, such
    as `CSVIter`, `ImageRecordIter`, `MNISTIter`, etc. When initializing
    `CSVIter` for example, you will get an `MXDataIter` instance to use in your
    Python code. Calls to `next`, `reset`, etc will be delegated to the
    underlying C++ data iterators.

    Usually you don't need to interact with `MXDataIter` directly unless you are
    implementing your own data iterators in C++. To do that, please refer to
    examples under the `src/io` folder.

    Parameters
    ----------
    handle : DataIterHandle, required
        The handle to the underlying C++ Data Iterator.
    data_name : str, optional
        Data name. Default to "data".
    label_name : str, optional
        Label name. Default to "softmax_label".

    See Also
    --------
    src/io : The underlying C++ data iterator implementation, e.g., `CSVIter`.
    """
    def __init__(self, handle, data_name='data', label_name='softmax_label', **_):
        super(MXDataIter, self).__init__()
        self.handle = handle
        # debug option, used to test the speed with io effect eliminated
        self._debug_skip_load = False

        # load the first batch to get shape information
        self.first_batch = None
        self.first_batch = self.next()
        first_data = self.first_batch.data[0]
        first_label = self.first_batch.label[0]

        # properties
        self.provide_data = [DataDesc(data_name, first_data.shape, first_data.dtype)]
        self.provide_label = [DataDesc(label_name, first_label.shape, first_label.dtype)]
        self.batch_size = first_data.shape[0]

    def __del__(self):
        check_call(_LIB.MXDataIterFree(self.handle))

    def debug_skip_load(self):
        # Set the iterator to simply return always first batch. This can be used
        # to test the speed of network without taking the loading delay into
        # account.
        self._debug_skip_load = True
        logging.info('Set debug_skip_load to be true, will simply return first batch')

    def reset(self):
        """Rewind the native iterator and drop the cached first batch."""
        self._debug_at_begin = True
        self.first_batch = None
        check_call(_LIB.MXDataIterBeforeFirst(self.handle))

    def _current_batch(self):
        """Wrap the native iterator's current data/label/pad/index in a `DataBatch`."""
        return DataBatch(data=[self.getdata()], label=[self.getlabel()],
                         pad=self.getpad(), index=self.getindex())

    def next(self):
        """Return the next `DataBatch` or raise `StopIteration`."""
        if self._debug_skip_load and not self._debug_at_begin:
            # Debug mode: do not advance, re-wrap whatever is currently loaded.
            return self._current_batch()
        if self.first_batch is not None:
            # Hand out the batch that was pre-loaded by ``__init__``.
            batch, self.first_batch = self.first_batch, None
            return batch
        self._debug_at_begin = False
        next_res = ctypes.c_int(0)
        check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res)))
        if not next_res.value:
            raise StopIteration
        return self._current_batch()

    def iter_next(self):
        """Advance the native iterator; truthy when a batch is available."""
        # NOTE(review): ``first_batch`` is not cleared here, so this keeps
        # returning True without advancing until ``next``/``reset`` clears it
        # -- confirm whether that is intended.
        if self.first_batch is not None:
            return True
        next_res = ctypes.c_int(0)
        check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res)))
        return next_res.value

    def getdata(self):
        """Return the data array of the native iterator's current batch."""
        hdl = NDArrayHandle()
        check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl)))
        return _ndarray_cls(hdl, False)

    def getlabel(self):
        """Return the label array of the native iterator's current batch."""
        hdl = NDArrayHandle()
        check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl)))
        return _ndarray_cls(hdl, False)

    def getindex(self):
        """Return a copy of the current batch's indices, or None if there are none."""
        index_size = ctypes.c_uint64(0)
        index_data = ctypes.POINTER(ctypes.c_uint64)()
        check_call(_LIB.MXDataIterGetIndex(self.handle,
                                           ctypes.byref(index_data),
                                           ctypes.byref(index_size)))
        count = index_size.value
        if not count:
            return None
        # View the native buffer as a numpy array, then copy it out.
        address = ctypes.addressof(index_data.contents)
        dbuffer = (ctypes.c_uint64 * count).from_address(address)
        return np.frombuffer(dbuffer, dtype=np.uint64).copy()

    def getpad(self):
        """Return the pad count reported by the native iterator."""
        pad = ctypes.c_int(0)
        check_call(_LIB.MXDataIterGetPadNum(self.handle, ctypes.byref(pad)))
        return pad.value
def _make_io_iterator(handle):
    """Create an io iterator by handle.

    Parameters
    ----------
    handle : ctypes.c_void_p
        Handle of a registered native data-iterator creator.

    Returns
    -------
    function
        A factory named after the native iterator. It accepts keyword
        arguments only and returns an `MXDataIter`; its docstring is built
        from the description and parameter list reported by the library.
    """
    name = ctypes.c_char_p()
    desc = ctypes.c_char_p()
    num_args = mx_uint()
    arg_names = ctypes.POINTER(ctypes.c_char_p)()
    arg_types = ctypes.POINTER(ctypes.c_char_p)()
    arg_descs = ctypes.POINTER(ctypes.c_char_p)()

    check_call(_LIB.MXDataIterGetIterInfo( \
            handle, ctypes.byref(name), ctypes.byref(desc), \
            ctypes.byref(num_args), \
            ctypes.byref(arg_names), \
            ctypes.byref(arg_types), \
            ctypes.byref(arg_descs)))
    iter_name = py_str(name.value)

    narg = int(num_args.value)
    param_str = _build_param_doc(
        [py_str(arg_names[i]) for i in range(narg)],
        [py_str(arg_types[i]) for i in range(narg)],
        [py_str(arg_descs[i]) for i in range(narg)])

    # ``c_char_p.value`` is ``bytes`` on Python 3; decode it like the other
    # strings so the docstring does not start with a literal ``b'...'``.
    desc_text = py_str(desc.value) if desc.value is not None else ''
    doc_str = ('%s\n\n' +
               '%s\n' +
               'Returns\n' +
               '-------\n' +
               'MXDataIter\n'+
               '    The result iterator.')
    doc_str = doc_str % (desc_text, param_str)

    def creator(*args, **kwargs):
        """Create an iterator.
        The parameters listed below can be passed in as keyword arguments.

        Parameters
        ----------
        name : string, required.
            Name of the resulting data iterator.

        Returns
        -------
        dataiter: Dataiter
            The resulting data iterator.
        """
        # Reject positional arguments before allocating the native iterator,
        # otherwise the freshly created handle would never be released.
        if len(args):
            raise TypeError('%s can only accept keyword arguments' % iter_name)

        param_keys = []
        param_vals = []

        for k, val in kwargs.items():
            param_keys.append(k)
            param_vals.append(str(val))
        # create atomic symbol
        param_keys = c_str_array(param_keys)
        param_vals = c_str_array(param_vals)
        iter_handle = DataIterHandle()
        check_call(_LIB.MXDataIterCreateIter(
            handle,
            mx_uint(len(param_keys)),
            param_keys, param_vals,
            ctypes.byref(iter_handle)))

        return MXDataIter(iter_handle, **kwargs)

    creator.__name__ = iter_name
    creator.__doc__ = doc_str
    return creator


def _init_io_module():
    """List and add all the data iterators to current module."""
    plist = ctypes.POINTER(ctypes.c_void_p)()
    size = ctypes.c_uint()
    check_call(_LIB.MXListDataIters(ctypes.byref(size), ctypes.byref(plist)))
    module_obj = sys.modules[__name__]
    for i in range(size.value):
        hdl = ctypes.c_void_p(plist[i])
        dataiter = _make_io_iterator(hdl)
        # Expose each native iterator factory as ``mxnet.io.<IterName>``.
        setattr(module_obj, dataiter.__name__, dataiter)


_init_io_module()