1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/op_strategy.h
 * \brief The Relay operator strategy and related data structures.
23 */
24
25#ifndef TVM_RELAY_OP_STRATEGY_H_
26#define TVM_RELAY_OP_STRATEGY_H_
27
28#include <tvm/relay/expr.h>
29#include <tvm/relay/op_attr_types.h>
30#include <tvm/target/target.h>
31#include <tvm/te/schedule.h>
32#include <tvm/te/tensor.h>
33
34#include <string>
35
36namespace tvm {
37namespace relay {
38
39/*!
40 * \brief Operator implementation that includes compute and schedule function.
41 */
42class OpImplementationNode : public Object {
43 public:
44 /*! \brief Compute function */
45 FTVMCompute fcompute;
46 /*! \brief Schedule function */
47 FTVMSchedule fschedule;
48 /*! \brief Name of the implementation */
49 String name;
50 /*! \brief Priority level */
51 int plevel;
52
53 void VisitAttrs(tvm::AttrVisitor* v) {
54 v->Visit("name", &name);
55 v->Visit("plevel", &plevel);
56 }
57
58 static constexpr const char* _type_key = "relay.OpImplementation";
59 TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object);
60};
61
62/*!
63 * \brief Operator implementation class.
64 */
65class OpImplementation : public ObjectRef {
66 public:
67 /*!
68 * \brief Invoke the operator compute function.
69 * \param attrs The attribute of the primitive
70 * \param inputs The input tensors.
71 * \param out_type The output type information.
72 * \return The output compute description of the operator.
73 */
74 TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs, const Array<te::Tensor>& inputs,
75 const Type& out_type);
76 /*!
77 * \brief Build the computation schedule.
78 * \param attrs The attribute of the node.
79 * \param outs The output tensors.
80 * \param target The build target.
81 * \return The computation schedule.
82 */
83 TVM_DLL te::Schedule Schedule(const Attrs& attrs, const Array<te::Tensor>& outs,
84 const Target& target);
85
86 TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode);
87};
88
89/*!
90 * \brief Specialized implementations for operators under certain conditions.
91 */
92class OpSpecializationNode : public Object {
93 public:
94 /*! \brief List of implementations. */
95 Array<OpImplementation> implementations;
96 /*! \brief Condition to enable the specialization.
97 * Could be undefined to represent generic case. */
98 te::SpecializedCondition condition;
99
100 void VisitAttrs(tvm::AttrVisitor* v) {
101 v->Visit("condition", &condition);
102 v->Visit("implementations", &implementations);
103 }
104
105 static constexpr const char* _type_key = "relay.OpSpecialization";
106 TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode);
107};
108
109/*!
110 * \brief Operator specialization class.
111 */
112class OpSpecialization : public ObjectRef {
113 public:
114 /*!
115 * \brief Add an implementation.
116 * \param fcompute Compute function
117 * \param fschedule Schedule function
118 * \param name Name of the implementation
119 * \param plevel Priority level of the implementation
120 */
121 TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name,
122 int plevel);
123
124 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode);
125};
126
127/*!
128 * \brief Operator strategy to choose implementation.
129 */
130class OpStrategyNode : public Object {
131 public:
132 /*! \brief List of operator specializations. */
133 Array<OpSpecialization> specializations;
134
135 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("specializations", &specializations); }
136
137 static constexpr const char* _type_key = "relay.OpStrategy";
138 TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode);
139};
140
141/*!
142 * \brief Operator strategy class.
143 */
144class OpStrategy : public ObjectRef {
145 public:
146 /*!
147 * \brief Add an implementation.
148 * \param fcompute Compute function
149 * \param fschedule Schedule function
150 * \param name Name of the implementation
151 * \param plevel Priority level of the implementation
152 */
153 TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name,
154 int plevel);
155
156 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode);
157};
158
159} // namespace relay
160} // namespace tvm
161#endif // TVM_RELAY_OP_STRATEGY_H_
162