1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_GEMM_PD_HPP
18#define COMMON_GEMM_PD_HPP
19
20#include "oneapi/dnnl/dnnl.h"
21
22#include "common/c_types_map.hpp"
23#include "common/gemm_utils.hpp"
24#include "common/primitive_desc.hpp"
25#include "common/utils.hpp"
26
27namespace dnnl {
28namespace impl {
29
30struct gemm_pd_t : public primitive_desc_t {
31 static constexpr auto base_pkind = primitive_kind::gemm;
32
33 typedef gemm_pd_t base_class;
34 typedef gemm_pd_t hint_class;
35
36 const gemm_desc_t *desc() const { return &desc_; }
37 const op_desc_t *op_desc() const override {
38 return reinterpret_cast<const op_desc_t *>(this->desc());
39 }
40
41 arg_usage_t arg_usage(int arg) const override {
42 if (utils::one_of(arg, DNNL_ARG_SRC_0, DNNL_ARG_SRC_1))
43 return arg_usage_t::input;
44
45 if (arg == DNNL_ARG_DST) return arg_usage_t::output;
46
47 return primitive_desc_t::arg_usage(arg);
48 }
49
50 const memory_desc_t *arg_md(int arg) const override {
51 switch (arg) {
52 case DNNL_ARG_SRC_0: return src_md(0);
53 case DNNL_ARG_SRC_1: return src_md(1);
54 case DNNL_ARG_BIAS: return src_md(2);
55 case DNNL_ARG_DST: return dst_md(0);
56 default: return primitive_desc_t::arg_md(arg);
57 }
58 }
59
60 const memory_desc_t *src_md(int index = 0) const override {
61 switch (index) {
62 case 0: return &desc_.a_desc;
63 case 1: return &desc_.b_desc;
64 case 2: return &desc_.bias_desc;
65 default: return &glob_zero_md;
66 }
67 }
68 const memory_desc_t *dst_md(int index = 0) const override {
69 return index == 0 ? &desc_.c_desc : &glob_zero_md;
70 }
71
72 int n_inputs() const override { return 2; }
73 int n_outputs() const override { return 1; }
74
75protected:
76 // Note: we do not copy memory desc locally to avoid
77 // overheads. This means we lose the users memory descs when we
78 // resolve the 'any' tags.
79 gemm_desc_t desc_;
80
81 gemm_pd_t(const gemm_desc_t *adesc, const primitive_attr_t *attr,
82 const hint_class *hint_fwd_pd)
83 : primitive_desc_t(attr, base_pkind), desc_(*adesc) {}
84
85 // By default, we just resolve 'any' with blocked layout and trivial strides
86 bool set_default_format(memory_desc_t *md) {
87 memory_desc_wrapper mdw(md);
88 if (mdw.format_any()) {
89 if (mdw.has_runtime_dims_or_strides()) return false;
90 status_t status = memory_desc_init_by_strides(*md, nullptr);
91 if (status != status::success) return false;
92 }
93
94 return true;
95 }
96
97 bool set_default_formats() {
98 bool ok = true;
99
100 for (auto md : {&desc_.a_desc, &desc_.b_desc, &desc_.bias_desc,
101 &desc_.c_desc}) {
102 ok = ok && set_default_format(md);
103 }
104
105 auto status = attr_.post_ops_.set_default_formats(&desc_.c_desc);
106 ok = ok && (status == status::success);
107
108 return ok;
109 }
110};
111
112} // namespace impl
113} // namespace dnnl
114
115#endif
116