1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file inject_prefetch.cc |
22 | */ |
// Inject prefetch statements into TIR (the IR formerly known as HalideIR).
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/arith/bound.h> |
26 | #include <tvm/runtime/registry.h> |
27 | #include <tvm/tir/expr.h> |
28 | #include <tvm/tir/op.h> |
29 | #include <tvm/tir/stmt_functor.h> |
30 | #include <tvm/tir/transform.h> |
31 | |
32 | #include <unordered_set> |
33 | |
34 | #include "ir_utils.h" |
35 | |
36 | namespace tvm { |
37 | namespace tir { |
38 | |
39 | using arith::DomainTouched; |
40 | using arith::IntSet; |
41 | |
42 | class PrefetchInjector : public StmtMutator { |
43 | public: |
44 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
45 | Stmt ret = StmtMutator::VisitStmt_(op); |
46 | op = ret.as<AttrStmtNode>(); |
47 | if (op && op->attr_key == attr::prefetch_scope) { |
48 | Buffer buffer = Downcast<Buffer>(op->node); |
49 | ICHECK_NE(loop_nest_.size(), 0U); |
50 | Region domain = DomainTouched(op->body, buffer, true, false); |
51 | Region region; |
52 | |
53 | auto iter_var = loop_nest_.back().get(); |
54 | vectorized_[iter_var] = IntSet::SinglePoint(loop_nest_.back() + op->value); |
55 | |
56 | for (Range r : domain) { |
57 | if (!r.defined()) { |
58 | LOG(WARNING) << "Cannot decide prefetch region for " << buffer; |
59 | return op->body; |
60 | } |
61 | Range res(EvalSet(r, vectorized_).CoverRange(none)); |
62 | region.push_back(Range::FromMinExtent(res->min, res->extent)); |
63 | } |
64 | |
65 | vectorized_.erase(iter_var); |
66 | |
67 | Stmt prefetch = Prefetch(buffer, region); |
68 | return SeqStmt({prefetch, op->body}); |
69 | } |
70 | return ret; |
71 | } |
72 | |
73 | Stmt VisitStmt_(const ForNode* op) final { |
74 | auto& var = op->loop_var; |
75 | loop_nest_.push_back(var); |
76 | if (op->kind == ForKind::kVectorized) { |
77 | vectorized_[var.get()] = IntSet::Interval(op->min, (op->min + op->extent) - 1); |
78 | } |
79 | Stmt ret = StmtMutator::VisitStmt_(op); |
80 | if (op->kind == ForKind::kVectorized) { |
81 | vectorized_.erase(var.get()); |
82 | } |
83 | loop_nest_.pop_back(); |
84 | return ret; |
85 | } |
86 | |
87 | private: |
88 | std::vector<Var> loop_nest_; |
89 | std::unordered_map<const VarNode*, IntSet> vectorized_; |
90 | static const Range none; |
91 | }; |
92 | |
93 | const Range PrefetchInjector::none; |
94 | |
95 | Stmt InjectPrefetch(Stmt stmt) { return PrefetchInjector()(std::move(stmt)); } |
96 | |
97 | namespace transform { |
98 | |
99 | Pass InjectPrefetch() { |
100 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
101 | // Only apply this pass to TIR from TE schedules |
102 | if (IsFromLegacyTESchedule(f)) { |
103 | auto* n = f.CopyOnWrite(); |
104 | n->body = PrefetchInjector()(std::move(n->body)); |
105 | return f; |
106 | } else { |
107 | return f; |
108 | } |
109 | }; |
110 | return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch" , {}); |
111 | } |
112 | |
// Register the pass constructor transform::InjectPrefetch() as the global
// function "tir.transform.InjectPrefetch".
TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch" ).set_body_typed(InjectPrefetch);
114 | |
115 | } // namespace transform |
116 | |
117 | } // namespace tir |
118 | } // namespace tvm |
119 | |