1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/gemm_post_ops_inner_product.hpp" |
18 | #include "common/c_types_map.hpp" |
19 | #include "gpu/gemm/gpu_gemm.hpp" |
20 | #include "gpu/ocl/ocl_utils.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace ocl { |
26 | |
27 | status_t gemm_post_ops_inner_product_fwd_t::execute_forward( |
28 | const exec_ctx_t &ctx) const { |
29 | using namespace memory_tracking::names; |
30 | |
31 | gemm_exec_args_t gemm_args; |
32 | gemm_args.a = &CTX_IN_STORAGE(DNNL_ARG_SRC); |
33 | gemm_args.b = &CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
34 | |
35 | std::unique_ptr<memory_storage_t> acc; |
36 | if (pd()->use_scratchpad() || pd()->use_temp_dst()) |
37 | acc = ctx.get_scratchpad_grantor().get_memory_storage( |
38 | key_iprod_int_dat_in_acc_dt); |
39 | |
40 | if (pd()->use_temp_dst()) { |
41 | gemm_args.c = acc.get(); |
42 | } else { |
43 | gemm_args.c = &CTX_OUT_STORAGE(DNNL_ARG_DST); |
44 | } |
45 | |
46 | gemm_exec_ctx_t gemm_ctx(ctx, gemm_args); |
47 | |
48 | nested_scratchpad_t ns(ctx, key_nested, gemm_); |
49 | gemm_ctx.set_scratchpad_grantor(ns.grantor()); |
50 | |
51 | status_t gemm_exec_status = gpu_gemm(gemm_)->execute(gemm_ctx); |
52 | if (gemm_exec_status != status::success) return gemm_exec_status; |
53 | |
54 | if (pd()->with_post_process()) { |
55 | compute::kernel_arg_list_t arg_list; |
56 | arg_list.set(0, CTX_OUT_STORAGE(DNNL_ARG_DST)); |
57 | arg_list.set(1, CTX_IN_STORAGE(DNNL_ARG_BIAS)); |
58 | arg_list.set(2, CTX_OUT_STORAGE(DNNL_ARG_DST)); |
59 | unsigned arg_idx = append_post_ops_to_arg_list( |
60 | ctx, arg_list, 3, pd()->attr()->post_ops_); |
61 | arg_list.set(arg_idx++, |
62 | pd()->use_scratchpad() ? *acc |
63 | : memory_storage_t::empty_storage()); |
64 | arg_list.set(arg_idx, |
65 | pd()->attr_info_.with_runtime_oscales |
66 | ? CTX_IN_STORAGE(DNNL_ARG_ATTR_OUTPUT_SCALES) |
67 | : memory_storage_t::empty_storage()); |
68 | |
69 | size_t mb = pd()->MB(); |
70 | size_t oc = pd()->OC(); |
71 | |
72 | auto nd_range = compute::nd_range_t({mb * oc}); |
73 | |
74 | status_t status |
75 | = parallel_for(ctx, nd_range, post_process_kernel_, arg_list); |
76 | if (status != status::success) return status; |
77 | } |
78 | |
79 | return status::success; |
80 | } |
81 | |
82 | } // namespace ocl |
83 | } // namespace gpu |
84 | } // namespace impl |
85 | } // namespace dnnl |
86 | |