1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file src/relay/op/annotation/annotation.cc
23 * \brief Helpers for working with various 'annotations' attributes.
24 */
25
26#include "./annotation.h"
27
28#include <tvm/relay/attrs/annotation.h>
29#include <tvm/relay/expr.h>
30#include <tvm/relay/op.h>
31#include <tvm/relay/op_attr_types.h>
32#include <tvm/tir/expr.h>
33#include <tvm/topi/elemwise.h>
34
35#include "../../transforms/infer_layout_utils.h"
36#include "../type_relations.h"
37
38namespace tvm {
39namespace relay {
40
41Expr StopFusion(Expr data) {
42 static const Op& op = Op::Get("annotation.stop_fusion");
43 return Call(op, {data}, Attrs{}, {});
44}
45
46TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion").set_body_typed([](Expr data) {
47 return StopFusion(data);
48});
49
50RELAY_REGISTER_OP("annotation.stop_fusion")
51 .describe(
52 R"code(Annotate an expression to prevent it being fused with following expressions.)code" TVM_ADD_FILELINE)
53 .set_num_inputs(1)
54 .add_argument("data", "Tensor", "The input data.")
55 .add_type_rel("Identity", IdentityRel)
56 .set_support_level(10)
57 .set_attr<TOpPattern>("TOpPattern", kOpaque)
58 .set_attr<TOpIsStateful>("TOpIsStateful", false)
59 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
60 .set_attr<FTVMCompute>("FTVMCompute",
61 [](const Attrs& attrs, const Array<te::Tensor>& inputs,
62 const Type& out_dtype) -> Array<te::Tensor> {
63 return {topi::identity(inputs[0])};
64 });
65
66// relay.annotation.cast_hint
67TVM_REGISTER_NODE_TYPE(CastHintAttrs);
68
69Expr CastHint(Expr data, DataType dtype) {
70 auto attrs = make_object<CastHintAttrs>();
71 attrs->dtype = dtype;
72 static const Op& op = Op::Get("annotation.cast_hint");
73 return Call(op, {data}, Attrs{attrs}, {});
74}
75
76RELAY_REGISTER_OP("annotation.cast_hint")
77 .describe(
78 R"code(Annotate an expression to be cast into specific data type.)code" TVM_ADD_FILELINE)
79 .set_num_inputs(1)
80 .add_argument("data", "Tensor", "The input data.")
81 .add_type_rel("Identity", IdentityRel)
82 .set_support_level(10)
83 .set_attr<TOpPattern>("TOpPattern", kOpaque)
84 .set_attr<TOpIsStateful>("TOpIsStateful", false)
85 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
86 .set_attr<FTVMCompute>("FTVMCompute",
87 [](const Attrs& attrs, const Array<te::Tensor>& inputs,
88 const Type& out_dtype) -> Array<te::Tensor> {
89 return {topi::identity(inputs[0])};
90 });
91
92RELAY_REGISTER_OP("annotation.bitpack_start")
93 .describe(R"code(
94Mark the start of bitpacking.
95)code" TVM_ADD_FILELINE)
96 .set_num_inputs(1)
97 .add_argument("data", "Tensor", "The input data.")
98 .set_support_level(10)
99 .add_type_rel("Identity", IdentityRel)
100 .set_attr<TOpPattern>("TOpPattern", kOpaque)
101 .set_attr<TOpIsStateful>("TOpIsStateful", false)
102 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
103 .set_attr<FTVMCompute>("FTVMCompute",
104 [](const Attrs& attrs, const Array<te::Tensor>& inputs,
105 const Type& out_dtype) -> Array<te::Tensor> {
106 return {topi::identity(inputs[0])};
107 });
108
109RELAY_REGISTER_OP("annotation.bitpack_end")
110 .describe(R"code(
111Mark the end of bitpacking.
112)code" TVM_ADD_FILELINE)
113 .set_num_inputs(1)
114 .add_argument("data", "Tensor", "The input data.")
115 .set_support_level(10)
116 .add_type_rel("Identity", IdentityRel)
117 .set_attr<TOpPattern>("TOpPattern", kOpaque)
118 .set_attr<TOpIsStateful>("TOpIsStateful", false)
119 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
120 .set_attr<FTVMCompute>("FTVMCompute",
121 [](const Attrs& attrs, const Array<te::Tensor>& inputs,
122 const Type& out_dtype) -> Array<te::Tensor> {
123 return {topi::identity(inputs[0])};
124 });
125
126TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint").set_body_typed([](Expr data) {
127 static const Op& op = Op::Get("annotation.checkpoint");
128 return Call(op, {data}, Attrs{}, {});
129});
130
131RELAY_REGISTER_OP("annotation.checkpoint")
132 .describe(R"code(
133Mark a checkpoint for checkpointing memory optimization.
134)code" TVM_ADD_FILELINE)
135 .set_num_inputs(1)
136 .set_support_level(10)
137 .add_argument("data", "Tensor", "The input data.")
138 .add_type_rel("Identity", IdentityRel)
139 .set_attr<TOpPattern>("TOpPattern", kOpaque)
140 .set_attr<TOpIsStateful>("TOpIsStateful", false)
141 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
142 .set_attr<FTVMCompute>("FTVMCompute",
143 [](const Attrs& attrs, const Array<te::Tensor>& inputs,
144 const Type& out_dtype) -> Array<te::Tensor> {
145 Array<te::Tensor> outputs;
146 for (size_t i = 0; i < inputs.size(); ++i) {
147 outputs.push_back(topi::identity(inputs[i]));
148 }
149 return outputs;
150 });
151
152TVM_REGISTER_NODE_TYPE(CompilerAttrs);
153
154RELAY_REGISTER_OP("annotation.compiler_begin")
155 .describe(R"code(
156Beginning of a region that is handled by a given compiler.
157)code" TVM_ADD_FILELINE)
158 .set_num_inputs(1)
159 .add_argument("data", "Tensor", "The input data.")
160 .set_support_level(10)
161 .add_type_rel("Identity", IdentityRel)
162 .set_attr<TOpPattern>("TOpPattern", kOpaque)
163 .set_attr<TOpIsStateful>("TOpIsStateful", false)
164 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
165 .set_attr<FTVMCompute>("FTVMCompute",
166 [](const Attrs& attrs, const Array<te::Tensor>& inputs,
167 const Type& out_dtype) -> Array<te::Tensor> {
168 return {topi::identity(inputs[0])};
169 });
170
171TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin")
172 .set_body_typed([](Expr expr, String compiler) {
173 auto attrs = make_object<CompilerAttrs>();
174 attrs->compiler = compiler;
175 static const Op& op = Op::Get("annotation.compiler_begin");
176 return Call(op, {expr}, Attrs(attrs), {});
177 });
178
179RELAY_REGISTER_OP("annotation.compiler_end")
180 .describe(R"code(
181End of a region that is handled by a given compiler.
182)code" TVM_ADD_FILELINE)
183 .set_num_inputs(1)
184 .add_argument("data", "Tensor", "The input data.")
185 .set_support_level(10)
186 .add_type_rel("Identity", IdentityRel)
187 .set_attr<TOpPattern>("TOpPattern", kOpaque)
188 .set_attr<TOpIsStateful>("TOpIsStateful", false)
189 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
190 .set_attr<FTVMCompute>("FTVMCompute",
191 [](const Attrs& attrs, const Array<te::Tensor>& inputs,
192 const Type& out_dtype) -> Array<te::Tensor> {
193 return {topi::identity(inputs[0])};
194 });
195
196TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end")
197 .set_body_typed([](Expr expr, String compiler) {
198 auto attrs = make_object<CompilerAttrs>();
199 attrs->compiler = compiler;
200 static const Op& op = Op::Get("annotation.compiler_end");
201 return Call(op, {expr}, Attrs(attrs), {});
202 });
203
204} // namespace relay
205} // namespace tvm
206