1 | /* |
---|---|
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /*! \brief Check if an IRModule has any dynamic loop. */ |
25 | struct DynamicExtentFinder : private StmtVisitor { |
26 | public: |
27 | static bool Find(const IRModule& mod) { |
28 | DynamicExtentFinder finder; |
29 | for (const auto& kv : mod->functions) { |
30 | const BaseFunc& func = kv.second; |
31 | if (const auto* prim_func = func.as<PrimFuncNode>()) { |
32 | finder(prim_func->body); |
33 | if (finder.found_) { |
34 | return true; |
35 | } |
36 | } |
37 | } |
38 | return false; |
39 | } |
40 | |
41 | private: |
42 | void VisitStmt_(const ForNode* loop) final { |
43 | if (!loop->extent->IsInstance<IntImmNode>()) { |
44 | found_ = true; |
45 | } else { |
46 | StmtVisitor::VisitStmt_(loop); |
47 | } |
48 | } |
49 | |
50 | void VisitStmt(const Stmt& stmt) final { |
51 | if (!found_) { |
52 | StmtVisitor::VisitStmt(stmt); |
53 | } |
54 | } |
55 | |
56 | bool found_ = false; |
57 | }; |
58 | |
59 | } // namespace tir |
60 | |
61 | namespace meta_schedule { |
62 | |
63 | /*! \brief Check if the IRModule has any loop with non-constant extent. */ |
64 | class DisallowDynamicLoopNode : public PostprocNode { |
65 | public: |
66 | // Inherited from PostprocNode |
67 | void InitializeWithTuneContext(const TuneContext& context) final {} |
68 | // Inherited from PostprocNode |
69 | bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); } |
70 | // Inherited from PostprocNode |
71 | Postproc Clone() const { |
72 | ObjectPtr<DisallowDynamicLoopNode> n = make_object<DisallowDynamicLoopNode>(*this); |
73 | return Postproc(n); |
74 | } |
75 | |
76 | static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop"; |
77 | TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode); |
78 | }; |
79 | |
80 | Postproc Postproc::DisallowDynamicLoop() { |
81 | ObjectPtr<DisallowDynamicLoopNode> n = make_object<DisallowDynamicLoopNode>(); |
82 | return Postproc(n); |
83 | } |
84 | |
85 | TVM_REGISTER_NODE_TYPE(DisallowDynamicLoopNode); |
86 | TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop") |
87 | .set_body_typed(Postproc::DisallowDynamicLoop); |
88 | |
89 | } // namespace meta_schedule |
90 | } // namespace tvm |
91 |