1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file combine_parallel_op_batch.h |
22 | * \brief Combine parallel ops into a single batch op. |
23 | */ |
24 | #ifndef TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_ |
25 | #define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_ |
26 | |
27 | #include <tvm/relay/analysis.h> |
28 | #include <tvm/relay/attrs/nn.h> |
29 | #include <tvm/relay/attrs/transform.h> |
30 | #include <tvm/relay/expr_functor.h> |
31 | #include <tvm/relay/op_attr_types.h> |
32 | #include <tvm/relay/transform.h> |
33 | |
34 | #include <string> |
35 | #include <unordered_map> |
36 | #include <unordered_set> |
37 | |
38 | #include "./combine_parallel_op.h" |
39 | #include "./expr_subst.h" |
40 | #include "pattern_utils.h" |
41 | |
42 | namespace tvm { |
43 | namespace relay { |
44 | |
45 | /* |
46 | * Class to find and combine parallel ops and following element-wise |
47 | * and broadcast ops into a single batch op. Ops can be combined |
48 | * if they have the same input data. Batch op is formed by |
49 | * stacking inputs. Final results are retrieved by splitting output. |
50 | * For example: |
51 | * |
52 | * data |
53 | * / \ |
54 | * dense (2,2) dense (2,2) |
55 | * | | |
56 | * elemwise/bcast (2,2) elemwise/bcast (2,2) |
57 | * |
58 | * Would become: |
59 | * |
60 | * data |
61 | * | |
62 | * batch_matmul+elemwise/bcast (2,2,2) |
63 | */ |
64 | class ParallelOpBatchCombiner : public ParallelOpCombiner { |
65 | public: |
66 | /* |
67 | * \brief Constructor. |
68 | * \param op_name name of op to combine |
69 | * \param batch_op_name name of op that combined branches will be joined into |
70 | * \param min_num_branches min number of parallel branches beginning with op |
71 | * to start combining |
72 | */ |
73 | ParallelOpBatchCombiner(const std::string& op_name, const std::string& batch_op_name, |
74 | uint64_t min_num_branches); |
75 | |
76 | protected: |
77 | /* |
78 | * \brief Checks if node is supported to be combined |
79 | * \param n node in question |
80 | * \return True by default |
81 | */ |
82 | virtual bool IsSupportedOp(const CallNode* n); |
83 | |
84 | /* |
85 | * \brief Checks if two ops can be combined |
86 | * \param a node a |
87 | * \param b node b |
88 | * \return True if shapes and dtypes of all args of a and b are the same |
89 | */ |
90 | virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b); |
91 | |
92 | /* |
93 | * \brief Makes combined op from parallel ops in branches. This usually involves |
94 | * concatenating or stacking inputs, then creating a new call. |
95 | * \param branches branches that are to be combined |
96 | * \return new call with branches combined as batch op by stacking args |
97 | */ |
98 | virtual Call MakeCombinedOp(const Group& branches); |
99 | |
100 | /* |
101 | * \brief Checks if argument of op following combined ops are able to be combined |
102 | * \param a node a |
103 | * \param b node b |
104 | * \param index index of argument in question |
105 | * \return True if shapes and dtypes of args[index] a and b are the same |
106 | */ |
107 | bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) final; |
108 | |
109 | /* |
110 | * \brief Create combined call from ops that follow the initial combined op at the depth-th level. |
111 | * This usually involves concatenating or stacking inputs, then creating a new call. |
112 | * Only called if IsArgCompatbile returns true for each arg. |
113 | * \param data combined op |
114 | * \param branches branches of parallel ops to be combined |
115 | * \param depth depth at which to combine ops |
116 | * \param parent_index index of arg that corresponds to original input that was shared among |
117 | * all combined ops |
118 | * \return new combined call as batch op by stacking args |
119 | */ |
120 | Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth, |
121 | size_t parent_index) final; |
122 | |
123 | /* |
124 | * \brief Updates map of expr to substitute with combined expr. This usually involves |
125 | * slicing or splitting data. |
126 | * \param data combined op |
127 | * \param branches branches of parallel ops to be combined |
128 | * \param depth depth at which to substitute |
129 | * \param subst_map map of Expr to replace with Expr to replace it with |
130 | */ |
131 | void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, |
132 | ExprSubstMap* subst_map) final; |
133 | |
134 | private: |
135 | /* \brief name of op to replace combined ops with. for example, |
136 | * for combining parallel dense, this will will be set to |
137 | * nn.batch_matmul |
138 | */ |
139 | std::string batch_op_name_; |
140 | }; |
141 | |
142 | } // namespace relay |
143 | } // namespace tvm |
144 | |
145 | #endif // TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_ |
146 | |