1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_MATMUL_PD_HPP
18#define COMMON_MATMUL_PD_HPP
19
20#include <assert.h>
21
22#include "oneapi/dnnl/dnnl.h"
23
24#include "c_types_map.hpp"
25#include "primitive_desc.hpp"
26#include "utils.hpp"
27
28namespace dnnl {
29namespace impl {
30
31struct matmul_pd_t : public primitive_desc_t {
32 static constexpr auto base_pkind = primitive_kind::matmul;
33
34 typedef matmul_pd_t base_class;
35 typedef matmul_pd_t hint_class;
36
37 const matmul_desc_t *desc() const { return &desc_; }
38 const op_desc_t *op_desc() const override {
39 return reinterpret_cast<const op_desc_t *>(this->desc());
40 }
41
42 arg_usage_t arg_usage(int arg) const override {
43 const bool input = utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS);
44 if (input) return arg_usage_t::input;
45
46 if (arg == DNNL_ARG_BIAS && with_bias()) return arg_usage_t::input;
47
48 if (arg == DNNL_ARG_DST) return arg_usage_t::output;
49
50 return primitive_desc_t::arg_usage(arg);
51 }
52
53 const memory_desc_t *arg_md(int arg) const override {
54 switch (arg) {
55 case DNNL_ARG_SRC: return src_md(0);
56 case DNNL_ARG_WEIGHTS: return weights_md(0);
57 case DNNL_ARG_BIAS: return weights_md(1);
58 case DNNL_ARG_DST: return dst_md(0);
59 default: return primitive_desc_t::arg_md(arg);
60 }
61 }
62
63 const memory_desc_t *src_md(int index = 0) const override {
64 return index == 0 ? &src_md_ : &glob_zero_md;
65 }
66
67 const memory_desc_t *weights_md(int index = 0) const override {
68 if (index == 0) return &weights_md_;
69 if (index == 1 && with_bias()) return &bias_md_;
70 return &glob_zero_md;
71 }
72
73 const memory_desc_t *dst_md(int index = 0) const override {
74 return index == 0 ? &dst_md_ : &glob_zero_md;
75 }
76
77 int n_inputs() const override {
78 return 2 + with_bias() + n_binary_po_inputs();
79 }
80 int n_outputs() const override { return 1; }
81
82 bool has_zero_dim_memory() const {
83 return memory_desc_wrapper(src_md(0)).has_zero_dim()
84 || memory_desc_wrapper(weights_md(0)).has_zero_dim()
85 || memory_desc_wrapper(dst_md(0)).has_zero_dim();
86 }
87
88 bool has_runtime_dims_or_strides() const {
89 return memory_desc_wrapper(src_md_).has_runtime_dims_or_strides()
90 || memory_desc_wrapper(weights_md_)
91 .has_runtime_dims_or_strides()
92 || memory_desc_wrapper(dst_md_).has_runtime_dims_or_strides();
93 };
94
95 int ndims() const { return dst_md_.ndims; }
96
97 dim_t ldc() const {
98 return memory_desc_wrapper(dst_md(0))
99 .blocking_desc()
100 .strides[ndims() - 2];
101 }
102
103 bool with_bias() const { return bias_md_.ndims != 0; }
104 bool batched() const { return ndims() > 2; }
105
106 dim_t batch() const {
107 return utils::array_product(dst_md_.dims, ndims() - 2);
108 }
109 dim_t M() const { return dst_md_.dims[ndims() - 2]; }
110 dim_t N() const { return dst_md_.dims[ndims() - 1]; }
111 dim_t K() const { return src_md_.dims[ndims() - 1]; }
112
113 bool is_bias_1xN() const {
114 if (!with_bias()) return false;
115
116 const auto &dims = weights_md(1)->dims;
117 const int n_dims = ndims();
118 for (int i = 0; i < n_dims - 1; ++i) {
119 if (dims[i] != 1) return false;
120 }
121
122 return dims[n_dims - 1] == N();
123 }
124
125protected:
126 matmul_desc_t desc_;
127
128 memory_desc_t src_md_;
129 memory_desc_t weights_md_;
130 memory_desc_t bias_md_;
131 memory_desc_t dst_md_;
132
133 matmul_pd_t(const matmul_desc_t *adesc, const primitive_attr_t *attr,
134 const matmul_pd_t *hint_fwd_pd)
135 : primitive_desc_t(attr, base_pkind)
136 , desc_(*adesc)
137 , src_md_(desc_.src_desc)
138 , weights_md_(desc_.weights_desc)
139 , bias_md_(desc_.bias_desc)
140 , dst_md_(desc_.dst_desc) {}
141
142 // temporary solution to deal with format `any`
143 bool set_default_formats() {
144 for (auto md : {&src_md_, &weights_md_, &bias_md_, &dst_md_}) {
145 memory_desc_wrapper mdw(md);
146 if (mdw.format_any()) {
147 if (mdw.has_runtime_dims_or_strides()) return false;
148 status_t status = memory_desc_init_by_strides(*md, nullptr);
149 if (status != status::success) return false;
150 }
151 }
152
153 return true;
154 }
155};
156
157} // namespace impl
158} // namespace dnnl
159
160#endif
161