mxnet
ndarray.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_NDARRAY_H_
27 #define MXNET_CPP_NDARRAY_H_
28 
29 #include <map>
30 #include <memory>
31 #include <string>
32 #include <vector>
33 #include <iostream>
34 #include "mxnet-cpp/base.h"
35 #include "mxnet-cpp/shape.h"
36 
37 namespace mxnet {
38 namespace cpp {
39 
/*!
 * \brief Kinds of device a Context can refer to.
 *
 * Every enumerator has an explicit numeric value.
 */
enum DeviceType {
  kCPU = 1,
  kGPU = 2,
  // NOTE(review): the rendered listing skipped source line 43 although the
  // symbol index records a definition there. The enumerator below is restored
  // from the upstream mxnet-cpp header -- confirm name and value.
  kCPUPinned = 3
};
45 
49 class Context {
50  public:
56  Context(const DeviceType &type, int id) : type_(type), id_(id) {}
60  DeviceType GetDeviceType() const { return type_; }
64  int GetDeviceId() const { return id_; }
65 
71  static Context gpu(int device_id = 0) {
72  return Context(DeviceType::kGPU, device_id);
73  }
74 
80  static Context cpu(int device_id = 0) {
81  return Context(DeviceType::kCPU, device_id);
82  }
83 
84  private:
85  DeviceType type_;
86  int id_;
87 };
88 
92 struct NDBlob {
93  public:
97  NDBlob() : handle_(nullptr) {}
102  explicit NDBlob(NDArrayHandle handle) : handle_(handle) {}
106  ~NDBlob() { MXNDArrayFree(handle_); }
111 
112  private:
113  NDBlob(const NDBlob &);
114  NDBlob &operator=(const NDBlob &);
115 };
116 
120 class NDArray {
121  public:
125  NDArray();
129  explicit NDArray(const NDArrayHandle &handle);
136  NDArray(const std::vector<mx_uint> &shape, const Context &context,
137  bool delay_alloc = true);
144  NDArray(const Shape &shape, const Context &context, bool delay_alloc = true);
145  NDArray(const mx_float *data, size_t size);
152  NDArray(const mx_float *data, const Shape &shape, const Context &context);
159  NDArray(const std::vector<mx_float> &data, const Shape &shape,
160  const Context &context);
161  explicit NDArray(const std::vector<mx_float> &data);
162  NDArray operator+(mx_float scalar);
163  NDArray operator-(mx_float scalar);
164  NDArray operator*(mx_float scalar);
165  NDArray operator/(mx_float scalar);
166  NDArray operator%(mx_float scalar);
167  NDArray operator+(const NDArray &);
168  NDArray operator-(const NDArray &);
169  NDArray operator*(const NDArray &);
170  NDArray operator/(const NDArray &);
171  NDArray operator%(const NDArray &);
177  NDArray &operator=(mx_float scalar);
184  NDArray &operator+=(mx_float scalar);
191  NDArray &operator-=(mx_float scalar);
198  NDArray &operator*=(mx_float scalar);
205  NDArray &operator/=(mx_float scalar);
212  NDArray &operator%=(mx_float scalar);
219  NDArray &operator+=(const NDArray &src);
226  NDArray &operator-=(const NDArray &src);
233  NDArray &operator*=(const NDArray &src);
240  NDArray &operator/=(const NDArray &src);
247  NDArray &operator%=(const NDArray &src);
248  NDArray ArgmaxChannel();
259  void SyncCopyFromCPU(const mx_float *data, size_t size);
269  void SyncCopyFromCPU(const std::vector<mx_float> &data);
280  void SyncCopyToCPU(mx_float *data, size_t size = 0);
291  void SyncCopyToCPU(std::vector<mx_float> *data, size_t size = 0);
297  NDArray CopyTo(NDArray * other) const;
303  NDArray Copy(const Context &) const;
310  size_t Offset(size_t h = 0, size_t w = 0) const;
318  size_t Offset(size_t c, size_t h, size_t w) const;
325  mx_float At(size_t h, size_t w) const;
333  mx_float At(size_t c, size_t h, size_t w) const;
340  NDArray Slice(mx_uint begin, mx_uint end) const;
346  NDArray Reshape(const Shape &new_shape) const;
351  void WaitToRead() const;
356  void WaitToWrite();
361  static void WaitAll();
368  static void SampleGaussian(mx_float mu, mx_float sigma, NDArray *out);
375  static void SampleUniform(mx_float begin, mx_float end, NDArray *out);
384  static void Load(const std::string &file_name,
385  std::vector<NDArray> *array_list = nullptr,
386  std::map<std::string, NDArray> *array_map = nullptr);
392  static std::map<std::string, NDArray> LoadToMap(const std::string &file_name);
398  static std::vector<NDArray> LoadToList(const std::string &file_name);
404  static void Save(const std::string &file_name,
405  const std::map<std::string, NDArray> &array_map);
411  static void Save(const std::string &file_name,
412  const std::vector<NDArray> &array_list);
416  size_t Size() const;
420  std::vector<mx_uint> GetShape() const;
424  int GetDType() const;
429  const mx_float *GetData() const;
430 
434  Context GetContext() const;
435 
439  NDArrayHandle GetHandle() const { return blob_ptr_->handle_; }
440 
441  private:
442  std::shared_ptr<NDBlob> blob_ptr_;
443 };
444 
445 std::ostream& operator<<(std::ostream& out, const NDArray &ndarray);
446 } // namespace cpp
447 } // namespace mxnet
448 
449 #endif // MXNET_CPP_NDARRAY_H_
NDBlob()
default constructor
Definition: ndarray.h:97
Symbol operator/(mx_float lhs, const Symbol &rhs)
namespace of mxnet
Definition: base.h:126
Definition: ndarray.h:43
dynamic shape class that can hold shape of arbitrary dimension
Definition: shape.h:42
MXNET_DLL int MXNDArrayFree(NDArrayHandle handle)
free the NDArray handle
~NDBlob()
destructor, free the NDArrayHandle
Definition: ndarray.h:106
static Context cpu(int device_id=0)
Return a CPU context.
Definition: ndarray.h:80
Symbol Reshape(const std::string &symbol_name, Symbol data, Shape shape=Shape(), bool reverse=0, Shape target_shape=Shape(), bool keep_highest=0)
Definition: op.h:301
Symbol operator%(mx_float lhs, const Symbol &rhs)
Definition: ndarray.h:41
DeviceType
Definition: ndarray.h:40
NDArray interface.
Definition: ndarray.h:120
NDBlob(NDArrayHandle handle)
construct with a NDArrayHandle
Definition: ndarray.h:102
void SampleGaussian(real_t mu, real_t sigma, NDArray *out)
Sample gaussian distribution for each element of out.
void * NDArrayHandle
handle to NDArray
Definition: c_api.h:64
NDArrayHandle GetHandle() const
Definition: ndarray.h:439
NDArrayHandle handle_
the NDArrayHandle
Definition: ndarray.h:110
Symbol operator+(mx_float lhs, const Symbol &rhs)
unsigned int mx_uint
manually define unsigned int
Definition: c_api.h:57
void SampleUniform(real_t begin, real_t end, NDArray *out)
Sample uniform distribution for each element of out.
DeviceType GetDeviceType() const
Definition: ndarray.h:60
std::ostream & operator<<(std::ostream &out, const NDArray &ndarray)
int GetDeviceId() const
Definition: ndarray.h:64
Definition: ndarray.h:42
float mx_float
manually define float
Definition: c_api.h:59
Symbol operator-(mx_float lhs, const Symbol &rhs)
struct to store NDArrayHandle
Definition: ndarray.h:92
definition of shape
static Context gpu(int device_id=0)
Return a GPU context.
Definition: ndarray.h:71
Context interface.
Definition: ndarray.h:49
Symbol operator*(mx_float lhs, const Symbol &rhs)
Context(const DeviceType &type, int id)
Context constructor.
Definition: ndarray.h:56