1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_REORDER_PD_HPP |
18 | #define COMMON_REORDER_PD_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "engine.hpp" |
24 | #include "primitive.hpp" |
25 | #include "primitive_attr.hpp" |
26 | #include "primitive_desc_iface.hpp" |
27 | #include "primitive_iface.hpp" |
28 | #include "type_helpers.hpp" |
29 | #include "utils.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | |
34 | struct reorder_primitive_desc_iface_t : public dnnl_primitive_desc { |
35 | reorder_primitive_desc_iface_t(const std::shared_ptr<primitive_desc_t> &pd, |
36 | engine_t *engine, engine_t *src_engine, engine_t *dst_engine) |
37 | : dnnl_primitive_desc(pd, engine) |
38 | , src_engine_(src_engine) |
39 | , dst_engine_(dst_engine) |
40 | , scratchpad_engine_(nullptr) {} |
41 | |
42 | dnnl::impl::engine_t *src_engine() const override { return src_engine_; } |
43 | dnnl::impl::engine_t *dst_engine() const override { return dst_engine_; } |
44 | |
45 | dnnl::impl::engine_t *scratchpad_engine() const override { |
46 | return scratchpad_engine_; |
47 | } |
48 | |
49 | dnnl::impl::status_t query( |
50 | dnnl::impl::query_t what, int idx, void *result) const override { |
51 | auto status = dnnl::impl::status::success; |
52 | switch (what) { |
53 | case dnnl::impl::query::reorder_src_engine: |
54 | *(dnnl::impl::engine_t **)result = src_engine(); |
55 | break; |
56 | case dnnl::impl::query::reorder_dst_engine: |
57 | *(dnnl::impl::engine_t **)result = dst_engine(); |
58 | break; |
59 | default: status = dnnl_primitive_desc::query(what, idx, result); |
60 | } |
61 | return status; |
62 | } |
63 | |
64 | status_t create_primitive_iface( |
65 | std::pair<primitive_iface_t *, bool> &primitive_iface, |
66 | const cache_blob_t &cache_blob) const override { |
67 | // Step 1: create impl::primitive_t or get it from primitive cache |
68 | std::pair<std::shared_ptr<primitive_t>, bool> p; |
69 | auto status = pd_->create_primitive(p, engine(), cache_blob); |
70 | if (status != status::success) return status; |
71 | // Step 2: create primitive_iface_t, init and return it to user |
72 | primitive_iface_t *p_iface = nullptr; |
73 | CHECK(safe_ptr_assign(p_iface, |
74 | new primitive_iface_t( |
75 | p.first, engine(), src_engine_, dst_engine_))); |
76 | status = p_iface->init(); |
77 | if (status != status::success) { |
78 | p_iface->release(); |
79 | return status; |
80 | } |
81 | primitive_iface = std::make_pair(p_iface, p.second); |
82 | return status::success; |
83 | } |
84 | |
85 | private: |
86 | dnnl::impl::engine_t *src_engine_; |
87 | dnnl::impl::engine_t *dst_engine_; |
88 | dnnl::impl::engine_t *scratchpad_engine_; |
89 | }; |
90 | |
91 | struct reorder_pd_t : public primitive_desc_t { |
92 | const reorder_desc_t *desc() const { return &desc_; } |
93 | const op_desc_t *op_desc() const override { |
94 | return reinterpret_cast<const op_desc_t *>(this->desc()); |
95 | } |
96 | |
97 | arg_usage_t arg_usage(int arg) const override { |
98 | if (arg == DNNL_ARG_FROM) return arg_usage_t::input; |
99 | |
100 | if (arg == DNNL_ARG_TO) return arg_usage_t::output; |
101 | |
102 | return primitive_desc_t::arg_usage(arg); |
103 | } |
104 | |
105 | const memory_desc_t *arg_md(int arg) const override { |
106 | switch (arg) { |
107 | case DNNL_ARG_FROM: return src_md(0); |
108 | case DNNL_ARG_TO: return dst_md(0); |
109 | default: return primitive_desc_t::arg_md(arg); |
110 | } |
111 | } |
112 | |
113 | const memory_desc_t *src_md(int index = 0) const override { |
114 | return index == 0 ? &src_md_ : &glob_zero_md; |
115 | } |
116 | const memory_desc_t *dst_md(int index = 0) const override { |
117 | return index == 0 ? &dst_md_ : &glob_zero_md; |
118 | } |
119 | |
120 | int n_inputs() const override { return 1; } |
121 | int n_outputs() const override { return 1; } |
122 | |
123 | float beta() const { |
124 | const int sum_idx = attr()->post_ops_.find(primitive_kind::sum); |
125 | return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale; |
126 | } |
127 | |
128 | protected: |
129 | reorder_desc_t desc_; |
130 | memory_desc_t src_md_; |
131 | memory_desc_t dst_md_; |
132 | |
133 | reorder_pd_t(const primitive_attr_t *attr, engine_kind_t src_engine_kind, |
134 | const memory_desc_t *src_md, engine_kind_t dst_engine_kind, |
135 | const memory_desc_t *dst_md) |
136 | : primitive_desc_t(attr, primitive_kind::reorder) |
137 | , src_md_(*src_md) |
138 | , dst_md_(*dst_md) { |
139 | |
140 | init_desc(src_engine_kind, dst_engine_kind, false); |
141 | } |
142 | |
143 | reorder_pd_t(const reorder_pd_t &other) : primitive_desc_t(other) { |
144 | src_md_ = other.src_md_; |
145 | dst_md_ = other.dst_md_; |
146 | |
147 | init_desc(other.desc_.src_engine_kind, other.desc_.dst_engine_kind, |
148 | other.desc_.is_cross_engine); |
149 | } |
150 | |
151 | protected: |
152 | void init_desc(engine_kind_t src_engine_kind, engine_kind_t dst_engine_kind, |
153 | bool is_cross_engine) { |
154 | desc_ = reorder_desc_t(); |
155 | desc_.primitive_kind = primitive_kind::reorder; |
156 | desc_.src_md = &src_md_; |
157 | desc_.dst_md = &dst_md_; |
158 | desc_.src_engine_kind = src_engine_kind; |
159 | desc_.dst_engine_kind = dst_engine_kind; |
160 | desc_.is_cross_engine = is_cross_engine; |
161 | } |
162 | }; |
163 | |
164 | } // namespace impl |
165 | } // namespace dnnl |
166 | |
167 | #endif |
168 | |
169 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
170 | |