mxnet
expr_operator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 // Acknowledgement: This file originates from incubator-tvm
28 // Acknowledgement: Most operator APIs originate from Halide.
29 #ifndef MXNET_EXPR_OPERATOR_H_
30 #define MXNET_EXPR_OPERATOR_H_
31 
32 #include <mxnet/ir/expr.h>
33 
34 namespace mxnet {
35 
36 template<typename ValueType>
37 inline PrimExpr MakeConstScalar(MXNetDataType t, ValueType value) {
38  if (t.is_int()) return IntImm(t, static_cast<int64_t>(value));
39  if (t.is_float()) return FloatImm(t, static_cast<double>(value));
40  // customized type and uint is not supported for MXNet for now
41  LOG(FATAL) << "cannot make const for type " << t;
42  return PrimExpr();
43 }
44 
45 
46 template<typename ValueType>
47 inline PrimExpr make_const(MXNetDataType t, ValueType value) {
48  if (t.lanes() == 1) {
49  return MakeConstScalar(t, value);
50  } else {
51  LOG(FATAL) << "MXNetDataType::lanes() != 1 is not supported ";
52  }
53  return PrimExpr();
54 }
55 
56 } // namespace mxnet
57 
58 #endif // MXNET_EXPR_OPERATOR_H_
int lanes() const
Definition: data_type.h:82
Managed reference class to FloatImmNode.
Definition: expr.h:197
PrimExpr make_const(MXNetDataType t, ValueType value)
Definition: expr_operator.h:47
namespace of mxnet
Definition: api_registry.h:33
bool is_float() const
Definition: data_type.h:94
Managed reference class to IntImmNode.
Definition: expr.h:152
Reference to PrimExprNode.
Definition: expr.h:101
PrimExpr MakeConstScalar(MXNetDataType t, ValueType value)
Definition: expr_operator.h:37
Runtime primitive data type.
Definition: data_type.h:41
Base expr nodes in MXNet.
bool is_int() const
Definition: data_type.h:98