1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/ir/expr.cc |
22 | * \brief The expression AST nodes for the common IR infra. |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/ir/expr.h> |
26 | #include <tvm/ir/function.h> |
27 | #include <tvm/runtime/registry.h> |
28 | #include <tvm/te/tensor.h> |
29 | #include <tvm/tir/expr.h> |
30 | |
31 | #include "../support/scalars.h" |
32 | |
33 | namespace tvm { |
34 | |
35 | PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {} |
36 | |
37 | PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} |
38 | |
39 | PrimExpr PrimExpr::FromObject_(ObjectRef ref) { |
40 | using runtime::ObjectTypeChecker; |
41 | if (auto* ptr = ref.as<tir::IterVarNode>()) { |
42 | return GetRef<tir::IterVar>(ptr)->var; |
43 | } |
44 | if (auto* ptr = ref.as<te::TensorNode>()) { |
45 | return GetRef<te::Tensor>(ptr)(); |
46 | } |
47 | if (auto* ptr = ref.as<runtime::StringObj>()) { |
48 | return tir::StringImm(GetRef<runtime::String>(ptr)); |
49 | } |
50 | if (const auto* buffer_region = ref.as<tir::BufferRegionNode>()) { |
51 | Array<PrimExpr> indices; |
52 | indices.reserve(buffer_region->region.size()); |
53 | for (const Range& r : buffer_region->region) { |
54 | if (tvm::tir::is_one(r->extent)) { |
55 | indices.push_back(r->min); |
56 | } else if (const auto* extent = r->extent.as<IntImmNode>()) { |
57 | indices.push_back(tir::Ramp(r->min, tvm::tir::make_const(r->min->dtype, 1), extent->value)); |
58 | } else { |
59 | LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << ref; |
60 | } |
61 | } |
62 | return tir::BufferLoad(buffer_region->buffer, indices); |
63 | } |
64 | Optional<String> actual_type = ObjectTypeChecker<PrimExpr>::CheckAndGetMismatch(ref.get()); |
65 | ICHECK(!actual_type.defined()) << "Expected type " << ObjectTypeChecker<PrimExpr>::TypeName() |
66 | << " but got " << actual_type.value(); |
67 | return Downcast<PrimExpr>(ref); |
68 | } |
69 | |
70 | IntImm::IntImm(DataType dtype, int64_t value, Span span) { |
71 | ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " << dtype |
72 | << " was supplied." ; |
73 | ICHECK(dtype.is_int() || dtype.is_uint()) |
74 | << "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied." ; |
75 | if (dtype.is_uint()) { |
76 | ICHECK_GE(value, 0U) << "ValueError: Literal value " << value |
77 | << " is negative for unsigned integer type " << dtype; |
78 | if (dtype.bits() < 64) { |
79 | ICHECK_LT(value, 1LL << dtype.bits()) |
80 | << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; |
81 | } |
82 | } else if (dtype.bits() == 1) { |
83 | // int(1) |
84 | ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype; |
85 | } else if (dtype.bits() < 64) { |
86 | ICHECK_GE(value, -(1LL << (dtype.bits() - 1))) |
87 | << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; |
88 | ICHECK_LT(value, 1LL << (dtype.bits() - 1)) |
89 | << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; |
90 | } |
91 | ObjectPtr<IntImmNode> node = make_object<IntImmNode>(); |
92 | node->dtype = dtype; |
93 | node->value = value; |
94 | node->span = span; |
95 | data_ = std::move(node); |
96 | } |
97 | |
98 | TVM_REGISTER_GLOBAL("ir.IntImm" ).set_body_typed([](DataType dtype, int64_t value, Span span) { |
99 | return IntImm(dtype, value, span); |
100 | }); |
101 | |
102 | TVM_REGISTER_NODE_TYPE(IntImmNode); |
103 | |
104 | FloatImm::FloatImm(DataType dtype, double value, Span span) { |
105 | ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar." ; |
106 | |
107 | ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.code() >= DataType::kCustomBegin) |
108 | << "ValueError: FloatImm supports only float, but " << dtype << " was supplied." ; |
109 | |
110 | // check range for float32 and float16 since they have specified range. |
111 | if (!std::isinf(value) && !std::isnan(value)) { |
112 | if (dtype.bits() == 32) { |
113 | ICHECK_GE(value, std::numeric_limits<float>::lowest()) |
114 | << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; |
115 | ICHECK_LE(value, std::numeric_limits<float>::max()) |
116 | << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; |
117 | } else if (dtype.is_float16()) { |
118 | ICHECK_GE(value, -support::kMaxFloat16) |
119 | << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; |
120 | ICHECK_LE(value, support::kMaxFloat16) |
121 | << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; |
122 | } |
123 | } |
124 | ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>(); |
125 | node->dtype = dtype; |
126 | node->value = value; |
127 | node->span = span; |
128 | data_ = std::move(node); |
129 | } |
130 | |
131 | TVM_REGISTER_GLOBAL("ir.FloatImm" ).set_body_typed([](DataType dtype, double value, Span span) { |
132 | return FloatImm(dtype, value, span); |
133 | }); |
134 | |
135 | TVM_REGISTER_NODE_TYPE(FloatImmNode); |
136 | |
137 | Range::Range(PrimExpr begin, PrimExpr end, Span span) |
138 | : Range(make_object<RangeNode>(begin, tir::is_zero(begin) ? end : (end - begin), span)) {} |
139 | |
140 | Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) { |
141 | return Range(make_object<RangeNode>(min, extent, span)); |
142 | } |
143 | |
144 | TVM_REGISTER_GLOBAL("ir.Range_from_min_extent" ).set_body_typed(Range::FromMinExtent); |
145 | |
146 | TVM_REGISTER_GLOBAL("ir.Range" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
147 | *ret = Range(args[0], args[1], args[2]); |
148 | }); |
149 | |
150 | TVM_REGISTER_NODE_TYPE(RangeNode); |
151 | |
152 | GlobalVar::GlobalVar(String name_hint, Type type, Span span) { |
153 | ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>(); |
154 | n->name_hint = std::move(name_hint); |
155 | n->checked_type_ = std::move(type); |
156 | n->span = std::move(span); |
157 | data_ = std::move(n); |
158 | } |
159 | |
160 | TVM_REGISTER_NODE_TYPE(GlobalVarNode); |
161 | |
162 | TVM_REGISTER_GLOBAL("ir.GlobalVar" ).set_body_typed([](String name, Type type) { |
163 | return GlobalVar(name, type); |
164 | }); |
165 | |
166 | TVM_REGISTER_GLOBAL("ir.DebugPrint" ).set_body_typed([](ObjectRef ref) { |
167 | std::stringstream ss; |
168 | ss << ref; |
169 | return ss.str(); |
170 | }); |
171 | |
172 | } // namespace tvm |
173 | |