1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file type_relations.cc
22 * \brief A set of utilities and common functionality
23 * for type relations.
24 */
25#include "./type_relations.h"
26
27#include <tvm/arith/analyzer.h>
28#include <tvm/relay/attrs/transform.h>
29#include <tvm/relay/expr.h>
30#include <tvm/relay/op.h>
31#include <tvm/tir/op.h>
32
33#include <numeric>
34
35namespace tvm {
36namespace relay {
37
38bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
39 const TypeReporter& reporter) {
40 for (size_t i = 1; i < types.size(); ++i) {
41 reporter->Assign(types[i], types[0]);
42 }
43 return true;
44}
45
46bool EqualCheck(const IndexExpr& lhs, const IndexExpr& rhs) {
47 IndexExpr diff = lhs - rhs;
48 if (const int64_t* pdiff = tir::as_const_int(diff)) {
49 return pdiff[0] == 0;
50 }
51 // symbolic
52 tvm::arith::Analyzer ana;
53 diff = ana.Simplify(diff);
54 if (const int64_t* pdiff = tir::as_const_int(diff)) {
55 return pdiff[0] == 0;
56 }
57 return false;
58}
59
60bool EqualConstInt(const IndexExpr& lhs, int64_t value) {
61 if (const int64_t* pvalue = tir::as_const_int(lhs)) {
62 return pvalue[0] == value;
63 }
64 return false;
65}
66
67TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
68 std::vector<IndexExpr> oshape;
69 size_t ndim1 = t1->shape.size();
70 size_t ndim2 = t2->shape.size();
71 size_t i = 1;
72 for (; i <= std::min(ndim1, ndim2); ++i) {
73 IndexExpr s1 = t1->shape[ndim1 - i];
74 IndexExpr s2 = t2->shape[ndim2 - i];
75 if (EqualConstInt(s1, 1)) {
76 oshape.push_back(s2);
77 } else if (EqualConstInt(s2, 1)) {
78 oshape.push_back(s1);
79 } else if (s1.as<AnyNode>()) {
80 // s1 == 1 || s1 == s2
81 oshape.push_back(s2);
82 } else if (s2.as<AnyNode>()) {
83 // s2 == 1 || s2 == s1
84 oshape.push_back(s1);
85 } else if (EqualCheck(s1, s2)) {
86 oshape.push_back(s1);
87 } else {
88 throw CompileError(ErrorBuilder() << "Incompatible broadcast type " << t1 << " and " << t2);
89 }
90 }
91
92 size_t max_ndim = std::max(ndim1, ndim2);
93 auto& rshape = (ndim1 > ndim2) ? t1->shape : t2->shape;
94 for (; i <= max_ndim; ++i) {
95 oshape.push_back(rshape[max_ndim - i]);
96 }
97 return TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), output_dtype);
98}
99
100bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
101 const TypeReporter& reporter) {
102 ICHECK_EQ(types.size(), 3);
103 // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
104 // << ",Out:" << types[2] << std::endl;
105 if (auto* t0 = types[0].as<TensorTypeNode>()) {
106 if (auto* t1 = types[1].as<TensorTypeNode>()) {
107 if (t0->dtype != t1->dtype) {
108 reporter->GetDiagCtx().Emit(Diagnostic::Error(t0->span)
109 << "data types " << t0->dtype << " and " << t1->dtype
110 << " do not match in BroadcastRel");
111 }
112 reporter->Assign(
113 types[2], ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), t0->dtype));
114 return true;
115 }
116 }
117 return false;
118}
119
120bool BroadcastCompRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
121 const TypeReporter& reporter) {
122 ICHECK_EQ(types.size(), 3);
123 // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
124 // << ",Out:" << types[2] << std::endl;
125 if (auto* t0 = types[0].as<TensorTypeNode>()) {
126 if (auto* t1 = types[1].as<TensorTypeNode>()) {
127 if (t0->dtype != t1->dtype) {
128 reporter->GetDiagCtx().Emit(Diagnostic::Error(t0->span)
129 << "data types " << t0->dtype << " and " << t1->dtype
130 << " do not match in BroadcastCompRel");
131 }
132 reporter->Assign(types[2], ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1),
133 DataType::Bool()));
134 return true;
135 }
136 }
137 return false;
138}
139
140bool IdentityCompRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
141 const TypeReporter& reporter) {
142 if (auto* t0 = types[0].as<TensorTypeNode>()) {
143 Type out_type = TensorType(GetRef<TensorType>(t0)->shape, DataType::Bool());
144 reporter->Assign(types[1], out_type);
145 return true;
146 }
147 return false;
148}
149
150Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
151 if (shape.size() == 0) {
152 return {};
153 } else {
154 return {tvm::Integer(shape.size())};
155 }
156}
157
158bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
159 const TypeReporter& reporter) {
160 ICHECK_EQ(num_inputs, 1);
161 auto tt = types[0].as<TensorTypeNode>();
162 if (tt == nullptr) {
163 return false;
164 }
165 const auto* param = attrs.as<ShapeOfAttrs>();
166 ICHECK(param != nullptr);
167 auto rank_shape = RankShape(tt->shape);
168 reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
169 return true;
170}
171
172} // namespace relay
173} // namespace tvm
174