1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file binary.cc |
22 | * \brief binary broadcast operators. |
23 | */ |
24 | #include <tvm/relay/expr.h> |
25 | #include <tvm/relay/op.h> |
26 | #include <tvm/topi/broadcast.h> |
27 | |
28 | #include "../op_common.h" |
29 | #include "../type_relations.h" |
30 | |
31 | namespace tvm { |
32 | namespace relay { |
33 | |
/*!
 * \brief Expands to an FTVMCompute lambda for a two-input elementwise operator.
 *
 * The generated lambda requires exactly two input tensors (checked with
 * ICHECK_EQ). It passes them, in order, to FTOPI and returns FTOPI's result
 * as a one-element tensor array. The attrs and out_type arguments are not
 * consulted.
 *
 * \param FTOPI A callable invoked as FTOPI(lhs, rhs); its result becomes the
 *        sole element of the returned Array<te::Tensor>.
 */
#define RELAY_BINARY_COMPUTE(FTOPI)                             \
  [](const Attrs& /*attrs*/, const Array<te::Tensor>& inputs,   \
     const Type& /*out_type*/) -> Array<te::Tensor> {           \
    ICHECK_EQ(inputs.size(), 2U);                               \
    const te::Tensor lhs = inputs[0];                           \
    const te::Tensor rhs = inputs[1];                           \
    return Array<te::Tensor>{FTOPI(lhs, rhs)};                  \
  }
40 | |
41 | // Addition |
42 | RELAY_REGISTER_BINARY_OP("add" ) |
43 | .describe("Elementwise add with broadcasting" ) |
44 | .set_support_level(1) |
45 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::add)); |
46 | |
47 | // Subtraction |
48 | RELAY_REGISTER_BINARY_OP("subtract" ) |
49 | .describe("Elementwise substract with broadcasting" ) |
50 | .set_support_level(1) |
51 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::subtract)); |
52 | |
53 | // Right shift |
54 | RELAY_REGISTER_BINARY_OP("right_shift" ) |
55 | .describe("Elementwise right shift with broadcasting" ) |
56 | .set_support_level(4) |
57 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::right_shift)); |
58 | |
59 | RELAY_REGISTER_BINARY_OP("left_shift" ) |
60 | .describe("Elementwise left shift with broadcasting" ) |
61 | .set_support_level(4) |
62 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::left_shift)); |
63 | |
64 | RELAY_REGISTER_BINARY_OP("maximum" ) |
65 | .describe("Elementwise maximum of two tensors with broadcasting" ) |
66 | .set_support_level(4) |
67 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::maximum)); |
68 | |
69 | RELAY_REGISTER_BINARY_OP("minimum" ) |
70 | .describe("Elementwise minimum of two tensors with broadcasting" ) |
71 | .set_support_level(4) |
72 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::minimum)); |
73 | |
74 | RELAY_REGISTER_BINARY_OP("divide" ) |
75 | .describe("Elementwise divide with broadcasting" ) |
76 | .set_support_level(1) |
77 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::divide)); |
78 | |
79 | RELAY_REGISTER_BINARY_OP("trunc_divide" ) |
80 | .describe("Elementwise trunc divide with broadcasting" ) |
81 | .set_support_level(1) |
82 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::trunc_divide)); |
83 | |
84 | RELAY_REGISTER_BINARY_OP("floor_divide" ) |
85 | .describe("Elementwise floor divide with broadcasting" ) |
86 | .set_support_level(1) |
87 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::floor_divide)); |
88 | |
89 | RELAY_REGISTER_BINARY_OP("multiply" ) |
90 | .describe("Elementwise multiply with broadcasting" ) |
91 | .set_support_level(1) |
92 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::multiply)); |
93 | |
94 | RELAY_REGISTER_BINARY_OP("power" ) |
95 | .describe("Elementwise power with broadcasting" ) |
96 | .set_support_level(4) |
97 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::power)); |
98 | |
99 | RELAY_REGISTER_BINARY_OP("mod" ) |
100 | .describe("Elementwise mod with broadcasting" ) |
101 | .set_support_level(1) |
102 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::mod)); |
103 | |
104 | RELAY_REGISTER_BINARY_OP("floor_mod" ) |
105 | .describe("Elementwise floor mod with broadcasting" ) |
106 | .set_support_level(1) |
107 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::floor_mod)); |
108 | |
109 | RELAY_REGISTER_BINARY_OP("trunc_mod" ) |
110 | .describe("Elementwise trunc mod with broadcasting" ) |
111 | .set_support_level(1) |
112 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::trunc_mod)); |
113 | |
114 | RELAY_REGISTER_BINARY_OP("logical_and" ) |
115 | .describe("Elementwise logical AND with broadcasting" ) |
116 | .set_support_level(4) |
117 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::logical_and)); |
118 | |
119 | RELAY_REGISTER_BINARY_OP("logical_or" ) |
120 | .describe("Elementwise logical OR with broadcasting" ) |
121 | .set_support_level(4) |
122 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::logical_or)); |
123 | |
124 | RELAY_REGISTER_BINARY_OP("logical_xor" ) |
125 | .describe("Elementwise logical XOR with broadcasting" ) |
126 | .set_support_level(4) |
127 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::logical_xor)); |
128 | |
129 | RELAY_REGISTER_BINARY_OP("bitwise_and" ) |
130 | .describe("Elementwise bitwise AND with broadcasting" ) |
131 | .set_support_level(4) |
132 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::bitwise_and)); |
133 | |
134 | RELAY_REGISTER_BINARY_OP("bitwise_or" ) |
135 | .describe("Elementwise bitwise OR with broadcasting" ) |
136 | .set_support_level(4) |
137 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::bitwise_or)); |
138 | |
139 | RELAY_REGISTER_BINARY_OP("bitwise_xor" ) |
140 | .describe("Elementwise bitwise XOR with broadcasting" ) |
141 | .set_support_level(4) |
142 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::bitwise_xor)); |
143 | |
144 | RELAY_REGISTER_CMP_OP("equal" ) |
145 | .describe("Elementwise equal compare with broadcasting" ) |
146 | .set_support_level(4) |
147 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::equal)); |
148 | |
149 | RELAY_REGISTER_CMP_OP("not_equal" ) |
150 | .describe("Elementwise not equal with broadcasting" ) |
151 | .set_support_level(4) |
152 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::not_equal)); |
153 | |
154 | RELAY_REGISTER_CMP_OP("less" ) |
155 | .describe("Elementwise less than with broadcasting" ) |
156 | .set_support_level(4) |
157 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::less)); |
158 | |
159 | RELAY_REGISTER_CMP_OP("less_equal" ) |
160 | .describe("Elementwise less than or equal compare with broadcasting" ) |
161 | .set_support_level(4) |
162 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::less_equal)); |
163 | |
164 | RELAY_REGISTER_CMP_OP("greater" ) |
165 | .describe("Elementwise greater than compare with broadcasting" ) |
166 | .set_support_level(4) |
167 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::greater)); |
168 | |
169 | RELAY_REGISTER_CMP_OP("greater_equal" ) |
170 | .describe("Elementwise greater than or equal compare with broadcasting" ) |
171 | .set_support_level(4) |
172 | .set_attr<FTVMCompute>("FTVMCompute" , RELAY_BINARY_COMPUTE(topi::greater_equal)); |
173 | |
174 | } // namespace relay |
175 | } // namespace tvm |
176 | |