1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file src/relay/collage/mock_cost_estimator.h
22 * \brief A mock CostEstimator to support unit tests.
23 */
24
25#ifndef TVM_RELAY_COLLAGE_MOCK_COST_ESTIMATOR_H_
26#define TVM_RELAY_COLLAGE_MOCK_COST_ESTIMATOR_H_
27
28#include <tvm/relay/function.h>
29
30#include "./cost.h"
31#include "./cost_estimator.h"
32
33namespace tvm {
34namespace relay {
35namespace collage {
36
37// Clang (15.0.3, at least) validly complains about `@main`, but it invalidly
38// complains even about `\c @main`.
39#if __clang__
40#pragma clang diagnostic push
41#pragma clang diagnostic ignored "-Wdocumentation-unknown-command"
42#endif
43
44/*!
45 * \brief A mock cost estimator which can determine the cost of a candidate based on both
46 * the candidate's target and the number of operator calls inside it.
47 *
 * To help unit tests, the estimator also ICHECK-fails if:
49 * - the module has inlined "Compiler" functions
50 * - @main has non-tensor arguments (eg a tuple)
51 * - more than the given number of candidate modules are measured
52 *
53 * To support unit testing only.
54 */
class MockCostEstimatorNode : public CostEstimatorNode {
 public:
  /*!
   * \brief Returns the mock cost of \p mod for \p target.
   *
   * Overrides the estimation entry point inherited from CostEstimatorNode. Only the
   * declaration is visible in this header; the definition lives elsewhere.
   *
   * \param mod The module to estimate.
   * \param target The target the module is being estimated for.
   * \return The estimated cost.
   */
  Cost Estimate(const IRModule& mod, const Target& target) const override;

  /*! \brief Type key string under which this node type is declared. */
  static constexpr const char* _type_key = "relay.collage.MockCostEstimator";
  TVM_DECLARE_FINAL_OBJECT_INFO(MockCostEstimatorNode, CostEstimatorNode);

 protected:
  /*!
   * \brief Map from target kind name to assumed baseline cost (in integer seconds) for all
   * operator calls.
   */
  Map<String, Integer> target_costs_;

  /*!
   * \brief If non-zero, the maximum number of distinct modules which may be estimated.
   */
  Integer max_estimates_;

  /*!
   * \brief Number of calls to Estimate.
   *
   * Declared mutable so it can be modified from const member functions -- presumably
   * the const Estimate() above increments it (definition not visible here; confirm).
   */
  mutable size_t num_estimates_ = 0;

  // Gives the MockCostEstimator reference class access to the protected fields above
  // (presumably so its constructor can fill them in -- confirm against the definition).
  friend class MockCostEstimator;
};
79#if __clang__
80#pragma clang diagnostic pop
81#endif
82
/*!
 * \brief Reference class for MockCostEstimatorNode.
 *
 * To support unit testing only.
 */
class MockCostEstimator : public CostEstimator {
 public:
  /*!
   * \brief Constructs a mock cost estimator.
   *
   * The constructor is only declared here; how the arguments are stored is defined
   * elsewhere. The descriptions below assume they populate the like-named fields of
   * MockCostEstimatorNode (TODO confirm against the definition).
   *
   * \param target_costs Map from target kind name to assumed baseline cost (in integer
   * seconds) for all operator calls.
   * \param max_estimates If non-zero, the maximum number of distinct modules which may be
   * estimated. Defaults to 0.
   */
  explicit MockCostEstimator(Map<String, Integer> target_costs, Integer max_estimates = 0);

  TVM_DEFINE_OBJECT_REF_METHODS(MockCostEstimator, CostEstimator, MockCostEstimatorNode);
};
89
90} // namespace collage
91} // namespace relay
92} // namespace tvm
93
94#endif // TVM_RELAY_COLLAGE_MOCK_COST_ESTIMATOR_H_
95