1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_BINARY_PD_HPP
18#define COMMON_BINARY_PD_HPP
19
20#include <assert.h>
21
22#include "oneapi/dnnl/dnnl.h"
23
24#include "c_types_map.hpp"
25#include "primitive_desc.hpp"
26#include "utils.hpp"
27
28namespace dnnl {
29namespace impl {
30
31struct binary_pd_t : public primitive_desc_t {
32 static constexpr auto base_pkind = primitive_kind::binary;
33
34 typedef binary_pd_t base_class;
35 typedef binary_pd_t hint_class;
36
37 const binary_desc_t *desc() const { return &desc_; }
38 const op_desc_t *op_desc() const override {
39 return reinterpret_cast<const op_desc_t *>(this->desc());
40 }
41
42 status_t query(query_t what, int idx, void *result) const override {
43 switch (what) {
44 case query::alg_kind:
45 *(alg_kind_t *)result = desc()->alg_kind;
46 break;
47 default: return primitive_desc_t::query(what, idx, result);
48 }
49 return status::success;
50 }
51
52 arg_usage_t arg_usage(int arg) const override {
53 if (arg == DNNL_ARG_SRC_0 || arg == DNNL_ARG_SRC_1)
54 return arg_usage_t::input;
55
56 if (arg == DNNL_ARG_DST) return arg_usage_t::output;
57
58 return primitive_desc_t::arg_usage(arg);
59 }
60
61 const memory_desc_t *arg_md(int arg) const override {
62 switch (arg) {
63 case DNNL_ARG_SRC_0: return src_md(0);
64 case DNNL_ARG_SRC_1: return src_md(1);
65 case DNNL_ARG_DST: return dst_md(0);
66 default: return primitive_desc_t::arg_md(arg);
67 }
68 }
69
70 const memory_desc_t *src_md(int index = 0) const override {
71 if (index == 0)
72 return &src0_md_;
73 else if (index == 1)
74 return &src1_md_;
75 return &glob_zero_md;
76 }
77 const memory_desc_t *dst_md(int index = 0) const override {
78 return index == 0 ? &dst_md_ : &glob_zero_md;
79 }
80
81 int n_inputs() const override { return 2 + n_binary_po_inputs(); }
82 int n_outputs() const override { return 1; }
83
84 const dims_t &broadcast_dims() const { return broadcast_dims_; }
85
86 bool has_zero_dim_memory() const {
87 return memory_desc_wrapper(src_md(0)).has_zero_dim();
88 }
89
90 int ndims() const { return memory_desc_wrapper(src_md(0)).ndims(); }
91
92 bool is_tensor_op() const {
93 const memory_desc_wrapper src0_d(src_md(0));
94 const memory_desc_wrapper src1_d(src_md(1));
95 return src0_d.consistent_with(src1_d);
96 }
97
98protected:
99 binary_desc_t desc_;
100
101 memory_desc_t src0_md_;
102 memory_desc_t src1_md_;
103 memory_desc_t dst_md_;
104
105 dims_t broadcast_dims_;
106
107 binary_pd_t(const binary_desc_t *adesc, const primitive_attr_t *attr,
108 const binary_pd_t *hint_fwd_pd)
109 : primitive_desc_t(attr, base_pkind)
110 , desc_(*adesc)
111 , src0_md_(desc_.src_desc[0])
112 , src1_md_(desc_.src_desc[1])
113 , dst_md_(desc_.dst_desc) {
114 init_broadcast_dims();
115 }
116
117 status_t set_default_params() {
118 if (src1_md_.format_kind == format_kind::any) {
119 const memory_desc_wrapper src_d(src_md(0));
120 if (src_d.is_blocking_desc()) {
121 CHECK(memory_desc_init_by_blocking_desc(
122 src1_md_, src_d.blocking_desc()));
123 }
124 }
125
126 if (dst_md_.format_kind == format_kind::any) {
127 const memory_desc_wrapper src_d(src_md(0));
128 if (src_d.is_blocking_desc()) {
129 CHECK(memory_desc_init_by_blocking_desc(
130 dst_md_, src_d.blocking_desc()));
131 }
132 }
133
134 return status::success;
135 }
136
137 bool attr_post_ops_ok() const {
138 using namespace primitive_kind;
139 const auto &p = attr()->post_ops_;
140 switch (p.len()) {
141 case 0: return true;
142 case 1: return p.contain(sum, 0) || p.contain(eltwise, 0);
143 case 2: return p.contain(sum, 0) && p.contain(eltwise, 1);
144 default: return false;
145 }
146 }
147
148private:
149 void init_broadcast_dims() {
150 const dims_t &dims_A = src_md(0)->dims;
151 const dims_t &dims_B = src_md(1)->dims;
152
153 for (int d = 0; d < ndims(); ++d)
154 broadcast_dims_[d]
155 = (dims_A[d] == dims_B[d] && dims_A[d] != 1) ? 0 : 1;
156 }
157};
158
159} // namespace impl
160} // namespace dnnl
161
162#endif
163