1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_REF_BINARY_HPP
18#define CPU_REF_BINARY_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/platform.hpp"
28#include "cpu/primitive_attr_postops.hpp"
29
30#include "cpu/cpu_binary_pd.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35
36struct ref_binary_t : public primitive_t {
37 struct pd_t : public cpu_binary_pd_t {
38 using cpu_binary_pd_t::cpu_binary_pd_t;
39
40 DECLARE_COMMON_PD_T("ref:any", ref_binary_t);
41
42 status_t init(engine_t *engine) {
43 using namespace data_type;
44 using sm = primitive_attr_t::skip_mask_t;
45
46 const bool ok
47 = platform::has_data_type_support(src_md(0)->data_type)
48 && platform::has_data_type_support(src_md(1)->data_type)
49 && platform::has_data_type_support(dst_md()->data_type)
50 && set_default_params() == status::success
51 && attr()->has_default_values(
52 sm::post_ops | sm::scales_runtime)
53 && IMPLICATION(!attr()->scales_.has_default_values(),
54 check_scales_mask())
55 && attr_.set_default_formats(dst_md(0)) == status::success;
56 if (!ok) return status::unimplemented;
57
58 return status::success;
59 }
60
61 private:
62 bool check_scales_mask() const {
63 for (const auto &s : attr()->scales_.scales_) {
64 if (s.second.mask_ != 0) return false;
65 }
66 return true;
67 }
68 };
69
70 ref_binary_t(const pd_t *apd) : primitive_t(apd) {}
71
72 status_t init(engine_t *engine) override {
73 ref_post_ops
74 = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_);
75 if (!ref_post_ops) return status::out_of_memory;
76 return status::success;
77 }
78
79 status_t execute(const exec_ctx_t &ctx) const override {
80 return execute_ref(ctx);
81 }
82
83private:
84 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
85 status_t execute_ref(const exec_ctx_t &ctx) const;
86 std::unique_ptr<ref_post_ops_t> ref_post_ops;
87};
88
89} // namespace cpu
90} // namespace impl
91} // namespace dnnl
92
93#endif
94