1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include "oneapi/dnnl/dnnl.h" |
19 | |
20 | #include "c_types_map.hpp" |
21 | #include "engine.hpp" |
22 | #include "impl_list_item.hpp" |
23 | #include "primitive_cache.hpp" |
24 | #include "primitive_hashing.hpp" |
25 | #include "type_helpers.hpp" |
26 | #include "utils.hpp" |
27 | |
28 | #include "reorder_pd.hpp" |
29 | |
30 | using namespace dnnl::impl; |
31 | using namespace dnnl::impl::utils; |
32 | using namespace dnnl::impl::status; |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | |
37 | namespace { |
38 | engine_t *get_reorder_engine(engine_t *src_engine, engine_t *dst_engine) { |
39 | auto s_ek = src_engine->kind(); |
40 | auto d_ek = dst_engine->kind(); |
41 | auto s_rk = src_engine->runtime_kind(); |
42 | auto d_rk = dst_engine->runtime_kind(); |
43 | |
44 | if (is_native_runtime(d_rk)) return src_engine; |
45 | |
46 | if (is_native_runtime(s_rk)) return dst_engine; |
47 | |
48 | if (d_ek == engine_kind::cpu) return src_engine; |
49 | |
50 | if (s_ek == engine_kind::cpu) return dst_engine; |
51 | |
52 | assert(s_ek == engine_kind::gpu); |
53 | assert(d_ek == engine_kind::gpu); |
54 | return src_engine; |
55 | } |
56 | } // namespace |
57 | |
58 | status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd, |
59 | engine_t *engine, const memory_desc_t *src_md, engine_t *src_engine, |
60 | const memory_desc_t *dst_md, engine_t *dst_engine, |
61 | const primitive_attr_t *attr) { |
62 | pd.reset(); |
63 | |
64 | auto s_ek = src_engine->kind(); |
65 | auto d_ek = dst_engine->kind(); |
66 | |
67 | bool args_ok = !memory_desc_wrapper(src_md).format_any() |
68 | && !memory_desc_wrapper(dst_md).format_any() |
69 | && IMPLICATION( |
70 | s_ek != d_ek, utils::one_of(engine_kind::cpu, s_ek, d_ek)); |
71 | if (!args_ok) return invalid_arguments; |
72 | |
73 | auto s_mdw = memory_desc_wrapper(*src_md); |
74 | auto d_mdw = memory_desc_wrapper(*dst_md); |
75 | |
76 | if (!s_mdw.consistent_with(d_mdw)) return invalid_arguments; |
77 | |
78 | if (attr == nullptr) attr = &default_attr(); |
79 | |
80 | // Zero points are only allowed for integral data types |
81 | auto zero_points = attr->zero_points_; |
82 | const bool is_src_zp_ok = types::is_integral_dt(src_md->data_type) |
83 | || zero_points.has_default_values(DNNL_ARG_SRC); |
84 | if (!is_src_zp_ok) return status::unimplemented; |
85 | const bool is_dst_zp_ok = types::is_integral_dt(dst_md->data_type) |
86 | || zero_points.has_default_values(DNNL_ARG_DST); |
87 | if (!is_dst_zp_ok) return status::unimplemented; |
88 | |
89 | bool is_cross_engine = src_engine != dst_engine |
90 | && utils::one_of( |
91 | engine_kind::gpu, src_engine->kind(), dst_engine->kind()); |
92 | |
93 | reorder_desc_t desc = {primitive_kind::reorder, src_md, dst_md, s_ek, d_ek, |
94 | is_cross_engine}; |
95 | primitive_hashing::key_t key( |
96 | engine, reinterpret_cast<op_desc_t *>(&desc), attr, 0, {}); |
97 | pd = primitive_cache().get_pd(key); |
98 | if (pd) return success; |
99 | |
100 | for (auto r = engine->get_reorder_implementation_list(src_md, dst_md); *r; |
101 | ++r) { |
102 | reorder_pd_t *reorder_pd = nullptr; |
103 | if ((*r)(&reorder_pd, engine, attr, src_engine, src_md, dst_engine, |
104 | dst_md) |
105 | == success) { |
106 | pd.reset(reorder_pd); |
107 | return success; |
108 | } |
109 | } |
110 | return unimplemented; |
111 | } |
112 | |
113 | status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd, |
114 | engine_t *engine, const memory_desc_t *src_md, |
115 | const memory_desc_t *dst_md, const primitive_attr_t *attr) { |
116 | return reorder_primitive_desc_create( |
117 | pd, engine, src_md, engine, dst_md, engine, attr); |
118 | } |
119 | |
120 | } // namespace impl |
121 | } // namespace dnnl |
122 | |
123 | status_t dnnl_reorder_primitive_desc_create( |
124 | primitive_desc_iface_t **reorder_pd_iface, const memory_desc_t *src_md, |
125 | engine_t *src_engine, const memory_desc_t *dst_md, engine_t *dst_engine, |
126 | const primitive_attr_t *attr) { |
127 | if (any_null(reorder_pd_iface, src_engine, src_md, dst_engine, dst_md)) |
128 | return invalid_arguments; |
129 | |
130 | std::shared_ptr<primitive_desc_t> pd; |
131 | auto e = get_reorder_engine(src_engine, dst_engine); |
132 | CHECK(reorder_primitive_desc_create( |
133 | pd, e, src_md, src_engine, dst_md, dst_engine, attr)); |
134 | |
135 | return safe_ptr_assign(*reorder_pd_iface, |
136 | new reorder_primitive_desc_iface_t(pd, e, src_engine, dst_engine)); |
137 | } |
138 | |
139 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
140 | |