1 | /******************************************************************************* |
2 | * Copyright 2019-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_GEN9_SIMPLE_SUM_KERNEL_F32_HPP |
18 | #define GPU_JIT_GEN9_SIMPLE_SUM_KERNEL_F32_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "gpu/jit/jit_generator.hpp" |
22 | #include "gpu/primitive_conf.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace gpu { |
27 | namespace jit { |
28 | |
29 | class gen9_simple_sum_kernel_f32_t : public jit_generator<gpu_gen9> { |
30 | public: |
31 | gen9_simple_sum_kernel_f32_t() : jit_generator<gpu_gen9>() { |
32 | using namespace ngen; |
33 | constexpr auto GlobalPtr = ExternalArgumentType::GlobalPtr; |
34 | constexpr auto Scalar = ExternalArgumentType::Scalar; |
35 | |
36 | newArgument("input" , GlobalPtr); |
37 | newArgument("output" , GlobalPtr); |
38 | newArgument("scale" , DataType::f, Scalar); |
39 | newArgument("a" , DataType::d, Scalar); |
40 | externalName("ngen_gen9_simple_sum" ); |
41 | finalizeInterface(); |
42 | |
43 | setDefaultNoMask(); |
44 | |
45 | Label append, done; |
46 | |
47 | auto global_id0_arg = r0.ud(1); |
48 | auto src_ptr = r32; |
49 | auto dst_ptr = r34; |
50 | auto global_id = r33; |
51 | |
52 | auto src = r40; |
53 | auto dst = r42; |
54 | auto factor = r41; |
55 | auto sum = r43; |
56 | |
57 | mov<uint32_t>(1, global_id, global_id0_arg); |
58 | mov<uint64_t>(1, src_ptr, getArgument("input" )); |
59 | mov<uint64_t>(1, dst_ptr, getArgument("output" )); |
60 | mov<float>(1, factor, getArgument("scale" )); |
61 | |
62 | mul<uint32_t>(1, global_id, global_id, 4); |
63 | add<uint32_t>(1, src_ptr, src_ptr, global_id); |
64 | add<uint32_t>(1, dst_ptr, dst_ptr, global_id); |
65 | |
66 | load(1, src, scattered_dword(), A64, src_ptr); |
67 | mul<float>(1, sum, factor, src); |
68 | |
69 | cmp(1 | eq | f0[0], null.ud(), getArgument("a" ), 0); |
70 | jmpi(1 | ~f0[0], append); |
71 | store(1, scattered_dword(), A64, dst_ptr, sum); |
72 | jmpi(1, done); |
73 | |
74 | mark(append); |
75 | load(1, dst, scattered_dword(), A64, dst_ptr); |
76 | add<float>(1, sum, sum, dst); |
77 | store(1, scattered_dword(), A64, dst_ptr, sum); |
78 | |
79 | mark(done); |
80 | mov<uint32_t>(8, r127, r0); |
81 | threadend(r127); |
82 | } |
83 | }; |
84 | |
85 | } // namespace jit |
86 | } // namespace gpu |
87 | } // namespace impl |
88 | } // namespace dnnl |
89 | |
90 | #endif // GPU_JIT_GEN9_SIMPLE_SUM_KERNEL_F32_HPP |
91 | |