/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_REF_BINARY_HPP
#define GPU_OCL_REF_BINARY_HPP

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_binary_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/ocl/ocl_stream.hpp"
#include "gpu/ocl/ocl_utils.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
struct ref_binary_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_binary_pd_t {
        using gpu_binary_pd_t::gpu_binary_pd_t;

        DECLARE_COMMON_PD_T("ocl:ref:any", ref_binary_t);

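        // Checks that the problem (data types, ndims, attributes) is
        // supported by the reference kernel, then computes the kernel
        // configuration via init_conf().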
        status_t init(engine_t *engine) {
            using namespace data_type;
            using sm = primitive_attr_t::skip_mask_t;

            const auto attr_skip_mask = sm::post_ops | sm::scales_runtime;

            bool ok = set_default_params() == status::success
                    && ((utils::everyone_is(bf16, src_md(0)->data_type,
                                 src_md(1)->data_type)
                                && utils::one_of(dst_md()->data_type, bf16, u8))
                            || (utils::one_of(
                                        src_md(0)->data_type, f16, f32, s8, u8)
                                    && utils::one_of(src_md(1)->data_type, f16,
                                            f32, s8, u8)
                                    && utils::one_of(dst_md()->data_type, f16,
                                            f32, s8, u8)))
                    && !memory_desc_ndims_ok(src_md(0), src_md(1), dst_md())
                    && IMPLICATION(!attr()->scales_.has_default_values(),
                            check_scales_mask())
                    && attr()->has_default_values(attr_skip_mask)
                    && post_ops_with_binary_ok(
                            attr(), dst_md()->data_type, MAX_NDIMS)
                    && attr_.set_default_formats(dst_md(0)) == status::success
                    && !(attr()->post_ops_.len() > 0
                            && src_md(0)->data_type == bf16
                            && src_md(1)->data_type == bf16
                            && dst_md()->data_type == u8);

            if (!ok) return status::unimplemented;

            return init_conf(engine);
        }

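        // init_conf() analyzes the problem and fills in conf below;
        // init_kernel_ctx() exposes that configuration to the OpenCL kernel
        // through the kernel context.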
        status_t init_conf(engine_t *engine);
        status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;

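        // Returns true if a non-default scale is set for the given source
        // argument (DNNL_ARG_SRC_0 or DNNL_ARG_SRC_1).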
        bool with_scales(int position) const {
            return !attr()->scales_.get(position).has_default_values();
        }

        bool with_scales() const {
            return with_scales(DNNL_ARG_SRC_0) || with_scales(DNNL_ARG_SRC_1);
        }

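        // Post-op query helpers: report the presence and the parameters of
        // eltwise and sum post-ops (neutral defaults are returned when the
        // corresponding post-op is absent).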
        bool with_eltwise(int position) const {
            return attr()->post_ops_.contain(primitive_kind::eltwise, position);
        }

        bool with_sum() const {
            return attr()->post_ops_.find(primitive_kind::sum) != -1;
        }

        float eltwise_alpha() const {
            const int eltwise_idx
                    = attr()->post_ops_.find(primitive_kind::eltwise);
            return eltwise_idx != -1
                    ? attr()->post_ops_.entry_[eltwise_idx].eltwise.alpha
                    : 1.0f;
        }

        float eltwise_beta() const {
            const int eltwise_idx
                    = attr()->post_ops_.find(primitive_kind::eltwise);
            return eltwise_idx != -1
                    ? attr()->post_ops_.entry_[eltwise_idx].eltwise.beta
                    : 0.0f;
        }

        float eltwise_scale() const {
            const int eltwise_idx
                    = attr()->post_ops_.find(primitive_kind::eltwise);
            return eltwise_idx != -1
                    ? attr()->post_ops_.entry_[eltwise_idx].eltwise.scale
                    : 1.0f;
        }

        float sum_scale() const {
            const int sum_idx = attr()->post_ops_.find(primitive_kind::sum);
            return sum_idx != -1 ? attr()->post_ops_.entry_[sum_idx].sum.scale
                                 : 0.0f;
        }

        alg_kind_t eltwise_alg_kind() const {
            const int eltwise_idx
                    = attr()->post_ops_.find(primitive_kind::eltwise);
            return eltwise_idx != -1
                    ? attr()->post_ops_.entry_[eltwise_idx].eltwise.alg
                    : dnnl_alg_kind_undef;
        }

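        // Kernel configuration computed by init_conf().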
        binary_conf_t conf;

    private:
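        // The reference kernel only supports common (per-tensor) scales,
        // i.e. a scale mask of 0 for every argument.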
        bool check_scales_mask() const {
            for (const auto &s : attr()->scales_.scales_) {
                if (s.second.mask_ != 0) return false;
            }
            return true;
        }
    };

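    // Builds the kernel context from the primitive descriptor and compiles
    // the "ref_binary" OpenCL kernel.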
    status_t init(engine_t *engine) override {
        compute::kernel_ctx_t kernel_ctx;

        auto status = pd()->init_kernel_ctx(kernel_ctx);
        if (status != status::success) return status;

        create_kernel(engine, &kernel_, "ref_binary", kernel_ctx);
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_ref(ctx);
    }

private:
    status_t execute_ref(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    compute::kernel_t kernel_;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif