1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file unsafe_select_rewrite.cc |
22 | * \brief Rewrite unsafe select expressions. |
23 | */ |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/tir/builtin.h> |
26 | #include <tvm/tir/expr.h> |
27 | #include <tvm/tir/op_attr_types.h> |
28 | #include <tvm/tir/stmt_functor.h> |
29 | #include <tvm/tir/transform.h> |
30 | |
31 | namespace tvm { |
32 | namespace tir { |
33 | |
34 | // For now, rewrite unsafe select expression to if_then_else |
35 | // TODO(tqchen) pattern matching to support masked load |
36 | class UnsafeExprDetector : public ExprFunctor<bool(const PrimExpr& n)> { |
37 | public: |
38 | // select itself is always considered safe if condition is safe |
39 | // Because we will issue guard to make sure it is. |
40 | bool VisitExpr_(const SelectNode* op) { return VisitExpr(op->condition); } |
41 | bool VisitExpr_(const CallNode* op) { |
42 | if (op->op.same_as(builtin::if_then_else())) { |
43 | return VisitExpr(op->args[0]); |
44 | } else if (op->op.same_as(builtin::address_of())) { |
45 | const BufferLoadNode* load = op->args[0].as<BufferLoadNode>(); |
46 | for (const auto& index : load->indices) { |
47 | if (VisitExpr(index)) { |
48 | return true; |
49 | } |
50 | } |
51 | return false; |
52 | } else if (auto* ptr_op = op->op.as<OpNode>()) { |
53 | auto effect_kind = op_call_effect_[GetRef<Op>(ptr_op)]; |
54 | if (effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation) { |
55 | for (PrimExpr e : op->args) { |
56 | if (VisitExpr(e)) return true; |
57 | } |
58 | return false; |
59 | } else { |
60 | return true; |
61 | } |
62 | } else { |
63 | return true; |
64 | } |
65 | } |
66 | bool VisitExpr_(const BufferLoadNode* op) { |
67 | // Load is considered unsafe. |
68 | return true; |
69 | } |
70 | bool VisitExpr_(const LoadNode* op) { |
71 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
72 | } |
73 | bool VisitExpr_(const AddNode* op) final { return BinaryOp(op); } |
74 | bool VisitExpr_(const SubNode* op) final { return BinaryOp(op); } |
75 | bool VisitExpr_(const MulNode* op) final { return BinaryOp(op); } |
76 | bool VisitExpr_(const DivNode* op) final { return BinaryOp(op); } |
77 | bool VisitExpr_(const ModNode* op) final { return BinaryOp(op); } |
78 | bool VisitExpr_(const FloorDivNode* op) final { return BinaryOp(op); } |
79 | bool VisitExpr_(const FloorModNode* op) final { return BinaryOp(op); } |
80 | bool VisitExpr_(const MinNode* op) final { return BinaryOp(op); } |
81 | bool VisitExpr_(const MaxNode* op) final { return BinaryOp(op); } |
82 | bool VisitExpr_(const EQNode* op) final { return BinaryOp(op); } |
83 | bool VisitExpr_(const NENode* op) final { return BinaryOp(op); } |
84 | bool VisitExpr_(const LTNode* op) final { return BinaryOp(op); } |
85 | bool VisitExpr_(const LENode* op) final { return BinaryOp(op); } |
86 | bool VisitExpr_(const GTNode* op) final { return BinaryOp(op); } |
87 | bool VisitExpr_(const GENode* op) final { return BinaryOp(op); } |
88 | bool VisitExpr_(const AndNode* op) final { return BinaryOp(op); } |
89 | bool VisitExpr_(const OrNode* op) final { return BinaryOp(op); } |
90 | bool VisitExpr_(const NotNode* op) final { return VisitExpr(op->a); } |
91 | bool VisitExpr_(const LetNode* op) final { return VisitExpr(op->body) || VisitExpr(op->value); } |
92 | bool VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } |
93 | bool VisitExpr_(const BroadcastNode* op) final { return VisitExpr(op->value); } |
94 | bool VisitExpr_(const RampNode* op) final { return VisitExpr(op->base) && VisitExpr(op->stride); } |
95 | bool VisitExpr_(const ShuffleNode* op) final { |
96 | for (PrimExpr e : op->vectors) { |
97 | if (VisitExpr(e)) return true; |
98 | } |
99 | return false; |
100 | } |
101 | bool VisitExpr_(const VarNode* op) final { return false; } |
102 | bool VisitExpr_(const IntImmNode* op) final { return false; } |
103 | bool VisitExpr_(const FloatImmNode* op) final { return false; } |
104 | bool VisitExpr_(const StringImmNode* op) final { return false; } |
105 | |
106 | private: |
107 | template <typename T> |
108 | bool BinaryOp(const T* op) { |
109 | return VisitExpr(op->a) || VisitExpr(op->b); |
110 | } |
111 | |
112 | OpAttrMap<TCallEffectKind> op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind" ); |
113 | }; |
114 | |
115 | class UnsafeSelectRewriter : public StmtExprMutator { |
116 | public: |
117 | PrimExpr VisitExpr_(const SelectNode* op) { |
118 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
119 | op = expr.as<SelectNode>(); |
120 | UnsafeExprDetector unsafe; |
121 | bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar(); |
122 | if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) && |
123 | cond_is_scalar_bool) { |
124 | return Call(op->dtype, builtin::if_then_else(), |
125 | {op->condition, op->true_value, op->false_value}); |
126 | } else { |
127 | return expr; |
128 | } |
129 | } |
130 | }; |
131 | |
132 | Stmt RewriteUnsafeSelect(Stmt stmt) { return UnsafeSelectRewriter()(std::move(stmt)); } |
133 | |
134 | namespace transform { |
135 | |
136 | Pass RewriteUnsafeSelect() { |
137 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
138 | auto* n = f.CopyOnWrite(); |
139 | n->body = UnsafeSelectRewriter()(std::move(n->body)); |
140 | return f; |
141 | }; |
142 | return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect" , {}); |
143 | } |
144 | |
145 | TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect" ).set_body_typed(RewriteUnsafeSelect); |
146 | |
147 | } // namespace transform |
148 | |
149 | } // namespace tir |
150 | } // namespace tvm |
151 | |