1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Registration of broadcast operators
22 * \file broadcast.cc
23 */
24#include <tvm/runtime/packed_func.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/topi/broadcast.h>
27#include <tvm/topi/utils.h>
28
29namespace tvm {
30namespace topi {
31
32using namespace tvm;
33using namespace tvm::runtime;
34
/*!
 * \brief Register a binary broadcast operator as a packed global function.
 *
 * The registered function reads two arguments. Each one is tested with
 * IsObjectRef<tvm::te::Tensor>(): a tensor argument is passed to \p Op as a
 * tvm::te::Tensor, and any other argument is converted to tvm::PrimExpr. This
 * covers all four tensor/scalar combinations, each forwarded to the matching
 * \p Op overload, and the result is written to the packed return value.
 *
 * \param OpName String literal naming the global function, e.g. "topi.add".
 * \param Op The overloaded operator to invoke.
 */
#define TOPI_REGISTER_BCAST_OP(OpName, Op)                                          \
  TVM_REGISTER_GLOBAL(OpName).set_body([](TVMArgs call_args, TVMRetValue* ret) {  \
    const bool first_is_tensor = call_args[0].IsObjectRef<tvm::te::Tensor>();   \
    const bool second_is_tensor = call_args[1].IsObjectRef<tvm::te::Tensor>();  \
    if (first_is_tensor) {                                                         \
      if (second_is_tensor) {                                                      \
        *ret = Op(call_args[0].operator tvm::te::Tensor(),                         \
                  call_args[1].operator tvm::te::Tensor());                        \
      } else {                                                                     \
        *ret = Op(call_args[0].operator tvm::te::Tensor(),                         \
                  call_args[1].operator tvm::PrimExpr());                          \
      }                                                                            \
    } else if (second_is_tensor) {                                                 \
      *ret = Op(call_args[0].operator tvm::PrimExpr(),                             \
                call_args[1].operator tvm::te::Tensor());                          \
    } else {                                                                       \
      *ret = Op(call_args[0].operator tvm::PrimExpr(),                             \
                call_args[1].operator tvm::PrimExpr());                            \
    }                                                                              \
  });
49
50TOPI_REGISTER_BCAST_OP("topi.add", topi::add);
51TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract);
52TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply);
53TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide);
54TOPI_REGISTER_BCAST_OP("topi.floor_divide", topi::floor_divide);
55TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod);
56TOPI_REGISTER_BCAST_OP("topi.floor_mod", topi::floor_mod);
57TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum);
58TOPI_REGISTER_BCAST_OP("topi.minimum", topi::minimum);
59TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
60TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
61TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
62TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
63TOPI_REGISTER_BCAST_OP("topi.logical_xor", topi::logical_xor);
64TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and);
65TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or);
66TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor);
67TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
68TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
69TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
70TOPI_REGISTER_BCAST_OP("topi.equal", topi::equal);
71TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal);
72TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal);
73TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal);
74
75TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body([](TVMArgs args, TVMRetValue* rv) {
76 *rv = broadcast_to(args[0], args[1]);
77});
78
79} // namespace topi
80} // namespace tvm
81