1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include "oneapi/dnnl/dnnl.h"
19
20#include "c_types_map.hpp"
21#include "engine.hpp"
22#include "impl_list_item.hpp"
23#include "primitive_cache.hpp"
24#include "primitive_hashing.hpp"
25#include "type_helpers.hpp"
26#include "utils.hpp"
27
28#include "reorder_pd.hpp"
29
30using namespace dnnl::impl;
31using namespace dnnl::impl::utils;
32using namespace dnnl::impl::status;
33
34namespace dnnl {
35namespace impl {
36
37namespace {
38engine_t *get_reorder_engine(engine_t *src_engine, engine_t *dst_engine) {
39 auto s_ek = src_engine->kind();
40 auto d_ek = dst_engine->kind();
41 auto s_rk = src_engine->runtime_kind();
42 auto d_rk = dst_engine->runtime_kind();
43
44 if (is_native_runtime(d_rk)) return src_engine;
45
46 if (is_native_runtime(s_rk)) return dst_engine;
47
48 if (d_ek == engine_kind::cpu) return src_engine;
49
50 if (s_ek == engine_kind::cpu) return dst_engine;
51
52 assert(s_ek == engine_kind::gpu);
53 assert(d_ek == engine_kind::gpu);
54 return src_engine;
55}
56} // namespace
57
58status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
59 engine_t *engine, const memory_desc_t *src_md, engine_t *src_engine,
60 const memory_desc_t *dst_md, engine_t *dst_engine,
61 const primitive_attr_t *attr) {
62 pd.reset();
63
64 auto s_ek = src_engine->kind();
65 auto d_ek = dst_engine->kind();
66
67 bool args_ok = !memory_desc_wrapper(src_md).format_any()
68 && !memory_desc_wrapper(dst_md).format_any()
69 && IMPLICATION(
70 s_ek != d_ek, utils::one_of(engine_kind::cpu, s_ek, d_ek));
71 if (!args_ok) return invalid_arguments;
72
73 auto s_mdw = memory_desc_wrapper(*src_md);
74 auto d_mdw = memory_desc_wrapper(*dst_md);
75
76 if (!s_mdw.consistent_with(d_mdw)) return invalid_arguments;
77
78 if (attr == nullptr) attr = &default_attr();
79
80 // Zero points are only allowed for integral data types
81 auto zero_points = attr->zero_points_;
82 const bool is_src_zp_ok = types::is_integral_dt(src_md->data_type)
83 || zero_points.has_default_values(DNNL_ARG_SRC);
84 if (!is_src_zp_ok) return status::unimplemented;
85 const bool is_dst_zp_ok = types::is_integral_dt(dst_md->data_type)
86 || zero_points.has_default_values(DNNL_ARG_DST);
87 if (!is_dst_zp_ok) return status::unimplemented;
88
89 bool is_cross_engine = src_engine != dst_engine
90 && utils::one_of(
91 engine_kind::gpu, src_engine->kind(), dst_engine->kind());
92
93 reorder_desc_t desc = {primitive_kind::reorder, src_md, dst_md, s_ek, d_ek,
94 is_cross_engine};
95 primitive_hashing::key_t key(
96 engine, reinterpret_cast<op_desc_t *>(&desc), attr, 0, {});
97 pd = primitive_cache().get_pd(key);
98 if (pd) return success;
99
100 for (auto r = engine->get_reorder_implementation_list(src_md, dst_md); *r;
101 ++r) {
102 reorder_pd_t *reorder_pd = nullptr;
103 if ((*r)(&reorder_pd, engine, attr, src_engine, src_md, dst_engine,
104 dst_md)
105 == success) {
106 pd.reset(reorder_pd);
107 return success;
108 }
109 }
110 return unimplemented;
111}
112
113status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
114 engine_t *engine, const memory_desc_t *src_md,
115 const memory_desc_t *dst_md, const primitive_attr_t *attr) {
116 return reorder_primitive_desc_create(
117 pd, engine, src_md, engine, dst_md, engine, attr);
118}
119
120} // namespace impl
121} // namespace dnnl
122
123status_t dnnl_reorder_primitive_desc_create(
124 primitive_desc_iface_t **reorder_pd_iface, const memory_desc_t *src_md,
125 engine_t *src_engine, const memory_desc_t *dst_md, engine_t *dst_engine,
126 const primitive_attr_t *attr) {
127 if (any_null(reorder_pd_iface, src_engine, src_md, dst_engine, dst_md))
128 return invalid_arguments;
129
130 std::shared_ptr<primitive_desc_t> pd;
131 auto e = get_reorder_engine(src_engine, dst_engine);
132 CHECK(reorder_primitive_desc_create(
133 pd, e, src_md, src_engine, dst_md, dst_engine, attr));
134
135 return safe_ptr_assign(*reorder_pd_iface,
136 new reorder_primitive_desc_iface_t(pd, e, src_engine, dst_engine));
137}
138
139// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
140