# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Autograd for NDArray."""
from __future__ import absolute_import
from __future__ import division

from array import array
from threading import Lock
import traceback
import ctypes
from ctypes import c_int, c_void_p, CFUNCTYPE, POINTER, cast
from .base import _LIB, check_call, string_types, mx_uint
from .base import NDArrayHandle, c_array, c_handle_array, c_array_buf, MXCallbackList, SymbolHandle
from .ndarray import NDArray, _ndarray_cls
from .ndarray import _GRAD_REQ_MAP
from .symbol import Symbol

def set_recording(is_recording): #pylint: disable=redefined-outer-name
    """Turn autograd recording on or off.

    While recording is on, a graph is constructed so that gradients can be
    computed afterwards.

    Parameters
    ----------
    is_recording: bool
        The recording state to switch to.

    Returns
    -------
    bool
        The recording state that was in effect before this call.
    """
    old_state = ctypes.c_int()
    new_state = ctypes.c_int(is_recording)
    # The backend writes the previous flag into ``old_state``.
    status = _LIB.MXAutogradSetIsRecording(new_state, ctypes.byref(old_state))
    check_call(status)
    return bool(old_state.value)
def set_training(train_mode): #pylint: disable=redefined-outer-name
    """Switch between training and predicting mode.

    The flag affects ``ctx.is_train`` in the operator running context. For
    example, Dropout drops inputs randomly when ``train_mode=True`` and simply
    passes them through when ``train_mode=False``.

    Parameters
    ----------
    train_mode: bool
        The mode to switch to (True for training, False for predicting).

    Returns
    -------
    bool
        The mode that was in effect before this call.
    """
    old_mode = ctypes.c_int()
    new_mode = ctypes.c_int(train_mode)
    # The backend writes the previous flag into ``old_mode``.
    status = _LIB.MXAutogradSetIsTraining(new_mode, ctypes.byref(old_mode))
    check_call(status)
    return bool(old_mode.value)
def is_recording():
    """Query whether autograd is currently recording.

    Returns
    -------
    bool
        Current state of recording.
    """
    flag = ctypes.c_bool()
    status = _LIB.MXAutogradIsRecording(ctypes.byref(flag))
    check_call(status)
    return flag.value
def is_training():
    """Query whether the training (as opposed to predicting) mode is active.

    Returns
    -------
    bool
        Current state of training/predicting.
    """
    flag = ctypes.c_bool()
    status = _LIB.MXAutogradIsTraining(ctypes.byref(flag))
    check_call(status)
    return flag.value
class _RecordingStateScope(object): """Scope for managing training state. Example:: with _RecordingStateScope(True, True): y = model(x) backward([y]) """ def __init__(self, is_record, train_mode): #pylint: disable=redefined-outer-name self._enter_is_record = is_record self._enter_train_mode = train_mode self._prev_is_record = None self._prev_train_mode = None def __enter__(self): if self._enter_is_record is not None: self._prev_is_record = set_recording(self._enter_is_record) if self._enter_train_mode is not None: self._prev_train_mode = set_training(self._enter_train_mode) def __exit__(self, ptype, value, trace): if self._enter_is_record is not None and self._prev_is_record != self._enter_is_record: set_recording(self._prev_is_record) if self._enter_train_mode is not None and self._prev_train_mode != self._enter_train_mode: set_training(self._prev_train_mode)
def record(train_mode=True): #pylint: disable=redefined-outer-name
    """Create a recording scope for use in a 'with' statement.

    Code executed inside the scope is captured so that gradients can be
    calculated for it.

    .. note:: When forwarding with train_mode=False, the corresponding backward
              should also use train_mode=False, otherwise gradient is undefined.

    Example::

        with autograd.record():
            y = model(x)
            backward([y])
        metric.update(...)
        optim.step(...)

    Parameters
    ----------
    train_mode: bool, default True
        Whether the forward pass is in training or predicting mode. This controls the
        behavior of some layers such as Dropout, BatchNorm.
    """
    scope = _RecordingStateScope(True, train_mode)
    return scope
def pause(train_mode=False): #pylint: disable=redefined-outer-name
    """Create a scope, for use in a 'with' statement, with recording turned off.

    Intended for code that does not need gradients to be calculated.

    Example::

        with autograd.record():
            y = model(x)
            backward([y])
            with autograd.pause():
                # testing, IO, gradient updates...

    Parameters
    ----------
    train_mode: bool, default False
        Whether to do forward for training or predicting.
    """
    scope = _RecordingStateScope(False, train_mode)
    return scope
