1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include "oneapi/dnnl/dnnl.h"
19#include "opdesc.hpp"
20#include "primitive_desc_iface.hpp"
21
22#include "c_types_map.hpp"
23#include "memory_desc_wrapper.hpp"
24#include "utils.hpp"
25
26using namespace dnnl::impl;
27using namespace dnnl::impl::utils;
28using namespace dnnl::impl::status;
29using namespace dnnl::impl::prop_kind;
30using namespace dnnl::impl::alg_kind;
31using namespace dnnl::impl::types;
32
33namespace {
34status_t softmax_desc_init(softmax_desc_t *softmax_desc, prop_kind_t prop_kind,
35 alg_kind_t alg_kind, const memory_desc_t *src_desc,
36 const memory_desc_t *dst_desc, const memory_desc_t *diff_src_desc,
37 const memory_desc_t *diff_dst_desc, int softmax_axis) {
38 const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
39 bool args_ok = !any_null(softmax_desc, dst_desc)
40 && IMPLICATION(is_fwd, src_desc != nullptr)
41 && IMPLICATION(!is_fwd, !any_null(diff_src_desc, diff_dst_desc))
42 && one_of(alg_kind, softmax_accurate, softmax_log)
43 && 0 <= softmax_axis && softmax_axis < dst_desc->ndims
44 && IMPLICATION(is_fwd, !memory_desc_wrapper(src_desc).format_any())
45 && IMPLICATION(
46 !is_fwd, !memory_desc_wrapper(dst_desc).format_any());
47 if (!args_ok) return invalid_arguments;
48
49 bool runtime_dims_or_strides
50 = memory_desc_wrapper(dst_desc).has_runtime_dims_or_strides();
51 if (is_fwd) {
52 runtime_dims_or_strides = runtime_dims_or_strides
53 || memory_desc_wrapper(src_desc).has_runtime_dims_or_strides();
54 } else {
55 runtime_dims_or_strides = runtime_dims_or_strides
56 || memory_desc_wrapper(diff_src_desc)
57 .has_runtime_dims_or_strides()
58 || memory_desc_wrapper(diff_dst_desc)
59 .has_runtime_dims_or_strides();
60 }
61 if (runtime_dims_or_strides) return unimplemented;
62
63 auto sd = softmax_desc_t();
64 sd.primitive_kind = primitive_kind::softmax;
65 sd.prop_kind = prop_kind;
66
67 if (is_fwd) sd.src_desc = *src_desc;
68 if (!is_fwd) sd.diff_src_desc = *diff_src_desc;
69 sd.softmax_axis = softmax_axis;
70 sd.alg_kind = alg_kind;
71 sd.dst_desc = *dst_desc;
72 if (!is_fwd) sd.diff_dst_desc = *diff_dst_desc;
73
74 *softmax_desc = sd;
75 return success;
76}
77} // namespace
78
79status_t dnnl_softmax_forward_primitive_desc_create(
80 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
81 prop_kind_t prop_kind, alg_kind_t alg_kind,
82 const memory_desc_t *src_desc, const memory_desc_t *dst_desc, int axis,
83 const primitive_attr_t *attr) {
84 if (!one_of(prop_kind, forward_inference, forward_training))
85 return invalid_arguments;
86
87 auto softmax_desc = softmax_desc_t();
88 CHECK(softmax_desc_init(&softmax_desc, prop_kind, alg_kind, src_desc,
89 dst_desc, nullptr, nullptr, axis));
90 return primitive_desc_create(primitive_desc_iface, engine,
91 (const op_desc_t *)&softmax_desc, nullptr, attr);
92}
93
94status_t dnnl_softmax_backward_primitive_desc_create(
95 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
96 alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
97 const memory_desc_t *diff_dst_desc, const memory_desc_t *dst_desc,
98 int axis, const primitive_desc_iface_t *hint_fwd_pd,
99 const primitive_attr_t *attr) {
100
101 auto softmax_desc = softmax_desc_t();
102 CHECK(softmax_desc_init(&softmax_desc, prop_kind::backward_data, alg_kind,
103 nullptr, dst_desc, diff_src_desc, diff_dst_desc, axis));
104 return primitive_desc_create(primitive_desc_iface, engine,
105 (const op_desc_t *)&softmax_desc, hint_fwd_pd, attr);
106}
107
108// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
109