1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
#include "cpu/binary_injector_utils.hpp"

#include <algorithm>
#include <vector>

#include "common/primitive.hpp"
#include "common/primitive_attr.hpp"
#include "oneapi/dnnl/dnnl_types.h"
20
21namespace dnnl {
22namespace impl {
23namespace cpu {
24namespace binary_injector_utils {
25
26std::vector<const void *> prepare_binary_args(const post_ops_t &post_ops,
27 const exec_ctx_t &ctx, const unsigned first_arg_idx_offset) {
28 std::vector<const void *> post_ops_binary_rhs_arg_vec;
29 post_ops_binary_rhs_arg_vec.reserve(post_ops.entry_.size());
30
31 unsigned idx = first_arg_idx_offset;
32 for (const auto &post_op : post_ops.entry_) {
33 if (post_op.is_binary()) {
34 post_ops_binary_rhs_arg_vec.emplace_back(CTX_IN_MEM(const void *,
35 DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1));
36 }
37 ++idx;
38 }
39
40 post_ops_binary_rhs_arg_vec.shrink_to_fit();
41
42 return post_ops_binary_rhs_arg_vec;
43}
44
45bool bcast_strategy_present(
46 const std::vector<broadcasting_strategy_t> &post_ops_bcasts,
47 const broadcasting_strategy_t bcast_strategy) {
48 for (const auto &post_op_bcast : post_ops_bcasts)
49 if (post_op_bcast == bcast_strategy) return true;
50 return false;
51}
52
53std::vector<broadcasting_strategy_t> extract_bcast_strategies(
54 const std::vector<dnnl_post_ops::entry_t> &post_ops,
55 const memory_desc_wrapper &dst_md) {
56 std::vector<broadcasting_strategy_t> post_ops_bcasts;
57 post_ops_bcasts.reserve(post_ops.size());
58 for (const auto &post_op : post_ops)
59 if (post_op.is_binary())
60 post_ops_bcasts.emplace_back(get_rhs_arg_broadcasting_strategy(
61 post_op.binary.src1_desc, dst_md));
62 return post_ops_bcasts;
63}
64
65} // namespace binary_injector_utils
66} // namespace cpu
67} // namespace impl
68} // namespace dnnl
69