1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "primitive_exec_types.hpp" |
18 | #include "engine.hpp" |
19 | #include "memory.hpp" |
20 | #include "memory_storage.hpp" |
21 | #include "primitive.hpp" |
22 | #include "primitive_desc.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | |
27 | status_t cvt_primitive_args(const primitive_desc_t *pd, int nargs, |
28 | const dnnl_exec_arg_t *c_args, exec_args_t &args) { |
29 | using namespace status; |
30 | |
31 | if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments; |
32 | |
33 | // TODO: better put extra_* in primitive_desc |
34 | int n_inputs = 0, = 0; |
35 | int n_outputs = 0, = 0; |
36 | |
37 | for (int i = 0; i < nargs; ++i) { |
38 | int arg = c_args[i].arg; |
39 | auto *mem = c_args[i].memory; |
40 | |
41 | // allows dummy arguments |
42 | if (mem == nullptr) continue; |
43 | |
44 | switch (pd->arg_usage(arg)) { |
45 | case primitive_desc_t::arg_usage_t::input: |
46 | if (args.count(arg) != 0) return invalid_arguments; |
47 | args[arg] = {mem, true}; |
48 | n_inputs++; |
49 | extra_inputs += (arg == DNNL_ARG_ATTR_OUTPUT_SCALES) |
50 | || (arg & DNNL_ARG_ATTR_ZERO_POINTS) |
51 | || (arg & DNNL_ARG_ATTR_SCALES) |
52 | // 1x1 + dw conv fusion |
53 | || (arg |
54 | == (DNNL_ARG_ATTR_POST_OP_DW |
55 | | DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC)) |
56 | || (arg |
57 | == (DNNL_ARG_ATTR_POST_OP_DW |
58 | | DNNL_ARG_ATTR_SCALES |
59 | | DNNL_ARG_WEIGHTS)) |
60 | || (arg |
61 | == (DNNL_ARG_ATTR_POST_OP_DW |
62 | | DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST)); |
63 | break; |
64 | case primitive_desc_t::arg_usage_t::output: |
65 | if (args.count(arg) != 0) return invalid_arguments; |
66 | args[arg] = {mem, false}; |
67 | n_outputs++; |
68 | extra_outputs += (arg == DNNL_ARG_SCRATCHPAD); |
69 | break; |
70 | case primitive_desc_t::arg_usage_t::unused: break; |
71 | } |
72 | } |
73 | |
74 | if (n_inputs != pd->n_inputs() + extra_inputs) return invalid_arguments; |
75 | if (n_outputs != pd->n_outputs() + extra_outputs) return invalid_arguments; |
76 | |
77 | return success; |
78 | } |
79 | |
80 | memory_t *exec_ctx_t::input(int arg) const { |
81 | if (args_.count(arg) != 1) return nullptr; |
82 | const auto ma = args_.at(arg); |
83 | assert(ma.is_const); |
84 | return ma.mem; |
85 | } |
86 | |
87 | memory_t *exec_ctx_t::output(int arg) const { |
88 | if (args_.count(arg) != 1) return nullptr; |
89 | const auto ma = args_.at(arg); |
90 | assert(!ma.is_const); |
91 | return ma.mem; |
92 | } |
93 | |
94 | status_t exec_ctx_t::zero_pad_output(int arg) const { |
95 | memory_t *mem = this->output(arg); |
96 | if (mem == nullptr) return status::success; |
97 | |
98 | return mem->zero_pad(*this); |
99 | } |
100 | |
101 | memory_t *exec_ctx_t::memory(int arg) const { |
102 | assert(args_.count(arg) == 1); |
103 | const auto ma = args_.at(arg); |
104 | assert(!ma.is_const); |
105 | return ma.mem; |
106 | } |
107 | |
108 | void exec_ctx_t::register_memory_mapping(void *handle, void *host_ptr) { |
109 | assert(memory_mapping_.count(handle) == 0); |
110 | memory_mapping_.insert({handle, host_ptr}); |
111 | } |
112 | |
113 | void *exec_ctx_t::host_ptr(int arg, bool do_zeropad, status_t *status_) const { |
114 | status_t status = status::success; |
115 | if (status_) *status_ = status; |
116 | |
117 | if (args_.count(arg) != 1) return nullptr; |
118 | |
119 | auto *mem = args_.at(arg).mem; |
120 | if (do_zeropad) status = mem->zero_pad(*this); |
121 | if (status_) *status_ = status; |
122 | |
123 | auto *mem_storage = mem->memory_storage(); |
124 | return host_ptr(mem_storage); |
125 | } |
126 | |
127 | void *exec_ctx_t::host_ptr(const memory_storage_t *mem_storage) const { |
128 | if (!mem_storage || mem_storage->is_null()) return nullptr; |
129 | |
130 | void *handle = mem_storage->data_handle(); |
131 | void *base_ptr = nullptr; |
132 | if (memory_mapping_.count(handle) > 0) { |
133 | base_ptr = memory_mapping_.at(handle); |
134 | } else { |
135 | assert(mem_storage->is_host_accessible()); |
136 | base_ptr = handle; |
137 | } |
138 | return base_ptr; |
139 | } |
140 | |
141 | void *exec_ctx_t::map_memory_storage( |
142 | const memory_storage_t *storage, stream_t *stream, size_t size) const { |
143 | if (!storage || storage->is_null()) return nullptr; |
144 | |
145 | if (memory_mapping_.count(storage->data_handle()) > 0) { |
146 | return host_ptr(storage); |
147 | } |
148 | |
149 | void *mapped_ptr; |
150 | status_t status = storage->map_data(&mapped_ptr, stream, size); |
151 | assert(status == status::success); |
152 | MAYBE_UNUSED(status); |
153 | return mapped_ptr; |
154 | } |
155 | |
156 | void exec_ctx_t::unmap_memory_storage(const memory_storage_t *storage, |
157 | void *mapped_ptr, stream_t *stream) const { |
158 | if (!storage || storage->is_null() |
159 | || memory_mapping_.count(storage->data_handle()) > 0) |
160 | return; |
161 | |
162 | status_t status = storage->unmap_data(mapped_ptr, stream); |
163 | assert(status == status::success); |
164 | MAYBE_UNUSED(status); |
165 | } |
166 | |
167 | memory_desc_wrapper exec_ctx_t::memory_mdw( |
168 | int arg, const memory_desc_t *md_from_primitive_desc) const { |
169 | if (md_from_primitive_desc) { |
170 | memory_desc_wrapper mdw_from_primitive_desc(md_from_primitive_desc); |
171 | if (!mdw_from_primitive_desc.has_runtime_dims_or_strides()) |
172 | return mdw_from_primitive_desc; |
173 | } |
174 | if (args_.count(arg) != 1) return memory_desc_wrapper(&glob_zero_md); |
175 | return memory_desc_wrapper(args_.at(arg).mem->md()); |
176 | } |
177 | |
178 | const resource_mapper_t *exec_ctx_t::get_resource_mapper() const { |
179 | assert(resource_mapper_); |
180 | return resource_mapper_; |
181 | } |
182 | |
183 | void exec_ctx_t::set_resource_mapper(const resource_mapper_t *resource_mapper) { |
184 | resource_mapper_ = resource_mapper; |
185 | } |
186 | |
187 | } // namespace impl |
188 | } // namespace dnnl |
189 | |