1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | // Version converter interface for ONNX models between different opset versions. |
6 | |
7 | #pragma once |
8 | |
#include <stdlib.h>

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "onnx/common/ir.h"
#include "onnx/common/ir_pb_converter.h"
#include "onnx/common/stl_backports.h"
#include "onnx/defs/schema.h"
#include "onnx/proto_utils.h"
#include "onnx/version_converter/adapters/adapter.h"
18 | |
19 | namespace ONNX_NAMESPACE { |
20 | namespace version_conversion { |
21 | |
22 | // TODO: Consider creating interface for this class. |
23 | class BaseVersionConverter { |
24 | // Schema for adapters: {<op_name>:{<from_domain>$<from_version>:{<to_domain> |
25 | // <to_version>: adapter}}} |
26 | protected: |
27 | std::unordered_map< |
28 | std::string, |
29 | std::unordered_map<std::string, std::unordered_map<std::string, std::unique_ptr<Adapter>>>> |
30 | adapters; |
31 | |
32 | // Map of All Versions of format {op_name: {domain: {version: schema}}} |
33 | std::unordered_map<std::string, std::unordered_map<std::string, std::map<int64_t, const OpSchema*>>> all_schemas; |
34 | |
35 | public: |
36 | BaseVersionConverter() = default; |
37 | |
38 | virtual ~BaseVersionConverter() = default; |
39 | |
40 | // adapter_lookup should be called in convert_version when the user would |
41 | // like to identify the proper registered adapter in the adapters map for |
42 | // a given Node from a certain version to another. It should only be called |
43 | // when the user knows that an adapter should exist for the given context. |
44 | const Adapter& adapter_lookup(const Node* op, const OpSetID& initial_version, const OpSetID& target_version) const { |
45 | const std::string op_name = op->kind().toString(); |
46 | const std::string initial = initial_version.toString(); |
47 | const std::string target = target_version.toString(); |
48 | // Find appropriate adapter in adapters map for provided initial and target versions |
49 | // TODO: Consider abstracting elements of this that are specific to |
50 | // DefaultConverter to separate methods here and maintain the procedure in Base Converter |
51 | const auto op_adapters = adapters.find(op_name); |
52 | if (op_adapters != adapters.end()) { |
53 | // If we're adapting downwards, we just want to find the one downwards |
54 | // adapter implemented for initial_version. If we're adapting upwards, we |
55 | // want to actually use the SinceVersion value for the given op. |
56 | const auto target_map = op_adapters->second.find(initial); |
57 | if (target_map != op_adapters->second.end()) { |
58 | // Either adapt from SinceVersion or Incompatible Breaking Change |
59 | const auto adapter_ptr = target_map->second.find(target); |
60 | if (adapter_ptr != target_map->second.end()) { |
61 | return *(adapter_ptr->second); |
62 | } else { |
63 | ONNX_ASSERTM(false, "No Adapter To Version %s for %s" , target.c_str(), op_name.c_str()); |
64 | } |
65 | } else { |
66 | ONNX_ASSERTM(false, "No Adapter From Version %s for %s" , initial.c_str(), op_name.c_str()); |
67 | } |
68 | } else { |
69 | // No adapters exist for the given op |
70 | ONNX_ASSERTM(false, "No Adapter For %s" , op_name.c_str()); |
71 | } |
72 | } |
73 | |
74 | virtual ModelProto |
75 | convert_version(const ModelProto& mp_in, const OpSetID& initial_version, const OpSetID& target_version) const = 0; |
76 | |
77 | void registerAdapter(std::unique_ptr<Adapter> a_ptr) { |
78 | const OpSetID& iv = a_ptr->initial_version(); |
79 | const OpSetID& tv = a_ptr->target_version(); |
80 | adapters[a_ptr->name()][iv.toString()][tv.toString()] = std::move(a_ptr); |
81 | } |
82 | |
83 | void registerAdapter(const char* op, int64_t from, int64_t to, NodeTransformerFunction transformer) { |
84 | registerAdapter(make_unique<GenericAdapter>(op, from, to, transformer)); |
85 | } |
86 | }; |
87 | |
88 | } // namespace version_conversion |
89 | } // namespace ONNX_NAMESPACE |
90 | |