1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/ir/expr.cc
22 * \brief The expression AST nodes for the common IR infra.
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/ir/expr.h>
26#include <tvm/ir/function.h>
27#include <tvm/runtime/registry.h>
28#include <tvm/te/tensor.h>
29#include <tvm/tir/expr.h>
30
31#include "../support/scalars.h"
32
33namespace tvm {
34
35PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {}
36
37PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {}
38
39PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
40 using runtime::ObjectTypeChecker;
41 if (auto* ptr = ref.as<tir::IterVarNode>()) {
42 return GetRef<tir::IterVar>(ptr)->var;
43 }
44 if (auto* ptr = ref.as<te::TensorNode>()) {
45 return GetRef<te::Tensor>(ptr)();
46 }
47 if (auto* ptr = ref.as<runtime::StringObj>()) {
48 return tir::StringImm(GetRef<runtime::String>(ptr));
49 }
50 if (const auto* buffer_region = ref.as<tir::BufferRegionNode>()) {
51 Array<PrimExpr> indices;
52 indices.reserve(buffer_region->region.size());
53 for (const Range& r : buffer_region->region) {
54 if (tvm::tir::is_one(r->extent)) {
55 indices.push_back(r->min);
56 } else if (const auto* extent = r->extent.as<IntImmNode>()) {
57 indices.push_back(tir::Ramp(r->min, tvm::tir::make_const(r->min->dtype, 1), extent->value));
58 } else {
59 LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << ref;
60 }
61 }
62 return tir::BufferLoad(buffer_region->buffer, indices);
63 }
64 Optional<String> actual_type = ObjectTypeChecker<PrimExpr>::CheckAndGetMismatch(ref.get());
65 ICHECK(!actual_type.defined()) << "Expected type " << ObjectTypeChecker<PrimExpr>::TypeName()
66 << " but got " << actual_type.value();
67 return Downcast<PrimExpr>(ref);
68}
69
/*!
 * \brief Construct an integer literal node after checking that the value fits its type.
 *
 * Checks performed before the node is built:
 *  - dtype must be scalar and be an int or uint type;
 *  - uint: value must be >= 0 and, for widths below 64 bits, strictly less than 2^bits;
 *  - int(1): value must be 0 or 1;
 *  - other int widths below 64 bits: value must lie in [-2^(bits-1), 2^(bits-1)).
 * Widths of 64 bits or more get no bound check, except that uint values must be
 * non-negative. Because value is an int64_t, a uint64 literal above INT64_MAX cannot
 * be passed in at all.
 *
 * \param dtype The scalar int/uint data type of the literal.
 * \param value The literal value.
 * \param span  Source span stored on the node.
 * \note A violated condition fails an ICHECK carrying a "ValueError: ..." message.
 */
IntImm::IntImm(DataType dtype, int64_t value, Span span) {
  ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " << dtype
                            << " was supplied.";
  ICHECK(dtype.is_int() || dtype.is_uint())
      << "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied.";
  if (dtype.is_uint()) {
    // 0U is converted to int64_t by the usual arithmetic conversions, so this is a signed
    // comparison and negative literals are rejected.
    ICHECK_GE(value, 0U) << "ValueError: Literal value " << value
                         << " is negative for unsigned integer type " << dtype;
    // Below 64 bits the exclusive bound 2^bits fits in a long long. At 64 bits every
    // non-negative int64_t is already in range.
    if (dtype.bits() < 64) {
      ICHECK_LT(value, 1LL << dtype.bits())
          << "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
    }
  } else if (dtype.bits() == 1) {
    // int(1)
    // NOTE(review): a 1-bit two's-complement range would be {-1, 0}. This branch accepts
    // {0, 1} instead; confirm that this is intended.
    ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype;
  } else if (dtype.bits() < 64) {
    // Signed range is [-2^(bits-1), 2^(bits-1)).
    ICHECK_GE(value, -(1LL << (dtype.bits() - 1)))
        << "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
    ICHECK_LT(value, 1LL << (dtype.bits() - 1))
        << "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
  }
  ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
  node->dtype = dtype;
  node->value = value;
  node->span = span;
  data_ = std::move(node);
}

// Register the constructor as the global function "ir.IntImm" (dtype, value, span).
TVM_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value, Span span) {
  return IntImm(dtype, value, span);
});

// Node-type registration for IntImmNode.
TVM_REGISTER_NODE_TYPE(IntImmNode);
103
/*!
 * \brief Construct a floating-point literal node after validating its type and value.
 *
 * \param dtype Single-lane float, bfloat16, or custom (code >= DataType::kCustomBegin) type.
 * \param value The literal value, carried as a double.
 * \param span  Source span stored on the node.
 * \note A violated condition fails an ICHECK carrying a "ValueError: ..." message.
 */
FloatImm::FloatImm(DataType dtype, double value, Span span) {
  // Only single-lane types are accepted.
  // NOTE(review): IntImm performs the same check via dtype.is_scalar() and includes the
  // dtype in its message; consider aligning the two.
  ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";

  ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.code() >= DataType::kCustomBegin)
      << "ValueError: FloatImm supports only float, but " << dtype << " was supplied.";

  // check range for float32 and float16 since they have specified range.
  // Only finite values are range-checked: inf and NaN are accepted for every type.
  // The first branch keys on bit width alone, so any 32-bit dtype (including a 32-bit
  // custom type) is checked against float32's limits. float16 is checked against
  // +/-support::kMaxFloat16. Other widths (e.g. float64) are not range-checked here. The
  // same holds for bfloat16, assuming is_float16() excludes it (TODO confirm).
  if (!std::isinf(value) && !std::isnan(value)) {
    if (dtype.bits() == 32) {
      ICHECK_GE(value, std::numeric_limits<float>::lowest())
          << "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
      ICHECK_LE(value, std::numeric_limits<float>::max())
          << "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
    } else if (dtype.is_float16()) {
      ICHECK_GE(value, -support::kMaxFloat16)
          << "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
      ICHECK_LE(value, support::kMaxFloat16)
          << "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
    }
  }
  ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
  node->dtype = dtype;
  node->value = value;
  node->span = span;
  data_ = std::move(node);
}

// Register the constructor as the global function "ir.FloatImm" (dtype, value, span).
TVM_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double value, Span span) {
  return FloatImm(dtype, value, span);
});

// Node-type registration for FloatImmNode.
TVM_REGISTER_NODE_TYPE(FloatImmNode);
136
137Range::Range(PrimExpr begin, PrimExpr end, Span span)
138 : Range(make_object<RangeNode>(begin, tir::is_zero(begin) ? end : (end - begin), span)) {}
139
140Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) {
141 return Range(make_object<RangeNode>(min, extent, span));
142}
143
144TVM_REGISTER_GLOBAL("ir.Range_from_min_extent").set_body_typed(Range::FromMinExtent);
145
146TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) {
147 *ret = Range(args[0], args[1], args[2]);
148});
149
150TVM_REGISTER_NODE_TYPE(RangeNode);
151
152GlobalVar::GlobalVar(String name_hint, Type type, Span span) {
153 ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
154 n->name_hint = std::move(name_hint);
155 n->checked_type_ = std::move(type);
156 n->span = std::move(span);
157 data_ = std::move(n);
158}
159
160TVM_REGISTER_NODE_TYPE(GlobalVarNode);
161
162TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) {
163 return GlobalVar(name, type);
164});
165
166TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) {
167 std::stringstream ss;
168 ss << ref;
169 return ss.str();
170});
171
172} // namespace tvm
173