/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
/*!
 * \file combine_parallel_op_batch.h
 * \brief Combine parallel ops into a single batch op.
 */
24#ifndef TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_
25#define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_
26
27#include <tvm/relay/analysis.h>
28#include <tvm/relay/attrs/nn.h>
29#include <tvm/relay/attrs/transform.h>
30#include <tvm/relay/expr_functor.h>
31#include <tvm/relay/op_attr_types.h>
32#include <tvm/relay/transform.h>
33
34#include <string>
35#include <unordered_map>
36#include <unordered_set>
37
38#include "./combine_parallel_op.h"
39#include "./expr_subst.h"
40#include "pattern_utils.h"
41
42namespace tvm {
43namespace relay {
44
45/*
46 * Class to find and combine parallel ops and following element-wise
47 * and broadcast ops into a single batch op. Ops can be combined
48 * if they have the same input data. Batch op is formed by
49 * stacking inputs. Final results are retrieved by splitting output.
50 * For example:
51 *
52 * data
53 * / \
54 * dense (2,2) dense (2,2)
55 * | |
56 * elemwise/bcast (2,2) elemwise/bcast (2,2)
57 *
58 * Would become:
59 *
60 * data
61 * |
62 * batch_matmul+elemwise/bcast (2,2,2)
63 */
64class ParallelOpBatchCombiner : public ParallelOpCombiner {
65 public:
66 /*
67 * \brief Constructor.
68 * \param op_name name of op to combine
69 * \param batch_op_name name of op that combined branches will be joined into
70 * \param min_num_branches min number of parallel branches beginning with op
71 * to start combining
72 */
73 ParallelOpBatchCombiner(const std::string& op_name, const std::string& batch_op_name,
74 uint64_t min_num_branches);
75
76 protected:
77 /*
78 * \brief Checks if node is supported to be combined
79 * \param n node in question
80 * \return True by default
81 */
82 virtual bool IsSupportedOp(const CallNode* n);
83
84 /*
85 * \brief Checks if two ops can be combined
86 * \param a node a
87 * \param b node b
88 * \return True if shapes and dtypes of all args of a and b are the same
89 */
90 virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b);
91
92 /*
93 * \brief Makes combined op from parallel ops in branches. This usually involves
94 * concatenating or stacking inputs, then creating a new call.
95 * \param branches branches that are to be combined
96 * \return new call with branches combined as batch op by stacking args
97 */
98 virtual Call MakeCombinedOp(const Group& branches);
99
100 /*
101 * \brief Checks if argument of op following combined ops are able to be combined
102 * \param a node a
103 * \param b node b
104 * \param index index of argument in question
105 * \return True if shapes and dtypes of args[index] a and b are the same
106 */
107 bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) final;
108
109 /*
110 * \brief Create combined call from ops that follow the initial combined op at the depth-th level.
111 * This usually involves concatenating or stacking inputs, then creating a new call.
112 * Only called if IsArgCompatbile returns true for each arg.
113 * \param data combined op
114 * \param branches branches of parallel ops to be combined
115 * \param depth depth at which to combine ops
116 * \param parent_index index of arg that corresponds to original input that was shared among
117 * all combined ops
118 * \return new combined call as batch op by stacking args
119 */
120 Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
121 size_t parent_index) final;
122
123 /*
124 * \brief Updates map of expr to substitute with combined expr. This usually involves
125 * slicing or splitting data.
126 * \param data combined op
127 * \param branches branches of parallel ops to be combined
128 * \param depth depth at which to substitute
129 * \param subst_map map of Expr to replace with Expr to replace it with
130 */
131 void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
132 ExprSubstMap* subst_map) final;
133
134 private:
135 /* \brief name of op to replace combined ops with. for example,
136 * for combining parallel dense, this will will be set to
137 * nn.batch_matmul
138 */
139 std::string batch_op_name_;
140};
141
142} // namespace relay
143} // namespace tvm
144
145#endif // TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_
146