/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_GEN9_BINARY_HPP
#define GPU_OCL_GEN9_BINARY_HPP

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_binary_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/ocl/ocl_stream.hpp"
#include "gpu/ocl/ocl_utils.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

struct gen9_binary_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_binary_pd_t {
        using gpu_binary_pd_t::gpu_binary_pd_t;

        DECLARE_COMMON_PD_T("ocl:gen9", gen9_binary_t);

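        // Descriptor-level dispatch check: verifies layouts, data type
        // combinations, attributes (post-ops, runtime scales), and device
        // capabilities before committing to this implementation.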
        status_t init(engine_t *engine) {
            using namespace data_type;
            using namespace format_tag;
            using sm = primitive_attr_t::skip_mask_t;

            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);

            const auto attr_skip_mask = sm::post_ops | sm::scales_runtime;
            const memory_desc_wrapper dst_d(dst_md());
            format_tag_t dst_tag
                    = dst_d.matches_one_of_tag(nc, ncw, nchw, ncdhw);
            bool is_plain_layout = dst_d.matches_tag(dst_tag);
            bool ok = set_default_params() == status::success
                    && IMPLICATION(is_broadcast(), is_plain_layout)
                    && !memory_desc_ndims_ok(src_md(0), src_md(1), dst_md())
                    && ((utils::everyone_is(bf16, src_md(0)->data_type,
                                 src_md(1)->data_type)
                                && utils::one_of(dst_md()->data_type, bf16, u8))
                            || (utils::one_of(
                                        src_md(0)->data_type, f16, f32, s8, u8)
                                    && utils::one_of(src_md(1)->data_type, f16,
                                            f32, s8, u8)
                                    && utils::one_of(dst_md()->data_type, f16,
                                            f32, s8, u8)))
                    && IMPLICATION(!attr()->scales_.has_default_values(),
                            utils::one_of(dst_md()->data_type, s8, u8)
                                    && utils::everyone_is(
                                            attr()->scales_.get(DNNL_ARG_SRC_0)
                                                    .mask_,
                                            attr()->scales_.get(DNNL_ARG_SRC_1)
                                                    .mask_,
                                            0))
                    && attr()->has_default_values(attr_skip_mask)
                    && compute_engine->mayiuse(
                            compute::device_ext_t::intel_subgroups)
                    && IMPLICATION(
                            utils::one_of(f16, src_md(1)->data_type,
                                    src_md(0)->data_type, dst_md()->data_type),
                            compute_engine->mayiuse(
                                    compute::device_ext_t::khr_fp16)
                                    && compute_engine->mayiuse(
                                            compute::device_ext_t::
                                                    intel_subgroups_short))
                    && post_ops_with_binary_ok(
                            attr(), dst_md()->data_type, MAX_NDIMS)
                    && attr_.set_default_formats(dst_md(0)) == status::success
                    && !(attr()->post_ops_.len() > 0
                            && src_md(0)->data_type == bf16
                            && src_md(1)->data_type == bf16
                            && dst_md()->data_type == u8);

            if (!ok) return status::unimplemented;

            return init_conf(engine);
        }

        status_t init_conf(engine_t *engine);

        status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;

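        // Returns true if any dimension requires broadcasting between the
        // two source tensors.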
        bool is_broadcast() {
            auto bcast_dims = broadcast_dims();
            for (int i = 0; i < src_md(0)->ndims; ++i) {
                if (bcast_dims[i] != 0) { return true; }
            }
            return false;
        }

        binary_conf_t conf;
    };

    status_t init(engine_t *engine) override {
        compute::kernel_ctx_t kernel_ctx;

        auto status = pd()->init_kernel_ctx(kernel_ctx);
        if (status != status::success) return status;

        CHECK(create_kernel(engine, &kernel_, "gen9_binary", kernel_ctx));
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

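    // Kernel arguments are bound in a fixed order: src0, src1 and dst first,
    // then any post-op arguments, and finally the two source scales.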
    status_t execute(const exec_ctx_t &ctx) const override {

        auto &src0 = CTX_IN_STORAGE(DNNL_ARG_SRC_0);
        auto &src1 = CTX_IN_STORAGE(DNNL_ARG_SRC_1);
        auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);

        const auto &conf = pd()->conf;

        auto &src0_scale
                = CTX_IN_STORAGE(DNNL_ARG_SRC_0 | DNNL_ARG_ATTR_SCALES);
        auto &src1_scale
                = CTX_IN_STORAGE(DNNL_ARG_SRC_1 | DNNL_ARG_ATTR_SCALES);

        compute::kernel_arg_list_t arg_list;
        arg_list.set(0, src0);
        arg_list.set(1, src1);
        arg_list.set(2, dst);

        unsigned arg_idx = append_post_ops_to_arg_list(
                ctx, arg_list, 3, pd()->attr()->post_ops_);

        arg_list.set(arg_idx++, src0_scale);
        arg_list.set(arg_idx, src1_scale);

        auto nd_range = conf.dispatch.nd_range();

        return parallel_for(ctx, nd_range, kernel_, arg_list);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    compute::kernel_t kernel_;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif