/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#ifndef COMMON_BROADCAST_STRATEGY_HPP
18#define COMMON_BROADCAST_STRATEGY_HPP
19
20#include <array>
21#include <set>
22
23#include "common/c_types_map.hpp"
24#include "common/memory_desc_wrapper.hpp"
25
26namespace dnnl {
27namespace impl {
28
29using output_dims_t = std::array<dim_t, DNNL_MAX_NDIMS>;
30
// Broadcasting patterns, as reported by get_rhs_arg_broadcasting_strategy().
//
// The shape next to each enumerator is written in [n, c, d, h, w] order; a
// `1` marks an axis that is broadcast.
//
// Keep the declaration order: the implicit underlying values follow it, and
// so does the ordering of elements stored in a std::set of this type.
enum class broadcasting_strategy_t {
    // [1, 1, 1, 1, 1]: channel-shared.
    scalar,
    // [1, c, 1, 1, 1]: channel-wise.
    per_oc,
    // [1, c, 1, 1, 1]: same shape as per_oc; specific case for the binary
    // kernel with nchw format.
    per_oc_spatial,
    // [n, 1, d, h, w]: only the channel axis is broadcast.
    per_mb_spatial,
    // [n, 1, 1, 1, w]: broadcast per batch and width.
    per_mb_w,
    // [1, 1, 1, 1, w]: broadcast per width.
    per_w,
    // e.g. [n, 1, d, h, 1]: general case, any combination of broadcast axes.
    shared_axes,
    // [n, c, d, h, w]: no axis is broadcast.
    no_broadcast,
    // NOTE(review): presumably reported when none of the patterns above
    // applies or is allowed -- confirm against the implementation.
    unsupported
};
43
44using bcast_set_t = std::set<broadcasting_strategy_t>;
45
46inline const bcast_set_t &default_strategies() {
47 static const bcast_set_t s
48 = {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
49 broadcasting_strategy_t::no_broadcast};
50 return s;
51}
52
53output_dims_t make_output_dims(const memory_desc_wrapper &dst_d);
54
55broadcasting_strategy_t get_rhs_arg_broadcasting_strategy(
56 const memory_desc_t &rhs_arg_md, const memory_desc_wrapper &dst_d);
57
58broadcasting_strategy_t get_rhs_arg_broadcasting_strategy(
59 const memory_desc_t &rhs_arg_md, const memory_desc_wrapper &dst_d,
60 const bcast_set_t &supported_strategy_set);
61
62} // namespace impl
63} // namespace dnnl
64
65#endif
66