1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file binary.cc
22 * \brief binary broadcast operators.
23 */
24#include <tvm/relay/expr.h>
25#include <tvm/relay/op.h>
26#include <tvm/topi/broadcast.h>
27
28#include "../op_common.h"
29#include "../type_relations.h"
30
31namespace tvm {
32namespace relay {
33
/*!
 * \brief Expand to an FTVMCompute lambda for a two-operand operator.
 *
 * The generated callback checks that exactly two input tensors were passed,
 * then forwards them, in order, to TOPI_FN and wraps the single result in an
 * Array. The op attributes and the inferred output type are not consulted.
 *
 * \param TOPI_FN Callable invoked as TOPI_FN(lhs, rhs); every use in this file
 *        passes a topi:: broadcast function.
 */
#define RELAY_BINARY_COMPUTE(TOPI_FN)                                  \
  [](const Attrs& /*attrs*/, const Array<te::Tensor>& inputs,          \
     const Type& /*out_type*/) -> Array<te::Tensor> {                  \
    ICHECK_EQ(inputs.size(), 2U);                                      \
    const te::Tensor& lhs = inputs[0];                                 \
    const te::Tensor& rhs = inputs[1];                                 \
    return {TOPI_FN(lhs, rhs)};                                        \
  }
40
41// Addition
42RELAY_REGISTER_BINARY_OP("add")
43 .describe("Elementwise add with broadcasting")
44 .set_support_level(1)
45 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add));
46
47// Subtraction
48RELAY_REGISTER_BINARY_OP("subtract")
49 .describe("Elementwise substract with broadcasting")
50 .set_support_level(1)
51 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract));
52
53// Right shift
54RELAY_REGISTER_BINARY_OP("right_shift")
55 .describe("Elementwise right shift with broadcasting")
56 .set_support_level(4)
57 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift));
58
59RELAY_REGISTER_BINARY_OP("left_shift")
60 .describe("Elementwise left shift with broadcasting")
61 .set_support_level(4)
62 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift));
63
64RELAY_REGISTER_BINARY_OP("maximum")
65 .describe("Elementwise maximum of two tensors with broadcasting")
66 .set_support_level(4)
67 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum));
68
69RELAY_REGISTER_BINARY_OP("minimum")
70 .describe("Elementwise minimum of two tensors with broadcasting")
71 .set_support_level(4)
72 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum));
73
74RELAY_REGISTER_BINARY_OP("divide")
75 .describe("Elementwise divide with broadcasting")
76 .set_support_level(1)
77 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide));
78
79RELAY_REGISTER_BINARY_OP("trunc_divide")
80 .describe("Elementwise trunc divide with broadcasting")
81 .set_support_level(1)
82 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::trunc_divide));
83
84RELAY_REGISTER_BINARY_OP("floor_divide")
85 .describe("Elementwise floor divide with broadcasting")
86 .set_support_level(1)
87 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_divide));
88
89RELAY_REGISTER_BINARY_OP("multiply")
90 .describe("Elementwise multiply with broadcasting")
91 .set_support_level(1)
92 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply));
93
94RELAY_REGISTER_BINARY_OP("power")
95 .describe("Elementwise power with broadcasting")
96 .set_support_level(4)
97 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power));
98
99RELAY_REGISTER_BINARY_OP("mod")
100 .describe("Elementwise mod with broadcasting")
101 .set_support_level(1)
102 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod));
103
104RELAY_REGISTER_BINARY_OP("floor_mod")
105 .describe("Elementwise floor mod with broadcasting")
106 .set_support_level(1)
107 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_mod));
108
109RELAY_REGISTER_BINARY_OP("trunc_mod")
110 .describe("Elementwise trunc mod with broadcasting")
111 .set_support_level(1)
112 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::trunc_mod));
113
114RELAY_REGISTER_BINARY_OP("logical_and")
115 .describe("Elementwise logical AND with broadcasting")
116 .set_support_level(4)
117 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_and));
118
119RELAY_REGISTER_BINARY_OP("logical_or")
120 .describe("Elementwise logical OR with broadcasting")
121 .set_support_level(4)
122 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));
123
124RELAY_REGISTER_BINARY_OP("logical_xor")
125 .describe("Elementwise logical XOR with broadcasting")
126 .set_support_level(4)
127 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor));
128
129RELAY_REGISTER_BINARY_OP("bitwise_and")
130 .describe("Elementwise bitwise AND with broadcasting")
131 .set_support_level(4)
132 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and));
133
134RELAY_REGISTER_BINARY_OP("bitwise_or")
135 .describe("Elementwise bitwise OR with broadcasting")
136 .set_support_level(4)
137 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or));
138
139RELAY_REGISTER_BINARY_OP("bitwise_xor")
140 .describe("Elementwise bitwise XOR with broadcasting")
141 .set_support_level(4)
142 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor));
143
144RELAY_REGISTER_CMP_OP("equal")
145 .describe("Elementwise equal compare with broadcasting")
146 .set_support_level(4)
147 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::equal));
148
149RELAY_REGISTER_CMP_OP("not_equal")
150 .describe("Elementwise not equal with broadcasting")
151 .set_support_level(4)
152 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::not_equal));
153
154RELAY_REGISTER_CMP_OP("less")
155 .describe("Elementwise less than with broadcasting")
156 .set_support_level(4)
157 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less));
158
159RELAY_REGISTER_CMP_OP("less_equal")
160 .describe("Elementwise less than or equal compare with broadcasting")
161 .set_support_level(4)
162 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less_equal));
163
164RELAY_REGISTER_CMP_OP("greater")
165 .describe("Elementwise greater than compare with broadcasting")
166 .set_support_level(4)
167 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater));
168
169RELAY_REGISTER_CMP_OP("greater_equal")
170 .describe("Elementwise greater than or equal compare with broadcasting")
171 .set_support_level(4)
172 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater_equal));
173
174} // namespace relay
175} // namespace tvm
176