mxnet
expr_operator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 // Acknowledgement: This file originates from incubator-tvm
28 // Acknowledgement: Most operator APIs originate from Halide.
29 #ifndef MXNET_EXPR_OPERATOR_H_
30 #define MXNET_EXPR_OPERATOR_H_
31 
32 #include <mxnet/ir/expr.h>
33 
34 namespace mxnet {
35 
36 template <typename ValueType>
37 inline PrimExpr MakeConstScalar(MXNetDataType t, ValueType value) {
38  if (t.is_int())
39  return IntImm(t, static_cast<int64_t>(value));
40  if (t.is_float())
41  return FloatImm(t, static_cast<double>(value));
42  // customized type and uint is not supported for MXNet for now
43  LOG(FATAL) << "cannot make const for type " << t;
44  return PrimExpr();
45 }
46 
47 template <typename ValueType>
48 inline PrimExpr make_const(MXNetDataType t, ValueType value) {
49  if (t.lanes() == 1) {
50  return MakeConstScalar(t, value);
51  } else {
52  LOG(FATAL) << "MXNetDataType::lanes() != 1 is not supported ";
53  }
54  return PrimExpr();
55 }
56 
57 } // namespace mxnet
58 
59 #endif // MXNET_EXPR_OPERATOR_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
expr.h
Base expr nodes in MXNet.
mxnet::runtime::MXNetDataType::lanes
int lanes() const
Definition: data_type.h:80
mxnet::make_const
PrimExpr make_const(MXNetDataType t, ValueType value)
Definition: expr_operator.h:48
mxnet::FloatImm
Managed reference class to FloatImmNode.
Definition: expr.h:197
mxnet::runtime::MXNetDataType::is_float
bool is_float() const
Definition: data_type.h:92
mxnet::PrimExpr
Reference to PrimExprNode.
Definition: expr.h:101
mxnet::runtime::MXNetDataType::is_int
bool is_int() const
Definition: data_type.h:96
mxnet::runtime::MXNetDataType
Runtime primitive data type.
Definition: data_type.h:40
mxnet::MakeConstScalar
PrimExpr MakeConstScalar(MXNetDataType t, ValueType value)
Definition: expr_operator.h:37
mxnet::IntImm
Managed reference class to IntImmNode.
Definition: expr.h:152