1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/contrib/ethosu/cascader/proposal.h
22 * \brief Proposal object for the NPU cascader
23 */
24#ifndef TVM_CONTRIB_ETHOSU_CASCADER_PROPOSAL_H_
25#define TVM_CONTRIB_ETHOSU_CASCADER_PROPOSAL_H_
26
27#include <tvm/node/reflection.h>
28#include <tvm/runtime/object.h>
29
30#include <unordered_map>
31#include <unordered_set>
32#include <vector>
33
34#include "graph.h"
35#include "plan.h"
36#include "tensor_config.h"
37
38namespace tvm {
39namespace contrib {
40namespace ethosu {
41namespace cascader {
42
43using MemoryUsageMap = std::unordered_map<MemoryRegion, int, ObjectPtrHash, ObjectPtrEqual>;
44using TensorConfigMap = std::unordered_map<Tensor, TensorConfig, ObjectPtrHash, ObjectPtrEqual>;
45
46/*! \brief Node to represent a Proposal */
47class ProposalNode : public Object {
48 public:
49 void VisitAttrs(AttrVisitor* v);
50
51 /*! \return The CascaderGraph to which the Proposal applies */
52 const CascaderGraph GetGraph() const { return graph_; }
53 /*! \return The Parts which are covered by the Proposal */
54 const std::vector<Part> GetPartGroup() const { return part_group_; }
55 /*! \return The Plans used in the Proposal */
56 const std::vector<Plan> GetPlans() const { return plans_; }
57 /*! \return The TensorConfigs indexed by Tensor in the Proposal which aren't produced by a Plan */
58 const TensorConfigMap GetInputTensorConfigs() const { return input_tensor_configs_; }
59 /*! \return The MemoryRegion where cascading buffers should be homed */
60 const MemoryRegion GetCascadeRegion() const { return cascade_region_; }
61 /*! \return The memory required to execute the Proposal in the cascading MemoryRegion */
62 const int GetMemoryUsage() const { return memory_usage_; }
63 /*! \return The estimated cycles taken to execute the Proposal */
64 int GetCycles() const { return cycles_; }
65
66 static constexpr const char* _type_key = "contrib.ethosu.cascader.Proposal";
67 TVM_DECLARE_FINAL_OBJECT_INFO(ProposalNode, Object);
68
69 protected:
70 friend class Proposal;
71
72 /*! \brief The CascaderGraph to which the Proposal applies */
73 CascaderGraph graph_;
74 /*! \brief The Parts which are covered by the Proposal */
75 std::vector<Part> part_group_;
76 /*! \brief The Plans used in the Proposal */
77 std::vector<Plan> plans_;
78 /*! \brief The TensorConfigs indexed by Tensor in the Proposal which aren't produced by a Plan */
79 TensorConfigMap input_tensor_configs_;
80 /*! \brief The MemoryRegion where cascading buffers should be homed */
81 MemoryRegion cascade_region_;
82 /*! \brief The memory required to execute the Proposal in the cascading MemoryRegion */
83 int memory_usage_;
84 /*! \brief The estimated cycles taken to execute the Proposal */
85 int cycles_;
86};
87
88/*!
89 * \brief A class which describes how to schedule a CascaderGraph as a series of disjoint Plans.
90 */
91class Proposal : public ObjectRef {
92 public:
93 Proposal(const CascaderGraph& graph, const std::vector<Part>& part_group,
94 const std::vector<Plan>& plans, const TensorConfigMap& input_tensor_configs,
95 const MemoryRegion& cascade_region, int memory_usage, int cycles);
96
97 TVM_DEFINE_OBJECT_REF_METHODS(Proposal, ObjectRef, ProposalNode);
98};
99
100} // namespace cascader
101} // namespace ethosu
102} // namespace contrib
103} // namespace tvm
104
105#endif // TVM_CONTRIB_ETHOSU_CASCADER_PROPOSAL_H_
106