1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24/*! \brief Check if an IRModule has any dynamic loop. */
25struct DynamicExtentFinder : private StmtVisitor {
26 public:
27 static bool Find(const IRModule& mod) {
28 DynamicExtentFinder finder;
29 for (const auto& kv : mod->functions) {
30 const BaseFunc& func = kv.second;
31 if (const auto* prim_func = func.as<PrimFuncNode>()) {
32 finder(prim_func->body);
33 if (finder.found_) {
34 return true;
35 }
36 }
37 }
38 return false;
39 }
40
41 private:
42 void VisitStmt_(const ForNode* loop) final {
43 if (!loop->extent->IsInstance<IntImmNode>()) {
44 found_ = true;
45 } else {
46 StmtVisitor::VisitStmt_(loop);
47 }
48 }
49
50 void VisitStmt(const Stmt& stmt) final {
51 if (!found_) {
52 StmtVisitor::VisitStmt(stmt);
53 }
54 }
55
56 bool found_ = false;
57};
58
59} // namespace tir
60
61namespace meta_schedule {
62
63/*! \brief Check if the IRModule has any loop with non-constant extent. */
64class DisallowDynamicLoopNode : public PostprocNode {
65 public:
66 // Inherited from PostprocNode
67 void InitializeWithTuneContext(const TuneContext& context) final {}
68 // Inherited from PostprocNode
69 bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); }
70 // Inherited from PostprocNode
71 Postproc Clone() const {
72 ObjectPtr<DisallowDynamicLoopNode> n = make_object<DisallowDynamicLoopNode>(*this);
73 return Postproc(n);
74 }
75
76 static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop";
77 TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode);
78};
79
80Postproc Postproc::DisallowDynamicLoop() {
81 ObjectPtr<DisallowDynamicLoopNode> n = make_object<DisallowDynamicLoopNode>();
82 return Postproc(n);
83}
84
// Register DisallowDynamicLoopNode with TVM's object/node type registry, and expose the
// Postproc::DisallowDynamicLoop factory as the global function
// "meta_schedule.PostprocDisallowDynamicLoop".
TVM_REGISTER_NODE_TYPE(DisallowDynamicLoopNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop")
    .set_body_typed(Postproc::DisallowDynamicLoop);
88
89} // namespace meta_schedule
90} // namespace tvm
91