1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_GPU_PRIMITIVE_ATTR_HPP
18#define GPU_GPU_PRIMITIVE_ATTR_HPP
19
20#include "common/primitive_attr.hpp"
21#include "common/serialization_stream.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace gpu {
26
27struct gpu_primitive_attr_t : public primitive_attr_item_t {
28 gpu_primitive_attr_t(int threads_per_eu = 0)
29 : threads_per_eu_(threads_per_eu) {}
30
31 std::unique_ptr<primitive_attr_item_t> clone() const override {
32 return utils::make_unique<gpu_primitive_attr_t>(threads_per_eu_);
33 }
34
35 bool has_default_values() const override { return threads_per_eu_ == 0; }
36
37 bool is_equal(const primitive_attr_item_t &other) const override {
38 auto *other_ptr = utils::downcast<const gpu_primitive_attr_t *>(&other);
39 return threads_per_eu_ == other_ptr->threads_per_eu_;
40 }
41
42 size_t get_hash() const override { return threads_per_eu_; }
43
44 void serialize(serialization_stream_t &stream) const override {
45 stream.write(&threads_per_eu_);
46 }
47
48 int threads_per_eu() const { return threads_per_eu_; }
49
50private:
51 int threads_per_eu_;
52};
53
54} // namespace gpu
55} // namespace impl
56} // namespace dnnl
57
58#endif
59