1 | /******************************************************************************* |
2 | * Copyright 2020-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include "cpu/binary_injector_utils.hpp" |
17 | #include "common/primitive.hpp" |
18 | #include "common/primitive_attr.hpp" |
19 | #include "oneapi/dnnl/dnnl_types.h" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace cpu { |
24 | namespace binary_injector_utils { |
25 | |
26 | std::vector<const void *> prepare_binary_args(const post_ops_t &post_ops, |
27 | const exec_ctx_t &ctx, const unsigned first_arg_idx_offset) { |
28 | std::vector<const void *> post_ops_binary_rhs_arg_vec; |
29 | post_ops_binary_rhs_arg_vec.reserve(post_ops.entry_.size()); |
30 | |
31 | unsigned idx = first_arg_idx_offset; |
32 | for (const auto &post_op : post_ops.entry_) { |
33 | if (post_op.is_binary()) { |
34 | post_ops_binary_rhs_arg_vec.emplace_back(CTX_IN_MEM(const void *, |
35 | DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1)); |
36 | } |
37 | ++idx; |
38 | } |
39 | |
40 | post_ops_binary_rhs_arg_vec.shrink_to_fit(); |
41 | |
42 | return post_ops_binary_rhs_arg_vec; |
43 | } |
44 | |
45 | bool bcast_strategy_present( |
46 | const std::vector<broadcasting_strategy_t> &post_ops_bcasts, |
47 | const broadcasting_strategy_t bcast_strategy) { |
48 | for (const auto &post_op_bcast : post_ops_bcasts) |
49 | if (post_op_bcast == bcast_strategy) return true; |
50 | return false; |
51 | } |
52 | |
53 | std::vector<broadcasting_strategy_t> ( |
54 | const std::vector<dnnl_post_ops::entry_t> &post_ops, |
55 | const memory_desc_wrapper &dst_md) { |
56 | std::vector<broadcasting_strategy_t> post_ops_bcasts; |
57 | post_ops_bcasts.reserve(post_ops.size()); |
58 | for (const auto &post_op : post_ops) |
59 | if (post_op.is_binary()) |
60 | post_ops_bcasts.emplace_back(get_rhs_arg_broadcasting_strategy( |
61 | post_op.binary.src1_desc, dst_md)); |
62 | return post_ops_bcasts; |
63 | } |
64 | |
65 | } // namespace binary_injector_utils |
66 | } // namespace cpu |
67 | } // namespace impl |
68 | } // namespace dnnl |
69 | |