1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Registration of broadcast operators |
22 | * \file broadcast.cc |
23 | */ |
24 | #include <tvm/runtime/packed_func.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/topi/broadcast.h> |
27 | #include <tvm/topi/utils.h> |
28 | |
29 | namespace tvm { |
30 | namespace topi { |
31 | |
32 | using namespace tvm; |
33 | using namespace tvm::runtime; |
34 | |
// Registers a global packed function named `OpName` that forwards its first two
// arguments to the binary operator `Op` and stores the result in the return value.
// Each argument is passed as a tvm::te::Tensor when it holds one and is otherwise
// converted to a tvm::PrimExpr, so all four tensor/expression combinations reach
// the corresponding overload of `Op`.
// NOTE: keep comments out of the macro body itself -- a `//` comment on a
// backslash-continued line would comment out the remainder of the macro.
#define TOPI_REGISTER_BCAST_OP(OpName, Op)                                                \
  TVM_REGISTER_GLOBAL(OpName).set_body([](TVMArgs args, TVMRetValue* rv) {                \
    const bool first_is_tensor = args[0].IsObjectRef<tvm::te::Tensor>();                  \
    const bool second_is_tensor = args[1].IsObjectRef<tvm::te::Tensor>();                 \
    if (first_is_tensor) {                                                                \
      if (second_is_tensor) {                                                             \
        *rv = Op(args[0].operator tvm::te::Tensor(), args[1].operator tvm::te::Tensor()); \
      } else {                                                                            \
        *rv = Op(args[0].operator tvm::te::Tensor(), args[1].operator tvm::PrimExpr());   \
      }                                                                                   \
    } else if (second_is_tensor) {                                                        \
      *rv = Op(args[0].operator tvm::PrimExpr(), args[1].operator tvm::te::Tensor());     \
    } else {                                                                              \
      *rv = Op(args[0].operator tvm::PrimExpr(), args[1].operator tvm::PrimExpr());       \
    }                                                                                     \
  });
49 | |
50 | TOPI_REGISTER_BCAST_OP("topi.add" , topi::add); |
51 | TOPI_REGISTER_BCAST_OP("topi.subtract" , topi::subtract); |
52 | TOPI_REGISTER_BCAST_OP("topi.multiply" , topi::multiply); |
53 | TOPI_REGISTER_BCAST_OP("topi.divide" , topi::divide); |
54 | TOPI_REGISTER_BCAST_OP("topi.floor_divide" , topi::floor_divide); |
55 | TOPI_REGISTER_BCAST_OP("topi.mod" , topi::mod); |
56 | TOPI_REGISTER_BCAST_OP("topi.floor_mod" , topi::floor_mod); |
57 | TOPI_REGISTER_BCAST_OP("topi.maximum" , topi::maximum); |
58 | TOPI_REGISTER_BCAST_OP("topi.minimum" , topi::minimum); |
59 | TOPI_REGISTER_BCAST_OP("topi.power" , topi::power); |
60 | TOPI_REGISTER_BCAST_OP("topi.left_shift" , topi::left_shift); |
61 | TOPI_REGISTER_BCAST_OP("topi.logical_and" , topi::logical_and); |
62 | TOPI_REGISTER_BCAST_OP("topi.logical_or" , topi::logical_or); |
63 | TOPI_REGISTER_BCAST_OP("topi.logical_xor" , topi::logical_xor); |
64 | TOPI_REGISTER_BCAST_OP("topi.bitwise_and" , topi::bitwise_and); |
65 | TOPI_REGISTER_BCAST_OP("topi.bitwise_or" , topi::bitwise_or); |
66 | TOPI_REGISTER_BCAST_OP("topi.bitwise_xor" , topi::bitwise_xor); |
67 | TOPI_REGISTER_BCAST_OP("topi.right_shift" , topi::right_shift); |
68 | TOPI_REGISTER_BCAST_OP("topi.greater" , topi::greater); |
69 | TOPI_REGISTER_BCAST_OP("topi.less" , topi::less); |
70 | TOPI_REGISTER_BCAST_OP("topi.equal" , topi::equal); |
71 | TOPI_REGISTER_BCAST_OP("topi.not_equal" , topi::not_equal); |
72 | TOPI_REGISTER_BCAST_OP("topi.greater_equal" , topi::greater_equal); |
73 | TOPI_REGISTER_BCAST_OP("topi.less_equal" , topi::less_equal); |
74 | |
75 | TVM_REGISTER_GLOBAL("topi.broadcast_to" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
76 | *rv = broadcast_to(args[0], args[1]); |
77 | }); |
78 | |
79 | } // namespace topi |
80 | } // namespace tvm |
81 | |