1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_REF_BINARY_HPP |
18 | #define GPU_OCL_REF_BINARY_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "gpu/compute/compute.hpp" |
23 | #include "gpu/gpu_binary_pd.hpp" |
24 | #include "gpu/gpu_primitive.hpp" |
25 | #include "gpu/gpu_resource.hpp" |
26 | #include "gpu/ocl/ocl_stream.hpp" |
27 | #include "gpu/ocl/ocl_utils.hpp" |
28 | #include "gpu/primitive_conf.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace gpu { |
33 | namespace ocl { |
34 | |
35 | struct ref_binary_t : public gpu_primitive_t { |
36 | using gpu_primitive_t::gpu_primitive_t; |
37 | struct pd_t : public gpu_binary_pd_t { |
38 | using gpu_binary_pd_t::gpu_binary_pd_t; |
39 | |
40 | DECLARE_COMMON_PD_T("ocl:ref:any" , ref_binary_t); |
41 | |
42 | status_t init(engine_t *engine) { |
43 | using namespace data_type; |
44 | using sm = primitive_attr_t::skip_mask_t; |
45 | |
46 | const auto attr_skip_mask = sm::post_ops | sm::scales_runtime; |
47 | |
48 | bool ok = set_default_params() == status::success |
49 | && ((utils::everyone_is(bf16, src_md(0)->data_type, |
50 | src_md(1)->data_type) |
51 | && utils::one_of(dst_md()->data_type, bf16, u8)) |
52 | || (utils::one_of( |
53 | src_md(0)->data_type, f16, f32, s8, u8) |
54 | && utils::one_of(src_md(1)->data_type, f16, |
55 | f32, s8, u8) |
56 | && utils::one_of(dst_md()->data_type, f16, |
57 | f32, s8, u8))) |
58 | && !memory_desc_ndims_ok(src_md(0), src_md(1), dst_md()) |
59 | && IMPLICATION(!attr()->scales_.has_default_values(), |
60 | check_scales_mask()) |
61 | && attr()->has_default_values(attr_skip_mask) |
62 | && post_ops_with_binary_ok( |
63 | attr(), dst_md()->data_type, MAX_NDIMS) |
64 | && attr_.set_default_formats(dst_md(0)) == status::success |
65 | && !(attr()->post_ops_.len() > 0 |
66 | && src_md(0)->data_type == bf16 |
67 | && src_md(1)->data_type == bf16 |
68 | && dst_md()->data_type == u8); |
69 | |
70 | if (!ok) return status::unimplemented; |
71 | |
72 | return init_conf(engine); |
73 | } |
74 | |
75 | status_t init_conf(engine_t *engine); |
76 | status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const; |
77 | |
78 | bool with_scales(int position) const { |
79 | return !attr()->scales_.get(position).has_default_values(); |
80 | } |
81 | |
82 | bool with_scales() const { |
83 | return with_scales(DNNL_ARG_SRC_0) || with_scales(DNNL_ARG_SRC_1); |
84 | } |
85 | |
86 | bool with_eltwise(int position) const { |
87 | return attr()->post_ops_.contain(primitive_kind::eltwise, position); |
88 | } |
89 | |
90 | bool with_sum() const { |
91 | return attr()->post_ops_.find(primitive_kind::sum) != -1; |
92 | } |
93 | |
94 | float eltwise_alpha() const { |
95 | const int eltwise_idx |
96 | = attr()->post_ops_.find(primitive_kind::eltwise); |
97 | return eltwise_idx != -1 |
98 | ? attr()->post_ops_.entry_[eltwise_idx].eltwise.alpha |
99 | : 1.0f; |
100 | } |
101 | |
102 | float eltwise_beta() const { |
103 | const int eltwise_idx |
104 | = attr()->post_ops_.find(primitive_kind::eltwise); |
105 | return eltwise_idx != -1 |
106 | ? attr()->post_ops_.entry_[eltwise_idx].eltwise.beta |
107 | : 0.0f; |
108 | } |
109 | |
110 | float eltwise_scale() const { |
111 | const int eltwise_idx |
112 | = attr()->post_ops_.find(primitive_kind::eltwise); |
113 | return eltwise_idx != -1 |
114 | ? attr()->post_ops_.entry_[eltwise_idx].eltwise.scale |
115 | : 1.0f; |
116 | } |
117 | |
118 | float sum_scale() const { |
119 | const int sum_idx = attr()->post_ops_.find(primitive_kind::sum); |
120 | return sum_idx != -1 ? attr()->post_ops_.entry_[sum_idx].sum.scale |
121 | : 0.0f; |
122 | } |
123 | |
124 | alg_kind_t eltwise_alg_kind() const { |
125 | const int eltwise_idx |
126 | = attr()->post_ops_.find(primitive_kind::eltwise); |
127 | return eltwise_idx != -1 |
128 | ? attr()->post_ops_.entry_[eltwise_idx].eltwise.alg |
129 | : dnnl_alg_kind_undef; |
130 | } |
131 | |
132 | binary_conf_t conf; |
133 | |
134 | private: |
135 | bool check_scales_mask() const { |
136 | for (const auto &s : attr()->scales_.scales_) { |
137 | if (s.second.mask_ != 0) return false; |
138 | } |
139 | return true; |
140 | } |
141 | }; |
142 | |
143 | status_t init(engine_t *engine) override { |
144 | compute::kernel_ctx_t kernel_ctx; |
145 | |
146 | auto status = pd()->init_kernel_ctx(kernel_ctx); |
147 | if (status != status::success) return status; |
148 | |
149 | create_kernel(engine, &kernel_, "ref_binary" , kernel_ctx); |
150 | if (!kernel_) return status::runtime_error; |
151 | |
152 | return status::success; |
153 | } |
154 | |
155 | status_t execute(const exec_ctx_t &ctx) const override { |
156 | return execute_ref(ctx); |
157 | } |
158 | |
159 | private: |
160 | status_t execute_ref(const exec_ctx_t &ctx) const; |
161 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
162 | compute::kernel_t kernel_; |
163 | }; |
164 | |
165 | } // namespace ocl |
166 | } // namespace gpu |
167 | } // namespace impl |
168 | } // namespace dnnl |
169 | |
170 | #endif |
171 | |