1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/ref_softmax.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace ocl { |
23 | |
24 | status_t ref_softmax_fwd_t::execute_generic(const exec_ctx_t &ctx) const { |
25 | if (pd()->has_zero_dim_memory()) return status::success; |
26 | |
27 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
28 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
29 | auto &src_scale = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); |
30 | auto &dst_scale = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); |
31 | |
32 | compute::kernel_arg_list_t arg_list; |
33 | arg_list.set(0, src); |
34 | arg_list.set(1, dst); |
35 | arg_list.set(2, src_scale); |
36 | arg_list.set(3, dst_scale); |
37 | |
38 | if (pd()->group_size > 1) { |
39 | auto nd_range = compute::nd_range_t(pd()->gws, pd()->lws); |
40 | return parallel_for(ctx, nd_range, kernel_, arg_list); |
41 | } else { |
42 | auto nd_range = compute::nd_range_t(pd()->gws); |
43 | return parallel_for(ctx, nd_range, kernel_, arg_list); |
44 | } |
45 | } |
46 | |
47 | status_t ref_softmax_bwd_t::execute_generic(const exec_ctx_t &ctx) const { |
48 | if (pd()->has_zero_dim_memory()) return status::success; |
49 | |
50 | auto &dst = CTX_IN_STORAGE(DNNL_ARG_DST); |
51 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
52 | auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC); |
53 | |
54 | compute::kernel_arg_list_t arg_list; |
55 | arg_list.set(0, dst); |
56 | arg_list.set(1, diff_src); |
57 | arg_list.set(2, diff_dst); |
58 | |
59 | auto nd_range = compute::nd_range_t(pd()->gws, pd()->lws); |
60 | return parallel_for(ctx, nd_range, kernel_, arg_list); |
61 | } |
62 | } // namespace ocl |
63 | } // namespace gpu |
64 | } // namespace impl |
65 | } // namespace dnnl |
66 | |
67 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
68 | |