1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file unsafe_select_rewrite.cc
 * \brief Rewrite unsafe select expressions.
23 */
24#include <tvm/runtime/registry.h>
25#include <tvm/tir/builtin.h>
26#include <tvm/tir/expr.h>
27#include <tvm/tir/op_attr_types.h>
28#include <tvm/tir/stmt_functor.h>
29#include <tvm/tir/transform.h>
30
31namespace tvm {
32namespace tir {
33
34// For now, rewrite unsafe select expression to if_then_else
35// TODO(tqchen) pattern matching to support masked load
36class UnsafeExprDetector : public ExprFunctor<bool(const PrimExpr& n)> {
37 public:
38 // select itself is always considered safe if condition is safe
39 // Because we will issue guard to make sure it is.
40 bool VisitExpr_(const SelectNode* op) { return VisitExpr(op->condition); }
41 bool VisitExpr_(const CallNode* op) {
42 if (op->op.same_as(builtin::if_then_else())) {
43 return VisitExpr(op->args[0]);
44 } else if (op->op.same_as(builtin::address_of())) {
45 const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
46 for (const auto& index : load->indices) {
47 if (VisitExpr(index)) {
48 return true;
49 }
50 }
51 return false;
52 } else if (auto* ptr_op = op->op.as<OpNode>()) {
53 auto effect_kind = op_call_effect_[GetRef<Op>(ptr_op)];
54 if (effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation) {
55 for (PrimExpr e : op->args) {
56 if (VisitExpr(e)) return true;
57 }
58 return false;
59 } else {
60 return true;
61 }
62 } else {
63 return true;
64 }
65 }
66 bool VisitExpr_(const BufferLoadNode* op) {
67 // Load is considered unsafe.
68 return true;
69 }
70 bool VisitExpr_(const LoadNode* op) {
71 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
72 }
73 bool VisitExpr_(const AddNode* op) final { return BinaryOp(op); }
74 bool VisitExpr_(const SubNode* op) final { return BinaryOp(op); }
75 bool VisitExpr_(const MulNode* op) final { return BinaryOp(op); }
76 bool VisitExpr_(const DivNode* op) final { return BinaryOp(op); }
77 bool VisitExpr_(const ModNode* op) final { return BinaryOp(op); }
78 bool VisitExpr_(const FloorDivNode* op) final { return BinaryOp(op); }
79 bool VisitExpr_(const FloorModNode* op) final { return BinaryOp(op); }
80 bool VisitExpr_(const MinNode* op) final { return BinaryOp(op); }
81 bool VisitExpr_(const MaxNode* op) final { return BinaryOp(op); }
82 bool VisitExpr_(const EQNode* op) final { return BinaryOp(op); }
83 bool VisitExpr_(const NENode* op) final { return BinaryOp(op); }
84 bool VisitExpr_(const LTNode* op) final { return BinaryOp(op); }
85 bool VisitExpr_(const LENode* op) final { return BinaryOp(op); }
86 bool VisitExpr_(const GTNode* op) final { return BinaryOp(op); }
87 bool VisitExpr_(const GENode* op) final { return BinaryOp(op); }
88 bool VisitExpr_(const AndNode* op) final { return BinaryOp(op); }
89 bool VisitExpr_(const OrNode* op) final { return BinaryOp(op); }
90 bool VisitExpr_(const NotNode* op) final { return VisitExpr(op->a); }
91 bool VisitExpr_(const LetNode* op) final { return VisitExpr(op->body) || VisitExpr(op->value); }
92 bool VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }
93 bool VisitExpr_(const BroadcastNode* op) final { return VisitExpr(op->value); }
94 bool VisitExpr_(const RampNode* op) final { return VisitExpr(op->base) && VisitExpr(op->stride); }
95 bool VisitExpr_(const ShuffleNode* op) final {
96 for (PrimExpr e : op->vectors) {
97 if (VisitExpr(e)) return true;
98 }
99 return false;
100 }
101 bool VisitExpr_(const VarNode* op) final { return false; }
102 bool VisitExpr_(const IntImmNode* op) final { return false; }
103 bool VisitExpr_(const FloatImmNode* op) final { return false; }
104 bool VisitExpr_(const StringImmNode* op) final { return false; }
105
106 private:
107 template <typename T>
108 bool BinaryOp(const T* op) {
109 return VisitExpr(op->a) || VisitExpr(op->b);
110 }
111
112 OpAttrMap<TCallEffectKind> op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
113};
114
115class UnsafeSelectRewriter : public StmtExprMutator {
116 public:
117 PrimExpr VisitExpr_(const SelectNode* op) {
118 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
119 op = expr.as<SelectNode>();
120 UnsafeExprDetector unsafe;
121 bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar();
122 if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) &&
123 cond_is_scalar_bool) {
124 return Call(op->dtype, builtin::if_then_else(),
125 {op->condition, op->true_value, op->false_value});
126 } else {
127 return expr;
128 }
129 }
130};
131
132Stmt RewriteUnsafeSelect(Stmt stmt) { return UnsafeSelectRewriter()(std::move(stmt)); }
133
134namespace transform {
135
136Pass RewriteUnsafeSelect() {
137 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
138 auto* n = f.CopyOnWrite();
139 n->body = UnsafeSelectRewriter()(std::move(n->body));
140 return f;
141 };
142 return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {});
143}
144
145TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect").set_body_typed(RewriteUnsafeSelect);
146
147} // namespace transform
148
149} // namespace tir
150} // namespace tvm
151