1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/gemm_post_ops_inner_product.hpp"
18#include "common/c_types_map.hpp"
19#include "gpu/gemm/gpu_gemm.hpp"
20#include "gpu/ocl/ocl_utils.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace gpu {
25namespace ocl {
26
27status_t gemm_post_ops_inner_product_fwd_t::execute_forward(
28 const exec_ctx_t &ctx) const {
29 using namespace memory_tracking::names;
30
31 gemm_exec_args_t gemm_args;
32 gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_SRC);
33 gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
34
35 std::unique_ptr<memory_storage_t> acc;
36 if (pd()->use_scratchpad() || pd()->use_temp_dst())
37 acc = ctx.get_scratchpad_grantor().get_memory_storage(
38 key_iprod_int_dat_in_acc_dt);
39
40 if (pd()->use_temp_dst()) {
41 gemm_args.c = acc.get();
42 } else {
43 gemm_args.c = &CTX_OUT_STORAGE(DNNL_ARG_DST);
44 }
45
46 gemm_exec_ctx_t gemm_ctx(ctx, gemm_args);
47
48 nested_scratchpad_t ns(ctx, key_nested, gemm_);
49 gemm_ctx.set_scratchpad_grantor(ns.grantor());
50
51 status_t gemm_exec_status = gpu_gemm(gemm_)->execute(gemm_ctx);
52 if (gemm_exec_status != status::success) return gemm_exec_status;
53
54 if (pd()->with_post_process()) {
55 compute::kernel_arg_list_t arg_list;
56 arg_list.set(0, CTX_OUT_STORAGE(DNNL_ARG_DST));
57 arg_list.set(1, CTX_IN_STORAGE(DNNL_ARG_BIAS));
58 arg_list.set(2, CTX_OUT_STORAGE(DNNL_ARG_DST));
59 unsigned arg_idx = append_post_ops_to_arg_list(
60 ctx, arg_list, 3, pd()->attr()->post_ops_);
61 arg_list.set(arg_idx++,
62 pd()->use_scratchpad() ? *acc
63 : memory_storage_t::empty_storage());
64 arg_list.set(arg_idx,
65 pd()->attr_info_.with_runtime_oscales
66 ? CTX_IN_STORAGE(DNNL_ARG_ATTR_OUTPUT_SCALES)
67 : memory_storage_t::empty_storage());
68
69 size_t mb = pd()->MB();
70 size_t oc = pd()->OC();
71
72 auto nd_range = compute::nd_range_t({mb * oc});
73
74 status_t status
75 = parallel_for(ctx, nd_range, post_process_kernel_, arg_list);
76 if (status != status::success) return status;
77 }
78
79 return status::success;
80}
81
82} // namespace ocl
83} // namespace gpu
84} // namespace impl
85} // namespace dnnl
86