/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_REF_SUM_HPP
#define CPU_REF_SUM_HPP

#include <cstring>

#include "common/engine.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/reorder.hpp"
#include "common/reorder_pd.hpp"

#include "cpu/cpu_sum_pd.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

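// Reference sum implementation. Each source is brought into the destination
// (or into an intermediate accumulation buffer) by a reorder primitive; the
// reorders for inputs 1..n-1 carry a sum post-op so they accumulate onto the
// previous result, and per-input scales are applied as run-time SRC scales.
// A final reorder writes the accumulated result to dst when needed.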
struct ref_sum_t : public primitive_t {
    struct pd_t : public cpu_sum_pd_t {
        using cpu_sum_pd_t::cpu_sum_pd_t;

        pd_t(const pd_t &rhs) = default;

        DECLARE_SUM_PD_T("ref:any", ref_sum_t);

        status_t init(engine_t *engine) {
            bool ok = cpu_sum_pd_t::init(engine) == status::success;
            if (!ok) return status::unimplemented;

            if (has_zero_dim_memory()) return status::success;

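            // One reorder per input: the first overwrites the accumulation
            // destination, the others accumulate onto it via a sum post-op.
            // Each reorder applies its input's run-time scale on SRC.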
            reorder_pds_.resize(n_ + need_output_reorder());
            for (int i = 0; i < n_; ++i) {
                primitive_attr_t r_attr;
                r_attr.scales_.set(DNNL_ARG_SRC, 0);
                if (i != 0) r_attr.post_ops_.append_sum(1.0);
                CHECK(reorder_primitive_desc_create(reorder_pds_[i], engine,
                        src_md(i), dst_acc_md(), &r_attr));
            }

            if (need_output_reorder()) {
                CHECK(reorder_primitive_desc_create(
                        reorder_pds_[n_], engine, dst_acc_md(), dst_md()));
            }

            init_scratchpad();
            return status::success;
        }

        std::vector<std::shared_ptr<primitive_desc_t>> reorder_pds_;

    private:
        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            if (need_output_reorder()) {
                const memory_desc_wrapper dst_acc_d(dst_acc_md());
                scratchpad.book(key_sum_reduction, dst_acc_d.size(), 1,
                        dst_acc_d.data_type_size());
            }

            for (size_t i = 0; i < reorder_pds_.size(); i++) {
                scratchpad.book(key_nested_multiple + (int)i,
                        reorder_pds_[i]->scratchpad_registry());
            }
        }
    };

    ref_sum_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        const size_t n = pd()->reorder_pds_.size();
        reorders_.resize(n);
        for (size_t i = 0; i < n; ++i)
            CHECK(pd()->reorder_pds_[i]->create_primitive(
                    reorders_[i], engine));

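        // Wrap each per-input scale in a one-element f32 memory object so it
        // can be passed to its reorder as a run-time SRC scale argument.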
        memory_desc_t scales_md;
        scales_md.ndims = 1;
        scales_md.dims[0] = 1;
        scales_md.data_type = data_type::f32;
        CHECK(memory_desc_init_by_tag(scales_md, format_tag::x));
        const float *scales = pd()->scales();

        scales_mem_.resize(n);
        for (size_t i = 0; i < n; ++i)
            scales_mem_[i] = std::make_shared<memory_t>(get_service_engine(),
                    &scales_md, use_runtime_ptr,
                    const_cast<float *>(&(scales[i])));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;

        if (pd()->has_zero_dim_memory()) return status::success;

        const auto n = pd()->n_inputs();
        exec_args_t r_args;

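        // If a final output reorder is needed, the per-input reorders
        // accumulate into a scratchpad buffer described by dst_acc_md();
        // otherwise they write directly into the user's dst memory.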
        auto sum_reduce = pd()->need_output_reorder()
                ? ctx.get_scratchpad_grantor().get_memory_storage(
                        key_sum_reduction)
                : nullptr;
        auto dst = ctx.args().at(DNNL_ARG_DST);
        memory_t acc(
                dst.mem->engine(), pd()->dst_acc_md(), std::move(sum_reduce));
        memory_arg_t dst_acc = {&acc, false};

        /* fix: clang MemorySanitizer: use-of-uninitialized-value */
        if (pd()->need_output_reorder()) {
            const memory_desc_wrapper acc_d(acc.md());
            std::memset(acc.memory_storage()->data_handle(), 0, acc_d.size());
        }

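        // Run one reorder per source, each with its own nested scratchpad
        // and run-time scale argument.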
        for (int i = 0; i < n; ++i) {
            r_args[DNNL_ARG_SRC] = ctx.args().at(DNNL_ARG_MULTIPLE_SRC + i);
            r_args[DNNL_ARG_DST] = pd()->need_output_reorder() ? dst_acc : dst;
            r_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC]
                    = {scales_mem_[i].get(), true};

            exec_ctx_t r_ctx(ctx, std::move(r_args));

            nested_scratchpad_t ns(ctx, key_nested_multiple + i, reorders_[i]);
            r_ctx.set_scratchpad_grantor(ns.grantor());
            reorders_[i]->execute(r_ctx);
        }

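        // Convert the accumulated result into the user's dst memory.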
        if (pd()->need_output_reorder()) {
            dst_acc = {&acc, true};
            r_args[DNNL_ARG_SRC] = dst_acc;
            r_args[DNNL_ARG_DST] = dst;
            exec_ctx_t r_ctx(ctx, std::move(r_args));

            nested_scratchpad_t ns(ctx, key_nested_multiple + n, reorders_[n]);
            r_ctx.set_scratchpad_grantor(ns.grantor());
            reorders_[n]->execute(r_ctx);
        }
        return status::success;
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::vector<std::shared_ptr<primitive_t>> reorders_;
    std::vector<std::shared_ptr<memory_t>> scales_mem_;
};

} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif