1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_GEMM_PD_HPP |
18 | #define COMMON_GEMM_PD_HPP |
19 | |
20 | #include "oneapi/dnnl/dnnl.h" |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/gemm_utils.hpp" |
24 | #include "common/primitive_desc.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | |
30 | struct gemm_pd_t : public primitive_desc_t { |
31 | static constexpr auto base_pkind = primitive_kind::gemm; |
32 | |
33 | typedef gemm_pd_t base_class; |
34 | typedef gemm_pd_t hint_class; |
35 | |
36 | const gemm_desc_t *desc() const { return &desc_; } |
37 | const op_desc_t *op_desc() const override { |
38 | return reinterpret_cast<const op_desc_t *>(this->desc()); |
39 | } |
40 | |
41 | arg_usage_t arg_usage(int arg) const override { |
42 | if (utils::one_of(arg, DNNL_ARG_SRC_0, DNNL_ARG_SRC_1)) |
43 | return arg_usage_t::input; |
44 | |
45 | if (arg == DNNL_ARG_DST) return arg_usage_t::output; |
46 | |
47 | return primitive_desc_t::arg_usage(arg); |
48 | } |
49 | |
50 | const memory_desc_t *arg_md(int arg) const override { |
51 | switch (arg) { |
52 | case DNNL_ARG_SRC_0: return src_md(0); |
53 | case DNNL_ARG_SRC_1: return src_md(1); |
54 | case DNNL_ARG_BIAS: return src_md(2); |
55 | case DNNL_ARG_DST: return dst_md(0); |
56 | default: return primitive_desc_t::arg_md(arg); |
57 | } |
58 | } |
59 | |
60 | const memory_desc_t *src_md(int index = 0) const override { |
61 | switch (index) { |
62 | case 0: return &desc_.a_desc; |
63 | case 1: return &desc_.b_desc; |
64 | case 2: return &desc_.bias_desc; |
65 | default: return &glob_zero_md; |
66 | } |
67 | } |
68 | const memory_desc_t *dst_md(int index = 0) const override { |
69 | return index == 0 ? &desc_.c_desc : &glob_zero_md; |
70 | } |
71 | |
72 | int n_inputs() const override { return 2; } |
73 | int n_outputs() const override { return 1; } |
74 | |
75 | protected: |
76 | // Note: we do not copy memory desc locally to avoid |
77 | // overheads. This means we lose the users memory descs when we |
78 | // resolve the 'any' tags. |
79 | gemm_desc_t desc_; |
80 | |
81 | gemm_pd_t(const gemm_desc_t *adesc, const primitive_attr_t *attr, |
82 | const hint_class *hint_fwd_pd) |
83 | : primitive_desc_t(attr, base_pkind), desc_(*adesc) {} |
84 | |
85 | // By default, we just resolve 'any' with blocked layout and trivial strides |
86 | bool set_default_format(memory_desc_t *md) { |
87 | memory_desc_wrapper mdw(md); |
88 | if (mdw.format_any()) { |
89 | if (mdw.has_runtime_dims_or_strides()) return false; |
90 | status_t status = memory_desc_init_by_strides(*md, nullptr); |
91 | if (status != status::success) return false; |
92 | } |
93 | |
94 | return true; |
95 | } |
96 | |
97 | bool set_default_formats() { |
98 | bool ok = true; |
99 | |
100 | for (auto md : {&desc_.a_desc, &desc_.b_desc, &desc_.bias_desc, |
101 | &desc_.c_desc}) { |
102 | ok = ok && set_default_format(md); |
103 | } |
104 | |
105 | auto status = attr_.post_ops_.set_default_formats(&desc_.c_desc); |
106 | ok = ok && (status == status::success); |
107 | |
108 | return ok; |
109 | } |
110 | }; |
111 | |
112 | } // namespace impl |
113 | } // namespace dnnl |
114 | |
115 | #endif |
116 | |