1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "proposal.h"
20
21#include <tvm/runtime/container/array.h>
22#include <tvm/runtime/container/map.h>
23#include <tvm/runtime/object.h>
24#include <tvm/runtime/registry.h>
25
26#include <algorithm>
27#include <utility>
28#include <vector>
29
30#include "plan.h"
31
32namespace tvm {
33namespace contrib {
34namespace ethosu {
35namespace cascader {
36
37void ProposalNode::VisitAttrs(AttrVisitor* v) {
38 v->Visit("_graph", &graph_);
39 Array<Part> tmp_parts(part_group_.begin(), part_group_.end());
40 v->Visit("_part_group", &tmp_parts);
41 Array<Plan> tmp_plans(plans_.begin(), plans_.end());
42 v->Visit("_plans", &tmp_plans);
43 Map<Tensor, TensorConfig> tmp_tmap(input_tensor_configs_.begin(), input_tensor_configs_.end());
44 v->Visit("_input_tensor_configs", &tmp_tmap);
45 v->Visit("_cascade_region", &cascade_region_);
46 v->Visit("_memory_usage", &memory_usage_);
47 v->Visit("_cycles", &cycles_);
48}
49
50Proposal::Proposal(const CascaderGraph& graph, const std::vector<Part>& part_group,
51 const std::vector<Plan>& plans, const TensorConfigMap& input_tensor_configs,
52 const MemoryRegion& cascade_region, int memory_usage, int cycles) {
53 auto n = make_object<ProposalNode>();
54 n->graph_ = std::move(graph);
55 n->part_group_ = std::move(part_group);
56 std::sort(n->part_group_.begin(), n->part_group_.end());
57 n->plans_ = std::move(plans);
58 n->input_tensor_configs_ = std::move(input_tensor_configs);
59 n->cascade_region_ = std::move(cascade_region);
60 n->memory_usage_ = std::move(memory_usage);
61 n->cycles_ = cycles;
62 data_ = std::move(n);
63}
64
65TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.Proposal")
66 .set_body_typed([](CascaderGraph graph, Array<Part> part_group, Array<Plan> plans,
67 Map<Tensor, TensorConfig> input_tensor_configs, MemoryRegion cascade_region,
68 int memory_usage, int cycles) {
69 std::vector<Part> spart_group(part_group.begin(), part_group.end());
70 std::vector<Plan> vplans(plans.begin(), plans.end());
71 TensorConfigMap minput_tensor_configs(input_tensor_configs.begin(),
72 input_tensor_configs.end());
73 return Proposal(graph, spart_group, vplans, minput_tensor_configs, cascade_region,
74 memory_usage, cycles);
75 });
76
77TVM_REGISTER_NODE_TYPE(ProposalNode);
78
79} // namespace cascader
80} // namespace ethosu
81} // namespace contrib
82} // namespace tvm
83