def train_mode():
    """Create a scope, for use in a 'with' statement, that forces training mode.

    Forward pass behavior is set to training mode inside the scope; the
    recording state is not changed.

    Example::

        y = model(x)
        with autograd.train_mode():
            y = dropout(y)

    """
    # None for the recording flag means "leave recording as it is".
    scope = _RecordingStateScope(None, True)
    return scope
def predict_mode():
    """Create a scope, for use in a 'with' statement, that forces inference mode.

    Forward pass behavior is set to inference mode inside the scope; the
    recording state is not changed.

    Example::

        with autograd.record():
            y = model(x)
            with autograd.predict_mode():
                y = sampling(y)
            backward([y])

    """
    # None for the recording flag means "leave recording as it is".
    scope = _RecordingStateScope(None, False)
    return scope
def mark_variables(variables, gradients, grad_reqs='write'):
    """Mark NDArrays as variables to compute gradient for autograd.

    Parameters
    ----------
    variables: NDArray or list of NDArray
        Arrays to be treated as variables.
    gradients: NDArray or list of NDArray
        Gradient buffers, one per variable. Must be a single NDArray when
        `variables` is a single NDArray.
    grad_reqs: str or list of str
        Gradient request type(s), looked up in ``_GRAD_REQ_MAP``. A single
        string is applied to every variable.

    Raises
    ------
    AssertionError
        If `gradients` or `grad_reqs` does not have one entry per variable.
    KeyError
        If a request string is not a key of ``_GRAD_REQ_MAP``.
    """
    if isinstance(variables, NDArray):
        assert isinstance(gradients, NDArray)
        variables = [variables]
        gradients = [gradients]

    if isinstance(grad_reqs, string_types):
        grad_reqs = [_GRAD_REQ_MAP[grad_reqs]] * len(variables)
    else:
        grad_reqs = [_GRAD_REQ_MAP[req] for req in grad_reqs]

    # The backend is only told len(variables) and reads that many entries from
    # each array, so every sequence must have exactly that length.
    num_variables = len(variables)
    assert len(gradients) == num_variables, \
        "variables and gradients must be lists of the same length"
    assert len(grad_reqs) == num_variables, \
        "grad_reqs must have one entry per variable"

    check_call(_LIB.MXAutogradMarkVariables(
        num_variables,
        c_handle_array(variables),
        c_array_buf(mx_uint, array('I', grad_reqs)),
        c_handle_array(gradients)))
def _parse_head(heads, head_grads):
    """Convert heads and head gradients into C handle arrays.

    A single NDArray for either argument is treated as a one-element list.

    Returns
    -------
    tuple
        ``(head_handles, hgrad_handles)``. ``hgrad_handles`` is a null
        ``c_void_p`` when `head_grads` is None; otherwise it is an array of
        handles in which a ``None`` entry becomes a null handle.
    """
    if isinstance(heads, NDArray):
        heads = [heads]
    if isinstance(head_grads, NDArray):
        head_grads = [head_grads]

    head_handles = c_handle_array(heads)

    if head_grads is None:
        return head_handles, ctypes.c_void_p(0)

    assert len(heads) == len(head_grads), \
        "heads and head_grads must be lists of the same length"
    raw_handles = []
    for grad_arr in head_grads:
        if grad_arr is None:
            raw_handles.append(NDArrayHandle(0))
        else:
            raw_handles.append(grad_arr.handle)
    return head_handles, c_array(NDArrayHandle, raw_handles)
def backward(heads, head_grads=None, retain_graph=False, train_mode=True): #pylint: disable=redefined-outer-name
    """Compute the gradients of heads w.r.t previously marked variables.

    Parameters
    ----------
    heads: NDArray or list of NDArray
        Output NDArray(s)
    head_grads: NDArray or list of NDArray or None
        Gradients with respect to heads.
    retain_graph: bool, optional
        Passed to the backend as the retain-graph flag.
    train_mode: bool, optional
        Whether to do backward for training or predicting.
    """
    head_handles, hgrad_handles = _parse_head(heads, head_grads)

    # No explicit variables and no output buffers are supplied here, so the
    # corresponding arguments are a zero count and null pointers.
    no_variables = ctypes.c_void_p(0)
    keep_graph = ctypes.c_int(retain_graph)
    create_graph_flag = ctypes.c_int(0)
    training_flag = ctypes.c_int(train_mode)
    no_grad_out = ctypes.c_void_p(0)
    no_stype_out = ctypes.c_void_p(0)

    status = _LIB.MXAutogradBackwardEx(
        len(head_handles),
        head_handles,
        hgrad_handles,
        0,
        no_variables,
        keep_graph,
        create_graph_flag,
        training_flag,
        no_grad_out,
        no_stype_out)
    check_call(status)
