1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <bitset>
18
19#include "common/broadcast_strategy.hpp"
20
21namespace dnnl {
22namespace impl {
23
24output_dims_t make_output_dims(const memory_desc_wrapper &dst_d) {
25 output_dims_t od {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
26 for (int i = 0; i < dst_d.ndims(); ++i)
27 od[i] = dst_d.dims()[i];
28 return od;
29}
30
31broadcasting_strategy_t get_rhs_arg_broadcasting_strategy(
32 const memory_desc_t &rhs_arg_md, const memory_desc_wrapper &dst_d) {
33
34 static const bcast_set_t all_bcast_strategies {
35 broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
36 broadcasting_strategy_t::per_oc_spatial,
37 broadcasting_strategy_t::shared_axes,
38 broadcasting_strategy_t::per_mb_spatial,
39 broadcasting_strategy_t::per_mb_w, broadcasting_strategy_t::per_w,
40 broadcasting_strategy_t::no_broadcast};
41
42 return get_rhs_arg_broadcasting_strategy(
43 rhs_arg_md, dst_d, all_bcast_strategies);
44}
45
46namespace {
47
48// Checks if mask corresponds to broadcast per first and last dimensions
49// Returns true if mask (5D) is equal to [0, 1, 1, 1, 0]
50bool is_per_mb_w_bcast(const std::bitset<DNNL_MAX_NDIMS> mask,
51 const memory_desc_wrapper &dst_d) {
52 const auto ndims = dst_d.ndims();
53 const int last_dim = ndims - 1;
54
55 bool per_mb_w_bcast = !mask.test(0) && !mask.test(last_dim);
56 if (!per_mb_w_bcast) return false;
57
58 for (int d = 1; d < last_dim; ++d)
59 per_mb_w_bcast = per_mb_w_bcast && mask.test(d);
60 return per_mb_w_bcast;
61}
62
63// Checks if mask corresponds to broadcast per last dimension
64// Returns true if mask (5D) is equal to [1, 1, 1, 1, 0]
65bool is_per_w_bcast(const std::bitset<DNNL_MAX_NDIMS> mask,
66 const memory_desc_wrapper &dst_d) {
67 const auto ndims = dst_d.ndims();
68 const int last_dim = ndims - 1;
69
70 bool per_w_bcast = !mask.test(last_dim);
71 if (!per_w_bcast) return false;
72
73 for (int d = 0; d < last_dim; ++d)
74 per_w_bcast = per_w_bcast && mask.test(d);
75 return per_w_bcast;
76}
77
78// Checks if mask corresponds to broadcast per batch and spatial dimensions
79// Returns true if mask (5D) is equal to [0, 1, 0, 0, 0] and
80// also if any of mask bits equal 0 will be equal to 1,
81// but only if corresponding output dimensions are also equal to 1.
82bool is_channel_bcast(const std::bitset<DNNL_MAX_NDIMS> mask,
83 const memory_desc_wrapper &dst_d) {
84 for (int d = 0; d < dst_d.ndims(); ++d) {
85 if (d == 1 && !mask.test(1)) return false;
86 if (d != 1 && mask.test(d) && dst_d.dims()[d] != 1) return false;
87 }
88 return true;
89}
90
91// Check if mask corresponds to broadcast per oc
92// Returns true if mask (5D) is equal to [1, 0, 1, 1, 1]
93bool is_per_oc_bcast(const std::bitset<DNNL_MAX_NDIMS> mask,
94 const memory_desc_t &rhs_arg_md) {
95 const bool broadcast_per_oc = !mask.test(1);
96
97 if (!broadcast_per_oc) return false;
98
99 const auto ndims = rhs_arg_md.ndims;
100
101 if (ndims > 0 && rhs_arg_md.dims[0] != 1) return false;
102
103 for (int dim = 2; dim < ndims; dim++) {
104 if (rhs_arg_md.dims[dim] != 1) return false;
105 }
106 return true;
107}
108
109bool bcast_strategy_enabled(const bcast_set_t &supported_strategy_set,
110 const broadcasting_strategy_t &bcast) {
111 return supported_strategy_set.find(bcast) != supported_strategy_set.cend();
112}
113
114broadcasting_strategy_t get_per_oc_bcast(
115 const bcast_set_t &supported_strategy_set,
116 const memory_desc_wrapper &dst_d) {
117
118 const auto ndims = dst_d.ndims();
119 const bool use_per_oc_spatial_strategy = bcast_strategy_enabled(
120 supported_strategy_set, broadcasting_strategy_t::per_oc_spatial);
121
122 if (use_per_oc_spatial_strategy && dst_d.is_blocking_desc()) {
123 const auto &strides = dst_d.blocking_desc().strides;
124
125 //per_oc_spatial used in nchw data format and matmul having ndims >= 3
126 return (dst_d.is_plain() && strides[0] >= strides[1]
127 && IMPLICATION(ndims < 3, strides[1] != 1)
128 && IMPLICATION(ndims >= 3, strides[1] >= strides[2]))
129 ? broadcasting_strategy_t::per_oc_spatial
130 : broadcasting_strategy_t::per_oc;
131 }
132
133 return broadcasting_strategy_t::per_oc;
134}
135} // namespace
136
137// Compares dimensions of rhs arg (src1) with dimensions of destination.
138// Produces broadcast mask and returns broadcast strategy which
139// corresponds to given mask.
140// Mask bits are set to 1 if corresponding dimensions are different
141// or if both dimensions are equal to 1.
142// Otherwise mask bits are set to 0.
143broadcasting_strategy_t get_rhs_arg_broadcasting_strategy(
144 const memory_desc_t &rhs_arg_md, const memory_desc_wrapper &dst_d,
145 const bcast_set_t &supported_strategy_set) {
146
147 const auto is_enabled = [&](const broadcasting_strategy_t &bcast) {
148 return bcast_strategy_enabled(supported_strategy_set, bcast);
149 };
150
151 const int ndims = rhs_arg_md.ndims;
152 const auto output_dims = make_output_dims(dst_d);
153
154 bool all_ones = true;
155 bool all_equal = true;
156 std::bitset<DNNL_MAX_NDIMS> mask(0);
157 for (int d = 0; d < ndims; d++) {
158 const auto &rhs_arg_dim = rhs_arg_md.dims[d];
159 if (rhs_arg_md.dims[d] != 1 && rhs_arg_md.dims[d] != output_dims[d])
160 return broadcasting_strategy_t::unsupported;
161
162 if (rhs_arg_dim != 1) all_ones = false;
163
164 const bool different_dims = output_dims[d] != rhs_arg_md.dims[d];
165 if (different_dims) all_equal = false;
166
167 if (different_dims || output_dims[d] == 1) mask.set(d);
168 }
169
170 broadcasting_strategy_t bcast = broadcasting_strategy_t::unsupported;
171
172 if (all_ones && is_enabled(broadcasting_strategy_t::scalar))
173 bcast = broadcasting_strategy_t::scalar;
174 else if (all_equal && is_enabled(broadcasting_strategy_t::no_broadcast))
175 bcast = broadcasting_strategy_t::no_broadcast;
176 else if (is_per_mb_w_bcast(mask, dst_d)
177 && is_enabled(broadcasting_strategy_t::per_mb_w))
178 bcast = broadcasting_strategy_t::per_mb_w;
179 else if (is_per_oc_bcast(mask, rhs_arg_md)
180 && (is_enabled(broadcasting_strategy_t::per_oc)
181 || is_enabled(broadcasting_strategy_t::per_oc_spatial))) {
182 bcast = get_per_oc_bcast(supported_strategy_set, dst_d);
183 } else if (is_per_w_bcast(mask, dst_d)
184 && is_enabled(broadcasting_strategy_t::per_w))
185 bcast = broadcasting_strategy_t::per_w;
186 else if (is_channel_bcast(mask, dst_d)
187 && is_enabled(broadcasting_strategy_t::per_mb_spatial))
188 bcast = broadcasting_strategy_t::per_mb_spatial;
189 else if (is_enabled(broadcasting_strategy_t::shared_axes))
190 bcast = broadcasting_strategy_t::shared_axes;
191
192 return bcast;
193}
194
195} // namespace impl
196} // namespace dnnl
197