1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_REORDER_CPU_REORDER_PD_HPP |
18 | #define CPU_REORDER_CPU_REORDER_PD_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/reorder_pd.hpp" |
24 | #include "common/utils.hpp" |
25 | #include "cpu/cpu_engine.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | |
31 | struct cpu_reorder_pd_t : public reorder_pd_t { |
32 | using reorder_pd_t::reorder_pd_t; |
33 | |
34 | status_t init( |
35 | engine_t *engine, engine_t *src_engine, engine_t *dst_engine) { |
36 | const auto &post_ops = attr()->post_ops_; |
37 | bool args_ok = IMPLICATION(post_ops.len() != 0, |
38 | post_ops.len() == 1 |
39 | && post_ops.entry_[0].kind == primitive_kind::sum); |
40 | return args_ok ? status::success : status::unimplemented; |
41 | } |
42 | |
43 | // The function splits dimension products based on input mask and returns |
44 | // them as `D_start`, `D_mask` and `D_rest`. |
45 | // Its used to estimate amount of memory for scratchpad and precomputed |
46 | // destination scales. |
47 | void get_D_values(const memory_desc_wrapper &input_d, int mask, |
48 | dim_t *D_start, dim_t *D_mask, dim_t *D_rest) const { |
49 | int ndims = input_d.ndims(); |
50 | int ndims_start = 0, ndims_mask = 0; |
51 | // XXX: Currently user can pass a mask that has non-zero values in |
52 | // dimensions that do not exist in a md. Since attributes are created |
53 | // separately mask can't be validated. |
54 | // This line truncates a given mask in range [0, 1 << ndims - 1] |
55 | // TODO: Such masks can be either prohibited at pd creation step at |
56 | // API level or checked by each implementation that relies on it. |
57 | mask &= (1 << ndims) - 1; |
58 | |
59 | for (; mask > 0 && !(mask & 0x1); mask >>= 1) |
60 | ++ndims_start; |
61 | for (; mask > 0 && mask & 0x1; mask >>= 1) |
62 | ++ndims_mask; |
63 | assert(mask == 0); |
64 | |
65 | if (D_start) |
66 | *D_start = utils::array_product(input_d.dims(), ndims_start); |
67 | if (D_mask) |
68 | *D_mask = utils::array_product( |
69 | input_d.dims() + ndims_start, ndims_mask); |
70 | assert(*D_mask >= 1); |
71 | if (D_rest) *D_rest = input_d.nelems() / (*D_start * *D_mask); |
72 | } |
73 | |
74 | // The function serves same purpose as `dnnl::impl::cpu::precompute_scales`. |
75 | // The reason it's dedicated to reorder is it's the only primitive so far |
76 | // that utilizes `mask > 0` for destination scales. |
77 | const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, |
78 | const primitive_attr_t *attr, size_t count, |
79 | const float *dst_scales) const { |
80 | using namespace dnnl::impl::memory_tracking::names; |
81 | |
82 | int mask = -1; |
83 | bool is_set = false; |
84 | auto status = attr->scales_.get(DNNL_ARG_DST, &mask, &is_set); |
85 | if (status != status::success) return nullptr; |
86 | |
87 | // It's possible that mask > 0 but `count` is still `1`. This case is |
88 | // covered by `DEFINE_ARG_SCALES_BUFFER` macro and no need to inverse |
89 | // in such case. |
90 | if (is_set && mask > 0 && count > 1) { |
91 | auto loc_scales = scratchpad.template get<float>( |
92 | key_reorder_precomputed_dst_scales); |
93 | if (!loc_scales) return nullptr; |
94 | |
95 | PRAGMA_OMP_SIMD() |
96 | for (size_t c = 0; c < count; c++) |
97 | loc_scales[c] = 1.f / dst_scales[c]; |
98 | |
99 | return loc_scales; |
100 | } else { |
101 | return dst_scales; |
102 | } |
103 | } |
104 | }; |
105 | |
106 | } // namespace cpu |
107 | } // namespace impl |
108 | } // namespace dnnl |
109 | |
110 | #endif |
111 | |
112 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
113 | |