1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_REORDER_PD_HPP
18#define COMMON_REORDER_PD_HPP
19
20#include <assert.h>
21
22#include "c_types_map.hpp"
23#include "engine.hpp"
24#include "primitive.hpp"
25#include "primitive_attr.hpp"
26#include "primitive_desc_iface.hpp"
27#include "primitive_iface.hpp"
28#include "type_helpers.hpp"
29#include "utils.hpp"
30
31namespace dnnl {
32namespace impl {
33
34struct reorder_primitive_desc_iface_t : public dnnl_primitive_desc {
35 reorder_primitive_desc_iface_t(const std::shared_ptr<primitive_desc_t> &pd,
36 engine_t *engine, engine_t *src_engine, engine_t *dst_engine)
37 : dnnl_primitive_desc(pd, engine)
38 , src_engine_(src_engine)
39 , dst_engine_(dst_engine)
40 , scratchpad_engine_(nullptr) {}
41
42 dnnl::impl::engine_t *src_engine() const override { return src_engine_; }
43 dnnl::impl::engine_t *dst_engine() const override { return dst_engine_; }
44
45 dnnl::impl::engine_t *scratchpad_engine() const override {
46 return scratchpad_engine_;
47 }
48
49 dnnl::impl::status_t query(
50 dnnl::impl::query_t what, int idx, void *result) const override {
51 auto status = dnnl::impl::status::success;
52 switch (what) {
53 case dnnl::impl::query::reorder_src_engine:
54 *(dnnl::impl::engine_t **)result = src_engine();
55 break;
56 case dnnl::impl::query::reorder_dst_engine:
57 *(dnnl::impl::engine_t **)result = dst_engine();
58 break;
59 default: status = dnnl_primitive_desc::query(what, idx, result);
60 }
61 return status;
62 }
63
64 status_t create_primitive_iface(
65 std::pair<primitive_iface_t *, bool> &primitive_iface,
66 const cache_blob_t &cache_blob) const override {
67 // Step 1: create impl::primitive_t or get it from primitive cache
68 std::pair<std::shared_ptr<primitive_t>, bool> p;
69 auto status = pd_->create_primitive(p, engine(), cache_blob);
70 if (status != status::success) return status;
71 // Step 2: create primitive_iface_t, init and return it to user
72 primitive_iface_t *p_iface = nullptr;
73 CHECK(safe_ptr_assign(p_iface,
74 new primitive_iface_t(
75 p.first, engine(), src_engine_, dst_engine_)));
76 status = p_iface->init();
77 if (status != status::success) {
78 p_iface->release();
79 return status;
80 }
81 primitive_iface = std::make_pair(p_iface, p.second);
82 return status::success;
83 }
84
85private:
86 dnnl::impl::engine_t *src_engine_;
87 dnnl::impl::engine_t *dst_engine_;
88 dnnl::impl::engine_t *scratchpad_engine_;
89};
90
91struct reorder_pd_t : public primitive_desc_t {
92 const reorder_desc_t *desc() const { return &desc_; }
93 const op_desc_t *op_desc() const override {
94 return reinterpret_cast<const op_desc_t *>(this->desc());
95 }
96
97 arg_usage_t arg_usage(int arg) const override {
98 if (arg == DNNL_ARG_FROM) return arg_usage_t::input;
99
100 if (arg == DNNL_ARG_TO) return arg_usage_t::output;
101
102 return primitive_desc_t::arg_usage(arg);
103 }
104
105 const memory_desc_t *arg_md(int arg) const override {
106 switch (arg) {
107 case DNNL_ARG_FROM: return src_md(0);
108 case DNNL_ARG_TO: return dst_md(0);
109 default: return primitive_desc_t::arg_md(arg);
110 }
111 }
112
113 const memory_desc_t *src_md(int index = 0) const override {
114 return index == 0 ? &src_md_ : &glob_zero_md;
115 }
116 const memory_desc_t *dst_md(int index = 0) const override {
117 return index == 0 ? &dst_md_ : &glob_zero_md;
118 }
119
120 int n_inputs() const override { return 1; }
121 int n_outputs() const override { return 1; }
122
123 float beta() const {
124 const int sum_idx = attr()->post_ops_.find(primitive_kind::sum);
125 return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale;
126 }
127
128protected:
129 reorder_desc_t desc_;
130 memory_desc_t src_md_;
131 memory_desc_t dst_md_;
132
133 reorder_pd_t(const primitive_attr_t *attr, engine_kind_t src_engine_kind,
134 const memory_desc_t *src_md, engine_kind_t dst_engine_kind,
135 const memory_desc_t *dst_md)
136 : primitive_desc_t(attr, primitive_kind::reorder)
137 , src_md_(*src_md)
138 , dst_md_(*dst_md) {
139
140 init_desc(src_engine_kind, dst_engine_kind, false);
141 }
142
143 reorder_pd_t(const reorder_pd_t &other) : primitive_desc_t(other) {
144 src_md_ = other.src_md_;
145 dst_md_ = other.dst_md_;
146
147 init_desc(other.desc_.src_engine_kind, other.desc_.dst_engine_kind,
148 other.desc_.is_cross_engine);
149 }
150
151protected:
152 void init_desc(engine_kind_t src_engine_kind, engine_kind_t dst_engine_kind,
153 bool is_cross_engine) {
154 desc_ = reorder_desc_t();
155 desc_.primitive_kind = primitive_kind::reorder;
156 desc_.src_md = &src_md_;
157 desc_.dst_md = &dst_md_;
158 desc_.src_engine_kind = src_engine_kind;
159 desc_.dst_engine_kind = dst_engine_kind;
160 desc_.is_cross_engine = is_cross_engine;
161 }
162};
163
164} // namespace impl
165} // namespace dnnl
166
167#endif
168
169// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
170