1/*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17#ifndef GPU_JIT_GEN9_SIMPLE_SUM_KERNEL_F32_HPP
18#define GPU_JIT_GEN9_SIMPLE_SUM_KERNEL_F32_HPP
19
20#include "common/c_types_map.hpp"
21#include "gpu/jit/jit_generator.hpp"
22#include "gpu/primitive_conf.hpp"
23
24namespace dnnl {
25namespace impl {
26namespace gpu {
27namespace jit {
28
29class gen9_simple_sum_kernel_f32_t : public jit_generator<gpu_gen9> {
30public:
31 gen9_simple_sum_kernel_f32_t() : jit_generator<gpu_gen9>() {
32 using namespace ngen;
33 constexpr auto GlobalPtr = ExternalArgumentType::GlobalPtr;
34 constexpr auto Scalar = ExternalArgumentType::Scalar;
35
36 newArgument("input", GlobalPtr);
37 newArgument("output", GlobalPtr);
38 newArgument("scale", DataType::f, Scalar);
39 newArgument("a", DataType::d, Scalar);
40 externalName("ngen_gen9_simple_sum");
41 finalizeInterface();
42
43 setDefaultNoMask();
44
45 Label append, done;
46
47 auto global_id0_arg = r0.ud(1);
48 auto src_ptr = r32;
49 auto dst_ptr = r34;
50 auto global_id = r33;
51
52 auto src = r40;
53 auto dst = r42;
54 auto factor = r41;
55 auto sum = r43;
56
57 mov<uint32_t>(1, global_id, global_id0_arg);
58 mov<uint64_t>(1, src_ptr, getArgument("input"));
59 mov<uint64_t>(1, dst_ptr, getArgument("output"));
60 mov<float>(1, factor, getArgument("scale"));
61
62 mul<uint32_t>(1, global_id, global_id, 4);
63 add<uint32_t>(1, src_ptr, src_ptr, global_id);
64 add<uint32_t>(1, dst_ptr, dst_ptr, global_id);
65
66 load(1, src, scattered_dword(), A64, src_ptr);
67 mul<float>(1, sum, factor, src);
68
69 cmp(1 | eq | f0[0], null.ud(), getArgument("a"), 0);
70 jmpi(1 | ~f0[0], append);
71 store(1, scattered_dword(), A64, dst_ptr, sum);
72 jmpi(1, done);
73
74 mark(append);
75 load(1, dst, scattered_dword(), A64, dst_ptr);
76 add<float>(1, sum, sum, dst);
77 store(1, scattered_dword(), A64, dst_ptr, sum);
78
79 mark(done);
80 mov<uint32_t>(8, r127, r0);
81 threadend(r127);
82 }
83};
84
85} // namespace jit
86} // namespace gpu
87} // namespace impl
88} // namespace dnnl
89
90#endif // GPU_JIT_GEN9_SIMPLE_SUM_KERNEL_F32_HPP
91