1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #include "attr_proto_util.h" |
6 | |
7 | namespace ONNX_NAMESPACE { |
8 | |
9 | #define ADD_BASIC_ATTR_IMPL(type, enumType, field) \ |
10 | AttributeProto MakeAttribute(const std::string& attr_name, const type& value) { \ |
11 | AttributeProto a; \ |
12 | a.set_name(attr_name); \ |
13 | a.set_type(enumType); \ |
14 | a.set_##field(value); \ |
15 | return a; \ |
16 | } |
17 | |
18 | #define ADD_ATTR_IMPL(type, enumType, field) \ |
19 | AttributeProto MakeAttribute(const std::string& attr_name, const type& value) { \ |
20 | AttributeProto a; \ |
21 | a.set_name(attr_name); \ |
22 | a.set_type(enumType); \ |
23 | *(a.mutable_##field()) = value; \ |
24 | return a; \ |
25 | } |
26 | |
27 | #define ADD_LIST_ATTR_IMPL(type, enumType, field) \ |
28 | AttributeProto MakeAttribute(const std::string& attr_name, const std::vector<type>& values) { \ |
29 | AttributeProto a; \ |
30 | a.set_name(attr_name); \ |
31 | a.set_type(enumType); \ |
32 | for (const auto& val : values) { \ |
33 | *(a.mutable_##field()->Add()) = val; \ |
34 | } \ |
35 | return a; \ |
36 | } |
37 | |
38 | ADD_BASIC_ATTR_IMPL(float, AttributeProto_AttributeType_FLOAT, f) |
39 | ADD_BASIC_ATTR_IMPL(int64_t, AttributeProto_AttributeType_INT, i) |
40 | ADD_BASIC_ATTR_IMPL(std::string, AttributeProto_AttributeType_STRING, s) |
41 | ADD_ATTR_IMPL(TensorProto, AttributeProto_AttributeType_TENSOR, t) |
42 | ADD_ATTR_IMPL(GraphProto, AttributeProto_AttributeType_GRAPH, g) |
43 | ADD_ATTR_IMPL(TypeProto, AttributeProto_AttributeType_TYPE_PROTO, tp) |
44 | ADD_LIST_ATTR_IMPL(float, AttributeProto_AttributeType_FLOATS, floats) |
45 | ADD_LIST_ATTR_IMPL(int64_t, AttributeProto_AttributeType_INTS, ints) |
46 | ADD_LIST_ATTR_IMPL(std::string, AttributeProto_AttributeType_STRINGS, strings) |
47 | ADD_LIST_ATTR_IMPL(TensorProto, AttributeProto_AttributeType_TENSORS, tensors) |
48 | ADD_LIST_ATTR_IMPL(GraphProto, AttributeProto_AttributeType_GRAPHS, graphs) |
49 | ADD_LIST_ATTR_IMPL(TypeProto, AttributeProto_AttributeType_TYPE_PROTOS, type_protos) |
50 | |
51 | AttributeProto MakeRefAttribute(const std::string& attr_name, AttributeProto_AttributeType type) { |
52 | return MakeRefAttribute(attr_name, attr_name, type); |
53 | } |
54 | |
55 | AttributeProto MakeRefAttribute( |
56 | const std::string& attr_name, |
57 | const std::string& referred_attr_name, |
58 | AttributeProto_AttributeType type) { |
59 | AttributeProto a; |
60 | a.set_name(attr_name); |
61 | a.set_ref_attr_name(referred_attr_name); |
62 | a.set_type(type); |
63 | return a; |
64 | } |
65 | |
66 | } // namespace ONNX_NAMESPACE |
67 | |