1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file src/relay/op/annotation/annotation.cc |
23 | * \brief Helpers for working with various 'annotations' attributes. |
24 | */ |
25 | |
26 | #include "./annotation.h" |
27 | |
28 | #include <tvm/relay/attrs/annotation.h> |
29 | #include <tvm/relay/expr.h> |
30 | #include <tvm/relay/op.h> |
31 | #include <tvm/relay/op_attr_types.h> |
32 | #include <tvm/tir/expr.h> |
33 | #include <tvm/topi/elemwise.h> |
34 | |
35 | #include "../../transforms/infer_layout_utils.h" |
36 | #include "../type_relations.h" |
37 | |
38 | namespace tvm { |
39 | namespace relay { |
40 | |
41 | Expr StopFusion(Expr data) { |
42 | static const Op& op = Op::Get("annotation.stop_fusion" ); |
43 | return Call(op, {data}, Attrs{}, {}); |
44 | } |
45 | |
46 | TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion" ).set_body_typed([](Expr data) { |
47 | return StopFusion(data); |
48 | }); |
49 | |
50 | RELAY_REGISTER_OP("annotation.stop_fusion" ) |
51 | .describe( |
52 | R"code(Annotate an expression to prevent it being fused with following expressions.)code" TVM_ADD_FILELINE) |
53 | .set_num_inputs(1) |
54 | .add_argument("data" , "Tensor" , "The input data." ) |
55 | .add_type_rel("Identity" , IdentityRel) |
56 | .set_support_level(10) |
57 | .set_attr<TOpPattern>("TOpPattern" , kOpaque) |
58 | .set_attr<TOpIsStateful>("TOpIsStateful" , false) |
59 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ElemwiseArbitraryLayout) |
60 | .set_attr<FTVMCompute>("FTVMCompute" , |
61 | [](const Attrs& attrs, const Array<te::Tensor>& inputs, |
62 | const Type& out_dtype) -> Array<te::Tensor> { |
63 | return {topi::identity(inputs[0])}; |
64 | }); |
65 | |
// annotation.cast_hint: node-type registration for CastHintAttrs, the attrs class that
// CastHint (below) allocates and attaches to each annotation.cast_hint call.
TVM_REGISTER_NODE_TYPE(CastHintAttrs);
68 | |
69 | Expr CastHint(Expr data, DataType dtype) { |
70 | auto attrs = make_object<CastHintAttrs>(); |
71 | attrs->dtype = dtype; |
72 | static const Op& op = Op::Get("annotation.cast_hint" ); |
73 | return Call(op, {data}, Attrs{attrs}, {}); |
74 | } |
75 | |
76 | RELAY_REGISTER_OP("annotation.cast_hint" ) |
77 | .describe( |
78 | R"code(Annotate an expression to be cast into specific data type.)code" TVM_ADD_FILELINE) |
79 | .set_num_inputs(1) |
80 | .add_argument("data" , "Tensor" , "The input data." ) |
81 | .add_type_rel("Identity" , IdentityRel) |
82 | .set_support_level(10) |
83 | .set_attr<TOpPattern>("TOpPattern" , kOpaque) |
84 | .set_attr<TOpIsStateful>("TOpIsStateful" , false) |
85 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ElemwiseArbitraryLayout) |
86 | .set_attr<FTVMCompute>("FTVMCompute" , |
87 | [](const Attrs& attrs, const Array<te::Tensor>& inputs, |
88 | const Type& out_dtype) -> Array<te::Tensor> { |
89 | return {topi::identity(inputs[0])}; |
90 | }); |
91 | |
92 | RELAY_REGISTER_OP("annotation.bitpack_start" ) |
93 | .describe(R"code( |
94 | Mark the start of bitpacking. |
95 | )code" TVM_ADD_FILELINE) |
96 | .set_num_inputs(1) |
97 | .add_argument("data" , "Tensor" , "The input data." ) |
98 | .set_support_level(10) |
99 | .add_type_rel("Identity" , IdentityRel) |
100 | .set_attr<TOpPattern>("TOpPattern" , kOpaque) |
101 | .set_attr<TOpIsStateful>("TOpIsStateful" , false) |
102 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ElemwiseArbitraryLayout) |
103 | .set_attr<FTVMCompute>("FTVMCompute" , |
104 | [](const Attrs& attrs, const Array<te::Tensor>& inputs, |
105 | const Type& out_dtype) -> Array<te::Tensor> { |
106 | return {topi::identity(inputs[0])}; |
107 | }); |
108 | |
109 | RELAY_REGISTER_OP("annotation.bitpack_end" ) |
110 | .describe(R"code( |
111 | Mark the end of bitpacking. |
112 | )code" TVM_ADD_FILELINE) |
113 | .set_num_inputs(1) |
114 | .add_argument("data" , "Tensor" , "The input data." ) |
115 | .set_support_level(10) |
116 | .add_type_rel("Identity" , IdentityRel) |
117 | .set_attr<TOpPattern>("TOpPattern" , kOpaque) |
118 | .set_attr<TOpIsStateful>("TOpIsStateful" , false) |
119 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ElemwiseArbitraryLayout) |
120 | .set_attr<FTVMCompute>("FTVMCompute" , |
121 | [](const Attrs& attrs, const Array<te::Tensor>& inputs, |
122 | const Type& out_dtype) -> Array<te::Tensor> { |
123 | return {topi::identity(inputs[0])}; |
124 | }); |
125 | |
126 | TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint" ).set_body_typed([](Expr data) { |
127 | static const Op& op = Op::Get("annotation.checkpoint" ); |
128 | return Call(op, {data}, Attrs{}, {}); |
129 | }); |
130 | |
131 | RELAY_REGISTER_OP("annotation.checkpoint" ) |
132 | .describe(R"code( |
133 | Mark a checkpoint for checkpointing memory optimization. |
134 | )code" TVM_ADD_FILELINE) |
135 | .set_num_inputs(1) |
136 | .set_support_level(10) |
137 | .add_argument("data" , "Tensor" , "The input data." ) |
138 | .add_type_rel("Identity" , IdentityRel) |
139 | .set_attr<TOpPattern>("TOpPattern" , kOpaque) |
140 | .set_attr<TOpIsStateful>("TOpIsStateful" , false) |
141 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ElemwiseArbitraryLayout) |
142 | .set_attr<FTVMCompute>("FTVMCompute" , |
143 | [](const Attrs& attrs, const Array<te::Tensor>& inputs, |
144 | const Type& out_dtype) -> Array<te::Tensor> { |
145 | Array<te::Tensor> outputs; |
146 | for (size_t i = 0; i < inputs.size(); ++i) { |
147 | outputs.push_back(topi::identity(inputs[i])); |
148 | } |
149 | return outputs; |
150 | }); |
151 | |
// Node-type registration for CompilerAttrs, the attrs class that the compiler_begin and
// compiler_end constructors in this file allocate to carry the compiler name on their calls.
TVM_REGISTER_NODE_TYPE(CompilerAttrs);
153 | |
154 | RELAY_REGISTER_OP("annotation.compiler_begin" ) |
155 | .describe(R"code( |
156 | Beginning of a region that is handled by a given compiler. |
157 | )code" TVM_ADD_FILELINE) |
158 | .set_num_inputs(1) |
159 | .add_argument("data" , "Tensor" , "The input data." ) |
160 | .set_support_level(10) |
161 | .add_type_rel("Identity" , IdentityRel) |
162 | .set_attr<TOpPattern>("TOpPattern" , kOpaque) |
163 | .set_attr<TOpIsStateful>("TOpIsStateful" , false) |
164 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ElemwiseArbitraryLayout) |
165 | .set_attr<FTVMCompute>("FTVMCompute" , |
166 | [](const Attrs& attrs, const Array<te::Tensor>& inputs, |
167 | const Type& out_dtype) -> Array<te::Tensor> { |
168 | return {topi::identity(inputs[0])}; |
169 | }); |
170 | |
171 | TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin" ) |
172 | .set_body_typed([](Expr expr, String compiler) { |
173 | auto attrs = make_object<CompilerAttrs>(); |
174 | attrs->compiler = compiler; |
175 | static const Op& op = Op::Get("annotation.compiler_begin" ); |
176 | return Call(op, {expr}, Attrs(attrs), {}); |
177 | }); |
178 | |
179 | RELAY_REGISTER_OP("annotation.compiler_end" ) |
180 | .describe(R"code( |
181 | End of a region that is handled by a given compiler. |
182 | )code" TVM_ADD_FILELINE) |
183 | .set_num_inputs(1) |
184 | .add_argument("data" , "Tensor" , "The input data." ) |
185 | .set_support_level(10) |
186 | .add_type_rel("Identity" , IdentityRel) |
187 | .set_attr<TOpPattern>("TOpPattern" , kOpaque) |
188 | .set_attr<TOpIsStateful>("TOpIsStateful" , false) |
189 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ElemwiseArbitraryLayout) |
190 | .set_attr<FTVMCompute>("FTVMCompute" , |
191 | [](const Attrs& attrs, const Array<te::Tensor>& inputs, |
192 | const Type& out_dtype) -> Array<te::Tensor> { |
193 | return {topi::identity(inputs[0])}; |
194 | }); |
195 | |
196 | TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end" ) |
197 | .set_body_typed([](Expr expr, String compiler) { |
198 | auto attrs = make_object<CompilerAttrs>(); |
199 | attrs->compiler = compiler; |
200 | static const Op& op = Op::Get("annotation.compiler_end" ); |
201 | return Call(op, {expr}, Attrs(attrs), {}); |
202 | }); |
203 | |
204 | } // namespace relay |
205 | } // namespace tvm |
206 | |