1 | /******************************************************************************* |
2 | * Copyright 2019-2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/ref_lrn.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace ocl { |
23 | |
24 | status_t ref_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const { |
25 | |
26 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
27 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
28 | auto &ws = CTX_OUT_STORAGE(DNNL_ARG_WORKSPACE); |
29 | |
30 | compute::kernel_arg_list_t arg_list; |
31 | arg_list.set(0, src); |
32 | if (pd()->desc()->prop_kind == prop_kind::forward_training) { |
33 | arg_list.set(1, ws); |
34 | arg_list.set(2, dst); |
35 | } else { |
36 | arg_list.set(1, dst); |
37 | } |
38 | |
39 | auto nd_range = pd()->dispatch.nd_range(); |
40 | |
41 | status_t status = parallel_for(ctx, nd_range, kernel_, arg_list); |
42 | return status; |
43 | } |
44 | |
45 | status_t ref_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const { |
46 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
47 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
48 | auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE); |
49 | auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC); |
50 | |
51 | compute::kernel_arg_list_t arg_list; |
52 | arg_list.set(0, src); |
53 | arg_list.set(1, diff_dst); |
54 | arg_list.set(2, ws); |
55 | arg_list.set(3, diff_src); |
56 | |
57 | auto nd_range = pd()->dispatch.nd_range(); |
58 | |
59 | status_t status = parallel_for(ctx, nd_range, kernel_, arg_list); |
60 | return status; |
61 | } |
62 | |
63 | } // namespace ocl |
64 | } // namespace gpu |
65 | } // namespace impl |
66 | } // namespace dnnl |
67 | |