1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_REF_SHUFFLE_HPP |
18 | #define CPU_REF_SHUFFLE_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/dnnl_thread.hpp" |
24 | #include "common/primitive.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | #include "common/utils.hpp" |
27 | |
28 | #include "cpu/platform.hpp" |
29 | |
30 | #include "cpu/cpu_shuffle_pd.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | |
36 | struct ref_shuffle_t : public primitive_t { |
37 | struct pd_t : public cpu_shuffle_pd_t { |
38 | using cpu_shuffle_pd_t::cpu_shuffle_pd_t; |
39 | |
40 | DECLARE_COMMON_PD_T("ref:any" , ref_shuffle_t); |
41 | |
42 | status_t init(engine_t *engine) { |
43 | using namespace format_tag; |
44 | |
45 | const memory_desc_wrapper src_d( |
46 | is_fwd() ? src_md() : diff_src_md()); |
47 | const memory_desc_wrapper dst_d( |
48 | is_fwd() ? dst_md() : diff_dst_md()); |
49 | |
50 | bool ok = src_d.data_type() == dst_d.data_type() |
51 | && platform::has_data_type_support(src_d.data_type()) |
52 | && attr()->has_default_values() |
53 | && set_default_formats_common() && src_d == dst_d; |
54 | if (!ok) return status::unimplemented; |
55 | |
56 | if (ndims() == 5) { |
57 | dat_tag_ = memory_desc_matches_one_of_tag( |
58 | *src_d.md_, nCdhw16c, nCdhw8c, nCdhw4c, ncdhw, ndhwc); |
59 | } else if (ndims() == 4) { |
60 | dat_tag_ = memory_desc_matches_one_of_tag( |
61 | *src_d.md_, nChw16c, nChw8c, nChw4c, nchw, nhwc); |
62 | } else |
63 | dat_tag_ = any; |
64 | |
65 | return status::success; |
66 | } |
67 | |
68 | format_tag_t dat_tag_; |
69 | }; |
70 | |
71 | ref_shuffle_t(const pd_t *apd) : primitive_t(apd) {} |
72 | |
73 | status_t init(engine_t *engine) override { |
74 | const int axis_size = pd()->axis_size(); |
75 | const dim_t group_size = pd()->group_size(); |
76 | const dim_t transpose_row |
77 | = pd()->is_fwd() ? group_size : axis_size / group_size; |
78 | const dim_t transpose_col |
79 | = pd()->is_fwd() ? axis_size / group_size : group_size; |
80 | rev_transposed_ = (int *)malloc( |
81 | axis_size * sizeof(int), platform::get_cache_line_size()); |
82 | if (rev_transposed_ == nullptr) return dnnl_out_of_memory; |
83 | parallel_nd(transpose_col, transpose_row, [&](dim_t i, dim_t j) { |
84 | rev_transposed_[j * transpose_col + i] = i * transpose_row + j; |
85 | }); |
86 | return dnnl_success; |
87 | } |
88 | |
89 | ~ref_shuffle_t() { free(rev_transposed_); } |
90 | |
91 | status_t execute(const exec_ctx_t &ctx) const override { |
92 | const memory_desc_wrapper src_d( |
93 | pd()->is_fwd() ? pd()->src_md() : pd()->diff_src_md()); |
94 | switch (types::data_type_size(src_d.data_type())) { |
95 | case sizeof(float): return execute_<sizeof(float)>(ctx); break; |
96 | case sizeof(bfloat16_t): |
97 | return execute_<sizeof(bfloat16_t)>(ctx); |
98 | break; |
99 | case sizeof(int8_t): return execute_<sizeof(int8_t)>(ctx); break; |
100 | default: assert(!"unsupported data type size" ); |
101 | } |
102 | return status::success; |
103 | } |
104 | |
105 | private: |
106 | template <int data_type_size> |
107 | status_t execute_(const exec_ctx_t &ctx) const; |
108 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
109 | int *rev_transposed_ = nullptr; |
110 | }; |
111 | |
112 | } // namespace cpu |
113 | } // namespace impl |
114 | } // namespace dnnl |
115 | |
116 | #endif |
117 | |
118 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
119 | |