/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include "opdesc.hpp"
#include "primitive_desc_iface.hpp"

#include "oneapi/dnnl/dnnl.h"

#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

using namespace dnnl::impl;
using namespace dnnl::impl::utils;
using namespace dnnl::impl::types;

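// C API entry point: validates that the src, weights, dst (and optional bias)
// memory descriptors describe a consistent matmul problem and, if so, creates
// the primitive descriptor.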
status_t dnnl_matmul_primitive_desc_create(
        primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
        const memory_desc_t *src_md, const memory_desc_t *weights_md,
        const memory_desc_t *bias_md, const memory_desc_t *dst_md,
        const primitive_attr_t *attr) {
    bool args_ok = !any_null(src_md, weights_md, dst_md);
    if (!args_ok) return status::invalid_arguments;

    auto op_d = matmul_desc_t();
    op_d.primitive_kind = primitive_kind::matmul;

    op_d.src_desc = *src_md;
    op_d.weights_desc = *weights_md;
    if (bias_md) op_d.bias_desc = *bias_md;
    op_d.dst_desc = *dst_md;

    const bool with_bias = op_d.bias_desc.ndims != 0;
    const int ndims = dst_md->ndims;
    bool ok = ndims >= 2 && ndims <= DNNL_MAX_NDIMS
            && everyone_is(ndims, src_md->ndims, weights_md->ndims)
            && IMPLICATION(with_bias, op_d.bias_desc.ndims == ndims);
    if (!ok) return status::invalid_arguments;

    // check: m, n, k
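    // E.g. for ndims == 2: src is {M, K}, weights is {K, N}, dst is {M, N},
    // and a bias shaped {M, N}, {1, N}, {M, 1}, or {1, 1} is accepted.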
    const int m_idx = ndims - 2;
    const int k_idx_src = m_idx + 1;
    const int k_idx_wei = m_idx;
    const int n_idx = ndims - 1;
    ok = dst_md->dims[m_idx] == src_md->dims[m_idx]
            && dst_md->dims[n_idx] == weights_md->dims[n_idx]
            && src_md->dims[k_idx_src] == weights_md->dims[k_idx_wei]
            && IMPLICATION(with_bias,
                    one_of(op_d.bias_desc.dims[n_idx], 1, dst_md->dims[n_idx]))
            && IMPLICATION(with_bias,
                    one_of(op_d.bias_desc.dims[m_idx], 1, dst_md->dims[m_idx]));
    if (!ok) return status::invalid_arguments;

    // Check that the remaining (batch) dims match or are broadcast-compatible.
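    // E.g. with ndims == 4, src {4, 1, M, K} and weights {1, 3, K, N} are
    // accepted with dst {4, 3, M, N}. A runtime batch dim
    // (DNNL_RUNTIME_DIM_VAL) must be runtime in src, weights, dst, and, if
    // present, bias alike.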
    for (int d = 0; d < ndims - 2; ++d) {
        const dim_t s_dim = src_md->dims[d];
        const dim_t w_dim = weights_md->dims[d];
        const dim_t d_dim = dst_md->dims[d];
        const dim_t b_dim = with_bias ? op_d.bias_desc.dims[d] : 0;

        if (one_of(DNNL_RUNTIME_DIM_VAL, s_dim, w_dim, d_dim, b_dim)) {
            if (!(everyone_is(DNNL_RUNTIME_DIM_VAL, s_dim, w_dim, d_dim)
                        && IMPLICATION(
                                with_bias, b_dim == DNNL_RUNTIME_DIM_VAL)))
                return status::invalid_arguments;
        } else {
            // This follows numpy semantics of broadcasting when 0 is involved.
            ok = IMPLICATION(!everyone_is(s_dim, w_dim, d_dim),
                         one_of(1, s_dim, w_dim))
                    && IMPLICATION(s_dim == 1, d_dim == w_dim)
                    && IMPLICATION(w_dim == 1, d_dim == s_dim)
                    && IMPLICATION(with_bias, one_of(b_dim, 1, d_dim));
            if (!ok) return status::invalid_arguments;
        }
    }

    // Deduce the accumulator data type from the src/weights/dst data types;
    // an undefined result means the combination is not supported.
    op_d.accum_data_type = types::default_accum_data_type(src_md->data_type,
            weights_md->data_type, dst_md->data_type, prop_kind::forward);
    if (op_d.accum_data_type == data_type::undef)
        return status::invalid_arguments;

    return primitive_desc_create(primitive_desc_iface, engine,
            (const op_desc_t *)&op_d, nullptr, attr);
}
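
// A minimal usage sketch (illustrative only, not part of this translation
// unit), assuming the oneDNN v3.x C API in which memory descriptors are
// created with dnnl_memory_desc_create_with_tag():
//
//     dnnl_engine_t engine;
//     dnnl_memory_desc_t src_md, weights_md, dst_md;
//     dnnl_primitive_desc_t matmul_pd;
//     dnnl_dims_t src_dims = {128, 256};
//     dnnl_dims_t wei_dims = {256, 512};
//     dnnl_dims_t dst_dims = {128, 512};
//
//     dnnl_engine_create(&engine, dnnl_cpu, 0);
//     dnnl_memory_desc_create_with_tag(
//             &src_md, 2, src_dims, dnnl_f32, dnnl_ab);
//     dnnl_memory_desc_create_with_tag(
//             &weights_md, 2, wei_dims, dnnl_f32, dnnl_ab);
//     dnnl_memory_desc_create_with_tag(
//             &dst_md, 2, dst_dims, dnnl_f32, dnnl_ab);
//
//     // Bias and attributes are optional; NULL selects "no bias" and the
//     // default attributes.
//     dnnl_matmul_primitive_desc_create(
//             &matmul_pd, engine, src_md, weights_md, NULL, dst_md, NULL);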