mxnet
monitor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_MONITOR_H_
27 #define MXNET_CPP_MONITOR_H_
28 
29 #include <regex>
30 #include <tuple>
31 #include <vector>
32 #include <map>
33 #include <set>
34 #include <string>
35 #include <functional>
36 #include "mxnet-cpp/base.h"
37 #include "mxnet-cpp/ndarray.h"
38 #include "mxnet-cpp/executor.h"
39 
40 namespace mxnet {
41 namespace cpp {
42 
49 NDArray _default_monitor_func(const NDArray &x);
50 
54 class Monitor {
55  public:
56  typedef std::function<NDArray(const NDArray&)> StatFunc;
57  typedef std::tuple<int, std::string, NDArray> Stat;
58 
66  Monitor(int interval, std::regex pattern = std::regex(".*"),
67  StatFunc stat_func = _default_monitor_func);
68 
73  void install(Executor *exe);
74 
78  void tic();
79 
84  std::vector<Stat> toc();
85 
89  void toc_print();
90 
91  protected:
92  int interval;
93  std::regex pattern;
94  StatFunc stat_func;
95  std::vector<Executor*> exes;
96 
97  int step;
98  bool activated;
99  std::vector<Stat> stats;
100 
101  static void executor_callback(const char *name, NDArrayHandle ndarray, void *monitor_ptr);
102 };
103 
104 } // namespace cpp
105 } // namespace mxnet
106 #endif // MXNET_CPP_MONITOR_H_
std::regex pattern
Definition: monitor.h:93
void toc_print()
End collecting and print results.
int interval
Definition: monitor.h:92
std::function< NDArray(const NDArray &)> StatFunc
Definition: monitor.h:56
namespace of mxnet
Definition: base.h:126
Executor interface.
Definition: executor.h:44
Monitor(int interval, std::regex pattern=std::regex(".*"), StatFunc stat_func=_default_monitor_func)
Monitor constructor.
NDArray _default_monitor_func(const NDArray &x)
Default function for monitor that computes statistics of the input tensor, which is the mean absolute...
void install(Executor *exe)
Install the monitor callback on an executor. Supports installing to multiple executors.
bool activated
Definition: monitor.h:98
void tic()
Start collecting stats for the current batch. Call before running the forward pass.
void * NDArrayHandle
Handle to an NDArray.
Definition: c_api.h:64
int step
Definition: monitor.h:97
std::vector< Stat > stats
Definition: monitor.h:99
Monitor interface.
Definition: monitor.h:54
StatFunc stat_func
Definition: monitor.h:94
std::vector< Executor * > exes
Definition: monitor.h:95
std::tuple< int, std::string, NDArray > Stat
Definition: monitor.h:57
std::vector< Stat > toc()
End collecting for current batch and return results. Call after computation of current batch...
static void executor_callback(const char *name, NDArrayHandle ndarray, void *monitor_ptr)