1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/tir/transform.h>
20
21#include "../utils.h"
22
23namespace tvm {
24namespace meta_schedule {
25
26class VerifyVTCMLimitNode : public PostprocNode {
27 public:
28 Integer vtcm_capacity;
29
30 void InitializeWithTuneContext(const TuneContext& context) final {
31 ICHECK(context->target.defined());
32 Target target = context->target.value();
33 ICHECK(target->kind->name == "hexagon");
34 // The value of 0 will disable VTCM verification.
35 vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity").value_or(0);
36 }
37
38 bool Verify(const IRModule& mod) const {
39 for (const auto& kv : mod->functions) {
40 if (const auto* prim_func = kv.second.as<tir::PrimFuncNode>()) {
41 if (!tir::VerifyVTCMLimit(GetRef<tir::PrimFunc>(prim_func), vtcm_capacity)) {
42 return false;
43 }
44 }
45 }
46 return true;
47 }
48
49 bool Apply(const tir::Schedule& sch) final {
50 IRModule mod = sch->mod();
51 for (const auto& kv : mod->functions) {
52 const GlobalVar& g_var = kv.first;
53 const BaseFunc& base_func = kv.second;
54 if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
55 IRModule lowered{nullptr};
56 try {
57 auto pass_list = Array<tvm::transform::Pass>();
58 pass_list.push_back(tir::transform::LowerInitBlock());
59 pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
60 pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
61 pass_list.push_back(tir::transform::CompactBufferAllocation());
62 pass_list.push_back(tir::transform::LowerMatchBuffer());
63 pass_list.push_back(tir::transform::InjectSoftwarePipeline());
64 pass_list.push_back(tir::transform::LowerOpaqueBlock());
65 pass_list.push_back(tir::transform::FlattenBuffer());
66 pass_list.push_back(tir::transform::Simplify());
67 pass_list.push_back(tir::transform::VectorizeLoop(true));
68 pass_list.push_back(tir::transform::StorageRewrite());
69 transform::PassContext pass_ctx = transform::PassContext::Current();
70 tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
71 runtime::String(g_var->name_hint));
72 IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
73 lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
74 } catch (const dmlc::Error& e) {
75 return false;
76 }
77 if (!Verify(lowered)) {
78 return false;
79 }
80 }
81 }
82 return true;
83 }
84
85 Postproc Clone() const {
86 ObjectPtr<VerifyVTCMLimitNode> n = make_object<VerifyVTCMLimitNode>(*this);
87 return Postproc(n);
88 }
89
90 static constexpr const char* _type_key = "meta_schedule.VerifyVTCMLimit";
91 TVM_DECLARE_FINAL_OBJECT_INFO(VerifyVTCMLimitNode, PostprocNode);
92};
93
94Postproc Postproc::VerifyVTCMLimit() {
95 ObjectPtr<VerifyVTCMLimitNode> n = make_object<VerifyVTCMLimitNode>();
96 return Postproc(n);
97}
98
99TVM_REGISTER_NODE_TYPE(VerifyVTCMLimitNode);
100TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyVTCMLimit")
101 .set_body_typed(Postproc::VerifyVTCMLimit);
102
103} // namespace meta_schedule
104} // namespace tvm
105