1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/ref_softmax.hpp"
18
19namespace dnnl {
20namespace impl {
21namespace gpu {
22namespace ocl {
23
24status_t ref_softmax_fwd_t::execute_generic(const exec_ctx_t &ctx) const {
25 if (pd()->has_zero_dim_memory()) return status::success;
26
27 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
28 auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
29 auto &src_scale = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
30 auto &dst_scale = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
31
32 compute::kernel_arg_list_t arg_list;
33 arg_list.set(0, src);
34 arg_list.set(1, dst);
35 arg_list.set(2, src_scale);
36 arg_list.set(3, dst_scale);
37
38 if (pd()->group_size > 1) {
39 auto nd_range = compute::nd_range_t(pd()->gws, pd()->lws);
40 return parallel_for(ctx, nd_range, kernel_, arg_list);
41 } else {
42 auto nd_range = compute::nd_range_t(pd()->gws);
43 return parallel_for(ctx, nd_range, kernel_, arg_list);
44 }
45}
46
47status_t ref_softmax_bwd_t::execute_generic(const exec_ctx_t &ctx) const {
48 if (pd()->has_zero_dim_memory()) return status::success;
49
50 auto &dst = CTX_IN_STORAGE(DNNL_ARG_DST);
51 auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
52 auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
53
54 compute::kernel_arg_list_t arg_list;
55 arg_list.set(0, dst);
56 arg_list.set(1, diff_src);
57 arg_list.set(2, diff_dst);
58
59 auto nd_range = compute::nd_range_t(pd()->gws, pd()->lws);
60 return parallel_for(ctx, nd_range, kernel_, arg_list);
61}
62} // namespace ocl
63} // namespace gpu
64} // namespace impl
65} // namespace dnnl
66
67// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
68