1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/op_strategy.h |
 * \brief The Relay operator strategy and related data structures.
23 | */ |
24 | |
25 | #ifndef TVM_RELAY_OP_STRATEGY_H_ |
26 | #define TVM_RELAY_OP_STRATEGY_H_ |
27 | |
28 | #include <tvm/relay/expr.h> |
29 | #include <tvm/relay/op_attr_types.h> |
30 | #include <tvm/target/target.h> |
31 | #include <tvm/te/schedule.h> |
32 | #include <tvm/te/tensor.h> |
33 | |
34 | #include <string> |
35 | |
36 | namespace tvm { |
37 | namespace relay { |
38 | |
39 | /*! |
40 | * \brief Operator implementation that includes compute and schedule function. |
41 | */ |
42 | class OpImplementationNode : public Object { |
43 | public: |
44 | /*! \brief Compute function */ |
45 | FTVMCompute fcompute; |
46 | /*! \brief Schedule function */ |
47 | FTVMSchedule fschedule; |
48 | /*! \brief Name of the implementation */ |
49 | String name; |
50 | /*! \brief Priority level */ |
51 | int plevel; |
52 | |
53 | void VisitAttrs(tvm::AttrVisitor* v) { |
54 | v->Visit("name" , &name); |
55 | v->Visit("plevel" , &plevel); |
56 | } |
57 | |
58 | static constexpr const char* _type_key = "relay.OpImplementation" ; |
59 | TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object); |
60 | }; |
61 | |
62 | /*! |
63 | * \brief Operator implementation class. |
64 | */ |
65 | class OpImplementation : public ObjectRef { |
66 | public: |
67 | /*! |
68 | * \brief Invoke the operator compute function. |
69 | * \param attrs The attribute of the primitive |
70 | * \param inputs The input tensors. |
71 | * \param out_type The output type information. |
72 | * \return The output compute description of the operator. |
73 | */ |
74 | TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
75 | const Type& out_type); |
76 | /*! |
77 | * \brief Build the computation schedule. |
78 | * \param attrs The attribute of the node. |
79 | * \param outs The output tensors. |
80 | * \param target The build target. |
81 | * \return The computation schedule. |
82 | */ |
83 | TVM_DLL te::Schedule Schedule(const Attrs& attrs, const Array<te::Tensor>& outs, |
84 | const Target& target); |
85 | |
86 | TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode); |
87 | }; |
88 | |
89 | /*! |
90 | * \brief Specialized implementations for operators under certain conditions. |
91 | */ |
92 | class OpSpecializationNode : public Object { |
93 | public: |
94 | /*! \brief List of implementations. */ |
95 | Array<OpImplementation> implementations; |
96 | /*! \brief Condition to enable the specialization. |
97 | * Could be undefined to represent generic case. */ |
98 | te::SpecializedCondition condition; |
99 | |
100 | void VisitAttrs(tvm::AttrVisitor* v) { |
101 | v->Visit("condition" , &condition); |
102 | v->Visit("implementations" , &implementations); |
103 | } |
104 | |
105 | static constexpr const char* _type_key = "relay.OpSpecialization" ; |
106 | TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode); |
107 | }; |
108 | |
109 | /*! |
110 | * \brief Operator specialization class. |
111 | */ |
112 | class OpSpecialization : public ObjectRef { |
113 | public: |
114 | /*! |
115 | * \brief Add an implementation. |
116 | * \param fcompute Compute function |
117 | * \param fschedule Schedule function |
118 | * \param name Name of the implementation |
119 | * \param plevel Priority level of the implementation |
120 | */ |
121 | TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name, |
122 | int plevel); |
123 | |
124 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode); |
125 | }; |
126 | |
127 | /*! |
128 | * \brief Operator strategy to choose implementation. |
129 | */ |
130 | class OpStrategyNode : public Object { |
131 | public: |
132 | /*! \brief List of operator specializations. */ |
133 | Array<OpSpecialization> specializations; |
134 | |
135 | void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("specializations" , &specializations); } |
136 | |
137 | static constexpr const char* _type_key = "relay.OpStrategy" ; |
138 | TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode); |
139 | }; |
140 | |
141 | /*! |
142 | * \brief Operator strategy class. |
143 | */ |
144 | class OpStrategy : public ObjectRef { |
145 | public: |
146 | /*! |
147 | * \brief Add an implementation. |
148 | * \param fcompute Compute function |
149 | * \param fschedule Schedule function |
150 | * \param name Name of the implementation |
151 | * \param plevel Priority level of the implementation |
152 | */ |
153 | TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name, |
154 | int plevel); |
155 | |
156 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode); |
157 | }; |
158 | |
159 | } // namespace relay |
160 | } // namespace tvm |
161 | #endif // TVM_RELAY_OP_STRATEGY_H_ |
162 | |