1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_CONV_ZERO_OUT_HPP
18#define GPU_JIT_CONV_ZERO_OUT_HPP
19
20#include "gpu/jit/codegen/kernel.hpp"
21#include "gpu/jit/codegen/register_scope.hpp"
22#include "gpu/jit/ir/tensor.hpp"
23#include "gpu/jit/ngen/ngen.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace gpu {
28namespace jit {
29
30template <ngen::HW hw = ngen::HW::Unknown>
31class zero_out_kernel_t : public ir_kernel_t<hw> {
32public:
33 IR_KERNEL_FORWARD(hw)
34
35 zero_out_kernel_t(const exec_config_t &exec_cfg,
36 const kernel_info_t &kernel_info, bool require_dpas,
37 grf_mode_t grf_mode)
38 : ir_kernel_t<hw>(
39 "zero_out", exec_cfg, kernel_info, require_dpas, grf_mode) {
40
41 setup_interface();
42 generate_prologue();
43
44 std::vector<std::string> arg_names(kernel_info.nargs());
45 for (int i = 0; i < kernel_info.nargs(); i++) {
46 arg_names[i] = kernel_info.arg_name(i);
47 }
48
49 int simd_size = getSIMD();
50 bool use_a64 = false;
51 // XXX: Stateful messages don't work on XeHPC.
52 use_a64 = (hw == ngen::HW::XeHPC);
53
54 auto ptr = getArgument(arg_names[0]);
55 auto surf = Surface(getArgumentSurface(arg_names[0]));
56 auto size = getArgument(arg_names[1]);
57 auto global_id = ra_.template alloc_sub<uint32_t>();
58 auto off0 = ra_.template alloc_sub<uint32_t>();
59
60 mul(1, global_id, r0.ud(1), getLocalSize(0).uw());
61 add(1, global_id, global_id, getLocalID(0));
62 shl(1, off0, global_id, math::ilog2q(bytes_per_thr / simd_size));
63
64 int grf_size = ngen::GRF::bytes(hw);
65 int bytes_per_store = 16;
66 int ud_size = sizeof(uint32_t);
67 int uq_size = sizeof(uint64_t);
68
69 auto zero = ra_.alloc_range(bytes_per_store * ud_size / grf_size);
70 auto off_vec = ra_.alloc_range(bytes_per_thr * ud_size / grf_size);
71 auto off_vec_q_strided
72 = ra_.alloc_range(bytes_per_thr * uq_size / grf_size);
73 auto ptr_vec = ra_.alloc_range(bytes_per_thr * uq_size / grf_size);
74
75 for (int i = 0; i < bytes_per_store * ud_size; i += 64) {
76 auto z = get_subregister(hw, ngen::DataType::ud, zero, i);
77 mov(16, z, 0);
78 }
79
80 auto idx_vec = ra_.alloc().uw();
81 mov(8, idx_vec, ngen::Immediate::uv(0, 1, 2, 3, 4, 5, 6, 7));
82
83 for (int i = 0; i < bytes_per_thr; i += 8) {
84 auto off_sub_vec
85 = get_subregister(hw, ngen::DataType::ud, off_vec, i)(1);
86 add3(8, off_sub_vec, off0, idx_vec, i);
87 if (use_a64) {
88 auto ptr_sub_vec = get_subregister(
89 hw, ngen::DataType::uq, ptr_vec, i)(1);
90 auto off_sub_vec_q_strided = get_subregister(
91 hw, ngen::DataType::ud, off_vec_q_strided, i * 2)(2);
92 emov(8, off_sub_vec_q_strided, off_sub_vec);
93 eadd(8, ptr_sub_vec, ptr, off_sub_vec_q_strided);
94 }
95 }
96
97 for (int i = 0; i < bytes_per_thr; i += bytes_per_store) {
98 auto off_sub_vec
99 = get_subregister(hw, ngen::DataType::ud, off_vec, i)(1);
100 cmp(16 | lt | f0[0], off_sub_vec, size);
101 if (use_a64) {
102 auto h_a64
103 = get_subregister(hw, ngen::DataType::uq, ptr_vec, i);
104 store(16 | f0[0], ngen::scattered_byte(), A64, h_a64, zero[0]);
105 } else {
106 auto h_bts = off_sub_vec;
107 store(16 | f0[0], ngen::scattered_byte(), surf, h_bts, zero[0]);
108 }
109 }
110
111 generate_epilogue();
112 }
113
114 static compute::nd_range_t nd_range(int simd, int size) {
115 return compute::nd_range_t(
116 {utils::div_up(size, bytes_per_thr) * simd, 1, 1});
117 }
118
119 static const int bytes_per_thr;
120};
121
122template <ngen::HW hw>
123const int zero_out_kernel_t<hw>::bytes_per_thr = 128;
124
125} // namespace jit
126} // namespace gpu
127} // namespace impl
128} // namespace dnnl
129
130#endif
131