1 | /******************************************************************************* |
2 | * Copyright 2020-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_BINARY_INJECTOR_UTILS_HPP |
18 | #define CPU_BINARY_INJECTOR_UTILS_HPP |
19 | |
20 | #include <tuple> |
21 | #include <vector> |
22 | |
23 | #include "common/broadcast_strategy.hpp" |
24 | #include "common/c_types_map.hpp" |
25 | #include "common/primitive_attr.hpp" |
26 | #include "common/primitive_exec_types.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace binary_injector_utils { |
32 | /* |
33 | * Extracts pointers to tensors passed by user as binary postops rhs (right-hand-side) |
34 | * arguments (arg1 from binary postop) from execution context. Those pointers are placed |
35 | * in vector in order of binary post-op appearance inside post_ops_t structure. Returned vector |
36 | * usually is passed to kernel during execution phase in runtime params. |
37 | * @param first_arg_idx_offset - offset for indexation of binary postop arguments |
38 | * (used for fusions with dw convolutions) |
39 | */ |
40 | std::vector<const void *> prepare_binary_args(const post_ops_t &post_ops, |
41 | const dnnl::impl::exec_ctx_t &ctx, |
42 | const unsigned first_arg_idx_offset = 0); |
43 | |
44 | bool bcast_strategy_present( |
45 | const std::vector<broadcasting_strategy_t> &post_ops_bcasts, |
46 | const broadcasting_strategy_t bcast_strategy); |
47 | |
48 | std::vector<broadcasting_strategy_t> ( |
49 | const std::vector<dnnl_post_ops::entry_t> &post_ops, |
50 | const memory_desc_wrapper &dst_md); |
51 | |
52 | /* |
53 | * Returns a tuple of bools, which size is equal to number of bcast |
54 | * strategies passed in. Values at consecutive positions indicate existence of |
55 | * binary postop with a particular bcast strategy in post_ops vector. |
56 | */ |
57 | template <typename... Str> |
58 | auto bcast_strategies_present_tup( |
59 | const std::vector<dnnl_post_ops::entry_t> &post_ops, |
60 | const memory_desc_wrapper &dst_md, Str... bcast_strategies) |
61 | -> decltype(std::make_tuple((bcast_strategies, false)...)) { |
62 | const auto post_ops_bcasts = extract_bcast_strategies(post_ops, dst_md); |
63 | return std::make_tuple( |
64 | bcast_strategy_present(post_ops_bcasts, bcast_strategies)...); |
65 | } |
66 | |
67 | } // namespace binary_injector_utils |
68 | } // namespace cpu |
69 | } // namespace impl |
70 | } // namespace dnnl |
71 | |
72 | #endif |
73 | |