1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, |
25 | const ObjectRef& ann_val) { |
26 | // Extract annotation |
27 | const Map<String, ObjectRef>* annotations = nullptr; |
28 | if (const auto* loop = sref->StmtAs<ForNode>()) { |
29 | annotations = &loop->annotations; |
30 | } else if (const auto* block = sref->StmtAs<BlockNode>()) { |
31 | annotations = &block->annotations; |
32 | } else { |
33 | LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); |
34 | } |
35 | // Check if the annotation already exists |
36 | if (annotations->find(ann_key) != annotations->end()) { |
37 | return; |
38 | } |
39 | // Add the new annotation |
40 | Map<String, ObjectRef> new_ann(*annotations); |
41 | new_ann.Set(ann_key, ann_val); |
42 | // Create the new stmt |
43 | if (const auto* loop = sref->StmtAs<ForNode>()) { |
44 | ObjectPtr<ForNode> n = make_object<ForNode>(*loop); |
45 | n->annotations = std::move(new_ann); |
46 | self->Replace(sref, For(n), {}); |
47 | } else if (const auto* block = sref->StmtAs<BlockNode>()) { |
48 | ObjectPtr<BlockNode> n = make_object<BlockNode>(*block); |
49 | n->annotations = std::move(new_ann); |
50 | Block p(n); |
51 | self->Replace(sref, p, {{GetRef<Block>(block), p}}); |
52 | } else { |
53 | LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); |
54 | throw; |
55 | } |
56 | } |
57 | |
58 | void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key) { |
59 | // Extract annotation |
60 | const Map<String, ObjectRef>* annotations = nullptr; |
61 | if (const auto* loop = sref->StmtAs<ForNode>()) { |
62 | annotations = &loop->annotations; |
63 | } else if (const auto* block = sref->StmtAs<BlockNode>()) { |
64 | annotations = &block->annotations; |
65 | } else { |
66 | LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); |
67 | } |
68 | // Remove the annotation |
69 | ICHECK(annotations->find(ann_key) != annotations->end()) |
70 | << "IndexError: Cannot find annotation key: " << ann_key; |
71 | Map<String, ObjectRef> new_ann(*annotations); |
72 | new_ann.erase(ann_key); |
73 | // Create the new stmt |
74 | if (const auto* loop = sref->StmtAs<ForNode>()) { |
75 | ObjectPtr<ForNode> n = make_object<ForNode>(*loop); |
76 | n->annotations = std::move(new_ann); |
77 | self->Replace(sref, For(n), {}); |
78 | } else if (const auto* block = sref->StmtAs<BlockNode>()) { |
79 | ObjectPtr<BlockNode> n = make_object<BlockNode>(*block); |
80 | n->annotations = std::move(new_ann); |
81 | Block p(n); |
82 | self->Replace(sref, p, {{GetRef<Block>(block), p}}); |
83 | } else { |
84 | LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); |
85 | throw; |
86 | } |
87 | } |
88 | |
89 | struct AnnotateTraits : public UnpackedInstTraits<AnnotateTraits> { |
90 | static constexpr const char* kName = "Annotate" ; |
91 | static constexpr bool kIsPure = false; |
92 | |
93 | private: |
94 | static constexpr size_t kNumInputs = 2; |
95 | static constexpr size_t kNumAttrs = 1; |
96 | static constexpr size_t kNumDecisions = 0; |
97 | |
98 | static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ObjectRef ann_val, |
99 | String ann_key) { |
100 | if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) { |
101 | return sch->Annotate(GetRef<BlockRV>(block), ann_key, ann_val); |
102 | } |
103 | if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) { |
104 | return sch->Annotate(GetRef<LoopRV>(loop), ann_key, ann_val); |
105 | } |
106 | LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); |
107 | throw; |
108 | } |
109 | |
110 | static String UnpackedAsPython(Array<String> outputs, ObjectRef block_or_loop_rv, |
111 | ObjectRef ann_val, String ann_key) { |
112 | PythonAPICall py("annotate" ); |
113 | py.Input("block_or_loop" , block_or_loop_rv); |
114 | py.Input("ann_key" , ann_key); |
115 | py.Input("ann_val" , ann_val); |
116 | return py.Str(); |
117 | } |
118 | |
119 | template <typename> |
120 | friend struct ::tvm::tir::UnpackedInstTraits; |
121 | }; |
122 | |
123 | struct UnannotateTraits : public UnpackedInstTraits<UnannotateTraits> { |
124 | static constexpr const char* kName = "Unannotate" ; |
125 | static constexpr bool kIsPure = false; |
126 | |
127 | private: |
128 | static constexpr size_t kNumInputs = 1; |
129 | static constexpr size_t kNumAttrs = 1; |
130 | static constexpr size_t kNumDecisions = 0; |
131 | |
132 | static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String ann_key) { |
133 | if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) { |
134 | return sch->Unannotate(GetRef<BlockRV>(block), ann_key); |
135 | } |
136 | if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) { |
137 | return sch->Unannotate(GetRef<LoopRV>(loop), ann_key); |
138 | } |
139 | LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); |
140 | throw; |
141 | } |
142 | |
143 | static String UnpackedAsPython(Array<String> outputs, ObjectRef block_or_loop_rv, |
144 | String ann_key) { |
145 | PythonAPICall py("unannotate" ); |
146 | py.Input("block_or_loop" , block_or_loop_rv); |
147 | py.Input("ann_key" , ann_key); |
148 | return py.Str(); |
149 | } |
150 | |
151 | template <typename> |
152 | friend struct ::tvm::tir::UnpackedInstTraits; |
153 | }; |
154 | |
// Register the two traits above as instruction kinds, named by their kName
// ("Annotate" and "Unannotate").
TVM_REGISTER_INST_KIND_TRAITS(AnnotateTraits);
TVM_REGISTER_INST_KIND_TRAITS(UnannotateTraits);
157 | |
158 | } // namespace tir |
159 | } // namespace tvm |
160 | |