1 | /******************************************************************************* |
2 | * Copyright 2019-2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_GPU_INNER_PRODUCT_PD_HPP |
18 | #define GPU_GPU_INNER_PRODUCT_PD_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/inner_product_pd.hpp" |
24 | #include "common/type_helpers.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace gpu { |
30 | |
31 | namespace { |
32 | inline bool dense_consistency_check(const memory_desc_wrapper &src_d, |
33 | const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) { |
34 | using namespace format_tag; |
35 | using namespace utils; |
36 | // Why is dense_gemm_consistency_check not enough (other than dst check)? |
37 | return IMPLICATION(src_d.matches_tag(ncw), wei_d.matches_tag(oiw)) |
38 | && IMPLICATION(src_d.matches_tag(nchw), wei_d.matches_tag(oihw)) |
39 | && IMPLICATION(src_d.matches_tag(ncdhw), wei_d.matches_tag(oidhw)) |
40 | && IMPLICATION( |
41 | src_d.matches_tag(nc), wei_d.matches_one_of_tag(oi, io)) |
42 | && dst_d.matches_tag(nc) && src_d.is_dense(true) && dst_d.is_dense() |
43 | && wei_d.is_dense(true); |
44 | } |
45 | |
46 | inline bool dense_gemm_consistency_check(const memory_desc_wrapper &src_d, |
47 | const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) { |
48 | using namespace utils; |
49 | |
50 | auto strides_compatible = [&]() { |
51 | bool ok = true; |
52 | auto w_str = wei_d.blocking_desc().strides; |
53 | auto d_str = src_d.blocking_desc().strides; |
54 | for (int i = 1; i < src_d.ndims() - 1; i++) { |
55 | ok = ok && w_str[i] / d_str[i] == w_str[i + 1] / d_str[i + 1]; |
56 | } |
57 | return ok && one_of(w_str[1] / d_str[1], 1, wei_d.padded_dims()[0]); |
58 | }; |
59 | return src_d.is_blocking_desc() && wei_d.is_blocking_desc() |
60 | && src_d.ndims() == wei_d.ndims() |
61 | && src_d.blocking_desc().inner_nblks |
62 | == wei_d.blocking_desc().inner_nblks |
63 | && utils::one_of(src_d.blocking_desc().inner_nblks, 0, 1) |
64 | && array_cmp(src_d.blocking_desc().inner_blks, |
65 | wei_d.blocking_desc().inner_blks, |
66 | wei_d.blocking_desc().inner_nblks) |
67 | && array_cmp(src_d.blocking_desc().inner_idxs, |
68 | wei_d.blocking_desc().inner_idxs, |
69 | wei_d.blocking_desc().inner_nblks) |
70 | && strides_compatible() && dst_d.matches_tag(format_tag::nc) |
71 | && src_d.only_padded_dim(1) && wei_d.only_padded_dim(1) |
72 | && src_d.padded_dims()[1] == wei_d.padded_dims()[1] |
73 | && src_d.is_dense(true) && dst_d.is_dense() && wei_d.is_dense(true); |
74 | } |
75 | |
76 | status_t template_set_default_params(memory_desc_t &src_md, |
77 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
78 | memory_desc_t *bias_md, int ndims, bool is_conv = false) { |
79 | using namespace format_tag; |
80 | |
81 | auto init_md = [&](memory_desc_t &out_md, const memory_desc_t &in_md) { |
82 | format_tag_t md_tag; |
83 | if (memory_desc_matches_one_of_tag(in_md, ba, cba, cdba, cdeba)) |
84 | md_tag = utils::pick(ndims - 2, ab, acb, acdb, acdeb); |
85 | else if (memory_desc_matches_one_of_tag(in_md, acb, acdb, acdeb)) |
86 | md_tag = utils::pick(ndims - 3, cba, cdba, cdeba); |
87 | else { |
88 | memory_desc_wrapper md_desc_wrapper(in_md); |
89 | return memory_desc_init_by_blocking_desc( |
90 | out_md, md_desc_wrapper.blocking_desc()); |
91 | } |
92 | return memory_desc_init_by_tag(out_md, md_tag); |
93 | }; |
94 | if (!is_conv) { |
95 | if (src_md.format_kind == format_kind::any |
96 | && weights_md.format_kind == format_kind::any) { |
97 | CHECK(memory_desc_init_by_tag( |
98 | src_md, utils::pick(ndims - 2, nc, ncw, nchw, ncdhw))); |
99 | CHECK(memory_desc_init_by_tag( |
100 | weights_md, utils::pick(ndims - 2, oi, oiw, oihw, oidhw))); |
101 | } else if (src_md.format_kind == format_kind::any) |
102 | CHECK(init_md(src_md, weights_md)); |
103 | else if (weights_md.format_kind == format_kind::any) |
104 | CHECK(init_md(weights_md, src_md)); |
105 | } |
106 | |
107 | if (dst_md.format_kind == format_kind::any) |
108 | CHECK(memory_desc_init_by_tag(dst_md, nc)); |
109 | if (bias_md->format_kind == format_kind::any) |
110 | CHECK(memory_desc_init_by_tag(*bias_md, x)); |
111 | |
112 | return status::success; |
113 | } |
114 | |
115 | } // namespace |
116 | |
117 | struct gpu_inner_product_fwd_pd_t : public inner_product_fwd_pd_t { |
118 | using inner_product_fwd_pd_t::inner_product_fwd_pd_t; |
119 | |
120 | protected: |
121 | status_t set_default_params(bool is_conv = false) { |
122 | return template_set_default_params( |
123 | src_md_, weights_md_, dst_md_, &bias_md_, ndims(), is_conv); |
124 | } |
125 | }; |
126 | |
127 | struct gpu_inner_product_bwd_data_pd_t : public inner_product_bwd_data_pd_t { |
128 | using inner_product_bwd_data_pd_t::inner_product_bwd_data_pd_t; |
129 | |
130 | protected: |
131 | status_t set_default_params() { |
132 | return template_set_default_params(diff_src_md_, weights_md_, |
133 | diff_dst_md_, &glob_zero_md, ndims()); |
134 | } |
135 | }; |
136 | |
137 | struct gpu_inner_product_bwd_weights_pd_t |
138 | : public inner_product_bwd_weights_pd_t { |
139 | using inner_product_bwd_weights_pd_t::inner_product_bwd_weights_pd_t; |
140 | |
141 | protected: |
142 | status_t set_default_params() { |
143 | return template_set_default_params(src_md_, diff_weights_md_, |
144 | diff_dst_md_, &diff_bias_md_, ndims()); |
145 | } |
146 | }; |
147 | |
148 | } // namespace gpu |
149 | } // namespace impl |
150 | } // namespace dnnl |
151 | |
152 | #endif |
153 | |