1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_BROADCAST_STRATEGY_HPP |
18 | #define COMMON_BROADCAST_STRATEGY_HPP |
19 | |
20 | #include <array> |
21 | #include <set> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/memory_desc_wrapper.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | |
// Fixed-size array holding DNNL_MAX_NDIMS values of type dim_t.
// This is the return type of make_output_dims().
using output_dims_t = std::array<dim_t, DNNL_MAX_NDIMS>;
30 | |
// Broadcasting strategies, as returned by get_rhs_arg_broadcasting_strategy().
// Each enumerator is annotated with an example shape of the broadcast
// operand, written in the axis order [n, c, d, h, w]; a 1 marks an axis
// that is broadcast.
// NOTE: the declaration order defines the underlying values and therefore
// the ordering of elements inside bcast_set_t (a std::set) - do not reorder.
enum class broadcasting_strategy_t {
    // [n, c, d, h, w]
    scalar, // [1, 1, 1, 1, 1] // Channel_shared: a single value for everything
    per_oc, // [1, c, 1, 1, 1] // Channel-wise: one value per channel
    per_oc_spatial, // [1, c, 1, 1, 1] same shape as per_oc; specific case for binary kernel nchw format
    per_mb_spatial, // [n, 1, d, h, w] // Broadcast only over the channel axis
    per_mb_w, // [n, 1, 1, 1, w] // Broadcast per batch and width
    per_w, // [1, 1, 1, 1, w] // Broadcast per width
    shared_axes, // [n, 1, d, h, 1] // General case broadcast (any combination); the shape shown is just one example
    no_broadcast, // [n, c, d, h, w] // Operand has the full shape
    unsupported // NOTE(review): presumably returned when no strategy above applies - confirm in the implementation
};
43 | |
// Ordered set of broadcasting strategies. Returned by default_strategies()
// and accepted by get_rhs_arg_broadcasting_strategy() as the set of
// supported strategies.
using bcast_set_t = std::set<broadcasting_strategy_t>;
45 | |
46 | inline const bcast_set_t &default_strategies() { |
47 | static const bcast_set_t s |
48 | = {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc, |
49 | broadcasting_strategy_t::no_broadcast}; |
50 | return s; |
51 | } |
52 | |
// Produces an output_dims_t from the destination memory descriptor wrapper
// `dst_d`. The definition is not in this header.
// NOTE(review): presumably copies the dims of dst_d into the fixed-size
// array - confirm against the implementation, including what the unused
// trailing entries contain.
output_dims_t make_output_dims(const memory_desc_wrapper &dst_d);
54 | |
// Returns the broadcasting strategy for the right-hand-side argument
// described by `rhs_arg_md`, given the destination descriptor `dst_d`.
// The definition is not in this header.
// NOTE(review): this overload takes no strategy set; presumably it uses
// default_strategies() - confirm against the implementation.
broadcasting_strategy_t get_rhs_arg_broadcasting_strategy(
        const memory_desc_t &rhs_arg_md, const memory_desc_wrapper &dst_d);
57 | |
// Returns the broadcasting strategy for the right-hand-side argument
// described by `rhs_arg_md`, given the destination descriptor `dst_d` and
// the caller-provided `supported_strategy_set`.
// The definition is not in this header.
// NOTE(review): presumably only strategies contained in
// `supported_strategy_set` can be returned, with
// broadcasting_strategy_t::unsupported otherwise - confirm against the
// implementation.
broadcasting_strategy_t get_rhs_arg_broadcasting_strategy(
        const memory_desc_t &rhs_arg_md, const memory_desc_wrapper &dst_d,
        const bcast_set_t &supported_strategy_set);
61 | |
62 | } // namespace impl |
63 | } // namespace dnnl |
64 | |
65 | #endif |
66 | |