1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_BINARY_HPP
18#define CPU_X64_JIT_UNI_BINARY_HPP
19
20#include "common/primitive.hpp"
21
22#include "cpu/cpu_eltwise_pd.hpp"
23#include "cpu/x64/jit_uni_binary_kernel.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29
30struct binary_kernel_t;
31
32using op_t = binary_op_t;
33using bcast_t = binary_bcast_t;
34
35struct jit_uni_binary_t : public primitive_t {
36 struct pd_t : public cpu_binary_pd_t {
37 using cpu_binary_pd_t::cpu_binary_pd_t;
38
39 DECLARE_COMMON_PD_T("jit:uni", jit_uni_binary_t);
40
41 status_t init(engine_t *engine);
42
43 jit_binary_conf_t get_conf() const { return conf_; };
44
45 private:
46 op_t get_op_type(const memory_desc_wrapper &src0_d);
47 bool is_only_dim0_bcasted(const dims_t &bcast_dims, const int ndims);
48 bcast_t get_bcast_type(
49 const memory_desc_wrapper &src1_d, const dims_t &bcast_dims);
50
51 // alg_preserves_zero returns true if operation preserves zero in case
52 // of both inputs contain zero.
53 bool alg_preserves_zero() const;
54 bool check_scales_mask() const;
55 bool is_bcast_pattern(const dims_t &bcast_dims, const dim_t ndims,
56 const dim_t N_bcast, const dim_t C_bcast,
57 const dim_t W_bcast) const;
58 bool is_bcast_pattern(const dims_t &bcast_dims, const dim_t N_bcast,
59 const dim_t C_bcast) const;
60 bool is_bcast_allowed(const int ndims) const;
61 bool is_format_non_blocked(const memory_desc_wrapper &mdw) const;
62 bool is_different_layouts_allowed(const memory_desc_wrapper &src0_d,
63 const memory_desc_wrapper &src1_d) const;
64 bool is_applicable();
65
66 jit_binary_conf_t conf_;
67 };
68
69 jit_uni_binary_t(const pd_t *apd);
70 ~jit_uni_binary_t() = default;
71
72 status_t init(engine_t *engine) override;
73
74 using data_t = int8_t;
75
76 void execute_no_bcast_strategy(const data_t *src0, const data_t *src1,
77 data_t *dst, const float *scale0, const float *scale1,
78 const std::vector<const void *> &post_ops_binary_rhs_arg_vec,
79 const bcast_t bcast_type) const;
80 void execute_bcast_per_batch_strategy(const data_t *src0,
81 const data_t *src1, data_t *dst, const float *scale0,
82 const float *scale1,
83 const std::vector<const void *> &post_ops_binary_rhs_arg_vec) const;
84 void execute_bcast_per_c_strategy(const data_t *src0, const data_t *src1,
85 data_t *dst, const float *scale0, const float *scale1,
86 const std::vector<const void *> &post_ops_binary_rhs_arg_vec,
87 const op_t op_type, const bcast_t bcast_type,
88 const bool blocked_oc_tail) const;
89 void execute_bcast_per_w_strategy(const data_t *src0, const data_t *src1,
90 data_t *dst, const float *scale0, const float *scale1,
91 const std::vector<const void *> &post_ops_binary_rhs_arg_vec,
92 const op_t op_type, const bool blocked_oc_tail) const;
93
94 status_t execute(const exec_ctx_t &ctx) const override;
95
96private:
97 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
98 static bool post_ops_ok(const primitive_attr_t *attr,
99 const memory_desc_wrapper &src0_d, const memory_desc_wrapper &dst_d,
100 const bool is_src1_different_layouts);
101
102 std::unique_ptr<binary_kernel_t> kernel_;
103 // used only in bcast_c_blocked strategy if tail exists
104 std::unique_ptr<binary_kernel_t> kernel_tail_;
105};
106
107} // namespace x64
108} // namespace cpu
109} // namespace impl
110} // namespace dnnl
111
112#endif
113
114// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
115