1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/candidate_function_cache.h
22 * \brief A cache of the unique global symbol name and cost for partitioned functions.
23 */
24
25#ifndef TVM_RELAY_COLLAGE_CANDIDATE_FUNCTION_CACHE_H_
26#define TVM_RELAY_COLLAGE_CANDIDATE_FUNCTION_CACHE_H_
27
28#include <tvm/relay/function.h>
29
30#include <memory>
31#include <string>
32#include <unordered_map>
33#include <utility>
34
35#include "../transforms/compiler_function_utils.h"
36#include "./cost.h"
37#include "./name_supply.h"
38
39namespace tvm {
40namespace relay {
41namespace collage {
42
43/*!
44 * \brief A cache of the unique global symbol and cost for functions extracted to represent
45 * partitions. If two functions are structurally equal (which includes equality of their "Compiler"
46 * attributes) then they will share the same global symbol and estimated cost. We rely on the
47 * function's attributes to distinguish partitions which are structurally the same graph but
48 * intended for different targets.
49 */
50class CandidateFunctionCache : public transform::GlobalSymbolCache {
51 public:
52 explicit CandidateFunctionCache(std::shared_ptr<NameSupply> name_supply)
53 : name_supply_(std::move(name_supply)) {}
54
55 struct Entry {
56 GlobalVar global_symbol;
57 Cost cost = Cost::Unknown(); // Filled in when have estimated cost.
58
59 explicit Entry(GlobalVar global_symbol) : global_symbol(std::move(global_symbol)) {}
60 };
61
62 /*!
63 * \brief Returns the unique entry for \p function. If no such entry already exists, create it
64 * and assign it a unique global symbol name.
65 */
66 Entry& GetEntry(const std::string& label, const Function& function);
67
68 GlobalVar GetGlobalSymbol(const Function& function) final;
69
70 private:
71 std::shared_ptr<NameSupply> name_supply_;
72 std::unordered_map<Function, Entry, StructuralHash, StructuralEqual> cache_;
73};
74
75} // namespace collage
76} // namespace relay
77} // namespace tvm
78
79#endif // TVM_RELAY_COLLAGE_CANDIDATE_FUNCTION_CACHE_H_
80