1/*******************************************************************************
2* Copyright 2019-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_GPU_INNER_PRODUCT_PD_HPP
18#define GPU_GPU_INNER_PRODUCT_PD_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/inner_product_pd.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace gpu {
30
31namespace {
32inline bool dense_consistency_check(const memory_desc_wrapper &src_d,
33 const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) {
34 using namespace format_tag;
35 using namespace utils;
36 // Why is dense_gemm_consistency_check not enough (other than dst check)?
37 return IMPLICATION(src_d.matches_tag(ncw), wei_d.matches_tag(oiw))
38 && IMPLICATION(src_d.matches_tag(nchw), wei_d.matches_tag(oihw))
39 && IMPLICATION(src_d.matches_tag(ncdhw), wei_d.matches_tag(oidhw))
40 && IMPLICATION(
41 src_d.matches_tag(nc), wei_d.matches_one_of_tag(oi, io))
42 && dst_d.matches_tag(nc) && src_d.is_dense(true) && dst_d.is_dense()
43 && wei_d.is_dense(true);
44}
45
46inline bool dense_gemm_consistency_check(const memory_desc_wrapper &src_d,
47 const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) {
48 using namespace utils;
49
50 auto strides_compatible = [&]() {
51 bool ok = true;
52 auto w_str = wei_d.blocking_desc().strides;
53 auto d_str = src_d.blocking_desc().strides;
54 for (int i = 1; i < src_d.ndims() - 1; i++) {
55 ok = ok && w_str[i] / d_str[i] == w_str[i + 1] / d_str[i + 1];
56 }
57 return ok && one_of(w_str[1] / d_str[1], 1, wei_d.padded_dims()[0]);
58 };
59 return src_d.is_blocking_desc() && wei_d.is_blocking_desc()
60 && src_d.ndims() == wei_d.ndims()
61 && src_d.blocking_desc().inner_nblks
62 == wei_d.blocking_desc().inner_nblks
63 && utils::one_of(src_d.blocking_desc().inner_nblks, 0, 1)
64 && array_cmp(src_d.blocking_desc().inner_blks,
65 wei_d.blocking_desc().inner_blks,
66 wei_d.blocking_desc().inner_nblks)
67 && array_cmp(src_d.blocking_desc().inner_idxs,
68 wei_d.blocking_desc().inner_idxs,
69 wei_d.blocking_desc().inner_nblks)
70 && strides_compatible() && dst_d.matches_tag(format_tag::nc)
71 && src_d.only_padded_dim(1) && wei_d.only_padded_dim(1)
72 && src_d.padded_dims()[1] == wei_d.padded_dims()[1]
73 && src_d.is_dense(true) && dst_d.is_dense() && wei_d.is_dense(true);
74}
75
76status_t template_set_default_params(memory_desc_t &src_md,
77 memory_desc_t &weights_md, memory_desc_t &dst_md,
78 memory_desc_t *bias_md, int ndims, bool is_conv = false) {
79 using namespace format_tag;
80
81 auto init_md = [&](memory_desc_t &out_md, const memory_desc_t &in_md) {
82 format_tag_t md_tag;
83 if (memory_desc_matches_one_of_tag(in_md, ba, cba, cdba, cdeba))
84 md_tag = utils::pick(ndims - 2, ab, acb, acdb, acdeb);
85 else if (memory_desc_matches_one_of_tag(in_md, acb, acdb, acdeb))
86 md_tag = utils::pick(ndims - 3, cba, cdba, cdeba);
87 else {
88 memory_desc_wrapper md_desc_wrapper(in_md);
89 return memory_desc_init_by_blocking_desc(
90 out_md, md_desc_wrapper.blocking_desc());
91 }
92 return memory_desc_init_by_tag(out_md, md_tag);
93 };
94 if (!is_conv) {
95 if (src_md.format_kind == format_kind::any
96 && weights_md.format_kind == format_kind::any) {
97 CHECK(memory_desc_init_by_tag(
98 src_md, utils::pick(ndims - 2, nc, ncw, nchw, ncdhw)));
99 CHECK(memory_desc_init_by_tag(
100 weights_md, utils::pick(ndims - 2, oi, oiw, oihw, oidhw)));
101 } else if (src_md.format_kind == format_kind::any)
102 CHECK(init_md(src_md, weights_md));
103 else if (weights_md.format_kind == format_kind::any)
104 CHECK(init_md(weights_md, src_md));
105 }
106
107 if (dst_md.format_kind == format_kind::any)
108 CHECK(memory_desc_init_by_tag(dst_md, nc));
109 if (bias_md->format_kind == format_kind::any)
110 CHECK(memory_desc_init_by_tag(*bias_md, x));
111
112 return status::success;
113}
114
115} // namespace
116
// Base primitive descriptor for GPU inner product forward implementations.
// Adds shared layout-defaulting on top of the common forward pd.
struct gpu_inner_product_fwd_pd_t : public inner_product_fwd_pd_t {
    using inner_product_fwd_pd_t::inner_product_fwd_pd_t;

protected:
    // Resolves format_kind::any on src/weights/dst/bias. With is_conv set,
    // src/weights are left untouched (presumably for a convolution-backed
    // implementation to choose — TODO confirm against callers).
    status_t set_default_params(bool is_conv = false) {
        return template_set_default_params(
                src_md_, weights_md_, dst_md_, &bias_md_, ndims(), is_conv);
    }
};
126
// Base primitive descriptor for GPU inner product backward-data
// implementations.
struct gpu_inner_product_bwd_data_pd_t : public inner_product_bwd_data_pd_t {
    using inner_product_bwd_data_pd_t::inner_product_bwd_data_pd_t;

protected:
    // Resolves format_kind::any on diff_src/weights/diff_dst. Backward data
    // has no bias, so the global zero md is passed as a placeholder.
    status_t set_default_params() {
        return template_set_default_params(diff_src_md_, weights_md_,
                diff_dst_md_, &glob_zero_md, ndims());
    }
};
136
// Base primitive descriptor for GPU inner product backward-weights
// implementations.
struct gpu_inner_product_bwd_weights_pd_t
    : public inner_product_bwd_weights_pd_t {
    using inner_product_bwd_weights_pd_t::inner_product_bwd_weights_pd_t;

protected:
    // Resolves format_kind::any on src/diff_weights/diff_dst/diff_bias.
    status_t set_default_params() {
        return template_set_default_params(src_md_, diff_weights_md_,
                diff_dst_md_, &diff_bias_md_, ndims());
    }
};
147
148} // namespace gpu
149} // namespace impl
150} // namespace dnnl
151
152#endif
153