1 | /******************************************************************************* |
2 | * Copyright 2016-2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_STREAM_HPP |
18 | #define COMMON_STREAM_HPP |
19 | |
20 | #include <assert.h> |
21 | #include "oneapi/dnnl/dnnl.h" |
22 | #include "oneapi/dnnl/dnnl_threadpool_iface.hpp" |
23 | |
24 | #include "c_types_map.hpp" |
25 | #include "engine.hpp" |
26 | #include "utils.hpp" |
27 | |
28 | struct dnnl_stream : public dnnl::impl::c_compatible { |
29 | dnnl_stream(dnnl::impl::engine_t *engine, unsigned flags) |
30 | : engine_(engine), flags_(flags) {} |
31 | virtual ~dnnl_stream() {} |
32 | |
33 | /** returns stream's engine */ |
34 | dnnl::impl::engine_t *engine() const { return engine_; } |
35 | template <typename tgt_engine_t> |
36 | tgt_engine_t *engine() const { |
37 | return dnnl::impl::utils::downcast<tgt_engine_t *>(engine_); |
38 | } |
39 | |
40 | /** returns stream's kind */ |
41 | unsigned flags() const { return flags_; } |
42 | |
43 | virtual dnnl::impl::status_t enqueue_primitive( |
44 | const primitive_iface_t *primitive_iface, |
45 | dnnl::impl::exec_ctx_t &ctx); |
46 | |
47 | /** blocks until all submitted primitives to the stream are completed */ |
48 | virtual dnnl::impl::status_t wait() = 0; |
49 | |
50 | virtual void before_exec_hook() {} |
51 | virtual void after_exec_hook() {} |
52 | |
53 | virtual dnnl::impl::status_t zero_pad(const dnnl::impl::memory_t *memory, |
54 | const dnnl::impl::exec_ctx_t &ctx); |
55 | |
56 | #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL |
57 | dnnl_stream(dnnl::impl::engine_t *engine, |
58 | dnnl::threadpool_interop::threadpool_iface *threadpool) |
59 | : dnnl_stream(engine, dnnl::impl::stream_flags::in_order) { |
60 | assert(engine->kind() == dnnl::impl::engine_kind::cpu); |
61 | threadpool_ = threadpool; |
62 | } |
63 | |
64 | dnnl::impl::status_t get_threadpool( |
65 | dnnl::threadpool_interop::threadpool_iface **threadpool) const { |
66 | using namespace dnnl::impl; |
67 | if (engine_->kind() != engine_kind::cpu) |
68 | return status::invalid_arguments; |
69 | *threadpool = threadpool_; |
70 | return status::success; |
71 | } |
72 | #endif |
73 | |
74 | protected: |
75 | dnnl::impl::engine_t *engine_; |
76 | unsigned flags_; |
77 | #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL |
78 | dnnl::threadpool_interop::threadpool_iface *threadpool_ = nullptr; |
79 | #endif |
80 | }; |
81 | |
82 | #endif |
83 | |
84 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
85 | |