mxnet
metric.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_METRIC_H_
27 #define MXNET_CPP_METRIC_H_
28 
29 #include <cmath>
30 #include <string>
31 #include <vector>
32 #include <algorithm>
33 #include "mxnet-cpp/ndarray.h"
34 #include "dmlc/logging.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 
39 class EvalMetric {
40  public:
41  explicit EvalMetric(const std::string& name, int num = 0) : name(name), num(num) {}
42  virtual void Update(NDArray labels, NDArray preds) = 0;
43  void Reset() {
44  num_inst = 0;
45  sum_metric = 0.0f;
46  }
47  float Get() {
48  return sum_metric / num_inst;
49  }
50  void GetNameValue();
51 
52  protected:
53  std::string name;
54  int num;
55  float sum_metric = 0.0f;
56  int num_inst = 0;
57 
58  static void CheckLabelShapes(NDArray labels, NDArray preds, bool strict = false) {
59  if (strict) {
60  CHECK_EQ(Shape(labels.GetShape()), Shape(preds.GetShape()));
61  } else {
62  CHECK_EQ(labels.Size(), preds.Size());
63  }
64  }
65 };
66 
67 class Accuracy : public EvalMetric {
68  public:
69  Accuracy() : EvalMetric("accuracy") {}
70 
71  void Update(NDArray labels, NDArray preds) override {
72  CHECK_EQ(labels.GetShape().size(), 1);
73  mx_uint len = labels.GetShape()[0];
74  std::vector<mx_float> pred_data(len);
75  std::vector<mx_float> label_data(len);
76  preds.ArgmaxChannel().SyncCopyToCPU(&pred_data, len);
77  labels.SyncCopyToCPU(&label_data, len);
78  for (mx_uint i = 0; i < len; ++i) {
79  sum_metric += (pred_data[i] == label_data[i]) ? 1 : 0;
80  num_inst += 1;
81  }
82  }
83 };
84 
85 class LogLoss : public EvalMetric {
86  public:
87  LogLoss() : EvalMetric("logloss") {}
88 
89  void Update(NDArray labels, NDArray preds) override {
90  static const float epsilon = 1e-15;
91  mx_uint len = labels.GetShape()[0];
92  mx_uint m = preds.GetShape()[1];
93  std::vector<mx_float> pred_data(len * m);
94  std::vector<mx_float> label_data(len);
95  preds.SyncCopyToCPU(&pred_data, pred_data.size());
96  labels.SyncCopyToCPU(&label_data, len);
97  for (mx_uint i = 0; i < len; ++i) {
98  sum_metric += -std::log(std::max(pred_data[i * m + label_data[i]], epsilon));
99  num_inst += 1;
100  }
101  }
102 };
103 
104 class MAE : public EvalMetric {
105  public:
106  MAE() : EvalMetric("mae") {}
107 
108  void Update(NDArray labels, NDArray preds) override {
109  CheckLabelShapes(labels, preds);
110 
111  std::vector<mx_float> pred_data;
112  preds.SyncCopyToCPU(&pred_data);
113  std::vector<mx_float> label_data;
114  labels.SyncCopyToCPU(&label_data);
115 
116  size_t len = preds.Size();
117  mx_float sum = 0;
118  for (size_t i = 0; i < len; ++i) {
119  sum += std::abs(pred_data[i] - label_data[i]);
120  }
121  sum_metric += sum / len;
122  ++num_inst;
123  }
124 };
125 
126 class MSE : public EvalMetric {
127  public:
128  MSE() : EvalMetric("mse") {}
129 
130  void Update(NDArray labels, NDArray preds) override {
131  CheckLabelShapes(labels, preds);
132 
133  std::vector<mx_float> pred_data;
134  preds.SyncCopyToCPU(&pred_data);
135  std::vector<mx_float> label_data;
136  labels.SyncCopyToCPU(&label_data);
137 
138  size_t len = preds.Size();
139  mx_float sum = 0;
140  for (size_t i = 0; i < len; ++i) {
141  mx_float diff = pred_data[i] - label_data[i];
142  sum += diff * diff;
143  }
144  sum_metric += sum / len;
145  ++num_inst;
146  }
147 };
148 
149 class RMSE : public EvalMetric {
150  public:
151  RMSE() : EvalMetric("rmse") {}
152 
153  void Update(NDArray labels, NDArray preds) override {
154  CheckLabelShapes(labels, preds);
155 
156  std::vector<mx_float> pred_data;
157  preds.SyncCopyToCPU(&pred_data);
158  std::vector<mx_float> label_data;
159  labels.SyncCopyToCPU(&label_data);
160 
161  size_t len = preds.Size();
162  mx_float sum = 0;
163  for (size_t i = 0; i < len; ++i) {
164  mx_float diff = pred_data[i] - label_data[i];
165  sum += diff * diff;
166  }
167  sum_metric += std::sqrt(sum / len);
168  ++num_inst;
169  }
170 };
171 
172 class PSNR : public EvalMetric {
173  public:
174  PSNR() : EvalMetric("psnr") {}
175 
176  void Update(NDArray labels, NDArray preds) override {
177  CheckLabelShapes(labels, preds);
178 
179  std::vector<mx_float> pred_data;
180  preds.SyncCopyToCPU(&pred_data);
181  std::vector<mx_float> label_data;
182  labels.SyncCopyToCPU(&label_data);
183 
184  size_t len = preds.Size();
185  mx_float sum = 0;
186  for (size_t i = 0; i < len; ++i) {
187  mx_float diff = pred_data[i] - label_data[i];
188  sum += diff * diff;
189  }
190  mx_float mse = sum / len;
191  if (mse > 0) {
192  sum_metric += 10 * std::log(255.0f / mse) / log10_;
193  } else {
194  sum_metric += 99.0f;
195  }
196  ++num_inst;
197  }
198 
199  private:
200  mx_float log10_ = std::log(10.0f);
201 };
202 
203 } // namespace cpp
204 } // namespace mxnet
205 
206 #endif // MXNET_CPP_METRIC_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::cpp::Accuracy
Definition: metric.h:67
mxnet::cpp::MAE::MAE
MAE()
Definition: metric.h:106
mxnet::cpp::PSNR::PSNR
PSNR()
Definition: metric.h:174
mxnet::cpp::Accuracy::Accuracy
Accuracy()
Definition: metric.h:69
mxnet::cpp::EvalMetric::num
int num
Definition: metric.h:54
mxnet::cpp::RMSE
Definition: metric.h:149
mxnet::cpp::MSE::MSE
MSE()
Definition: metric.h:128
mxnet::cpp::NDArray::SyncCopyToCPU
void SyncCopyToCPU(mx_float *data, size_t size=0)
Do a synchronous copy to a contiguous CPU memory region.
mxnet::cpp::RMSE::RMSE
RMSE()
Definition: metric.h:151
mxnet::cpp::EvalMetric::EvalMetric
EvalMetric(const std::string &name, int num=0)
Definition: metric.h:41
mxnet::cpp::EvalMetric::num_inst
int num_inst
Definition: metric.h:56
mxnet::cpp::PSNR::Update
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:176
mxnet::cpp::EvalMetric::Update
virtual void Update(NDArray labels, NDArray preds)=0
mxnet::cpp::Accuracy::Update
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:71
mxnet::cpp::NDArray
NDArray interface.
Definition: ndarray.h:122
mxnet::cpp::NDArray::Size
size_t Size() const
ndarray.h
definition of ndarray
mxnet::cpp::LogLoss
Definition: metric.h:85
mxnet::cpp::LogLoss::Update
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:89
mxnet::cpp::EvalMetric::sum_metric
float sum_metric
Definition: metric.h:55
mxnet::cpp::NDArray::ArgmaxChannel
NDArray ArgmaxChannel()
mx_float
float mx_float
manually define float
Definition: c_api.h:67
mxnet::cpp::EvalMetric::name
std::string name
Definition: metric.h:53
mxnet::cpp::NDArray::GetShape
std::vector< mx_uint > GetShape() const
mxnet::cpp::MSE::Update
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:130
mxnet::cpp::EvalMetric::GetNameValue
void GetNameValue()
mxnet::cpp::Shape
dynamic shape class that can hold shape of arbitrary dimension
Definition: shape.h:42
mxnet::cpp::LogLoss::LogLoss
LogLoss()
Definition: metric.h:87
mxnet::cpp::EvalMetric
Definition: metric.h:39
mxnet::cpp::MAE
Definition: metric.h:104
mxnet::cpp::RMSE::Update
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:153
mx_uint
uint32_t mx_uint
manually define unsigned int
Definition: c_api.h:65
mxnet::cpp::EvalMetric::Reset
void Reset()
Definition: metric.h:43
mxnet::cpp::MSE
Definition: metric.h:126
mxnet::cpp::PSNR
Definition: metric.h:172
mxnet::cpp::EvalMetric::Get
float Get()
Definition: metric.h:47
mxnet::cpp::MAE::Update
void Update(NDArray labels, NDArray preds) override
Definition: metric.h:108
mxnet::cpp::EvalMetric::CheckLabelShapes
static void CheckLabelShapes(NDArray labels, NDArray preds, bool strict=false)
Definition: metric.h:58