1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file type_relations.cc |
22 | * \brief A set of utilities and common functionality |
23 | * for type relations. |
24 | */ |
25 | #include "./type_relations.h" |
26 | |
27 | #include <tvm/arith/analyzer.h> |
28 | #include <tvm/relay/attrs/transform.h> |
29 | #include <tvm/relay/expr.h> |
30 | #include <tvm/relay/op.h> |
31 | #include <tvm/tir/op.h> |
32 | |
33 | #include <numeric> |
34 | |
35 | namespace tvm { |
36 | namespace relay { |
37 | |
38 | bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
39 | const TypeReporter& reporter) { |
40 | for (size_t i = 1; i < types.size(); ++i) { |
41 | reporter->Assign(types[i], types[0]); |
42 | } |
43 | return true; |
44 | } |
45 | |
46 | bool EqualCheck(const IndexExpr& lhs, const IndexExpr& rhs) { |
47 | IndexExpr diff = lhs - rhs; |
48 | if (const int64_t* pdiff = tir::as_const_int(diff)) { |
49 | return pdiff[0] == 0; |
50 | } |
51 | // symbolic |
52 | tvm::arith::Analyzer ana; |
53 | diff = ana.Simplify(diff); |
54 | if (const int64_t* pdiff = tir::as_const_int(diff)) { |
55 | return pdiff[0] == 0; |
56 | } |
57 | return false; |
58 | } |
59 | |
60 | bool EqualConstInt(const IndexExpr& lhs, int64_t value) { |
61 | if (const int64_t* pvalue = tir::as_const_int(lhs)) { |
62 | return pvalue[0] == value; |
63 | } |
64 | return false; |
65 | } |
66 | |
67 | TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) { |
68 | std::vector<IndexExpr> oshape; |
69 | size_t ndim1 = t1->shape.size(); |
70 | size_t ndim2 = t2->shape.size(); |
71 | size_t i = 1; |
72 | for (; i <= std::min(ndim1, ndim2); ++i) { |
73 | IndexExpr s1 = t1->shape[ndim1 - i]; |
74 | IndexExpr s2 = t2->shape[ndim2 - i]; |
75 | if (EqualConstInt(s1, 1)) { |
76 | oshape.push_back(s2); |
77 | } else if (EqualConstInt(s2, 1)) { |
78 | oshape.push_back(s1); |
79 | } else if (s1.as<AnyNode>()) { |
80 | // s1 == 1 || s1 == s2 |
81 | oshape.push_back(s2); |
82 | } else if (s2.as<AnyNode>()) { |
83 | // s2 == 1 || s2 == s1 |
84 | oshape.push_back(s1); |
85 | } else if (EqualCheck(s1, s2)) { |
86 | oshape.push_back(s1); |
87 | } else { |
88 | throw CompileError(ErrorBuilder() << "Incompatible broadcast type " << t1 << " and " << t2); |
89 | } |
90 | } |
91 | |
92 | size_t max_ndim = std::max(ndim1, ndim2); |
93 | auto& rshape = (ndim1 > ndim2) ? t1->shape : t2->shape; |
94 | for (; i <= max_ndim; ++i) { |
95 | oshape.push_back(rshape[max_ndim - i]); |
96 | } |
97 | return TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), output_dtype); |
98 | } |
99 | |
100 | bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
101 | const TypeReporter& reporter) { |
102 | ICHECK_EQ(types.size(), 3); |
103 | // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] |
104 | // << ",Out:" << types[2] << std::endl; |
105 | if (auto* t0 = types[0].as<TensorTypeNode>()) { |
106 | if (auto* t1 = types[1].as<TensorTypeNode>()) { |
107 | if (t0->dtype != t1->dtype) { |
108 | reporter->GetDiagCtx().Emit(Diagnostic::Error(t0->span) |
109 | << "data types " << t0->dtype << " and " << t1->dtype |
110 | << " do not match in BroadcastRel" ); |
111 | } |
112 | reporter->Assign( |
113 | types[2], ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), t0->dtype)); |
114 | return true; |
115 | } |
116 | } |
117 | return false; |
118 | } |
119 | |
120 | bool BroadcastCompRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
121 | const TypeReporter& reporter) { |
122 | ICHECK_EQ(types.size(), 3); |
123 | // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] |
124 | // << ",Out:" << types[2] << std::endl; |
125 | if (auto* t0 = types[0].as<TensorTypeNode>()) { |
126 | if (auto* t1 = types[1].as<TensorTypeNode>()) { |
127 | if (t0->dtype != t1->dtype) { |
128 | reporter->GetDiagCtx().Emit(Diagnostic::Error(t0->span) |
129 | << "data types " << t0->dtype << " and " << t1->dtype |
130 | << " do not match in BroadcastCompRel" ); |
131 | } |
132 | reporter->Assign(types[2], ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), |
133 | DataType::Bool())); |
134 | return true; |
135 | } |
136 | } |
137 | return false; |
138 | } |
139 | |
140 | bool IdentityCompRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
141 | const TypeReporter& reporter) { |
142 | if (auto* t0 = types[0].as<TensorTypeNode>()) { |
143 | Type out_type = TensorType(GetRef<TensorType>(t0)->shape, DataType::Bool()); |
144 | reporter->Assign(types[1], out_type); |
145 | return true; |
146 | } |
147 | return false; |
148 | } |
149 | |
150 | Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) { |
151 | if (shape.size() == 0) { |
152 | return {}; |
153 | } else { |
154 | return {tvm::Integer(shape.size())}; |
155 | } |
156 | } |
157 | |
158 | bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
159 | const TypeReporter& reporter) { |
160 | ICHECK_EQ(num_inputs, 1); |
161 | auto tt = types[0].as<TensorTypeNode>(); |
162 | if (tt == nullptr) { |
163 | return false; |
164 | } |
165 | const auto* param = attrs.as<ShapeOfAttrs>(); |
166 | ICHECK(param != nullptr); |
167 | auto rank_shape = RankShape(tt->shape); |
168 | reporter->Assign(types[1], TensorType(rank_shape, param->dtype)); |
169 | return true; |
170 | } |
171 | |
172 | } // namespace relay |
173 | } // namespace tvm |
174 | |