1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/meta_schedule/postproc.h>
20
21#include <algorithm>
22
23#include "../utils.h"
24
25namespace tvm {
26namespace meta_schedule {
27
28using tir::BlockRV;
29using tir::LoopRV;
30
31void CollectTensorizationJobs(
32 const tir::Schedule& sch, const String& func_name, const tir::PrimFuncNode* func,
33 bool vectorize_init_loop,
34 std::vector<std::tuple<String, String, std::function<void(tir::BlockRV)>>>* jobs) {
35 tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
36 if (const auto* block = obj.as<tir::BlockNode>()) {
37 tir::StmtSRef block_sref = sch->GetSRef(block);
38 std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
39 if (Optional<String> intrin_name =
40 tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
41 if (intrin_name.value() != "") {
42 jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
43 try {
44 sch->Tensorize(block, intrin_name.value());
45 } catch (const std::exception& e) {
46 LOG(WARNING) << "Tensorize failed with error " << e.what();
47 }
48 });
49 } else if (block_name.find("init") && vectorize_init_loop) {
50 jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
51 Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
52 ICHECK(child_blocks.size() == 1);
53 Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
54 ICHECK(init_loops.size() == 1);
55 sch->Vectorize(init_loops[0]);
56 });
57 }
58 }
59 }
60 });
61}
62
63class RewriteTensorizeNode : public PostprocNode {
64 public:
65 void InitializeWithTuneContext(const TuneContext& context) final {}
66
67 bool Apply(const tir::Schedule& sch) final;
68
69 void VisitAttrs(tvm::AttrVisitor* v) {}
70
71 Postproc Clone() const {
72 ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>(*this);
73 return Postproc(n);
74 }
75
76 bool vectorize_init_loop = false;
77
78 static constexpr const char* _type_key = "meta_schedule.RewriteTensorize";
79 TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorizeNode, PostprocNode);
80};
81
82bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) {
83 // The rewriting jobs, 3-tuple (block_name, func_name, job_func)
84 std::vector<std::tuple<String, String, std::function<void(tir::BlockRV)>>> jobs;
85 for (const auto& kv : sch->mod()->functions) {
86 GlobalVar g_var = kv.first;
87 BaseFunc base_func = kv.second;
88 if (const tir::PrimFuncNode* prim_func = base_func.as<tir::PrimFuncNode>()) {
89 CollectTensorizationJobs(sch, g_var->name_hint, prim_func, vectorize_init_loop, &jobs);
90 }
91 }
92 for (const auto& job : jobs) {
93 const String& block_name = std::get<0>(job);
94 const String& func_name = std::get<1>(job);
95 const auto& job_func = std::get<2>(job);
96 BlockRV block = sch->GetBlock(block_name, func_name);
97 sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize);
98 job_func(block);
99 }
100 return true;
101}
102
103Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) {
104 ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>();
105 n->vectorize_init_loop = vectorize_init_loop;
106 return Postproc(n);
107}
108
109TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode);
110TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize")
111 .set_body_typed(Postproc::RewriteTensorize);
112
113} // namespace meta_schedule
114} // namespace tvm
115