mxnet
expr.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 // Acknowledgement: This file originates from incubator-tvm
25 #ifndef MXNET_IR_EXPR_H_
26 #define MXNET_IR_EXPR_H_
27 
28 #include <mxnet/runtime/object.h>
29 #include <mxnet/node/node.h>
30 #include <mxnet/node/container.h>
32 #include <string>
33 
34 namespace mxnet {
35 
40 class BaseExprNode : public Object {
41  public:
42  static constexpr const char* _type_key = "Expr";
44 };
45 
50 class BaseExpr : public ObjectRef {
51  public:
53  BaseExpr() {}
58  explicit BaseExpr(runtime::ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
61 };
62 
75 class PrimExprNode : public BaseExprNode {
76  public:
92 
93  static constexpr const char* _type_key = "PrimExpr";
95 };
96 
101 class PrimExpr : public BaseExpr {
102  public:
104  PrimExpr() {}
114  MXNET_DLL PrimExpr(int32_t value); // NOLINT(*)
119  MXNET_DLL PrimExpr(float value); // NOLINT(*)
124  MXNET_DLL PrimExpr(std::string str); // NOLINT(*)
125 
128  return static_cast<const PrimExprNode*>(get())->dtype;
129  }
132 };
133 
138 class IntImmNode : public PrimExprNode {
139  public:
141  int64_t value;
142 
143  static constexpr const char* _type_key = "IntImm";
145 };
146 
152 class IntImm : public PrimExpr {
153  public:
157  IntImm() {}
161  explicit IntImm(runtime::ObjectPtr<Object> node) : PrimExpr(node) {}
167  MXNET_DLL IntImm(MXNetDataType dtype, int64_t value);
172  const IntImmNode* operator->() const {
173  return static_cast<const IntImmNode*>(get());
174  }
177 };
178 
183 class FloatImmNode : public PrimExprNode {
184  public:
186  double value;
187 
188  static constexpr const char* _type_key = "FloatImm";
190 };
191 
197 class FloatImm : public PrimExpr {
198  public:
202  FloatImm() {}
206  explicit FloatImm(runtime::ObjectPtr<Object> node) : PrimExpr(node) {}
212  MXNET_DLL FloatImm(MXNetDataType dtype, double value);
217  const FloatImmNode* operator->() const {
218  return static_cast<const FloatImmNode*>(get());
219  }
222 };
223 
224 } // namespace mxnet
225 #endif // MXNET_IR_EXPR_H_
Constant floating point literals in the program.
Definition: expr.h:183
Constant integer literals in the program.
Definition: expr.h:138
FloatImm()
Constructor.
Definition: expr.h:202
const FloatImmNode * operator->() const
Get pointer to the container.
Definition: expr.h:217
Managed reference class to FloatImmNode.
Definition: expr.h:197
namespace of mxnet
Definition: api_registry.h:33
IntImm(runtime::ObjectPtr< Object > node)
Constructor from node.
Definition: expr.h:161
A custom smart pointer for Object.
Definition: object.h:345
double value
The constant value content.
Definition: expr.h:186
Managed reference class to IntImmNode.
Definition: expr.h:152
Reference to PrimExprNode.
Definition: expr.h:101
const IntImmNode * operator->() const
Get pointer to the internal value.
Definition: expr.h:172
IntImm()
Constructor.
Definition: expr.h:157
PrimExpr(runtime::ObjectPtr< Object > ptr)
Constructor from object ptr.
Definition: expr.h:109
Runtime primitive data type.
Definition: data_type.h:41
Base node of all primitive expressions.
Definition: expr.h:75
A managed object in MXNet runtime.
Base type of all the expressions.
Definition: expr.h:40
BaseExpr(runtime::ObjectPtr< Object > ptr)
Constructor from object ptr.
Definition: expr.h:58
MXNET_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object)
static constexpr const char * _type_key
Definition: expr.h:42
#define MXNET_DLL
MXNET_DLL prefix for windows.
Definition: c_api.h:54
PrimExpr()
Constructor.
Definition: expr.h:104
int64_t value
The internal value.
Definition: expr.h:141
BaseExpr()
Constructor.
Definition: expr.h:53
FloatImm(runtime::ObjectPtr< Object > node)
Constructor from node.
Definition: expr.h:206
MXNetDataType dtype
The runtime data type of the primitive expression.
Definition: expr.h:91
Managed reference to BaseExprNode.
Definition: expr.h:50
MXNetDataType dtype() const
Definition: expr.h:127
#define MXNET_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType)
helper macro to declare type information in a final class.
Definition: object.h:669