1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/gen9_softmax.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace ocl { |
23 | |
24 | status_t gen9_softmax_fwd_t::execute_generic(const exec_ctx_t &ctx) const { |
25 | if (pd()->has_zero_dim_memory()) return status::success; |
26 | |
27 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
28 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
29 | auto &src_scale = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); |
30 | auto &dst_scale = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); |
31 | |
32 | compute::kernel_arg_list_t arg_list; |
33 | arg_list.set(0, src); |
34 | arg_list.set(1, dst); |
35 | arg_list.set(2, src_scale); |
36 | arg_list.set(3, dst_scale); |
37 | |
38 | auto nd_range = compute::nd_range_t(pd()->gws, pd()->lws); |
39 | return parallel_for(ctx, nd_range, kernel_, arg_list); |
40 | } |
41 | |
42 | status_t gen9_softmax_bwd_t::execute_generic(const exec_ctx_t &ctx) const { |
43 | if (pd()->has_zero_dim_memory()) return status::success; |
44 | |
45 | auto &dst = CTX_IN_STORAGE(DNNL_ARG_DST); |
46 | auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); |
47 | auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC); |
48 | |
49 | compute::kernel_arg_list_t arg_list; |
50 | arg_list.set(0, dst); |
51 | arg_list.set(1, diff_src); |
52 | arg_list.set(2, diff_dst); |
53 | |
54 | auto nd_range = compute::nd_range_t(pd()->gws, pd()->lws); |
55 | return parallel_for(ctx, nd_range, kernel_, arg_list); |
56 | } |
57 | |
58 | } // namespace ocl |
59 | } // namespace gpu |
60 | } // namespace impl |
61 | } // namespace dnnl |
62 | |
63 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
64 | |