1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "plan.h"
20
21#include <tvm/runtime/container/array.h>
22#include <tvm/runtime/container/map.h>
23#include <tvm/runtime/object.h>
24#include <tvm/runtime/registry.h>
25
26#include <algorithm>
27#include <utility>
28#include <vector>
29
30#include "graph.h"
31#include "tensor_config.h"
32
33namespace tvm {
34namespace contrib {
35namespace ethosu {
36namespace cascader {
37
38void PlanNode::VisitAttrs(AttrVisitor* v) {
39 Array<TensorConfig> tmp_arr(tensor_configs_);
40 v->Visit("_tensor_configs", &tmp_arr);
41 Array<TensorConfig> tmp_cfgs(open_configs_.begin(), open_configs_.end());
42 v->Visit("_open_configs", &tmp_cfgs);
43 v->Visit("_output_config", &output_config_);
44 Array<Part> tmp_parts(part_group_.begin(), part_group_.end());
45 v->Visit("_part_group", &tmp_parts);
46 v->Visit("_interior_region", &interior_region_);
47 v->Visit("_memory_usage", &memory_usage_);
48 v->Visit("_cycles", &cycles_);
49}
50
51Plan::Plan(const std::vector<TensorConfig>& tensor_configs,
52 const std::vector<TensorConfig>& open_configs, const TensorConfig& output_config,
53 const std::vector<Part>& part_group, const MemoryRegion& interior_region,
54 int memory_usage, int cycles) {
55 auto n = make_object<PlanNode>();
56 n->tensor_configs_ = std::move(tensor_configs);
57 n->open_configs_ = std::move(open_configs);
58 n->output_config_ = std::move(output_config);
59 n->part_group_ = std::move(part_group);
60 n->interior_region_ = interior_region;
61 n->memory_usage_ = memory_usage;
62 n->cycles_ = cycles;
63 data_ = std::move(n);
64}
65
66Plan Plan::Merge(const Plan& other) const {
67 auto n = make_object<PlanNode>(*this->operator->());
68 n->tensor_configs_.insert(n->tensor_configs_.end(), other->tensor_configs_.begin(),
69 other->tensor_configs_.end());
70 n->open_configs_.erase(
71 std::remove(n->open_configs_.begin(), n->open_configs_.end(), (*this)->output_config_),
72 n->open_configs_.end());
73 for (const auto& config : other->open_configs_) {
74 if (config->GetTensor() != (*this)->output_config_->GetTensor()) {
75 n->open_configs_.push_back(config);
76 }
77 }
78 n->output_config_ = other->output_config_;
79 n->part_group_.insert(n->part_group_.end(), other->part_group_.begin(), other->part_group_.end());
80 std::sort(n->part_group_.begin(), n->part_group_.end());
81 n->memory_usage_ += other->memory_usage_;
82 n->cycles_ += other->cycles_;
83 return Plan(n);
84}
85
86TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.Plan")
87 .set_body_typed([](Array<TensorConfig> tensor_configs, Array<TensorConfig> open_configs,
88 TensorConfig output_config, Array<Part> part_group,
89 MemoryRegion interior_region, int memory_usage, int cycles) {
90 std::vector<TensorConfig> vtensor_configs(tensor_configs.begin(), tensor_configs.end());
91 std::vector<TensorConfig> sopen_configs(open_configs.begin(), open_configs.end());
92 std::vector<Part> spart_group(part_group.begin(), part_group.end());
93 return Plan(vtensor_configs, sopen_configs, output_config, spart_group, interior_region,
94 memory_usage, cycles);
95 });
96
97TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PlanMerge").set_body_method(&Plan::Merge);
98
99TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PlanMergeBenchmark")
100 .set_body_typed([](Plan plan, Plan other, int repeats) {
101 for (int i = 0; i < repeats; i++) {
102 plan.Merge(other);
103 }
104 return plan.Merge(other);
105 });
106
107TVM_REGISTER_NODE_TYPE(PlanNode);
108
109} // namespace cascader
110} // namespace ethosu
111} // namespace contrib
112} // namespace tvm
113