1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/tir/transform.h> |
20 | |
21 | #include "../utils.h" |
22 | |
23 | namespace tvm { |
24 | namespace meta_schedule { |
25 | |
26 | class VerifyVTCMLimitNode : public PostprocNode { |
27 | public: |
28 | Integer vtcm_capacity; |
29 | |
30 | void InitializeWithTuneContext(const TuneContext& context) final { |
31 | ICHECK(context->target.defined()); |
32 | Target target = context->target.value(); |
33 | ICHECK(target->kind->name == "hexagon" ); |
34 | // The value of 0 will disable VTCM verification. |
35 | vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity" ).value_or(0); |
36 | } |
37 | |
38 | bool Verify(const IRModule& mod) const { |
39 | for (const auto& kv : mod->functions) { |
40 | if (const auto* prim_func = kv.second.as<tir::PrimFuncNode>()) { |
41 | if (!tir::VerifyVTCMLimit(GetRef<tir::PrimFunc>(prim_func), vtcm_capacity)) { |
42 | return false; |
43 | } |
44 | } |
45 | } |
46 | return true; |
47 | } |
48 | |
49 | bool Apply(const tir::Schedule& sch) final { |
50 | IRModule mod = sch->mod(); |
51 | for (const auto& kv : mod->functions) { |
52 | const GlobalVar& g_var = kv.first; |
53 | const BaseFunc& base_func = kv.second; |
54 | if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) { |
55 | IRModule lowered{nullptr}; |
56 | try { |
57 | auto pass_list = Array<tvm::transform::Pass>(); |
58 | pass_list.push_back(tir::transform::LowerInitBlock()); |
59 | pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); |
60 | pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); |
61 | pass_list.push_back(tir::transform::CompactBufferAllocation()); |
62 | pass_list.push_back(tir::transform::LowerMatchBuffer()); |
63 | pass_list.push_back(tir::transform::InjectSoftwarePipeline()); |
64 | pass_list.push_back(tir::transform::LowerOpaqueBlock()); |
65 | pass_list.push_back(tir::transform::FlattenBuffer()); |
66 | pass_list.push_back(tir::transform::Simplify()); |
67 | pass_list.push_back(tir::transform::VectorizeLoop(true)); |
68 | pass_list.push_back(tir::transform::StorageRewrite()); |
69 | transform::PassContext pass_ctx = transform::PassContext::Current(); |
70 | tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol" , |
71 | runtime::String(g_var->name_hint)); |
72 | IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}})); |
73 | lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); |
74 | } catch (const dmlc::Error& e) { |
75 | return false; |
76 | } |
77 | if (!Verify(lowered)) { |
78 | return false; |
79 | } |
80 | } |
81 | } |
82 | return true; |
83 | } |
84 | |
85 | Postproc Clone() const { |
86 | ObjectPtr<VerifyVTCMLimitNode> n = make_object<VerifyVTCMLimitNode>(*this); |
87 | return Postproc(n); |
88 | } |
89 | |
90 | static constexpr const char* _type_key = "meta_schedule.VerifyVTCMLimit" ; |
91 | TVM_DECLARE_FINAL_OBJECT_INFO(VerifyVTCMLimitNode, PostprocNode); |
92 | }; |
93 | |
94 | Postproc Postproc::VerifyVTCMLimit() { |
95 | ObjectPtr<VerifyVTCMLimitNode> n = make_object<VerifyVTCMLimitNode>(); |
96 | return Postproc(n); |
97 | } |
98 | |
99 | TVM_REGISTER_NODE_TYPE(VerifyVTCMLimitNode); |
100 | TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyVTCMLimit" ) |
101 | .set_body_typed(Postproc::VerifyVTCMLimit); |
102 | |
103 | } // namespace meta_schedule |
104 | } // namespace tvm |
105 | |