1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include "opdesc.hpp" |
19 | #include "primitive_desc_iface.hpp" |
20 | |
21 | #include "oneapi/dnnl/dnnl.h" |
22 | |
23 | #include "c_types_map.hpp" |
24 | #include "type_helpers.hpp" |
25 | #include "utils.hpp" |
26 | |
27 | using namespace dnnl::impl; |
28 | using namespace dnnl::impl::utils; |
29 | using namespace dnnl::impl::types; |
30 | |
31 | status_t dnnl_matmul_primitive_desc_create( |
32 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
33 | const memory_desc_t *src_md, const memory_desc_t *weights_md, |
34 | const memory_desc_t *bias_md, const memory_desc_t *dst_md, |
35 | const primitive_attr_t *attr) { |
36 | bool args_ok = !any_null(src_md, weights_md, dst_md); |
37 | if (!args_ok) return status::invalid_arguments; |
38 | |
39 | auto op_d = matmul_desc_t(); |
40 | op_d.primitive_kind = primitive_kind::matmul; |
41 | |
42 | op_d.src_desc = *src_md; |
43 | op_d.weights_desc = *weights_md; |
44 | if (bias_md) op_d.bias_desc = *bias_md; |
45 | op_d.dst_desc = *dst_md; |
46 | |
47 | const bool with_bias = op_d.bias_desc.ndims != 0; |
48 | const int ndims = dst_md->ndims; |
49 | bool ok = ndims >= 2 && ndims <= DNNL_MAX_NDIMS |
50 | && everyone_is(ndims, src_md->ndims, weights_md->ndims) |
51 | && IMPLICATION(with_bias, op_d.bias_desc.ndims == ndims); |
52 | if (!ok) return status::invalid_arguments; |
53 | |
54 | // check: m, n, k |
55 | const int m_idx = ndims - 2; |
56 | const int k_idx_src = m_idx + 1; |
57 | const int k_idx_wei = m_idx; |
58 | const int n_idx = ndims - 1; |
59 | ok = dst_md->dims[m_idx] == src_md->dims[m_idx] |
60 | && dst_md->dims[n_idx] == weights_md->dims[n_idx] |
61 | && src_md->dims[k_idx_src] == weights_md->dims[k_idx_wei] |
62 | && IMPLICATION(with_bias, |
63 | one_of(op_d.bias_desc.dims[n_idx], 1, dst_md->dims[n_idx])) |
64 | && IMPLICATION(with_bias, |
65 | one_of(op_d.bias_desc.dims[m_idx], 1, dst_md->dims[m_idx])); |
66 | if (!ok) return status::invalid_arguments; |
67 | |
68 | // check if other dims match. |
69 | for (int d = 0; d < ndims - 2; ++d) { |
70 | const dim_t s_dim = src_md->dims[d]; |
71 | const dim_t w_dim = weights_md->dims[d]; |
72 | const dim_t d_dim = dst_md->dims[d]; |
73 | const dim_t b_dim = with_bias ? op_d.bias_desc.dims[d] : 0; |
74 | |
75 | if (one_of(DNNL_RUNTIME_DIM_VAL, s_dim, w_dim, d_dim, b_dim)) { |
76 | |
77 | if (!(everyone_is(DNNL_RUNTIME_DIM_VAL, s_dim, w_dim, d_dim) |
78 | && IMPLICATION( |
79 | with_bias, b_dim == DNNL_RUNTIME_DIM_VAL))) |
80 | return status::invalid_arguments; |
81 | } else { |
82 | // This follows numpy semantics of broadcasting when 0 is involved. |
83 | ok = IMPLICATION(!everyone_is(s_dim, w_dim, d_dim), |
84 | one_of(1, s_dim, w_dim)) |
85 | && IMPLICATION(s_dim == 1, d_dim == w_dim) |
86 | && IMPLICATION(w_dim == 1, d_dim == s_dim) |
87 | && IMPLICATION(with_bias, one_of(b_dim, 1, d_dim)); |
88 | if (!ok) return status::invalid_arguments; |
89 | } |
90 | } |
91 | |
92 | op_d.accum_data_type = types::default_accum_data_type(src_md->data_type, |
93 | weights_md->data_type, dst_md->data_type, prop_kind::forward); |
94 | if (op_d.accum_data_type == data_type::undef) |
95 | return status::invalid_arguments; |
96 | |
97 | return primitive_desc_create(primitive_desc_iface, engine, |
98 | (const op_desc_t *)&op_d, nullptr, attr); |
99 | } |
100 | |