1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/ir/script/script_complete.cc |
22 | * \brief Used by TVM Script parser to expand incomplete TIR input |
23 | */ |
24 | |
25 | #include "./script_complete.h" |
26 | |
27 | #include <tvm/arith/int_set.h> |
28 | #include <tvm/tir/analysis.h> |
29 | |
30 | #include <utility> |
31 | |
32 | namespace tvm { |
33 | namespace tir { |
34 | |
35 | /*! \brief Generate surrounding loops automatically */ |
36 | class ScriptCompleter : public StmtMutator { |
37 | public: |
38 | explicit ScriptCompleter(Map<Var, Buffer>* buffer_var_map) : buffer_var_map_(buffer_var_map) {} |
39 | /*! \brief Whether the stmt contains at least one block. */ |
40 | bool contains_block = false; |
41 | |
42 | private: |
43 | Map<Var, Buffer>* buffer_var_map_; |
44 | Stmt VisitStmt_(const BlockRealizeNode* op) override { |
45 | contains_block = true; |
46 | for (const PrimExpr& value : op->iter_values) { |
47 | CHECK(value.dtype().is_int()) |
48 | << "BlockRealize iter_value expected a IntImm, but got " << value.dtype(); |
49 | } |
50 | return StmtMutator::VisitStmt_(op); |
51 | } |
52 | |
53 | Stmt VisitStmt_(const BlockNode* op) override { |
54 | // Buffers allocated in the block can be accessed by its body. |
55 | for (const auto& alloc_buffer : op->alloc_buffers) { |
56 | buffer_var_map_->Set(alloc_buffer->data, alloc_buffer); |
57 | } |
58 | for (const auto& match_buffer : op->match_buffers) { |
59 | const Buffer& target_buffer = match_buffer->buffer; |
60 | buffer_var_map_->Set(target_buffer->data, target_buffer); |
61 | } |
62 | Block block = Downcast<Block>(StmtMutator::VisitStmt_(op)); |
63 | // Remove buffers allocated inside block to detect its access region |
64 | for (const auto& alloc_buffer : op->alloc_buffers) { |
65 | buffer_var_map_->erase(alloc_buffer->data); |
66 | } |
67 | for (const auto& match_buffer : op->match_buffers) { |
68 | const Buffer& target_buffer = match_buffer->buffer; |
69 | buffer_var_map_->erase(target_buffer->data); |
70 | } |
71 | // Get access detection mask |
72 | // 0 for provided region, 1 and 3 for need detect read, 2 and 3 for need detect write |
73 | int mask = 0; |
74 | auto it = op->annotations.find(attr::script_parsing_detect_access); |
75 | if (it != op->annotations.end()) { |
76 | mask = Downcast<IntImm>((*it).second)->value; |
77 | } |
78 | // ignore root block or blocks which already has reads/writes regions |
79 | if (mask != 0) { |
80 | auto access_region = GetBlockAccessRegion(block, *buffer_var_map_); |
81 | const Array<BufferRegion>& reads = access_region[0]; |
82 | const Array<BufferRegion>& writes = access_region[1]; |
83 | const Array<BufferRegion>& opaque = access_region[2]; |
84 | CHECK(opaque.empty()) |
85 | << "ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or " |
86 | "direct access by buffer data. Please annotation the access region manually" ; |
87 | auto n = CopyOnWrite(block.operator->()); |
88 | if (mask & 1) n->reads = reads; |
89 | if (mask & 2) n->writes = writes; |
90 | n->annotations = op->annotations; |
91 | n->annotations.erase(attr::script_parsing_detect_access); |
92 | return Block(n); |
93 | } else { |
94 | return std::move(block); |
95 | } |
96 | } |
97 | }; |
98 | |
99 | PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates) { |
100 | Map<Var, Buffer> buffer_var_map; |
101 | for (const auto& pair : func->buffer_map) { |
102 | const Buffer& buffer = pair.second; |
103 | buffer_var_map.Set(buffer->data, buffer); |
104 | } |
105 | for (const auto& alloc : root_allocates) { |
106 | buffer_var_map.Set(alloc->data, alloc); |
107 | } |
108 | |
109 | Stmt res = func->body; |
110 | |
111 | // Generate root block automatically. This is done before |
112 | // ScriptCompleter, in order to fill the root block's T.reads() and |
113 | // T.writes() annotations, as if it had been explicitly written. |
114 | bool should_insert_root = [&]() -> bool { |
115 | if (root_allocates.size()) { |
116 | return true; |
117 | } |
118 | auto* block_realize = func->body.as<BlockRealizeNode>(); |
119 | if (block_realize && block_realize->block->iter_vars.size()) { |
120 | return true; |
121 | } |
122 | if (!block_realize && ContainsNode<BlockRealizeNode>(func->body)) { |
123 | return true; |
124 | } |
125 | return false; |
126 | }(); |
127 | |
128 | if (should_insert_root) { |
129 | Block root_block({}, {}, {}, "root" , std::move(res), NullOpt, root_allocates); |
130 | res = BlockRealize({}, Bool(true), std::move(root_block)); |
131 | } |
132 | |
133 | // generate surrounding loops automatically |
134 | ScriptCompleter script_completer(&buffer_var_map); |
135 | res = script_completer(std::move(res)); |
136 | |
137 | if (func->body.same_as(res)) { |
138 | return func; |
139 | } else { |
140 | auto fptr = func.CopyOnWrite(); |
141 | fptr->body = res; |
142 | return func; |
143 | } |
144 | } |
145 | |
// Expose ScriptComplete through the TVM global function registry under the
// name "script.Complete" (looked up by name, e.g. from the script parser).
TVM_REGISTER_GLOBAL("script.Complete").set_body_typed(ScriptComplete);
147 | |
148 | } // namespace tir |
149 | } // namespace tvm |
150 | |