1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_REF_SUM_HPP |
18 | #define CPU_REF_SUM_HPP |
19 | |
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

#include "common/engine.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/reorder.hpp"
#include "common/reorder_pd.hpp"

#include "cpu/cpu_sum_pd.hpp"
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | |
32 | struct ref_sum_t : public primitive_t { |
33 | struct pd_t : public cpu_sum_pd_t { |
34 | using cpu_sum_pd_t::cpu_sum_pd_t; |
35 | |
36 | pd_t(const pd_t &rhs) = default; |
37 | |
38 | DECLARE_SUM_PD_T("ref:any" , ref_sum_t); |
39 | |
40 | status_t init(engine_t *engine) { |
41 | bool ok = cpu_sum_pd_t::init(engine) == status::success; |
42 | if (!ok) return status::unimplemented; |
43 | |
44 | if (has_zero_dim_memory()) return status::success; |
45 | |
46 | reorder_pds_.resize(n_ + need_output_reorder()); |
47 | for (int i = 0; i < n_; ++i) { |
48 | primitive_attr_t r_attr; |
49 | r_attr.scales_.set(DNNL_ARG_SRC, 0); |
50 | if (i != 0) r_attr.post_ops_.append_sum(1.0); |
51 | CHECK(reorder_primitive_desc_create(reorder_pds_[i], engine, |
52 | src_md(i), dst_acc_md(), &r_attr)); |
53 | } |
54 | |
55 | if (need_output_reorder()) { |
56 | CHECK(reorder_primitive_desc_create( |
57 | reorder_pds_[n_], engine, dst_acc_md(), dst_md())); |
58 | } |
59 | |
60 | init_scratchpad(); |
61 | return status::success; |
62 | } |
63 | |
64 | std::vector<std::shared_ptr<primitive_desc_t>> reorder_pds_; |
65 | |
66 | private: |
67 | void init_scratchpad() { |
68 | using namespace memory_tracking::names; |
69 | auto scratchpad = scratchpad_registry().registrar(); |
70 | if (need_output_reorder()) { |
71 | const memory_desc_wrapper dst_acc_d(dst_acc_md()); |
72 | scratchpad.book(key_sum_reduction, dst_acc_d.size(), 1, |
73 | dst_acc_d.data_type_size()); |
74 | } |
75 | |
76 | for (size_t i = 0; i < reorder_pds_.size(); i++) { |
77 | scratchpad.book(key_nested_multiple + (int)i, |
78 | reorder_pds_[i]->scratchpad_registry()); |
79 | } |
80 | }; |
81 | }; |
82 | |
83 | ref_sum_t(const pd_t *apd) : primitive_t(apd) {} |
84 | |
85 | status_t init(engine_t *engine) override { |
86 | const size_t n = pd()->reorder_pds_.size(); |
87 | reorders_.resize(n); |
88 | for (size_t i = 0; i < n; ++i) |
89 | pd()->reorder_pds_[i]->create_primitive(reorders_[i], engine); |
90 | |
91 | memory_desc_t scales_md; |
92 | scales_md.ndims = 1; |
93 | scales_md.dims[0] = 1; |
94 | scales_md.data_type = data_type::f32; |
95 | CHECK(memory_desc_init_by_tag(scales_md, format_tag::x)); |
96 | const float *scales = pd()->scales(); |
97 | |
98 | scales_mem_.resize(n); |
99 | for (size_t i = 0; i < n; ++i) |
100 | scales_mem_[i] = std::make_shared<memory_t>(get_service_engine(), |
101 | &scales_md, use_runtime_ptr, |
102 | const_cast<float *>(&(scales[i]))); |
103 | return status::success; |
104 | } |
105 | |
106 | status_t execute(const exec_ctx_t &ctx) const override { |
107 | using namespace memory_tracking::names; |
108 | |
109 | if (pd()->has_zero_dim_memory()) return status::success; |
110 | |
111 | const auto n = pd()->n_inputs(); |
112 | exec_args_t r_args; |
113 | |
114 | auto sum_reduce = pd()->need_output_reorder() |
115 | ? ctx.get_scratchpad_grantor().get_memory_storage( |
116 | key_sum_reduction) |
117 | : nullptr; |
118 | auto dst = ctx.args().at(DNNL_ARG_DST); |
119 | memory_t acc( |
120 | dst.mem->engine(), pd()->dst_acc_md(), std::move(sum_reduce)); |
121 | memory_arg_t dst_acc = {&acc, false}; |
122 | |
123 | /* fix: clang MemorySanitizer: use-of-uninitialized-value */ |
124 | if (pd()->need_output_reorder()) { |
125 | const memory_desc_wrapper acc_d(acc.md()); |
126 | std::memset(acc.memory_storage()->data_handle(), 0, acc_d.size()); |
127 | } |
128 | |
129 | for (int i = 0; i < n; ++i) { |
130 | r_args[DNNL_ARG_SRC] = ctx.args().at(DNNL_ARG_MULTIPLE_SRC + i); |
131 | r_args[DNNL_ARG_DST] = pd()->need_output_reorder() ? dst_acc : dst; |
132 | r_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] |
133 | = {scales_mem_[i].get(), true}; |
134 | |
135 | exec_ctx_t r_ctx(ctx, std::move(r_args)); |
136 | |
137 | nested_scratchpad_t ns(ctx, key_nested_multiple + i, reorders_[i]); |
138 | r_ctx.set_scratchpad_grantor(ns.grantor()); |
139 | reorders_[i]->execute(r_ctx); |
140 | } |
141 | |
142 | if (pd()->need_output_reorder()) { |
143 | dst_acc = {&acc, true}; |
144 | r_args[DNNL_ARG_SRC] = dst_acc; |
145 | r_args[DNNL_ARG_DST] = dst; |
146 | exec_ctx_t r_ctx(ctx, std::move(r_args)); |
147 | |
148 | nested_scratchpad_t ns(ctx, key_nested_multiple + n, reorders_[n]); |
149 | r_ctx.set_scratchpad_grantor(ns.grantor()); |
150 | reorders_[n]->execute(r_ctx); |
151 | } |
152 | return status::success; |
153 | } |
154 | |
155 | private: |
156 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
157 | std::vector<std::shared_ptr<primitive_t>> reorders_; |
158 | std::vector<std::shared_ptr<memory_t>> scales_mem_; |
159 | }; |
160 | |
161 | } // namespace cpu |
162 | } // namespace impl |
163 | } // namespace dnnl |
164 | |
165 | #endif |
166 | |