1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_GPU_PRIMITIVE_ATTR_HPP |
18 | #define GPU_GPU_PRIMITIVE_ATTR_HPP |
19 | |
20 | #include "common/primitive_attr.hpp" |
21 | #include "common/serialization_stream.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace gpu { |
26 | |
27 | struct gpu_primitive_attr_t : public primitive_attr_item_t { |
28 | gpu_primitive_attr_t(int threads_per_eu = 0) |
29 | : threads_per_eu_(threads_per_eu) {} |
30 | |
31 | std::unique_ptr<primitive_attr_item_t> clone() const override { |
32 | return utils::make_unique<gpu_primitive_attr_t>(threads_per_eu_); |
33 | } |
34 | |
35 | bool has_default_values() const override { return threads_per_eu_ == 0; } |
36 | |
37 | bool is_equal(const primitive_attr_item_t &other) const override { |
38 | auto *other_ptr = utils::downcast<const gpu_primitive_attr_t *>(&other); |
39 | return threads_per_eu_ == other_ptr->threads_per_eu_; |
40 | } |
41 | |
42 | size_t get_hash() const override { return threads_per_eu_; } |
43 | |
44 | void serialize(serialization_stream_t &stream) const override { |
45 | stream.write(&threads_per_eu_); |
46 | } |
47 | |
48 | int threads_per_eu() const { return threads_per_eu_; } |
49 | |
50 | private: |
51 | int threads_per_eu_; |
52 | }; |
53 | |
54 | } // namespace gpu |
55 | } // namespace impl |
56 | } // namespace dnnl |
57 | |
58 | #endif |
59 | |