1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <bitset> |
18 | |
19 | #include "common/broadcast_strategy.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | |
24 | output_dims_t make_output_dims(const memory_desc_wrapper &dst_d) { |
25 | output_dims_t od {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}; |
26 | for (int i = 0; i < dst_d.ndims(); ++i) |
27 | od[i] = dst_d.dims()[i]; |
28 | return od; |
29 | } |
30 | |
31 | broadcasting_strategy_t get_rhs_arg_broadcasting_strategy( |
32 | const memory_desc_t &rhs_arg_md, const memory_desc_wrapper &dst_d) { |
33 | |
34 | static const bcast_set_t all_bcast_strategies { |
35 | broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc, |
36 | broadcasting_strategy_t::per_oc_spatial, |
37 | broadcasting_strategy_t::shared_axes, |
38 | broadcasting_strategy_t::per_mb_spatial, |
39 | broadcasting_strategy_t::per_mb_w, broadcasting_strategy_t::per_w, |
40 | broadcasting_strategy_t::no_broadcast}; |
41 | |
42 | return get_rhs_arg_broadcasting_strategy( |
43 | rhs_arg_md, dst_d, all_bcast_strategies); |
44 | } |
45 | |
46 | namespace { |
47 | |
48 | // Checks if mask corresponds to broadcast per first and last dimensions |
49 | // Returns true if mask (5D) is equal to [0, 1, 1, 1, 0] |
50 | bool is_per_mb_w_bcast(const std::bitset<DNNL_MAX_NDIMS> mask, |
51 | const memory_desc_wrapper &dst_d) { |
52 | const auto ndims = dst_d.ndims(); |
53 | const int last_dim = ndims - 1; |
54 | |
55 | bool per_mb_w_bcast = !mask.test(0) && !mask.test(last_dim); |
56 | if (!per_mb_w_bcast) return false; |
57 | |
58 | for (int d = 1; d < last_dim; ++d) |
59 | per_mb_w_bcast = per_mb_w_bcast && mask.test(d); |
60 | return per_mb_w_bcast; |
61 | } |
62 | |
63 | // Checks if mask corresponds to broadcast per last dimension |
64 | // Returns true if mask (5D) is equal to [1, 1, 1, 1, 0] |
65 | bool is_per_w_bcast(const std::bitset<DNNL_MAX_NDIMS> mask, |
66 | const memory_desc_wrapper &dst_d) { |
67 | const auto ndims = dst_d.ndims(); |
68 | const int last_dim = ndims - 1; |
69 | |
70 | bool per_w_bcast = !mask.test(last_dim); |
71 | if (!per_w_bcast) return false; |
72 | |
73 | for (int d = 0; d < last_dim; ++d) |
74 | per_w_bcast = per_w_bcast && mask.test(d); |
75 | return per_w_bcast; |
76 | } |
77 | |
78 | // Checks if mask corresponds to broadcast per batch and spatial dimensions |
79 | // Returns true if mask (5D) is equal to [0, 1, 0, 0, 0] and |
80 | // also if any of mask bits equal 0 will be equal to 1, |
81 | // but only if corresponding output dimensions are also equal to 1. |
82 | bool is_channel_bcast(const std::bitset<DNNL_MAX_NDIMS> mask, |
83 | const memory_desc_wrapper &dst_d) { |
84 | for (int d = 0; d < dst_d.ndims(); ++d) { |
85 | if (d == 1 && !mask.test(1)) return false; |
86 | if (d != 1 && mask.test(d) && dst_d.dims()[d] != 1) return false; |
87 | } |
88 | return true; |
89 | } |
90 | |
91 | // Check if mask corresponds to broadcast per oc |
92 | // Returns true if mask (5D) is equal to [1, 0, 1, 1, 1] |
93 | bool is_per_oc_bcast(const std::bitset<DNNL_MAX_NDIMS> mask, |
94 | const memory_desc_t &rhs_arg_md) { |
95 | const bool broadcast_per_oc = !mask.test(1); |
96 | |
97 | if (!broadcast_per_oc) return false; |
98 | |
99 | const auto ndims = rhs_arg_md.ndims; |
100 | |
101 | if (ndims > 0 && rhs_arg_md.dims[0] != 1) return false; |
102 | |
103 | for (int dim = 2; dim < ndims; dim++) { |
104 | if (rhs_arg_md.dims[dim] != 1) return false; |
105 | } |
106 | return true; |
107 | } |
108 | |
109 | bool bcast_strategy_enabled(const bcast_set_t &supported_strategy_set, |
110 | const broadcasting_strategy_t &bcast) { |
111 | return supported_strategy_set.find(bcast) != supported_strategy_set.cend(); |
112 | } |
113 | |
114 | broadcasting_strategy_t get_per_oc_bcast( |
115 | const bcast_set_t &supported_strategy_set, |
116 | const memory_desc_wrapper &dst_d) { |
117 | |
118 | const auto ndims = dst_d.ndims(); |
119 | const bool use_per_oc_spatial_strategy = bcast_strategy_enabled( |
120 | supported_strategy_set, broadcasting_strategy_t::per_oc_spatial); |
121 | |
122 | if (use_per_oc_spatial_strategy && dst_d.is_blocking_desc()) { |
123 | const auto &strides = dst_d.blocking_desc().strides; |
124 | |
125 | //per_oc_spatial used in nchw data format and matmul having ndims >= 3 |
126 | return (dst_d.is_plain() && strides[0] >= strides[1] |
127 | && IMPLICATION(ndims < 3, strides[1] != 1) |
128 | && IMPLICATION(ndims >= 3, strides[1] >= strides[2])) |
129 | ? broadcasting_strategy_t::per_oc_spatial |
130 | : broadcasting_strategy_t::per_oc; |
131 | } |
132 | |
133 | return broadcasting_strategy_t::per_oc; |
134 | } |
135 | } // namespace |
136 | |
137 | // Compares dimensions of rhs arg (src1) with dimensions of destination. |
138 | // Produces broadcast mask and returns broadcast strategy which |
139 | // corresponds to given mask. |
140 | // Mask bits are set to 1 if corresponding dimensions are different |
141 | // or if both dimensions are equal to 1. |
142 | // Otherwise mask bits are set to 0. |
143 | broadcasting_strategy_t get_rhs_arg_broadcasting_strategy( |
144 | const memory_desc_t &rhs_arg_md, const memory_desc_wrapper &dst_d, |
145 | const bcast_set_t &supported_strategy_set) { |
146 | |
147 | const auto is_enabled = [&](const broadcasting_strategy_t &bcast) { |
148 | return bcast_strategy_enabled(supported_strategy_set, bcast); |
149 | }; |
150 | |
151 | const int ndims = rhs_arg_md.ndims; |
152 | const auto output_dims = make_output_dims(dst_d); |
153 | |
154 | bool all_ones = true; |
155 | bool all_equal = true; |
156 | std::bitset<DNNL_MAX_NDIMS> mask(0); |
157 | for (int d = 0; d < ndims; d++) { |
158 | const auto &rhs_arg_dim = rhs_arg_md.dims[d]; |
159 | if (rhs_arg_md.dims[d] != 1 && rhs_arg_md.dims[d] != output_dims[d]) |
160 | return broadcasting_strategy_t::unsupported; |
161 | |
162 | if (rhs_arg_dim != 1) all_ones = false; |
163 | |
164 | const bool different_dims = output_dims[d] != rhs_arg_md.dims[d]; |
165 | if (different_dims) all_equal = false; |
166 | |
167 | if (different_dims || output_dims[d] == 1) mask.set(d); |
168 | } |
169 | |
170 | broadcasting_strategy_t bcast = broadcasting_strategy_t::unsupported; |
171 | |
172 | if (all_ones && is_enabled(broadcasting_strategy_t::scalar)) |
173 | bcast = broadcasting_strategy_t::scalar; |
174 | else if (all_equal && is_enabled(broadcasting_strategy_t::no_broadcast)) |
175 | bcast = broadcasting_strategy_t::no_broadcast; |
176 | else if (is_per_mb_w_bcast(mask, dst_d) |
177 | && is_enabled(broadcasting_strategy_t::per_mb_w)) |
178 | bcast = broadcasting_strategy_t::per_mb_w; |
179 | else if (is_per_oc_bcast(mask, rhs_arg_md) |
180 | && (is_enabled(broadcasting_strategy_t::per_oc) |
181 | || is_enabled(broadcasting_strategy_t::per_oc_spatial))) { |
182 | bcast = get_per_oc_bcast(supported_strategy_set, dst_d); |
183 | } else if (is_per_w_bcast(mask, dst_d) |
184 | && is_enabled(broadcasting_strategy_t::per_w)) |
185 | bcast = broadcasting_strategy_t::per_w; |
186 | else if (is_channel_bcast(mask, dst_d) |
187 | && is_enabled(broadcasting_strategy_t::per_mb_spatial)) |
188 | bcast = broadcasting_strategy_t::per_mb_spatial; |
189 | else if (is_enabled(broadcasting_strategy_t::shared_axes)) |
190 | bcast = broadcasting_strategy_t::shared_axes; |
191 | |
192 | return bcast; |
193 | } |
194 | |
195 | } // namespace impl |
196 | } // namespace dnnl |
197 | |