mxnet
model.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_MODEL_H_
27 #define MXNET_CPP_MODEL_H_
28 
29 #include <string>
30 #include <vector>
31 #include "mxnet-cpp/base.h"
32 #include "mxnet-cpp/symbol.h"
33 #include "mxnet-cpp/ndarray.h"
34 
35 namespace mxnet {
36 namespace cpp {
37 
40  std::vector<Context> ctx = {Context::cpu()};  // contexts to use; defaults to one Context::cpu() (device_id 0)
41  int num_epoch = 0;  // epoch count, default 0 (NOTE(review): presumably consumed by Fit() - confirm)
42  int epoch_size = 0;  // default 0 (NOTE(review): presumably batches per epoch - confirm)
43  std::string optimizer = "sgd";  // optimizer name; defaults to "sgd"
44  // TODO(zhangchen-qinyinghua) Options not implemented yet (Python-style defaults):
45  // initializer=Uniform(0.01),
46  // numpy_batch_size=128,
47  // arg_params=None, aux_params=None,
48  // allow_extra_params=False,
49  // begin_epoch=0,
50  // **kwargs):
53 };
54 class FeedForward {
55  public:
56  explicit FeedForward(const FeedForwardConfig &conf) : conf_(conf) {}
57  void Predict();
58  void Score();
59  void Fit();
60  void Save();
61  void Load();
62  static FeedForward Create();
63 
64  private:
65  void InitParams();
66  void InitPredictor();
67  void InitIter();
68  void InitEvalIter();
69  FeedForwardConfig conf_;
70 };
71 
72 } // namespace cpp
73 } // namespace mxnet
74 
75 #endif // MXNET_CPP_MODEL_H_
76 
definition of symbol
namespace of mxnet
Definition: base.h:126
static Context cpu(int device_id=0)
Return a CPU context.
Definition: ndarray.h:80
FeedForwardConfig(const FeedForwardConfig &other)
Definition: model.h:51
Definition: model.h:54
FeedForward(const FeedForwardConfig &conf)
Definition: model.h:56
int num_epoch
Definition: model.h:41
std::vector< Context > ctx
Definition: model.h:40
Symbol symbol
Definition: model.h:39
Definition: model.h:38
int epoch_size
Definition: model.h:42
FeedForwardConfig()
Definition: model.h:52
std::string optimizer
Definition: model.h:43
Symbol interface.
Definition: symbol.h:71