def grad(heads, variables, head_grads=None, retain_graph=None, create_graph=False,
         train_mode=True): #pylint: disable=redefined-outer-name
    """Compute the gradients of heads w.r.t variables. Gradients will be
    returned as new NDArrays instead of stored into `variable.grad`.
    Supports recording gradient graph for computing higher order gradients.

    .. Note: Currently only a very limited set of operators support higher order
    gradients.

    Parameters
    ----------
    heads: NDArray or list of NDArray
        Output NDArray(s)
    variables: NDArray or list of NDArray
        Input variables to compute gradients for. A single NDArray is treated
        as a one-element list.
    head_grads: NDArray or list of NDArray or None
        Gradients with respect to heads.
    retain_graph: bool
        Whether to keep computation graph to differentiate again, instead
        of clearing history and release memory. Defaults to the same value
        as create_graph.
    create_graph: bool
        Whether to record gradient graph for computing higher order gradients.
    train_mode: bool, optional
        Whether to do backward for training or prediction.

    Returns
    -------
    list of NDArray:
        Gradients with respect to variables, one entry per variable. The
        result is a list even when `variables` is a single NDArray.

    Raises
    ------
    AssertionError
        If `variables` is an empty sequence.

    Examples
    --------
    >>> x = mx.nd.ones((1,))
    >>> x.attach_grad()
    >>> with mx.autograd.record():
    ...     z = mx.nd.elemwise_add(mx.nd.exp(x), x)
    >>> dx = mx.autograd.grad(z, [x], create_graph=True)
    >>> print(dx[0])
    [ 3.71828175]
    <NDArray 1 @cpu(0)>
    """
    head_handles, hgrad_handles = _parse_head(heads, head_grads)

    if isinstance(variables, NDArray):
        variables = [variables]
    else:
        assert len(variables), "variables cannot be an empty list."
    var_handles = c_handle_array(variables)

    if retain_graph is None:
        retain_graph = create_graph

    # Output parameters filled in by the backend: one gradient handle and one
    # storage-type code per variable.
    grad_vars = ctypes.POINTER(NDArrayHandle)()
    grad_stypes = ctypes.POINTER(ctypes.c_int)()

    check_call(_LIB.MXAutogradBackwardEx(
        len(head_handles),
        head_handles,
        hgrad_handles,
        len(var_handles),
        var_handles,
        ctypes.c_int(retain_graph),
        ctypes.c_int(create_graph),
        ctypes.c_int(train_mode),
        ctypes.byref(grad_vars),
        ctypes.byref(grad_stypes)))

    # NOTE: an earlier revision ended with
    # ``if isinstance(variables, NDArray): return ret[0]``. That branch could
    # never run because ``variables`` has already been rebound to a list above,
    # so the function has always returned a list. The dead branch is removed
    # and the list return is kept so existing callers are unaffected.
    return [_ndarray_cls(ctypes.cast(grad_vars[i], NDArrayHandle),
                         stype=grad_stypes[i])
            for i in range(len(var_handles))]
def get_symbol(x):
    """Retrieve recorded computation history as `Symbol`.

    Parameters
    ----------
    x : NDArray
        Array representing the head of computation graph.

    Returns
    -------
    Symbol
        The retrieved Symbol.
    """
    sym_handle = SymbolHandle()
    # The backend stores the symbol handle into ``sym_handle``.
    status = _LIB.MXAutogradGetSymbol(x.handle, ctypes.byref(sym_handle))
    check_call(status)
    return Symbol(sym_handle)
