1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_CONV_ZERO_OUT_HPP |
18 | #define GPU_JIT_CONV_ZERO_OUT_HPP |
19 | |
20 | #include "gpu/jit/codegen/kernel.hpp" |
21 | #include "gpu/jit/codegen/register_scope.hpp" |
22 | #include "gpu/jit/ir/tensor.hpp" |
23 | #include "gpu/jit/ngen/ngen.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace gpu { |
28 | namespace jit { |
29 | |
// JIT kernel that zero-fills a buffer. Each hardware thread clears a
// fixed-size chunk of bytes_per_thr bytes using 16-lane scattered byte
// stores; a per-lane bounds check against the buffer size keeps the last
// (partial) chunk from writing past the end of the buffer.
template <ngen::HW hw = ngen::HW::Unknown>
class zero_out_kernel_t : public ir_kernel_t<hw> {
public:
    IR_KERNEL_FORWARD(hw)

    // Emits the full instruction stream in the constructor.
    // Expected kernel arguments (looked up through kernel_info):
    //   arg 0: buffer to clear (accessed via A64 pointer or surface,
    //          depending on use_a64 below)
    //   arg 1: buffer size in bytes
    zero_out_kernel_t(const exec_config_t &exec_cfg,
            const kernel_info_t &kernel_info, bool require_dpas,
            grf_mode_t grf_mode)
        : ir_kernel_t<hw>(
                "zero_out" , exec_cfg, kernel_info, require_dpas, grf_mode) {

        setup_interface();
        generate_prologue();

        std::vector<std::string> arg_names(kernel_info.nargs());
        for (int i = 0; i < kernel_info.nargs(); i++) {
            arg_names[i] = kernel_info.arg_name(i);
        }

        int simd_size = getSIMD();
        bool use_a64 = false;
        // XXX: Stateful messages don't work on XeHPC.
        use_a64 = (hw == ngen::HW::XeHPC);

        auto ptr = getArgument(arg_names[0]);
        auto surf = Surface(getArgumentSurface(arg_names[0]));
        auto size = getArgument(arg_names[1]);
        auto global_id = ra_.template alloc_sub<uint32_t>();
        auto off0 = ra_.template alloc_sub<uint32_t>();

        // global_id = group_id * local_size + local_id (scalar, lane 0).
        // NOTE(review): assumes r0.ud(1) holds the thread group ID in the
        // dispatch header -- confirm against the dispatch model in use.
        mul(1, global_id, r0.ud(1), getLocalSize(0).uw());
        add(1, global_id, global_id, getLocalID(0));
        // Starting byte offset of this thread's chunk:
        // off0 = global_id * (bytes_per_thr / simd_size). Lane-0 global IDs
        // advance in steps of simd_size, so this evaluates to
        // thread_index * bytes_per_thr. The shift requires
        // bytes_per_thr / simd_size to be a power of two.
        shl(1, off0, global_id, math::ilog2q(bytes_per_thr / simd_size));

        int grf_size = ngen::GRF::bytes(hw);
        int bytes_per_store = 16; // bytes written per scattered store
        int ud_size = sizeof(uint32_t);
        int uq_size = sizeof(uint64_t);

        // zero:              dword payload of zeros fed to the stores.
        // off_vec:           one dword byte-offset per byte to clear.
        // off_vec_q_strided: the same offsets placed at qword-strided
        //                    (stride-2 dword) positions for the 64-bit add.
        // ptr_vec:           final 64-bit A64 addresses (base + offset);
        //                    only populated on the A64 path.
        auto zero = ra_.alloc_range(bytes_per_store * ud_size / grf_size);
        auto off_vec = ra_.alloc_range(bytes_per_thr * ud_size / grf_size);
        auto off_vec_q_strided
                = ra_.alloc_range(bytes_per_thr * uq_size / grf_size);
        auto ptr_vec = ra_.alloc_range(bytes_per_thr * uq_size / grf_size);

        // Fill the zero payload, 16 dwords (64 bytes) per iteration.
        for (int i = 0; i < bytes_per_store * ud_size; i += 64) {
            auto z = get_subregister(hw, ngen::DataType::ud, zero, i);
            mov(16, z, 0);
        }

        // Lane index vector {0, 1, ..., 7} used to build per-lane offsets.
        auto idx_vec = ra_.alloc().uw();
        mov(8, idx_vec, ngen::Immediate::uv(0, 1, 2, 3, 4, 5, 6, 7));

        // Build per-byte offsets: off_vec[i + lane] = off0 + i + lane.
        // On the A64 path, additionally widen each 32-bit offset into a
        // qword-strided slot and add the buffer base pointer.
        for (int i = 0; i < bytes_per_thr; i += 8) {
            auto off_sub_vec
                    = get_subregister(hw, ngen::DataType::ud, off_vec, i)(1);
            add3(8, off_sub_vec, off0, idx_vec, i);
            if (use_a64) {
                auto ptr_sub_vec = get_subregister(
                        hw, ngen::DataType::uq, ptr_vec, i)(1);
                // Stride-2 dword region inside the qword buffer; emov places
                // the 32-bit offsets where eadd can consume them for the
                // 64-bit address computation below.
                auto off_sub_vec_q_strided = get_subregister(
                        hw, ngen::DataType::ud, off_vec_q_strided, i * 2)(2);
                emov(8, off_sub_vec_q_strided, off_sub_vec);
                eadd(8, ptr_sub_vec, ptr, off_sub_vec_q_strided);
            }
        }

        // Store zeros, bytes_per_store bytes (16 lanes x 1 byte) at a time.
        // The f0[0] mask disables lanes whose byte offset is >= size, so the
        // tail chunk never writes out of bounds.
        for (int i = 0; i < bytes_per_thr; i += bytes_per_store) {
            auto off_sub_vec
                    = get_subregister(hw, ngen::DataType::ud, off_vec, i)(1);
            cmp(16 | lt | f0[0], off_sub_vec, size);
            if (use_a64) {
                auto h_a64
                        = get_subregister(hw, ngen::DataType::uq, ptr_vec, i);
                store(16 | f0[0], ngen::scattered_byte(), A64, h_a64, zero[0]);
            } else {
                auto h_bts = off_sub_vec;
                store(16 | f0[0], ngen::scattered_byte(), surf, h_bts, zero[0]);
            }
        }

        generate_epilogue();
    }

    // ND-range for launching the kernel: one SIMD-wide hardware thread per
    // bytes_per_thr bytes of the buffer (local size chosen by the runtime).
    static compute::nd_range_t nd_range(int simd, int size) {
        return compute::nd_range_t(
                {utils::div_up(size, bytes_per_thr) * simd, 1, 1});
    }

    // Number of bytes each hardware thread is responsible for clearing.
    // Defined out of class below.
    static const int bytes_per_thr;
};
121 | |
// Out-of-class definition of the per-thread chunk size (bytes). Must keep
// bytes_per_thr / simd_size a power of two (see the shl in the constructor).
template <ngen::HW hw>
const int zero_out_kernel_t<hw>::bytes_per_thr = 128;
124 | |
125 | } // namespace jit |
126 | } // namespace gpu |
127 | } // namespace impl |
128 | } // namespace dnnl |
129 | |
130 | #endif |
131 | |