1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "primitive_exec_types.hpp"
18#include "engine.hpp"
19#include "memory.hpp"
20#include "memory_storage.hpp"
21#include "primitive.hpp"
22#include "primitive_desc.hpp"
23
24namespace dnnl {
25namespace impl {
26
27status_t cvt_primitive_args(const primitive_desc_t *pd, int nargs,
28 const dnnl_exec_arg_t *c_args, exec_args_t &args) {
29 using namespace status;
30
31 if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments;
32
33 // TODO: better put extra_* in primitive_desc
34 int n_inputs = 0, extra_inputs = 0;
35 int n_outputs = 0, extra_outputs = 0;
36
37 for (int i = 0; i < nargs; ++i) {
38 int arg = c_args[i].arg;
39 auto *mem = c_args[i].memory;
40
41 // allows dummy arguments
42 if (mem == nullptr) continue;
43
44 switch (pd->arg_usage(arg)) {
45 case primitive_desc_t::arg_usage_t::input:
46 if (args.count(arg) != 0) return invalid_arguments;
47 args[arg] = {mem, true};
48 n_inputs++;
49 extra_inputs += (arg == DNNL_ARG_ATTR_OUTPUT_SCALES)
50 || (arg & DNNL_ARG_ATTR_ZERO_POINTS)
51 || (arg & DNNL_ARG_ATTR_SCALES)
52 // 1x1 + dw conv fusion
53 || (arg
54 == (DNNL_ARG_ATTR_POST_OP_DW
55 | DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC))
56 || (arg
57 == (DNNL_ARG_ATTR_POST_OP_DW
58 | DNNL_ARG_ATTR_SCALES
59 | DNNL_ARG_WEIGHTS))
60 || (arg
61 == (DNNL_ARG_ATTR_POST_OP_DW
62 | DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST));
63 break;
64 case primitive_desc_t::arg_usage_t::output:
65 if (args.count(arg) != 0) return invalid_arguments;
66 args[arg] = {mem, false};
67 n_outputs++;
68 extra_outputs += (arg == DNNL_ARG_SCRATCHPAD);
69 break;
70 case primitive_desc_t::arg_usage_t::unused: break;
71 }
72 }
73
74 if (n_inputs != pd->n_inputs() + extra_inputs) return invalid_arguments;
75 if (n_outputs != pd->n_outputs() + extra_outputs) return invalid_arguments;
76
77 return success;
78}
79
80memory_t *exec_ctx_t::input(int arg) const {
81 if (args_.count(arg) != 1) return nullptr;
82 const auto ma = args_.at(arg);
83 assert(ma.is_const);
84 return ma.mem;
85}
86
87memory_t *exec_ctx_t::output(int arg) const {
88 if (args_.count(arg) != 1) return nullptr;
89 const auto ma = args_.at(arg);
90 assert(!ma.is_const);
91 return ma.mem;
92}
93
94status_t exec_ctx_t::zero_pad_output(int arg) const {
95 memory_t *mem = this->output(arg);
96 if (mem == nullptr) return status::success;
97
98 return mem->zero_pad(*this);
99}
100
101memory_t *exec_ctx_t::memory(int arg) const {
102 assert(args_.count(arg) == 1);
103 const auto ma = args_.at(arg);
104 assert(!ma.is_const);
105 return ma.mem;
106}
107
108void exec_ctx_t::register_memory_mapping(void *handle, void *host_ptr) {
109 assert(memory_mapping_.count(handle) == 0);
110 memory_mapping_.insert({handle, host_ptr});
111}
112
113void *exec_ctx_t::host_ptr(int arg, bool do_zeropad, status_t *status_) const {
114 status_t status = status::success;
115 if (status_) *status_ = status;
116
117 if (args_.count(arg) != 1) return nullptr;
118
119 auto *mem = args_.at(arg).mem;
120 if (do_zeropad) status = mem->zero_pad(*this);
121 if (status_) *status_ = status;
122
123 auto *mem_storage = mem->memory_storage();
124 return host_ptr(mem_storage);
125}
126
127void *exec_ctx_t::host_ptr(const memory_storage_t *mem_storage) const {
128 if (!mem_storage || mem_storage->is_null()) return nullptr;
129
130 void *handle = mem_storage->data_handle();
131 void *base_ptr = nullptr;
132 if (memory_mapping_.count(handle) > 0) {
133 base_ptr = memory_mapping_.at(handle);
134 } else {
135 assert(mem_storage->is_host_accessible());
136 base_ptr = handle;
137 }
138 return base_ptr;
139}
140
141void *exec_ctx_t::map_memory_storage(
142 const memory_storage_t *storage, stream_t *stream, size_t size) const {
143 if (!storage || storage->is_null()) return nullptr;
144
145 if (memory_mapping_.count(storage->data_handle()) > 0) {
146 return host_ptr(storage);
147 }
148
149 void *mapped_ptr;
150 status_t status = storage->map_data(&mapped_ptr, stream, size);
151 assert(status == status::success);
152 MAYBE_UNUSED(status);
153 return mapped_ptr;
154}
155
156void exec_ctx_t::unmap_memory_storage(const memory_storage_t *storage,
157 void *mapped_ptr, stream_t *stream) const {
158 if (!storage || storage->is_null()
159 || memory_mapping_.count(storage->data_handle()) > 0)
160 return;
161
162 status_t status = storage->unmap_data(mapped_ptr, stream);
163 assert(status == status::success);
164 MAYBE_UNUSED(status);
165}
166
167memory_desc_wrapper exec_ctx_t::memory_mdw(
168 int arg, const memory_desc_t *md_from_primitive_desc) const {
169 if (md_from_primitive_desc) {
170 memory_desc_wrapper mdw_from_primitive_desc(md_from_primitive_desc);
171 if (!mdw_from_primitive_desc.has_runtime_dims_or_strides())
172 return mdw_from_primitive_desc;
173 }
174 if (args_.count(arg) != 1) return memory_desc_wrapper(&glob_zero_md);
175 return memory_desc_wrapper(args_.at(arg).mem->md());
176}
177
178const resource_mapper_t *exec_ctx_t::get_resource_mapper() const {
179 assert(resource_mapper_);
180 return resource_mapper_;
181}
182
183void exec_ctx_t::set_resource_mapper(const resource_mapper_t *resource_mapper) {
184 resource_mapper_ = resource_mapper;
185}
186
187} // namespace impl
188} // namespace dnnl
189