1/*******************************************************************************
2 * Copyright 2021-2022 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17#include "gpu/jit/jit_post_op_injector.hpp"
18#include "common/impl_registration.hpp"
19
20namespace dnnl {
21namespace impl {
22namespace gpu {
23namespace jit {
24
25using namespace ngen;
26
27template <gpu_gen_t hw>
28int jit_post_op_injector<hw>::min_scratch_regs() {
29 int regs_cnt = 0;
30 for (size_t idx = 0; idx < workers_.size(); ++idx) {
31 regs_cnt = nstl::max(regs_cnt, workers_[idx].min_scratch_regs());
32 }
33 return regs_cnt;
34}
35
36template <gpu_gen_t hw>
37int jit_post_op_injector<hw>::preferred_scratch_regs() {
38 int regs_cnt = 0;
39 for (size_t idx = 0; idx < workers_.size(); ++idx) {
40 regs_cnt = nstl::max(regs_cnt, workers_[idx].preferred_scratch_regs());
41 }
42 return regs_cnt;
43}
44
45template <gpu_gen_t hw>
46void jit_post_op_injector<hw>::set_scratch(const ngen::GRFRange &scratch) {
47 for (size_t idx = 0; idx < workers_.size(); ++idx) {
48 workers_[idx].set_scratch(scratch);
49 if (workers_.size() == 1) workers_[idx].prepare();
50 }
51 scratch_ = scratch;
52}
53
54template <gpu_gen_t hw>
55void jit_post_op_injector<hw>::compute(const ngen::GRFRange &regs) {
56 for (size_t idx = 0; idx < workers_.size(); ++idx) {
57 if (workers_.size() > 1) workers_[idx].prepare();
58 workers_[idx].compute(regs);
59 }
60}
61
// Explicit instantiations of jit_post_op_injector, one per GPU generation.
// Each is wrapped in a REG_*_ISA macro from common/impl_registration.hpp.
// NOTE(review): presumably these macros drop the instantiation when the
// corresponding ISA is disabled in the build; verify against the macro
// definitions.
REG_GEN9_ISA(template struct jit_post_op_injector<gpu_gen9>);
REG_GEN11_ISA(template struct jit_post_op_injector<gpu_gen11>);
REG_XELP_ISA(template struct jit_post_op_injector<gpu_xe_lp>);
REG_XEHP_ISA(template struct jit_post_op_injector<gpu_xe_hp>);
REG_XEHPG_ISA(template struct jit_post_op_injector<gpu_xe_hpg>);
REG_XEHPC_ISA(template struct jit_post_op_injector<gpu_xe_hpc>);
68
69} // namespace jit
70} // namespace gpu
71} // namespace impl
72} // namespace dnnl
73