1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/cross_engine_reorder.hpp" |
18 | |
19 | #include "common/reorder.hpp" |
20 | #include "common/utils.hpp" |
21 | #include "gpu/ocl/ocl_stream.hpp" |
22 | #include "gpu/ocl/ocl_utils.hpp" |
23 | #include "gpu/primitive_conf.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace gpu { |
28 | namespace ocl { |
29 | |
30 | void cross_engine_reorder_t::pd_t::init_scratchpad() { |
31 | using namespace memory_tracking::names; |
32 | if (!do_reorder_) return; |
33 | |
34 | const memory_desc_wrapper wspace_md( |
35 | desc()->src_engine_kind == reorder_engine_kind_ ? dst_md() |
36 | : src_md()); |
37 | auto scratchpad = scratchpad_registry().registrar(); |
38 | scratchpad.book(memory_tracking::names::key_reorder_cross_space, |
39 | wspace_md.size(), 1, OCL_BUFFER_ALIGNMENT); |
40 | scratchpad.book(key_nested, reorder_pd_->scratchpad_registry().size(), 1, |
41 | OCL_BUFFER_ALIGNMENT); |
42 | } |
43 | |
44 | status_t cross_engine_reorder_t::pd_t::init( |
45 | engine_t *engine, engine_t *src_engine, engine_t *dst_engine) { |
46 | bool args_ok = src_engine != dst_engine |
47 | && utils::one_of( |
48 | engine_kind::gpu, src_engine->kind(), dst_engine->kind()) |
49 | && attr_ok() && extra_ok(); |
50 | |
51 | if (!args_ok) return status::unimplemented; |
52 | |
53 | memory_desc_wrapper src_mdw(src_md()); |
54 | memory_desc_wrapper dst_mdw(dst_md()); |
55 | |
56 | if (src_mdw.has_runtime_dims_or_strides()) return status::unimplemented; |
57 | |
58 | quantization_t src_quant {attr(), src_mdw, DNNL_ARG_SRC}; |
59 | quantization_t dst_quant {attr(), dst_mdw, DNNL_ARG_DST}; |
60 | sum_quantization_t sum_quant {attr()}; |
61 | bool with_sum_ab = src_quant.with_scale() || src_quant.with_zp() |
62 | || dst_quant.with_scale() || dst_quant.with_zp() |
63 | || sum_quant.with_scale() || sum_quant.with_zp(); |
64 | do_reorder_ = with_sum_ab || src_mdw != dst_mdw; |
65 | |
66 | engine_t *reorder_engine |
67 | = src_engine->kind() == engine_kind::gpu ? src_engine : dst_engine; |
68 | |
69 | primitive_attr_t r_attr(*attr()); |
70 | if (!r_attr.is_initialized()) return status::out_of_memory; |
71 | |
72 | CHECK(reorder_primitive_desc_create( |
73 | reorder_pd_, reorder_engine, src_md(), dst_md(), &r_attr)); |
74 | init_scratchpad(); |
75 | |
76 | reorder_pd_t::init_desc( |
77 | src_engine->kind(), dst_engine->kind(), true /* is_cross_engine */); |
78 | |
79 | return status::success; |
80 | } |
81 | |
82 | status_t cross_engine_reorder_t::execute(const exec_ctx_t &ctx) const { |
83 | using namespace memory_tracking::names; |
84 | auto *compute_stream |
85 | = utils::downcast<compute::compute_stream_t *>(ctx.stream()); |
86 | |
87 | status_t status = status::success; |
88 | |
89 | auto &src = CTX_IN_STORAGE(DNNL_ARG_FROM); |
90 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_TO); |
91 | |
92 | std::unique_ptr<memory_t> wspace; |
93 | if (pd()->do_reorder_) { |
94 | auto src_engine_kind = pd()->desc()->src_engine_kind; |
95 | auto reorder_engine_kind = pd()->reorder_engine_kind_; |
96 | auto scratchpad = ctx.get_scratchpad_grantor().get_memory_storage( |
97 | key_reorder_cross_space); |
98 | auto wspace_md = src_engine_kind == reorder_engine_kind |
99 | ? pd()->dst_md() |
100 | : pd()->src_md(); |
101 | CHECK(safe_ptr_assign(wspace, |
102 | new memory_t(ctx.stream()->engine(), wspace_md, |
103 | std::move(scratchpad)))); |
104 | } |
105 | |
106 | auto exec_reorder = [&](const memory_t *src_mem, const memory_t *dst_mem, |
107 | const memory_t *src_scales_mem, |
108 | const memory_t *dst_scales_mem) { |
109 | exec_args_t r_args; |
110 | r_args[DNNL_ARG_SRC] |
111 | = memory_arg_t {const_cast<memory_t *>(src_mem), true}; |
112 | r_args[DNNL_ARG_DST] |
113 | = memory_arg_t {const_cast<memory_t *>(dst_mem), false}; |
114 | r_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] |
115 | = memory_arg_t {const_cast<memory_t *>(src_scales_mem), true}; |
116 | r_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST] |
117 | = memory_arg_t {const_cast<memory_t *>(dst_scales_mem), true}; |
118 | |
119 | exec_ctx_t r_ctx(ctx, std::move(r_args)); |
120 | |
121 | nested_scratchpad_t ns(ctx, key_nested, reorder_); |
122 | r_ctx.set_scratchpad_grantor(ns.grantor()); |
123 | return reorder_->execute(r_ctx); |
124 | }; |
125 | |
126 | if (pd()->desc()->src_engine_kind == engine_kind::gpu) { |
127 | // GPU -> CPU or GPU -> GPU |
128 | memory_desc_wrapper dst_mdw(pd()->dst_md()); |
129 | if (pd()->do_reorder_) { |
130 | if (pd()->beta() != 0.f) { |
131 | status = compute_stream->copy( |
132 | dst, *wspace->memory_storage(), dst_mdw.size()); |
133 | } |
134 | if (status == status::success) |
135 | status = exec_reorder(ctx.input(DNNL_ARG_FROM), wspace.get(), |
136 | ctx.input(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC), |
137 | ctx.input(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST)); |
138 | } |
139 | if (status == status::success) { |
140 | status = compute_stream->copy( |
141 | pd()->do_reorder_ ? *wspace->memory_storage() : src, dst, |
142 | dst_mdw.size()); |
143 | } |
144 | } else { |
145 | // CPU -> GPU |
146 | memory_desc_wrapper src_mdw(pd()->src_md()); |
147 | status = compute_stream->copy(src, |
148 | pd()->do_reorder_ ? *wspace->memory_storage() : dst, |
149 | src_mdw.size()); |
150 | if (status == status::success && pd()->do_reorder_) { |
151 | status = exec_reorder(wspace.get(), ctx.output(DNNL_ARG_TO), |
152 | ctx.input(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC), |
153 | ctx.input(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST)); |
154 | } |
155 | } |
156 | return status; |
157 | } |
158 | |
159 | } // namespace ocl |
160 | } // namespace gpu |
161 | } // namespace impl |
162 | } // namespace dnnl |
163 | |