1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/target/metadata_utils.h
 * \brief Declares utility functions and classes for emitting metadata.
23 */
24#ifndef TVM_TARGET_METADATA_UTILS_H_
25#define TVM_TARGET_METADATA_UTILS_H_
26
27#include <tvm/runtime/data_type.h>
28#include <tvm/runtime/ndarray.h>
29#include <tvm/runtime/object.h>
30
31#include <string>
32#include <tuple>
33#include <unordered_map>
34#include <vector>
35
36#include "metadata.h"
37
38namespace tvm {
39namespace codegen {
40namespace metadata {
41
/*!
 * \brief Build the unique string "address" of a struct member by joining a list of pieces.
 *
 * Codegen often needs a C-style identifier for a member of Metadata that would otherwise be
 * anonymous. As an example, consider a Metadata struct that declares an array member:
 *
 * \code{.c}
 *   struct TVMMetadata {
 *     int64_t* shape;
 *   };
 * \endcode
 *
 * Initializing such a struct requires the array itself to be declared separately under a global
 * name. That global name is what this header calls an "address", and this function produces it.
 *
 * \param parts The pieces to join; typically the struct member names forming the path from the
 *  root of Metadata down to the member being addressed.
 * \return The pieces joined into a single string.
 */
std::string AddressFromParts(const std::vector<std::string>& parts);
59
/*!
 * \brief Prefix used in metadata symbol names.
 *
 * Callers typically pass this as element 0 of the parts vector given to AddressFromParts.
 */
static constexpr const char* const kMetadataGlobalSymbol = "kTvmgenMetadata";
65
66/*!
67 * \brief Post-order traverse metadata to discover arrays which need to be forward-defined.
68 */
69class DiscoverArraysVisitor : public AttrVisitor {
70 public:
71 /*! \brief Models a single array discovered in this visitor.
72 * Conatains two fields:
73 * 0. An address which uniquely identifies the array in this Metadata instance.
74 * 1. The discovered MetadataArray.
75 */
76 using DiscoveredArray = std::tuple<std::string, runtime::metadata::MetadataArray>;
77 explicit DiscoverArraysVisitor(std::vector<DiscoveredArray>* queue);
78
79 void Visit(const char* key, double* value) final;
80 void Visit(const char* key, int64_t* value) final;
81 void Visit(const char* key, uint64_t* value) final;
82 void Visit(const char* key, int* value) final;
83 void Visit(const char* key, bool* value) final;
84 void Visit(const char* key, std::string* value) final;
85 void Visit(const char* key, DataType* value) final;
86 void Visit(const char* key, runtime::NDArray* value) final;
87 void Visit(const char* key, void** value) final;
88
89 void Visit(const char* key, ObjectRef* value) final;
90
91 private:
92 /*! \brief The queue to be filled with discovered arrays. */
93 std::vector<DiscoveredArray>* queue_;
94
95 /*! \brief Tracks the preceding address pieces. */
96 std::vector<std::string> address_parts_;
97};
98
99/*!
100 * \brief Post-order traverse Metadata to discover all complex types which need to be
101 * forward-defined. This visitor finds one defined() MetadataBase instance for each unique subclass
102 * present inside Metadata in the order in which the subclass was first discovered.
103 */
104class DiscoverComplexTypesVisitor : public AttrVisitor {
105 public:
106 /*! \brief Construct a new instance.
107 * \param queue An ordered map which holds the
108 */
109 explicit DiscoverComplexTypesVisitor(std::vector<runtime::metadata::MetadataBase>* queue)
110 : queue_{queue} {
111 int i = 0;
112 for (auto q : *queue) {
113 type_key_to_position_[q->GetTypeKey()] = i++;
114 }
115 }
116
117 void Visit(const char* key, double* value) final;
118 void Visit(const char* key, int64_t* value) final;
119 void Visit(const char* key, uint64_t* value) final;
120 void Visit(const char* key, int* value) final;
121 void Visit(const char* key, bool* value) final;
122 void Visit(const char* key, std::string* value) final;
123 void Visit(const char* key, DataType* value) final;
124 void Visit(const char* key, runtime::NDArray* value) final;
125 void Visit(const char* key, void** value) final;
126
127 void Visit(const char* key, ObjectRef* value) final;
128
129 void Discover(runtime::metadata::MetadataBase metadata);
130
131 private:
132 bool DiscoverType(std::string type_key);
133
134 void DiscoverInstance(runtime::metadata::MetadataBase md);
135
136 std::vector<runtime::metadata::MetadataBase>* queue_;
137
138 /*! \brief map type_index to index in queue_. */
139 std::unordered_map<std::string, int> type_key_to_position_;
140};
141
142} // namespace metadata
143} // namespace codegen
144} // namespace tvm
145
146#endif // TVM_TARGET_METADATA_UTILS_H_
147