1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include "oneapi/dnnl/dnnl.h"
19#include "opdesc.hpp"
20#include "primitive_desc_iface.hpp"
21
22#include "c_types_map.hpp"
23#include "type_helpers.hpp"
24#include "utils.hpp"
25
26using namespace dnnl::impl;
27using namespace dnnl::impl::utils;
28using namespace dnnl::impl::status;
29using namespace dnnl::impl::prop_kind;
30using namespace dnnl::impl::types;
31
32namespace dnnl {
33namespace impl {
34status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind,
35 const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
36 const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) {
37 bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc);
38 if (!args_ok) return invalid_arguments;
39
40 auto id = inner_product_desc_t();
41 id.primitive_kind = primitive_kind::inner_product;
42 id.prop_kind = prop_kind;
43
44 id.diff_src_desc = id.src_desc = zero_md();
45 id.diff_dst_desc = id.dst_desc = zero_md();
46 id.diff_weights_desc = id.weights_desc = zero_md();
47 id.diff_bias_desc = id.bias_desc = zero_md();
48
49 const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
50 const bool with_bias
51 = bias_desc && bias_desc->format_kind != format_kind::undef;
52
53 bool runtime_dims_or_strides
54 = memory_desc_wrapper(src_desc).has_runtime_dims_or_strides()
55 || memory_desc_wrapper(weights_desc).has_runtime_dims_or_strides()
56 || memory_desc_wrapper(dst_desc).has_runtime_dims_or_strides();
57 if (with_bias)
58 runtime_dims_or_strides = runtime_dims_or_strides
59 || memory_desc_wrapper(bias_desc).has_runtime_dims_or_strides();
60 if (runtime_dims_or_strides) return unimplemented;
61
62 (prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc;
63 (is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc;
64 (prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc)
65 = *weights_desc;
66 if (with_bias)
67 (prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc)
68 = *bias_desc;
69
70 id.accum_data_type = types::default_accum_data_type(src_desc->data_type,
71 weights_desc->data_type, dst_desc->data_type, prop_kind);
72 if (id.accum_data_type == data_type::undef) return invalid_arguments;
73
74 bool consistency = true && memory_desc_wrapper(weights_desc).nelems()
75 && one_of(src_desc->ndims, 2, 3, 4, 5) && dst_desc->ndims == 2
76 && weights_desc->ndims == src_desc->ndims
77 && (with_bias ? bias_desc->ndims == 1 : true)
78 && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
79 && src_desc->dims[0] == dst_desc->dims[0]
80 && array_cmp(&src_desc->dims[1], &weights_desc->dims[1],
81 src_desc->ndims - 1)
82 && dst_desc->dims[1] == weights_desc->dims[0];
83 if (!consistency) return invalid_arguments;
84
85 *ip_desc = id;
86 return success;
87}
88} // namespace impl
89} // namespace dnnl
90
91status_t dnnl_inner_product_forward_primitive_desc_create(
92 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
93 prop_kind_t prop_kind, const memory_desc_t *src_desc,
94 const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
95 const memory_desc_t *dst_desc, const primitive_attr_t *attr) {
96 if (!one_of(prop_kind, forward_training, forward_inference))
97 return invalid_arguments;
98
99 auto ip_desc = inner_product_desc_t();
100 CHECK(ip_desc_init(
101 &ip_desc, prop_kind, src_desc, weights_desc, bias_desc, dst_desc));
102 return primitive_desc_create(primitive_desc_iface, engine,
103 (const op_desc_t *)&ip_desc, nullptr, attr);
104}
105
106status_t dnnl_inner_product_backward_data_primitive_desc_create(
107 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
108 const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
109 const memory_desc_t *diff_dst_desc,
110 const primitive_desc_iface_t *hint_fwd_pd,
111 const primitive_attr_t *attr) {
112
113 auto ip_desc = inner_product_desc_t();
114 CHECK(ip_desc_init(&ip_desc, backward_data, diff_src_desc, weights_desc,
115 nullptr, diff_dst_desc));
116 return primitive_desc_create(primitive_desc_iface, engine,
117 (const op_desc_t *)&ip_desc, hint_fwd_pd, attr);
118}
119
120status_t dnnl_inner_product_backward_weights_primitive_desc_create(
121 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
122 const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
123 const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
124 const primitive_desc_iface_t *hint_fwd_pd,
125 const primitive_attr_t *attr) {
126
127 auto ip_desc = inner_product_desc_t();
128 CHECK(ip_desc_init(&ip_desc, backward_weights, src_desc, diff_weights_desc,
129 diff_bias_desc, diff_dst_desc));
130 return primitive_desc_create(primitive_desc_iface, engine,
131 (const op_desc_t *)&ip_desc, hint_fwd_pd, attr);
132}
133
134// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
135