1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#include "attr_proto_util.h"
6
7namespace ONNX_NAMESPACE {
8
9#define ADD_BASIC_ATTR_IMPL(type, enumType, field) \
10 AttributeProto MakeAttribute(const std::string& attr_name, const type& value) { \
11 AttributeProto a; \
12 a.set_name(attr_name); \
13 a.set_type(enumType); \
14 a.set_##field(value); \
15 return a; \
16 }
17
18#define ADD_ATTR_IMPL(type, enumType, field) \
19 AttributeProto MakeAttribute(const std::string& attr_name, const type& value) { \
20 AttributeProto a; \
21 a.set_name(attr_name); \
22 a.set_type(enumType); \
23 *(a.mutable_##field()) = value; \
24 return a; \
25 }
26
27#define ADD_LIST_ATTR_IMPL(type, enumType, field) \
28 AttributeProto MakeAttribute(const std::string& attr_name, const std::vector<type>& values) { \
29 AttributeProto a; \
30 a.set_name(attr_name); \
31 a.set_type(enumType); \
32 for (const auto& val : values) { \
33 *(a.mutable_##field()->Add()) = val; \
34 } \
35 return a; \
36 }
37
38ADD_BASIC_ATTR_IMPL(float, AttributeProto_AttributeType_FLOAT, f)
39ADD_BASIC_ATTR_IMPL(int64_t, AttributeProto_AttributeType_INT, i)
40ADD_BASIC_ATTR_IMPL(std::string, AttributeProto_AttributeType_STRING, s)
41ADD_ATTR_IMPL(TensorProto, AttributeProto_AttributeType_TENSOR, t)
42ADD_ATTR_IMPL(GraphProto, AttributeProto_AttributeType_GRAPH, g)
43ADD_ATTR_IMPL(TypeProto, AttributeProto_AttributeType_TYPE_PROTO, tp)
44ADD_LIST_ATTR_IMPL(float, AttributeProto_AttributeType_FLOATS, floats)
45ADD_LIST_ATTR_IMPL(int64_t, AttributeProto_AttributeType_INTS, ints)
46ADD_LIST_ATTR_IMPL(std::string, AttributeProto_AttributeType_STRINGS, strings)
47ADD_LIST_ATTR_IMPL(TensorProto, AttributeProto_AttributeType_TENSORS, tensors)
48ADD_LIST_ATTR_IMPL(GraphProto, AttributeProto_AttributeType_GRAPHS, graphs)
49ADD_LIST_ATTR_IMPL(TypeProto, AttributeProto_AttributeType_TYPE_PROTOS, type_protos)
50
51AttributeProto MakeRefAttribute(const std::string& attr_name, AttributeProto_AttributeType type) {
52 return MakeRefAttribute(attr_name, attr_name, type);
53}
54
55AttributeProto MakeRefAttribute(
56 const std::string& attr_name,
57 const std::string& referred_attr_name,
58 AttributeProto_AttributeType type) {
59 AttributeProto a;
60 a.set_name(attr_name);
61 a.set_ref_attr_name(referred_attr_name);
62 a.set_type(type);
63 return a;
64}
65
66} // namespace ONNX_NAMESPACE
67