1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5// Version converter interface for ONNX models between different opset versions.
6
7#pragma once
8
9#include <stdlib.h>
10#include <iostream>
11#include <utility>
12#include "onnx/common/ir.h"
13#include "onnx/common/ir_pb_converter.h"
14#include "onnx/common/stl_backports.h"
15#include "onnx/defs/schema.h"
16#include "onnx/proto_utils.h"
17#include "onnx/version_converter/adapters/adapter.h"
18
19namespace ONNX_NAMESPACE {
20namespace version_conversion {
21
22// TODO: Consider creating interface for this class.
23class BaseVersionConverter {
24 // Schema for adapters: {<op_name>:{<from_domain>$<from_version>:{<to_domain>
25 // <to_version>: adapter}}}
26 protected:
27 std::unordered_map<
28 std::string,
29 std::unordered_map<std::string, std::unordered_map<std::string, std::unique_ptr<Adapter>>>>
30 adapters;
31
32 // Map of All Versions of format {op_name: {domain: {version: schema}}}
33 std::unordered_map<std::string, std::unordered_map<std::string, std::map<int64_t, const OpSchema*>>> all_schemas;
34
35 public:
36 BaseVersionConverter() = default;
37
38 virtual ~BaseVersionConverter() = default;
39
40 // adapter_lookup should be called in convert_version when the user would
41 // like to identify the proper registered adapter in the adapters map for
42 // a given Node from a certain version to another. It should only be called
43 // when the user knows that an adapter should exist for the given context.
44 const Adapter& adapter_lookup(const Node* op, const OpSetID& initial_version, const OpSetID& target_version) const {
45 const std::string op_name = op->kind().toString();
46 const std::string initial = initial_version.toString();
47 const std::string target = target_version.toString();
48 // Find appropriate adapter in adapters map for provided initial and target versions
49 // TODO: Consider abstracting elements of this that are specific to
50 // DefaultConverter to separate methods here and maintain the procedure in Base Converter
51 const auto op_adapters = adapters.find(op_name);
52 if (op_adapters != adapters.end()) {
53 // If we're adapting downwards, we just want to find the one downwards
54 // adapter implemented for initial_version. If we're adapting upwards, we
55 // want to actually use the SinceVersion value for the given op.
56 const auto target_map = op_adapters->second.find(initial);
57 if (target_map != op_adapters->second.end()) {
58 // Either adapt from SinceVersion or Incompatible Breaking Change
59 const auto adapter_ptr = target_map->second.find(target);
60 if (adapter_ptr != target_map->second.end()) {
61 return *(adapter_ptr->second);
62 } else {
63 ONNX_ASSERTM(false, "No Adapter To Version %s for %s", target.c_str(), op_name.c_str());
64 }
65 } else {
66 ONNX_ASSERTM(false, "No Adapter From Version %s for %s", initial.c_str(), op_name.c_str());
67 }
68 } else {
69 // No adapters exist for the given op
70 ONNX_ASSERTM(false, "No Adapter For %s", op_name.c_str());
71 }
72 }
73
74 virtual ModelProto
75 convert_version(const ModelProto& mp_in, const OpSetID& initial_version, const OpSetID& target_version) const = 0;
76
77 void registerAdapter(std::unique_ptr<Adapter> a_ptr) {
78 const OpSetID& iv = a_ptr->initial_version();
79 const OpSetID& tv = a_ptr->target_version();
80 adapters[a_ptr->name()][iv.toString()][tv.toString()] = std::move(a_ptr);
81 }
82
83 void registerAdapter(const char* op, int64_t from, int64_t to, NodeTransformerFunction transformer) {
84 registerAdapter(make_unique<GenericAdapter>(op, from, to, transformer));
85 }
86};
87
88} // namespace version_conversion
89} // namespace ONNX_NAMESPACE
90