1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/meta_schedule/postproc.h> |
20 | |
21 | #include <algorithm> |
22 | |
23 | #include "../utils.h" |
24 | |
25 | namespace tvm { |
26 | namespace meta_schedule { |
27 | |
28 | using tir::BlockRV; |
29 | using tir::LoopRV; |
30 | |
31 | void CollectTensorizationJobs( |
32 | const tir::Schedule& sch, const String& func_name, const tir::PrimFuncNode* func, |
33 | bool vectorize_init_loop, |
34 | std::vector<std::tuple<String, String, std::function<void(tir::BlockRV)>>>* jobs) { |
35 | tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) { |
36 | if (const auto* block = obj.as<tir::BlockNode>()) { |
37 | tir::StmtSRef block_sref = sch->GetSRef(block); |
38 | std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint; |
39 | if (Optional<String> intrin_name = |
40 | tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) { |
41 | if (intrin_name.value() != "" ) { |
42 | jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) { |
43 | try { |
44 | sch->Tensorize(block, intrin_name.value()); |
45 | } catch (const std::exception& e) { |
46 | LOG(WARNING) << "Tensorize failed with error " << e.what(); |
47 | } |
48 | }); |
49 | } else if (block_name.find("init" ) && vectorize_init_loop) { |
50 | jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) { |
51 | Array<BlockRV> child_blocks = sch->GetChildBlocks(block); |
52 | ICHECK(child_blocks.size() == 1); |
53 | Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]); |
54 | ICHECK(init_loops.size() == 1); |
55 | sch->Vectorize(init_loops[0]); |
56 | }); |
57 | } |
58 | } |
59 | } |
60 | }); |
61 | } |
62 | |
63 | class RewriteTensorizeNode : public PostprocNode { |
64 | public: |
65 | void InitializeWithTuneContext(const TuneContext& context) final {} |
66 | |
67 | bool Apply(const tir::Schedule& sch) final; |
68 | |
69 | void VisitAttrs(tvm::AttrVisitor* v) {} |
70 | |
71 | Postproc Clone() const { |
72 | ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>(*this); |
73 | return Postproc(n); |
74 | } |
75 | |
76 | bool vectorize_init_loop = false; |
77 | |
78 | static constexpr const char* _type_key = "meta_schedule.RewriteTensorize" ; |
79 | TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorizeNode, PostprocNode); |
80 | }; |
81 | |
82 | bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { |
83 | // The rewriting jobs, 3-tuple (block_name, func_name, job_func) |
84 | std::vector<std::tuple<String, String, std::function<void(tir::BlockRV)>>> jobs; |
85 | for (const auto& kv : sch->mod()->functions) { |
86 | GlobalVar g_var = kv.first; |
87 | BaseFunc base_func = kv.second; |
88 | if (const tir::PrimFuncNode* prim_func = base_func.as<tir::PrimFuncNode>()) { |
89 | CollectTensorizationJobs(sch, g_var->name_hint, prim_func, vectorize_init_loop, &jobs); |
90 | } |
91 | } |
92 | for (const auto& job : jobs) { |
93 | const String& block_name = std::get<0>(job); |
94 | const String& func_name = std::get<1>(job); |
95 | const auto& job_func = std::get<2>(job); |
96 | BlockRV block = sch->GetBlock(block_name, func_name); |
97 | sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize); |
98 | job_func(block); |
99 | } |
100 | return true; |
101 | } |
102 | |
103 | Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) { |
104 | ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>(); |
105 | n->vectorize_init_loop = vectorize_init_loop; |
106 | return Postproc(n); |
107 | } |
108 | |
// Register RewriteTensorizeNode as a TVM node type, and expose the factory
// Postproc::RewriteTensorize under the global name "meta_schedule.PostprocRewriteTensorize".
TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize")
    .set_body_typed(Postproc::RewriteTensorize);
112 | |
113 | } // namespace meta_schedule |
114 | } // namespace tvm |
115 | |