1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include "opdesc.hpp"
19#include "primitive_desc_iface.hpp"
20
21#include "oneapi/dnnl/dnnl.h"
22
23#include "c_types_map.hpp"
24#include "type_helpers.hpp"
25#include "utils.hpp"
26
27using namespace dnnl::impl;
28using namespace dnnl::impl::utils;
29using namespace dnnl::impl::status;
30using namespace dnnl::impl::alg_kind;
31using namespace dnnl::impl::types;
32
33status_t dnnl_binary_primitive_desc_create(
34 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
35 alg_kind_t alg_kind, const memory_desc_t *src0_md,
36 const memory_desc_t *src1_md, const memory_desc_t *dst_md,
37 const primitive_attr_t *attr) {
38 bool args_ok = !any_null(src0_md, src1_md, dst_md)
39 && one_of(alg_kind, binary_add, binary_mul, binary_max, binary_min,
40 binary_div, binary_sub, binary_ge, binary_gt, binary_le,
41 binary_lt, binary_eq, binary_ne)
42 // TODO - Add support for mutual or bi-directional broadcasts
43 && !memory_desc_wrapper(src0_md).format_any();
44 if (!args_ok) return invalid_arguments;
45
46 auto bod = binary_desc_t();
47 bod.primitive_kind = primitive_kind::binary;
48 bod.alg_kind = alg_kind;
49
50 bool runtime_dims_or_strides
51 = memory_desc_wrapper(src0_md).has_runtime_dims_or_strides()
52 || memory_desc_wrapper(src1_md).has_runtime_dims_or_strides()
53 || memory_desc_wrapper(dst_md).has_runtime_dims_or_strides();
54 if (runtime_dims_or_strides) return unimplemented;
55
56 bod.src_desc[0] = *src0_md;
57 bod.src_desc[1] = *src1_md;
58 bod.dst_desc = *dst_md;
59
60 const int ndims = dst_md->ndims;
61 const dims_t &dims = dst_md->dims;
62
63 if (!(src0_md->ndims == ndims && src1_md->ndims == ndims))
64 return invalid_arguments;
65 for (int d = 0; d < ndims; ++d) {
66 //dims must equal eachother or equal 1 (broadcast)
67 const bool ok = utils::one_of(src0_md->dims[d], 1, dims[d])
68 && utils::one_of(src1_md->dims[d], 1, dims[d])
69 && IMPLICATION(src0_md->dims[d] != dims[d],
70 src1_md->dims[d] == dims[d]);
71 if (!ok) return invalid_arguments;
72 }
73
74 return primitive_desc_create(primitive_desc_iface, engine,
75 (const op_desc_t *)&bod, nullptr, attr);
76}
77