class Function(object):
    """User-defined differentiable function.

    Function allows defining both forward and backward computation for
    custom operators. During gradient computation, the user-defined
    backward function will be used instead of the default chain-rule.
    You can also cast to numpy array and back for some operations in
    forward and backward.

    For example, a stable sigmoid function can be defined as::

        class sigmoid(Function):
            def forward(self, x):
                y = 1 / (1 + mx.nd.exp(-x))
                self.save_for_backward(y)
                return y

            def backward(self, dy):
                # backward takes as many inputs as forward's return value,
                # and returns as many NDArrays as forward's arguments.
                y, = self.saved_tensors
                return dy * y * (1-y)
    """
    # C signatures of the two callbacks handed to the backend.
    _bwd_functype = CFUNCTYPE(c_int, c_int, c_int, POINTER(c_void_p),
                              POINTER(c_int), c_int, c_void_p)
    _del_functype = CFUNCTYPE(c_int, c_void_p)

    class _Registry(object):
        """CustomOp registry."""
        def __init__(self):
            # Maps a key from inc() to the callback list that must stay alive.
            self.ref_holder = {}
            self.counter = 0
            self.lock = Lock()

        def inc(self):
            """Get index for new entry."""
            # ``with`` releases the lock even if an exception is raised.
            with self.lock:
                cur = self.counter
                self.counter += 1
            return cur

    _registry = _Registry()

    def __init__(self):
        self._used = False
        self.saved_tensors = ()

    def save_for_backward(self, *args):
        """Store values from forward so that backward can read them from
        ``self.saved_tensors``."""
        self.saved_tensors = args

    def __call__(self, *inputs):
        """Run forward with recording switched off and, if recording was
        active, register the custom backward with the backend.

        Returns whatever `forward` returned.
        """
        assert not self._used, \
            "Each Function instance can only be called once. "\
            "Please create another instance."
        self._used = True

        prev_recording = set_recording(False)
        try:
            outputs = self.forward(*inputs)
        finally:
            # Restore the caller's recording state even if forward raises.
            set_recording(prev_recording)

        if not prev_recording:
            return outputs

        ret_outputs = outputs
        if isinstance(outputs, NDArray):
            outputs = (outputs,)

        # Key under which the callback list is kept alive until the backend
        # invokes delete_entry.
        key = Function._registry.inc()

        def backward_entry(num_ograds, num_igrads, ptrs, reqs, is_train, _):
            """entry point for backward."""
            # pylint: disable=W0613
            try:
                # ptrs holds the output gradients followed by the input gradients.
                output_grads = [NDArray(ctypes.cast(i, NDArrayHandle), writable=False) \
                                for i in ptrs[:num_ograds]]
                input_grads = [NDArray(ctypes.cast(i, NDArrayHandle), writable=True) \
                               for i in ptrs[num_ograds:num_ograds+num_igrads]]
                reqs = [reqs[i] for i in range(num_igrads)]
                rets = self.backward(*output_grads)
                if isinstance(rets, NDArray):
                    rets = (rets,)
                assert len(rets) == len(input_grads), \
                    "%s.backward must return exactly the same number " \
                    "of NDArrays as the number of NDArrays arguments to forward." \
                    "Expecting %d got %d"%(
                        self.__class__.__name__, len(input_grads), len(rets))
                for igrad, ret, req in zip(input_grads, rets, reqs):
                    assert isinstance(ret, NDArray), \
                        "autograd.Function.backward must return NDArrays, not %s"%type(ret)
                    if req == 0:  # null
                        # Skip only this input; the remaining inputs still
                        # need their gradients written.
                        continue
                    elif req == 1 or req == 2:  # write or inplace
                        igrad[:] = ret
                    elif req == 3:  # add
                        # reqs are integer codes read from a C int array, so
                        # the accumulate request is matched by its code
                        # (kAddTo == 3 in MXNet's OpReqType), not by name.
                        igrad[:] += ret
            except Exception:  # pylint: disable=broad-except
                print('Error in Function.backward: %s' % traceback.format_exc())
                return False
            return True

        def delete_entry(_):
            """C Callback for CustomFunction::delete"""
            try:
                del Function._registry.ref_holder[key]
            except Exception:  # pylint: disable=broad-except
                print('Error in autograd.Function.delete: %s' % traceback.format_exc())
                return False
            return True

        callbacks = [Function._bwd_functype(backward_entry),
                     Function._del_functype(delete_entry)]
        callbacks = [cast(i, CFUNCTYPE(c_int)) for i in callbacks]
        context = MXCallbackList(c_int(len(callbacks)),
                                 cast(c_array(CFUNCTYPE(c_int), callbacks),
                                      POINTER(CFUNCTYPE(c_int))),
                                 cast(c_array(c_void_p, [None]*len(callbacks)),
                                      POINTER(c_void_p)))
        check_call(_LIB.MXCustomFunctionRecord(
            c_int(len(inputs)),
            c_handle_array(inputs),
            c_int(len(outputs)),
            c_handle_array(outputs),
            ctypes.byref(context)))

        # Hold a reference so the callbacks are not garbage collected while
        # the backend may still call them.
        Function._registry.ref_holder[key] = context

        return ret_outputs

    def forward(self, *inputs):
        """Forward computation."""
        raise NotImplementedError

    def backward(self, *output_grads):
        """Backward computation.

        Takes as many inputs as forward's outputs,
        and returns as many NDArrays as forward's inputs.
        """
        raise NotImplementedError