1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_COMPUTE_COMPUTE_STREAM_HPP |
18 | #define GPU_COMPUTE_COMPUTE_STREAM_HPP |
19 | |
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>

#include "common/stream.hpp"
#include "gpu/compute/kernel.hpp"
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace gpu { |
28 | namespace compute { |
29 | |
30 | class nd_range_t; |
31 | class kernel_arg_list_t; |
32 | |
33 | class compute_stream_t : public stream_t { |
34 | public: |
35 | using stream_t::stream_t; |
36 | |
37 | virtual status_t copy(const memory_storage_t &src, |
38 | const memory_storage_t &dst, size_t size) |
39 | = 0; |
40 | virtual status_t fill( |
41 | const memory_storage_t &dst, uint8_t pattern, size_t size) |
42 | = 0; |
43 | virtual status_t parallel_for(const nd_range_t &range, |
44 | const kernel_t &kernel, const kernel_arg_list_t &arg_list) { |
45 | return kernel.parallel_for(*this, range, arg_list); |
46 | } |
47 | |
48 | virtual status_t parallel_for( |
49 | const kernel_t &kernel, const std::function<void(void *)> &cgf) { |
50 | return kernel.parallel_for(*this, cgf); |
51 | } |
52 | |
53 | protected: |
54 | bool has_zero_pad_primitive() const { |
55 | return engine()->kind() == dnnl_gpu; |
56 | }; |
57 | |
58 | status_t zero_pad(const memory_t *memory, const exec_ctx_t &ctx) override; |
59 | }; |
60 | |
61 | } // namespace compute |
62 | } // namespace gpu |
63 | } // namespace impl |
64 | } // namespace dnnl |
65 | |
66 | #endif |
67 | |