#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/object.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/named_value.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

#include <torch/csrc/Export.h>
#include <torch/csrc/api/include/torch/ordered_dict.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/utils/memory.h>

#include <ATen/core/function_schema.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>

#include <functional>
#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// This file contains classes which assist in desugaring Python style
// modules and their methods into flattened graphs which don't have any
// function calls.

namespace torch {
namespace jit {

using ::c10::Argument;
using ::c10::FunctionSchema;
using ::c10::QualifiedName;
// Map which stores filename to content.
using ExtraFilesMap = std::unordered_map<std::string, std::string>;

using ModulePtr = c10::intrusive_ptr<c10::ivalue::Object>;

struct Module;

template <typename T>
struct slot_list_impl;

template <typename T>
struct Named {
  std::string name;
  T value;
};

using NameModule = Named<Module>;
using NameValue = Named<IValue>;
using NameTensor = Named<at::Tensor>;

namespace detail {
struct TORCH_API ModulePolicy;
struct TORCH_API ParameterPolicy;
struct TORCH_API AttributePolicy;
struct TORCH_API BufferPolicy;
template <typename P>
struct NamedPolicy;
} // namespace detail

using module_list = slot_list_impl<detail::ModulePolicy>;
using named_module_list =
    slot_list_impl<detail::NamedPolicy<detail::ModulePolicy>>;

using parameter_list = slot_list_impl<detail::ParameterPolicy>;
using named_parameter_list =
    slot_list_impl<detail::NamedPolicy<detail::ParameterPolicy>>;

using attribute_list = slot_list_impl<detail::AttributePolicy>;
using named_attribute_list =
    slot_list_impl<detail::NamedPolicy<detail::AttributePolicy>>;

using buffer_list = slot_list_impl<detail::BufferPolicy>;
using named_buffer_list =
    slot_list_impl<detail::NamedPolicy<detail::BufferPolicy>>;

using ModuleLookup = std::function<Module(const std::vector<std::string>&)>;

struct TORCH_API Module : public Object {
  explicit Module(c10::QualifiedName class_name);
  Module(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
  Module() = default;
  Module(
      c10::QualifiedName,
      std::shared_ptr<CompilationUnit> cu,
      bool shouldMangle = false);
  Module(ModulePtr module_value) : Object(std::move(module_value)) {}
  ~Module() = default;

  void set_optimized(bool o) {
    TORCH_WARN(
        "Module::set_optimized() is deprecated and has no effect. "
        "Please use setGraphExecutorOptimize()");
  }

  bool is_optimized() const {
    TORCH_WARN(
        "Module::is_optimized() is deprecated and always returns true. "
        "Please use getGraphExecutorOptimize()");
    return true;
  }

  IValue forward(std::vector<IValue> inputs, const Kwargs& kwargs = Kwargs()) {
    return get_method("forward")(std::move(inputs), kwargs);
  }
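  //
  // Example (illustrative sketch, not part of this API): invoking a loaded
  // module's forward method from C++. Assumes torch::jit::load (declared in
  // torch/script.h) and a hypothetical "model.pt" file.
  //
  //   torch::jit::Module m = torch::jit::load("model.pt");
  //   std::vector<torch::jit::IValue> inputs;
  //   inputs.emplace_back(torch::ones({1, 3, 224, 224}));
  //   torch::jit::IValue out = m.forward(std::move(inputs));
  //   at::Tensor result = out.toTensor();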

  // In script modules, buffers are Tensor attributes that are _not_ registered
  // as parameters. This is different from nn.Module, where there is a special
  // register_buffer method. With this simplification, we only need to track
  // whether a slot is a parameter in order to classify it.
  void register_buffer(const std::string& name, at::Tensor v) {
    bool is_param = false;
    bool is_buffer = true;
    type()->addOrCheckAttribute(name, TensorType::get(), is_param, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }

  void register_parameter(
      const std::string& name,
      at::Tensor v,
      bool is_buffer) {
    type()->addOrCheckAttribute(name, TensorType::get(), !is_buffer, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }

  void register_attribute(
      const std::string& name,
      const TypePtr& t,
      IValue v,
      bool is_param = false,
      bool is_buffer = false) {
    type()->addOrCheckAttribute(name, t, is_param, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }

  void register_module(const std::string& name, const Module& module) {
    type()->addOrCheckAttribute(name, module.type());
    _ivalue()->setAttr(name, module._ivalue());
  }
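  //
  // Example (illustrative sketch): building modules programmatically and
  // registering tensors and a submodule on them. The class names and attribute
  // names here are hypothetical; real code more commonly obtains modules from
  // torch::jit::load or a CompilationUnit.
  //
  //   torch::jit::Module child(c10::QualifiedName("__torch__.Child"));
  //   torch::jit::Module parent(c10::QualifiedName("__torch__.Parent"));
  //   parent.register_parameter("weight", torch::randn({4, 4}),
  //                             /*is_buffer=*/false);
  //   parent.register_buffer("running_mean", torch::zeros({4}));
  //   parent.register_attribute("scale", c10::FloatType::get(), 2.0);
  //   parent.register_module("child", child);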

  void apply(const std::function<void(Module&)>& fn);

  buffer_list buffers(bool recurse = true) const;
  named_buffer_list named_buffers(bool recurse = true) const;

  module_list children() const; // direct modules
  named_module_list named_children() const;
  module_list modules() const; // all modules, including this one, recursively
  named_module_list named_modules() const;

  // all tensors involved in gradient optimization
  parameter_list parameters(bool recurse = true) const;
  named_parameter_list named_parameters(bool recurse = true) const;

  // all members of the object, similar to iterating over dir(obj) in python
  attribute_list attributes(bool recurse = true) const;
  named_attribute_list named_attributes(bool recurse = true) const;

  void dump(
      bool print_method_bodies,
      bool print_attr_values,
      bool print_param_values) const;

  std::string dump_to_str(
      bool print_method_bodies,
      bool print_attr_values,
      bool print_param_values) const;

  /// Enables "training" mode.
  void train(bool on = true);
  /// Calls train(false) to enable "eval" mode.
  /// Do not override this method, override `train()` instead.
  void eval() {
    train(/*on=*/false);
  }
  /// True if the module is in training mode.
  bool is_training() const {
    return attr("training", true).toBool();
  }

  /// Recursively casts all parameters to the given `dtype` and `device`.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  void to(at::Device device, at::ScalarType dtype, bool non_blocking = false);

  /// Recursively casts all parameters to the given dtype.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  void to(at::ScalarType dtype, bool non_blocking = false);

  /// Recursively moves all parameters to the given device.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  void to(at::Device device, bool non_blocking = false);
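  //
  // Example (illustrative sketch): moving a module's parameters and buffers
  // to a different device and dtype. Assumes a CUDA-enabled build and that
  // `m` is a valid Module.
  //
  //   m.to(at::kCUDA);                // move to the default CUDA device
  //   m.to(at::kHalf);                // cast parameters to fp16
  //   m.to(at::Device(at::kCUDA, 0), at::kFloat, /*non_blocking=*/true);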

  void save(
      std::ostream& out,
      const ExtraFilesMap& extra_files = ExtraFilesMap()) const;

  void save(
      const std::string& filename,
      const ExtraFilesMap& extra_files = ExtraFilesMap()) const;
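  //
  // Example (illustrative sketch): serializing a module with an extra file
  // stored alongside the model archive. The output path and extra-file key
  // are hypothetical.
  //
  //   torch::jit::ExtraFilesMap extra_files;
  //   extra_files["metadata.json"] = "{\"version\": 1}";
  //   m.save("model.pt", extra_files);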

  void _save_for_mobile(
      std::ostream& out,
      const ExtraFilesMap& extra_files = ExtraFilesMap(),
      bool save_mobile_debug_info = false,
      bool use_flatbuffer = false) const;

  void _save_for_mobile(
      const std::string& filename,
      const ExtraFilesMap& extra_files = ExtraFilesMap(),
      bool save_mobile_debug_info = false,
      bool use_flatbuffer = false) const;

  Module copy() const;

  Module deepcopy() const;

  // Clones both the underlying `ClassType` and the module instance (data).
  // This function creates a new `ClassType` and returns a new instance that
  // has the same data as the current instance but with the new type. Shared
  // ClassTypes will be preserved as well.
  Module clone(bool inplace = false) const;

  // Clones both the underlying `ClassType` and the module instance (data).
  // This function creates a new `ClassType` and returns a new instance that
  // has the same data as the current instance but with the new type. Shared
  // ClassTypes will be preserved as well. Also allows the caller to specify a
  // set of method and attribute names to not clone.
  Module clone(
      bool inplace,
      const std::unordered_set<std::string>& ignored_method,
      const std::unordered_set<std::string>& ignored_attributes) const;

  void clone_method(const Module& orig, const std::string& name);

  IValue operator()(std::vector<IValue> inputs);

  template <typename... Types>
  IValue create_class(const c10::QualifiedName& name, Types&&... args) const {
    return create_class(name, {IValue(std::forward<Types>(args))...});
  }

  IValue create_class(const c10::QualifiedName& name, Stack stack) const;

  inline bool operator==(const Module& y) const noexcept {
    return _ivalue() == y._ivalue();
  }

  void set_delete_memory(std::shared_ptr<char> delete_mem) {
    mem_to_delete_ = delete_mem;
  }

  // A set of functions to maintain input shapes through torch.jit.save and
  // torch.jit.load. It only works on tensors and lists/dicts of tensors
  // because tracing is only supported for these types.
  void store_traced_inputs(std::string func_name, std::vector<IValue> inputs) {
    if (inputs.size() == 0) {
      return;
    }
    auto c10_inputs = c10::impl::GenericList(AnyType::get());
    for (const IValue& value : inputs) {
      // Not checking whether this is a traceable type, as that is already
      // checked higher up in the stack, and changing that would require a
      // larger restructuring.
      c10_inputs.push_back(value);
    }
    traced_inputs_.insert_or_assign(func_name, c10_inputs);
  }
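  //
  // Example (illustrative sketch): recording the inputs a method was traced
  // with so they survive a save/load round trip. The method name and tensor
  // shape here are hypothetical.
  //
  //   std::vector<c10::IValue> example_inputs;
  //   example_inputs.emplace_back(torch::ones({2, 3}));
  //   m.store_traced_inputs("forward", example_inputs);
  //   auto traced = m.retrieve_traced_inputs(); // Dict<string, GenericList>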

  c10::Dict<std::string, c10::impl::GenericList> retrieve_traced_inputs()
      const {
    return traced_inputs_;
  }

 private:
  Module clone_impl(
      std::unordered_map<TypePtr, TypePtr>& type_remap,
      bool inplace,
      IValue::HashAliasedIValueMap memo,
      const std::unordered_set<std::string>& ignored_methods,
      const std::unordered_set<std::string>& ignored_attributes) const;

  void clone_method(
      const Module& orig,
      const Function& method,
      const std::unordered_map<TypePtr, TypePtr>& type_remap);

  c10::QualifiedName getNameForMethod(std::string basename) const {
    return QualifiedName(*type()->name(), std::move(basename));
  }

  void to_impl(
      const c10::optional<at::Device>& device,
      const c10::optional<at::ScalarType>& dtype,
      bool non_blocking);

  // Extra handle for the module to delete when it itself is deleted
  std::shared_ptr<char> mem_to_delete_;

  // Map of function names to the inputs they were traced with
  c10::Dict<std::string, c10::impl::GenericList> traced_inputs_;
};

// C++ equivalent API of `torch.jit.freeze`. See documentation there for
// details.
TORCH_API Module freeze(
    const Module& module,
    c10::optional<std::vector<std::string>> preserved_attrs = c10::nullopt,
    bool optimize_numerics = true);

// C++ equivalent API of `torch.jit.optimize_for_inference`. See documentation
// there for details.
TORCH_API Module optimize_for_inference(
    Module& module,
    const std::vector<std::string>& other_methods = {});
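//
// Example (illustrative sketch): freezing a module for inference and then
// applying inference-oriented optimizations. Assumes `m` is a valid Module in
// eval mode.
//
//   m.eval();
//   torch::jit::Module frozen = torch::jit::freeze(m);
//   torch::jit::Module optimized = torch::jit::optimize_for_inference(frozen);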

enum class FusionBehavior { STATIC, DYNAMIC };

using FusionStrategy = std::vector<std::pair<FusionBehavior, size_t>>;
// clang-format off
/*
Sets the type and number of specializations that can occur during fusion.

Usage: provide a list of pairs (type, depth) where type is one of STATIC or DYNAMIC
and depth is an integer.

Behavior - static vs dynamic:
  In STATIC fusion, fused ops are compiled to have fixed input shapes. The shape is determined
  based on some initial profiling runs.
  In DYNAMIC fusion, fused ops are compiled to have variable input shapes, so that multiple
  shapes are possible.

In both cases, we also recompile on new striding behavior, device, or dtype.

Behavior - fallback functions & depth:
  When an input doesn't match the format required by the specialized compiled op, it will run
  a fallback function. Fallback functions are recursively compiled and specialized based
  on the observed tensor shapes. Since compilation can be slow, the "depth" parameter is provided to
  limit the number of specializations that can be compiled, before giving up on recompiling and
  falling back to a completely un-fused, un-specialized implementation.

The list of (type, depth) pairs controls the type of specializations and the number of
specializations. For example: [(STATIC, 2), (DYNAMIC, 2)] indicates that the first
two specializations will use static fusion, the following two specializations will use
dynamic fusion, and any inputs that satisfy none of the 4 options will run an
unfused implementation.

NB: in the future, as more fusion backends are added there may be more granular
APIs for specific fusers.
*/
// clang-format on
TORCH_API FusionStrategy getFusionStrategy();
// returns previous strategy
TORCH_API FusionStrategy setFusionStrategy(FusionStrategy& fusion_strategy);
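//
// Example (illustrative sketch): allow two statically-shaped specializations
// followed by two dynamically-shaped ones before falling back to the unfused
// implementation.
//
//   torch::jit::FusionStrategy strategy = {
//       {torch::jit::FusionBehavior::STATIC, 2},
//       {torch::jit::FusionBehavior::DYNAMIC, 2}};
//   torch::jit::FusionStrategy previous =
//       torch::jit::setFusionStrategy(strategy);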

namespace detail {

struct TORCH_API SlotCursor {
  Module module_;
  int64_t i_; // slot offset, -1 indicates the module itself
};

} // namespace detail

// This iterator allows the (optionally recursive) enumeration of
// the members of a Module. It performs a depth-first pre-order
// traversal of the module. The Policy template parameter determines
// which slots of the object should be included. For instance,
// when iterating parameters, we return the parameter tensors,
// but skip modules, buffers, and other attributes.
// See ModulePolicy for comments about the Policy object's API.
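//
// Example (illustrative sketch): walking all parameters of a module together
// with their fully qualified names, using the named_ variant built on this
// iterator. Assumes `m` is a valid Module and <iostream> is included.
//
//   for (const torch::jit::NameTensor& p : m.named_parameters()) {
//     std::cout << p.name << ": " << p.value.sizes() << "\n";
//   }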
template <typename Policy>
struct slot_iterator_impl {
  using SlotCursor = detail::SlotCursor;
  using value_type = typename Policy::value_type;
  slot_iterator_impl(
      Module root,
      bool recurse, // if true, do a depth-first search, otherwise, just look
                    // at slots of root
      bool return_module) // if true include root itself as the first thing
                          // visited (used in modules())
      : cursors_({SlotCursor{root, return_module ? -1 : 0}}),
        recurse_(recurse) {
    // advance iterator to first valid element (or the end, if empty)
    while_not_valid_next();
  }
  // empty cursors_, represents end of iteration
  slot_iterator_impl() : recurse_(false) {}
  value_type operator*() const {
    return Policy::create(cursors_, cur());
  }
  value_type operator->() const {
    return **this;
  }
  slot_iterator_impl& operator++() {
    next_valid();
    return *this;
  }
  slot_iterator_impl operator++(int) {
    // this is really expensive, should we delete it so people don't use it
    // instead of prefix?
    slot_iterator_impl old = *this;
    ++(*this);
    return old;
  }

 private:
  // return_module() is a corner case where instead of returning a submodule
  // of root, we are returning root itself, because we are iterating modules(),
  // which contains the root module itself.
  // It is represented with a single SlotCursor whose index is -1.
  bool return_module() const {
    return top().i_ == -1;
  }
  const SlotCursor& top() const {
    return cursors_.back();
  }
  SlotCursor& top() {
    return cursors_.back();
  }
  IValue cur() const {
    return return_module() ? top().module_._ivalue()
                           : top().module_._ivalue()->getSlot(top().i_);
  }

  // advance to the next slot in a depth-first pre-order traversal of the
  // module's slots. This function does not guarantee the next slot is a
  // valid element of the iteration. That is done by valid().
  // invariant: !cursors_.empty()
  void next() {
    // we just returned the module itself, advance i_ to 0 so we are now
    // at the first slot of the module.
    if (return_module()) {
      ++top().i_;
      return;
    }
    // the last traversal action advanced beyond the number of slots in the
    // module so continue the iteration in the parent.
    if (top().i_ >=
        int64_t(top().module_._ivalue()->type()->numAttributes())) {
      cursors_.pop_back();
      if (!cursors_.empty()) {
        ++top().i_;
      }
      return;
    }
    // if the current thing is a module, we have to scan it for recursive
    // traversals. We do this by adding a new SlotCursor to track the
    // traversal.
    if (recurse_ &&
        top().module_._ivalue()->type()->getAttribute(top().i_)->is_module()) {
      cursors_.emplace_back(SlotCursor{cur().toModule(), 0});
      return;
    }
    // common case: advance to the next slot.
    ++top().i_;
  }
  // is the current position of the iterator a valid one?
  // otherwise, we have to continue advancing.
  bool valid() const {
    return top().i_ <
        int64_t(top().module_._ivalue()->type()->numAttributes()) &&
        Policy::valid(
            top().module_._ivalue()->type(),
            top().i_,
            top().module_._ivalue()->getSlot(top().i_));
  }
  void while_not_valid_next() {
    // advance iteration until we are either at the end (cursors_.empty())
    // or in a valid state. return_module() is a special case,
    // and is always considered valid, regardless of Policy, because it is
    // only true when we are iterating modules.
    while (!cursors_.empty() && !return_module() && !valid()) {
      next();
    }
  }
  void next_valid() {
    // avoid crashing if this is empty
    if (cursors_.empty()) {
      return;
    }
    // advance to the next element, which may not be valid
    next();
    while_not_valid_next();
  }

  std::vector<SlotCursor> cursors_;
  bool recurse_;

  friend inline bool operator!=(
      const slot_iterator_impl<Policy>& a,
      const slot_iterator_impl<Policy>& b) {
    // we are finished iterating when we have no more iteration SlotCursors.
    // end is always an empty iterator with no cursors.
    return (a.cursors_.empty() != b.cursors_.empty());
  }
};

// This type represents lists of parameters, attributes, and
// submodules contained in the module. It is abstract because
// the elements are not stored directly in std::vectors but inside the
// module's IValue object itself.
template <typename Policy>
struct slot_list_impl {
  using iterator = slot_iterator_impl<Policy>;
  using const_iterator = slot_iterator_impl<Policy>;
  using value_type = typename iterator::value_type;
  slot_iterator_impl<Policy> begin() const {
    return slot_iterator_impl<Policy>(module_, recurse_, return_module_);
  }
  slot_iterator_impl<Policy> end() const {
    return slot_iterator_impl<Policy>();
  }
  size_t size() const {
    if (!size_) {
      size_ = size_t(0);
      // NOLINTNEXTLINE(clang-diagnostic-unused-variable)
      for (const value_type& s : *(this)) {
        (void)s; // Suppress unused variable warning
        ++*size_;
      }
    }
    return *size_;
  }

  slot_list_impl(Module module, bool recurse, bool return_module)
      : module_(module),
        recurse_(recurse),
        return_module_(return_module),
        size_(c10::nullopt) {
    if (!recurse && !return_module && Policy::all_slots) {
      size_ = module_.num_slots();
    }
  }

 private:
  Module module_;
  bool recurse_;
  bool return_module_;
  // size of this list, cached on first request
  // when we need to filter the slot list
  mutable c10::optional<size_t> size_;
  friend struct Module;
};

namespace detail {

// slot_iterator_impl always iterates over all the slots in a module;
// the Policy template argument determines which slots should be returned
// and their types.
struct TORCH_API ModulePolicy {
  // the type of the value being returned
  using value_type = Module;

  // the logic for creating the type being returned, given the raw IValue
  // of that object.
  static value_type create(
      const std::vector<detail::SlotCursor>& cursors,
      IValue v) {
    return Module(std::move(v).toObject());
  }
  // is slot i in typ something that this iterator should return, otherwise,
  // we skip it.
  static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) {
    return typ->getAttribute(i)->is_module();
  }
  // are we going to return everything? If so, we can optimize the calculation
  // of the size of the list.
  static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = false;
};

struct TORCH_API ParameterPolicy {
  using value_type = at::Tensor;
  static value_type create(
      const std::vector<detail::SlotCursor>& cursors,
      IValue v) {
    return std::move(v).toTensor();
  }
  static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) {
    return typ->is_parameter(i) && v.isTensor();
  }
  static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = false;
};

struct TORCH_API BufferPolicy {
  using value_type = at::Tensor;
  static value_type create(
      const std::vector<detail::SlotCursor>& cursors,
      IValue v) {
    return std::move(v).toTensor();
  }
  static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) {
    return typ->getAttribute(i)->isSubtypeOf(*TensorType::get()) &&
        typ->is_buffer(i);
  }
  static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = false;
};

struct TORCH_API AttributePolicy {
  using value_type = IValue;
  static value_type create(
      const std::vector<detail::SlotCursor>& cursors,
      IValue v) {
    return v;
  }
  static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) {
    return true;
  }
  static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = true;
};

// take a Policy object, and make a version of it that returns the slot
// along with the fully qualified name of that slot. This is used for the
// named_ variants like named_parameters().
template <typename Policy>
struct NamedPolicy {
  using value_type = Named<typename Policy::value_type>;
  static value_type create(
      const std::vector<detail::SlotCursor>& cursors,
      IValue v) {
    std::string name;
    if (cursors.size() == 1) {
      name = (cursors.back().i_ == -1) ? "" : nameFragment(cursors.back());
    } else {
      std::ostringstream ss;
      for (const auto i : c10::irange(cursors.size())) {
        if (i > 0) {
          ss << ".";
        }
        ss << nameFragment(cursors[i]);
      }
      name = ss.str();
    }
    return value_type{std::move(name), Policy::create(cursors, std::move(v))};
  }
  static bool valid(const ClassTypePtr& t, size_t i, const IValue& v) {
    return Policy::valid(t, i, v);
  }
  static constexpr bool all_slots = Policy::all_slots;

 private:
  static std::string nameFragment(const detail::SlotCursor& f) {
    return f.module_.type()->getAttributeName(f.i_);
  }
};

} // namespace detail

TORCH_API bool& getInlineEverythingMode();

namespace script {
// We once had a `script::` namespace that was deleted. This is for backcompat
// of the public API; new code should not use this type alias.
using Module = ::torch::jit::Module;
using ExtraFilesMap = ::torch::jit::ExtraFilesMap;
} // namespace script

} // namespace jit
} // namespace torch