1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key,
25 const ObjectRef& ann_val) {
26 // Extract annotation
27 const Map<String, ObjectRef>* annotations = nullptr;
28 if (const auto* loop = sref->StmtAs<ForNode>()) {
29 annotations = &loop->annotations;
30 } else if (const auto* block = sref->StmtAs<BlockNode>()) {
31 annotations = &block->annotations;
32 } else {
33 LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey();
34 }
35 // Check if the annotation already exists
36 if (annotations->find(ann_key) != annotations->end()) {
37 return;
38 }
39 // Add the new annotation
40 Map<String, ObjectRef> new_ann(*annotations);
41 new_ann.Set(ann_key, ann_val);
42 // Create the new stmt
43 if (const auto* loop = sref->StmtAs<ForNode>()) {
44 ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
45 n->annotations = std::move(new_ann);
46 self->Replace(sref, For(n), {});
47 } else if (const auto* block = sref->StmtAs<BlockNode>()) {
48 ObjectPtr<BlockNode> n = make_object<BlockNode>(*block);
49 n->annotations = std::move(new_ann);
50 Block p(n);
51 self->Replace(sref, p, {{GetRef<Block>(block), p}});
52 } else {
53 LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey();
54 throw;
55 }
56}
57
58void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key) {
59 // Extract annotation
60 const Map<String, ObjectRef>* annotations = nullptr;
61 if (const auto* loop = sref->StmtAs<ForNode>()) {
62 annotations = &loop->annotations;
63 } else if (const auto* block = sref->StmtAs<BlockNode>()) {
64 annotations = &block->annotations;
65 } else {
66 LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey();
67 }
68 // Remove the annotation
69 ICHECK(annotations->find(ann_key) != annotations->end())
70 << "IndexError: Cannot find annotation key: " << ann_key;
71 Map<String, ObjectRef> new_ann(*annotations);
72 new_ann.erase(ann_key);
73 // Create the new stmt
74 if (const auto* loop = sref->StmtAs<ForNode>()) {
75 ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
76 n->annotations = std::move(new_ann);
77 self->Replace(sref, For(n), {});
78 } else if (const auto* block = sref->StmtAs<BlockNode>()) {
79 ObjectPtr<BlockNode> n = make_object<BlockNode>(*block);
80 n->annotations = std::move(new_ann);
81 Block p(n);
82 self->Replace(sref, p, {{GetRef<Block>(block), p}});
83 } else {
84 LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey();
85 throw;
86 }
87}
88
89struct AnnotateTraits : public UnpackedInstTraits<AnnotateTraits> {
90 static constexpr const char* kName = "Annotate";
91 static constexpr bool kIsPure = false;
92
93 private:
94 static constexpr size_t kNumInputs = 2;
95 static constexpr size_t kNumAttrs = 1;
96 static constexpr size_t kNumDecisions = 0;
97
98 static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ObjectRef ann_val,
99 String ann_key) {
100 if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) {
101 return sch->Annotate(GetRef<BlockRV>(block), ann_key, ann_val);
102 }
103 if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) {
104 return sch->Annotate(GetRef<LoopRV>(loop), ann_key, ann_val);
105 }
106 LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey();
107 throw;
108 }
109
110 static String UnpackedAsPython(Array<String> outputs, ObjectRef block_or_loop_rv,
111 ObjectRef ann_val, String ann_key) {
112 PythonAPICall py("annotate");
113 py.Input("block_or_loop", block_or_loop_rv);
114 py.Input("ann_key", ann_key);
115 py.Input("ann_val", ann_val);
116 return py.Str();
117 }
118
119 template <typename>
120 friend struct ::tvm::tir::UnpackedInstTraits;
121};
122
123struct UnannotateTraits : public UnpackedInstTraits<UnannotateTraits> {
124 static constexpr const char* kName = "Unannotate";
125 static constexpr bool kIsPure = false;
126
127 private:
128 static constexpr size_t kNumInputs = 1;
129 static constexpr size_t kNumAttrs = 1;
130 static constexpr size_t kNumDecisions = 0;
131
132 static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String ann_key) {
133 if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) {
134 return sch->Unannotate(GetRef<BlockRV>(block), ann_key);
135 }
136 if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) {
137 return sch->Unannotate(GetRef<LoopRV>(loop), ann_key);
138 }
139 LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey();
140 throw;
141 }
142
143 static String UnpackedAsPython(Array<String> outputs, ObjectRef block_or_loop_rv,
144 String ann_key) {
145 PythonAPICall py("unannotate");
146 py.Input("block_or_loop", block_or_loop_rv);
147 py.Input("ann_key", ann_key);
148 return py.Str();
149 }
150
151 template <typename>
152 friend struct ::tvm::tir::UnpackedInstTraits;
153};
154
// Register the two instruction-kind trait structs defined above. Presumably they become
// discoverable by each struct's kName ("Annotate" / "Unannotate"); confirm against the
// TVM_REGISTER_INST_KIND_TRAITS definition.
TVM_REGISTER_INST_KIND_TRAITS(AnnotateTraits);
TVM_REGISTER_INST_KIND_TRAITS(UnannotateTraits);
157
158} // namespace tir
159} // namespace tvm
160