1 | #pragma once |
2 | #include <c10/util/Exception.h> |
3 | #include <torch/csrc/autograd/variable.h> |
4 | #include <torch/csrc/jit/api/object.h> |
5 | #include <torch/csrc/jit/frontend/source_range.h> |
6 | #include <torch/csrc/jit/ir/ir.h> |
7 | #include <torch/csrc/jit/ir/named_value.h> |
8 | #include <torch/csrc/jit/runtime/argument_spec.h> |
9 | #include <torch/csrc/jit/runtime/graph_executor.h> |
10 | |
11 | #include <torch/csrc/Export.h> |
12 | #include <torch/csrc/api/include/torch/ordered_dict.h> |
13 | #include <torch/csrc/jit/api/compilation_unit.h> |
14 | #include <torch/csrc/utils/memory.h> |
15 | |
16 | #include <ATen/core/function_schema.h> |
17 | #include <ATen/core/qualified_name.h> |
18 | #include <c10/util/ArrayRef.h> |
19 | #include <c10/util/Optional.h> |
20 | #include <c10/util/irange.h> |
21 | |
22 | #include <functional> |
23 | #include <memory> |
24 | #include <mutex> |
25 | #include <ostream> |
26 | #include <string> |
27 | #include <unordered_map> |
28 | #include <unordered_set> |
29 | #include <utility> |
30 | #include <vector> |
31 | |
// This file contains classes which assist in desugaring Python-style
// modules and their methods into flattened graphs which don't have any
// function calls.
35 | |
36 | namespace torch { |
37 | namespace jit { |
38 | |
39 | using ::c10::Argument; |
40 | using ::c10::FunctionSchema; |
41 | using ::c10::QualifiedName; |
// Map from filename to file content.
using ExtraFilesMap = std::unordered_map<std::string, std::string>;
44 | |
45 | using ModulePtr = c10::intrusive_ptr<c10::ivalue::Object>; |
46 | |
47 | struct Module; |
48 | |
49 | template <typename T> |
50 | struct slot_list_impl; |
51 | |
52 | template <typename T> |
53 | struct Named { |
54 | std::string name; |
55 | T value; |
56 | }; |
57 | |
58 | using NameModule = Named<Module>; |
59 | using NameValue = Named<IValue>; |
60 | using NameTensor = Named<at::Tensor>; |
61 | |
62 | namespace detail { |
63 | struct TORCH_API ModulePolicy; |
64 | struct TORCH_API ParameterPolicy; |
65 | struct TORCH_API AttributePolicy; |
66 | struct TORCH_API BufferPolicy; |
67 | template <typename P> |
68 | struct NamedPolicy; |
69 | } // namespace detail |
70 | |
71 | using module_list = slot_list_impl<detail::ModulePolicy>; |
72 | using named_module_list = |
73 | slot_list_impl<detail::NamedPolicy<detail::ModulePolicy>>; |
74 | |
75 | using parameter_list = slot_list_impl<detail::ParameterPolicy>; |
76 | using named_parameter_list = |
77 | slot_list_impl<detail::NamedPolicy<detail::ParameterPolicy>>; |
78 | |
79 | using attribute_list = slot_list_impl<detail::AttributePolicy>; |
80 | using named_attribute_list = |
81 | slot_list_impl<detail::NamedPolicy<detail::AttributePolicy>>; |
82 | |
83 | using buffer_list = slot_list_impl<detail::BufferPolicy>; |
84 | using named_buffer_list = |
85 | slot_list_impl<detail::NamedPolicy<detail::BufferPolicy>>; |
86 | |
87 | using ModuleLookup = std::function<Module(const std::vector<std::string>&)>; |
88 | |
89 | struct TORCH_API Module : public Object { |
90 | explicit Module(c10::QualifiedName class_name); |
91 | Module(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type); |
92 | Module() = default; |
93 | Module( |
94 | c10::QualifiedName, |
95 | std::shared_ptr<CompilationUnit> cu, |
96 | bool shouldMangle = false); |
97 | Module(ModulePtr module_value) : Object(std::move(module_value)) {} |
98 | ~Module() = default; |
99 | |
100 | void set_optimized(bool o) { |
101 | TORCH_WARN( |
102 | "Module::set_optimized() is deprecated and has no effect. " |
103 | "Please use setGraphExecutorOptimize()" ); |
104 | } |
105 | |
106 | bool is_optimized() const { |
107 | TORCH_WARN( |
108 | "Module::is_optimized() is deprecated and always returns true. " |
109 | "Please use getGraphExecutorOptimize()" ); |
110 | return true; |
111 | } |
112 | |
113 | IValue forward(std::vector<IValue> inputs, const Kwargs& kwargs = Kwargs()) { |
    return get_method("forward")(std::move(inputs), kwargs);
115 | } |
116 | |
  // In script modules, buffers are Tensor attributes that are _not_ registered
  // as parameters. This is different from nn.Module, which has a special
  // register_buffer method. With this simplification, we only need to track
  // whether a slot is a parameter in order to classify it.
121 | void register_buffer(const std::string& name, at::Tensor v) { |
122 | bool is_param = false; |
123 | bool is_buffer = true; |
124 | type()->addOrCheckAttribute(name, TensorType::get(), is_param, is_buffer); |
125 | _ivalue()->setAttr(name, std::move(v)); |
126 | } |
127 | |
128 | void register_parameter( |
129 | const std::string& name, |
130 | at::Tensor v, |
131 | bool is_buffer) { |
132 | type()->addOrCheckAttribute(name, TensorType::get(), !is_buffer, is_buffer); |
133 | _ivalue()->setAttr(name, std::move(v)); |
134 | } |
135 | |
136 | void register_attribute( |
137 | const std::string& name, |
138 | const TypePtr& t, |
139 | IValue v, |
140 | bool is_param = false, |
141 | bool is_buffer = false) { |
142 | type()->addOrCheckAttribute(name, t, is_param, is_buffer); |
143 | _ivalue()->setAttr(name, std::move(v)); |
144 | } |
145 | |
146 | void register_module(const std::string& name, const Module& module) { |
147 | type()->addOrCheckAttribute(name, module.type()); |
148 | _ivalue()->setAttr(name, module._ivalue()); |
149 | } |
150 | |
151 | void apply(const std::function<void(Module&)>& fn); |
152 | |
153 | buffer_list buffers(bool recurse = true) const; |
154 | named_buffer_list named_buffers(bool recurse = true) const; |
155 | |
156 | module_list children() const; // direct modules |
157 | named_module_list named_children() const; |
158 | module_list modules() const; // all modules, including this one, recursively |
159 | named_module_list named_modules() const; |
160 | |
161 | // all tensors involved in gradient optimization |
162 | parameter_list parameters(bool recurse = true) const; |
163 | named_parameter_list named_parameters(bool recurse = true) const; |
164 | |
165 | // all members of the object, similar to iterating over dir(obj) in python |
166 | attribute_list attributes(bool recurse = true) const; |
167 | named_attribute_list named_attributes(bool recurse = true) const; |
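
  // Example of iterating members (a sketch; assumes an existing Module `m`):
  //
  //   for (const NameTensor& p : m.named_parameters(/*recurse=*/true)) {
  //     std::cout << p.name << ": " << p.value.sizes() << "\n";
  //   }
  //   for (const NameModule& child : m.named_children()) {
  //     std::cout << child.name << "\n";
  //   }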
168 | |
169 | void dump( |
170 | bool print_method_bodies, |
171 | bool print_attr_values, |
172 | bool print_param_values) const; |
173 | |
174 | std::string dump_to_str( |
175 | bool print_method_bodies, |
176 | bool print_attr_values, |
177 | bool print_param_values) const; |
178 | |
179 | /// Enables "training" mode. |
180 | void train(bool on = true); |
181 | /// Calls train(false) to enable "eval" mode. |
182 | /// Do not override this method, override `train()` instead. |
183 | void eval() { |
184 | train(/*on=*/false); |
185 | } |
186 | /// True if the module is in training mode. |
187 | bool is_training() const { |
188 | return attr("training" , true).toBool(); |
189 | } |
190 | |
191 | /// Recursively casts all parameters to the given `dtype` and `device`. |
192 | /// |
193 | /// If `non_blocking` is true and the source is in pinned memory and |
194 | /// destination is on the GPU or vice versa, the copy is performed |
195 | /// asynchronously with respect to the host. Otherwise, the argument has no |
196 | /// effect. |
197 | void to(at::Device device, at::ScalarType dtype, bool non_blocking = false); |
198 | |
199 | /// Recursively casts all parameters to the given dtype. |
200 | /// |
201 | /// If `non_blocking` is true and the source is in pinned memory and |
202 | /// destination is on the GPU or vice versa, the copy is performed |
203 | /// asynchronously with respect to the host. Otherwise, the argument has no |
204 | /// effect. |
205 | void to(at::ScalarType dtype, bool non_blocking = false); |
206 | |
207 | /// Recursively moves all parameters to the given device. |
208 | /// |
209 | /// If `non_blocking` is true and the source is in pinned memory and |
210 | /// destination is on the GPU or vice versa, the copy is performed |
211 | /// asynchronously with respect to the host. Otherwise, the argument has no |
212 | /// effect. |
213 | void to(at::Device device, bool non_blocking = false); |
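
  // Example (a sketch; the CUDA variants require a CUDA-enabled build):
  //
  //   m.to(at::kHalf);                 // cast parameters to float16
  //   m.to(at::Device(at::kCUDA, 0));  // move parameters to cuda:0
  //   m.to(at::kCUDA, at::kFloat, /*non_blocking=*/true);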
214 | |
215 | void save( |
216 | std::ostream& out, |
217 | const ExtraFilesMap& = ExtraFilesMap()) const; |
218 | |
219 | void save( |
220 | const std::string& filename, |
221 | const ExtraFilesMap& = ExtraFilesMap()) const; |
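
  // Example (a sketch; the "metadata.json" entry is illustrative):
  //
  //   ExtraFilesMap extra_files;
  //   extra_files["metadata.json"] = "{\"version\": 1}";
  //   m.save("model.pt", extra_files);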
222 | |
223 | void _save_for_mobile( |
224 | std::ostream& out, |
225 | const ExtraFilesMap& = ExtraFilesMap(), |
226 | bool save_mobile_debug_info = false, |
227 | bool use_flatbuffer = false) const; |
228 | |
229 | void _save_for_mobile( |
230 | const std::string& filename, |
231 | const ExtraFilesMap& = ExtraFilesMap(), |
232 | bool save_mobile_debug_info = false, |
233 | bool use_flatbuffer = false) const; |
234 | |
235 | Module copy() const; |
236 | |
237 | Module deepcopy() const; |
238 | |
  // Clones both the underlying `ClassType` and the module instance (data):
  // this function creates a new `ClassType` and returns a new instance that
  // has the same data as the current instance but with the new type. Sharing
  // of `ClassType`s between submodules is preserved.
243 | Module clone(bool inplace = false) const; |
244 | |
  // Identical to the overload above, but also allows the caller to specify
  // sets of method and attribute names that should not be cloned.
250 | Module clone( |
251 | bool inplace, |
252 | const std::unordered_set<std::string>& ignored_method, |
253 | const std::unordered_set<std::string>& ignored_attributes) const; |
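
  // Example (a sketch; the ignored names are illustrative):
  //
  //   Module cloned = m.clone(/*inplace=*/false);
  //   Module partial = m.clone(
  //       /*inplace=*/false,
  //       /*ignored_methods=*/{"loss"},
  //       /*ignored_attributes=*/{"scratch_buffer"});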
254 | |
255 | void clone_method(const Module& orig, const std::string& name); |
256 | |
257 | IValue operator()(std::vector<IValue> inputs); |
258 | |
259 | template <typename... Types> |
260 | IValue create_class(const c10::QualifiedName& name, Types&&... args) const { |
261 | return create_class(name, {IValue(std::forward<Types>(args))...}); |
262 | } |
263 | |
264 | IValue create_class(const c10::QualifiedName& name, Stack stack) const; |
265 | |
266 | inline bool operator==(const Module& y) const noexcept { |
267 | return _ivalue() == y._ivalue(); |
268 | } |
269 | |
  void set_delete_memory(std::shared_ptr<char> delete_mem) {
    mem_to_delete_ = std::move(delete_mem);
  }
273 | |
  // A set of functions to maintain input shapes through torch.jit.save and
  // torch.jit.load. They only work on tensors and lists/dicts of tensors,
  // because tracing is only supported for those types.
  void store_traced_inputs(std::string func_name, std::vector<IValue> inputs) {
    if (inputs.empty()) {
      return;
    }
    auto c10_inputs = c10::impl::GenericList(AnyType::get());
    for (IValue& value : inputs) {
      // Not checking whether this is a traceable type, as that is already
      // checked higher up in the stack; changing it would require a larger
      // restructuring.
      c10_inputs.push_back(std::move(value));
    }
    traced_inputs_.insert_or_assign(
        std::move(func_name), std::move(c10_inputs));
  }
290 | |
291 | c10::Dict<std::string, c10::impl::GenericList> retrieve_traced_inputs() |
292 | const { |
293 | return traced_inputs_; |
294 | } |
295 | |
296 | private: |
297 | Module clone_impl( |
298 | std::unordered_map<TypePtr, TypePtr>& type_remap, |
299 | bool inplace, |
300 | IValue::HashAliasedIValueMap memo, |
301 | const std::unordered_set<std::string>& ignored_methods, |
302 | const std::unordered_set<std::string>& ignored_attributes) const; |
303 | |
304 | void clone_method( |
305 | const Module& orig, |
306 | const Function& method, |
307 | const std::unordered_map<TypePtr, TypePtr>& type_remap); |
308 | |
309 | c10::QualifiedName getNameForMethod(std::string basename) const { |
310 | return QualifiedName(*type()->name(), std::move(basename)); |
311 | } |
312 | |
313 | void to_impl( |
314 | const c10::optional<at::Device>& device, |
315 | const c10::optional<at::ScalarType>& dtype, |
316 | bool non_blocking); |
317 | |
318 | // Extra handle for the module to delete when itself is deleted |
319 | std::shared_ptr<char> mem_to_delete_; |
320 | |
  // Map of function names to the inputs with which they were traced
322 | c10::Dict<std::string, c10::impl::GenericList> traced_inputs_; |
323 | }; |
324 | |
// C++ equivalent API of `torch.jit.freeze`. See its documentation for
// details.
327 | TORCH_API Module freeze( |
328 | const Module& module, |
329 | c10::optional<std::vector<std::string>> preserved_attrs = c10::nullopt, |
330 | bool optimize_numerics = true); |
331 | |
// C++ equivalent API of `torch.jit.optimize_for_inference`. See its
// documentation for details.
334 | TORCH_API Module optimize_for_inference( |
335 | Module& module, |
336 | const std::vector<std::string>& other_methods = {}); |
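
// Example (a sketch; `freeze` expects the module to be in eval mode):
//
//   Module m = torch::jit::load("model.pt");
//   m.eval();
//   Module frozen = freeze(m);
//   Module ready = optimize_for_inference(frozen);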
337 | |
338 | enum class FusionBehavior { STATIC, DYNAMIC }; |
339 | |
340 | using FusionStrategy = std::vector<std::pair<FusionBehavior, size_t>>; |
341 | // clang-format off |
342 | /* |
343 | Sets the type and number of specializations that can occur during fusion. |
344 | |
345 | Usage: provide a list of pairs (type, depth) where type is one of STATIC or DYNAMIC |
346 | and depth is an integer. |
347 | |
348 | Behavior - static vs dynamic: |
349 | In STATIC fusion, fused ops are compiled to have fixed input shapes. The shape is determined |
350 | based on some initial profiling runs. |
351 | In DYNAMIC fusion, fused ops are compiled to have variable input shapes, so that multiple |
352 | shapes are possible. |
353 | |
354 | In both cases, we also recompile on new striding behavior, device, or dtype. |
355 | |
356 | Behavior - fallback functions & depth: |
357 | When an input doesn't match the format required by the specialized compiled op, it will run |
a fallback function. Fallback functions are recursively compiled and specialized based
on the observed tensor shapes. Since compilation can be slow, the "depth" parameter is provided to
limit the number of specializations that can be compiled before giving up on recompiling and
falling back to a completely un-fused, un-specialized implementation.
362 | |
363 | The list of (type, depth) pairs controls the type of specializations and the number of |
364 | specializations. For example: [(STATIC, 2), (DYNAMIC, 2)] indicates that the first |
365 | two specializations will use static fusions, the following two specializations will use |
366 | dynamic fusion, and any inputs that satisfy none of the 4 options will run an |
367 | unfused implementation. |
368 | |
NB: in the future, as more fusion backends are added, there may be more granular
APIs for specific fusers.
371 | */ |
372 | // clang-format on |
373 | TORCH_API FusionStrategy getFusionStrategy(); |
374 | // returns previous strategy |
375 | TORCH_API FusionStrategy setFusionStrategy(FusionStrategy& fusion_strategy); |
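
// Example (a sketch): allow two static specializations, then up to ten
// dynamic ones, before falling back to an unfused implementation:
//
//   FusionStrategy strategy = {
//       {FusionBehavior::STATIC, 2}, {FusionBehavior::DYNAMIC, 10}};
//   FusionStrategy previous = setFusionStrategy(strategy);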
376 | |
377 | namespace detail { |
378 | |
379 | struct TORCH_API SlotCursor { |
380 | Module module_; |
381 | int64_t i_; // slot offset, -1 indicates the module itself |
382 | }; |
383 | |
384 | } // namespace detail |
385 | |
386 | // This iterator allows the (optionally recursive) enumeration of |
387 | // the members of a Module. It performs a depth-first pre-order |
388 | // traversal of the module. The Policy template parameter determines |
389 | // which slots of the object should be included. For instance, |
390 | // when iterating parameters, we return the parameter tensors, |
391 | // but skip modules, buffers, and other attributes. |
// See ModulePolicy for comments about the Policy object's API.
393 | template <typename Policy> |
394 | struct slot_iterator_impl { |
395 | using SlotCursor = detail::SlotCursor; |
396 | using value_type = typename Policy::value_type; |
397 | slot_iterator_impl( |
398 | Module root, |
399 | bool recurse, // if true, do a depth-first search, otherwise, just look at |
400 | // slots of root |
401 | bool return_module) // if true include root itself as the first thing |
402 | // visited (used in modules()) |
403 | : cursors_({SlotCursor{root, return_module ? -1 : 0}}), |
404 | recurse_(recurse) { |
405 | // advance iterator to first valid element (or the end, if empty) |
406 | while_not_valid_next(); |
407 | } |
408 | // empty cursors_, represents end of iteration |
409 | slot_iterator_impl() : recurse_(false) {} |
410 | value_type operator*() const { |
411 | return Policy::create(cursors_, cur()); |
412 | } |
413 | value_type operator->() const { |
414 | return **this; |
415 | } |
416 | slot_iterator_impl& operator++() { |
417 | next_valid(); |
418 | return *this; |
419 | } |
420 | slot_iterator_impl operator++(int) { |
421 | // this is really expensive, should we delete it so people don't use it |
422 | // instead of prefix? |
423 | slot_iterator_impl old = *this; |
424 | ++(*this); |
425 | return old; |
426 | } |
427 | |
428 | private: |
429 | // return_module() is a corner case where instead of returning a submodule |
430 | // of root, we are returning root itself, because we are iterating modules(), |
431 | // which contains the root module itself. |
432 | // It is represented with a single SlotCursor whose index is -1. |
433 | bool return_module() const { |
434 | return top().i_ == -1; |
435 | } |
436 | const SlotCursor& top() const { |
437 | return cursors_.back(); |
438 | } |
439 | SlotCursor& top() { |
440 | return cursors_.back(); |
441 | } |
442 | IValue cur() const { |
443 | return return_module() ? top().module_._ivalue() |
444 | : top().module_._ivalue()->getSlot(top().i_); |
445 | } |
446 | |
  // advance to the next slot in a depth-first pre-order traversal of the
  // module's slots. This function does not guarantee the next slot is a
  // valid element of the iteration; that check is done by valid().
450 | // invariant: !cursors_.empty() |
451 | void next() { |
    // we just returned the module itself; advance i_ to 0 so we are now
    // at the first slot of the module.
454 | if (return_module()) { |
455 | ++top().i_; |
456 | return; |
457 | } |
    // the last traversal action advanced beyond the number of slots in the
    // module, so continue the iteration in the parent.
460 | if (top().i_ >= int64_t(top().module_._ivalue()->type()->numAttributes())) { |
461 | cursors_.pop_back(); |
462 | if (!cursors_.empty()) { |
463 | ++top().i_; |
464 | } |
465 | return; |
466 | } |
467 | // if the current thing is a module, we have to scan it for recursive |
468 | // traversals. We do this by adding a new SlotCursor to track the traversal. |
469 | if (recurse_ && |
470 | top().module_._ivalue()->type()->getAttribute(top().i_)->is_module()) { |
471 | cursors_.emplace_back(SlotCursor{cur().toModule(), 0}); |
472 | return; |
473 | } |
474 | // common case: advance to the next slot. |
475 | ++top().i_; |
476 | } |
  // is the current position of the iterator a valid one?
  // if not, we have to continue advancing.
479 | bool valid() const { |
480 | return top().i_ < |
481 | int64_t(top().module_._ivalue()->type()->numAttributes()) && |
482 | Policy::valid( |
483 | top().module_._ivalue()->type(), |
484 | top().i_, |
485 | top().module_._ivalue()->getSlot(top().i_)); |
486 | } |
487 | void while_not_valid_next() { |
    // advance iteration until we are either at the end (cursors_.empty())
    // or in a valid state. return_module() is a special case and is always
    // considered valid, regardless of Policy, because it is only true when
    // we are iterating modules.
492 | while (!cursors_.empty() && !return_module() && !valid()) { |
493 | next(); |
494 | } |
495 | } |
496 | void next_valid() { |
497 | // avoid crashing if this is empty |
498 | if (cursors_.empty()) { |
499 | return; |
500 | } |
501 | // advance to next element, which is maybe not valid |
502 | next(); |
503 | while_not_valid_next(); |
504 | } |
505 | |
506 | std::vector<SlotCursor> cursors_; |
507 | bool recurse_; |
508 | |
509 | friend inline bool operator!=( |
510 | const slot_iterator_impl<Policy>& a, |
511 | const slot_iterator_impl<Policy>& b) { |
    // we are finished iterating when we have no more SlotCursors.
513 | // end is always an empty iterator with no cursors. |
514 | return (a.cursors_.empty() != b.cursors_.empty()); |
515 | } |
516 | }; |
517 | |
// This type represents lists of parameters, attributes, and
// submodules contained in the module. It is abstract because
// the elements are not stored directly in std::vectors but inside
// the module's IValue object itself.
522 | template <typename Policy> |
523 | struct slot_list_impl { |
524 | using iterator = slot_iterator_impl<Policy>; |
525 | using const_iterator = slot_iterator_impl<Policy>; |
526 | using value_type = typename iterator::value_type; |
527 | slot_iterator_impl<Policy> begin() const { |
528 | return slot_iterator_impl<Policy>(module_, recurse_, return_module_); |
529 | } |
530 | slot_iterator_impl<Policy> end() const { |
531 | return slot_iterator_impl<Policy>(); |
532 | } |
533 | size_t size() const { |
534 | if (!size_) { |
535 | size_ = size_t(0); |
536 | // NOLINTNEXTLINE(clang-diagnostic-unused-variable) |
537 | for (const value_type& s : *(this)) { |
538 | (void)s; // Suppress unused variable warning |
539 | ++*size_; |
540 | } |
541 | } |
542 | return *size_; |
543 | } |
544 | |
  slot_list_impl(Module module, bool recurse, bool return_module)
      : module_(std::move(module)),
547 | recurse_(recurse), |
548 | return_module_(return_module), |
549 | size_(c10::nullopt) { |
550 | if (!recurse && !return_module && Policy::all_slots) { |
551 | size_ = module_.num_slots(); |
552 | } |
553 | } |
554 | |
555 | private: |
556 | Module module_; |
557 | bool recurse_; |
558 | bool return_module_; |
559 | // size of this list, cached on first request |
560 | // when we need to filter the slot list |
561 | mutable c10::optional<size_t> size_; |
562 | friend struct Module; |
563 | }; |
564 | |
565 | namespace detail { |
566 | |
// slot_iterator_impl always iterates over all the slots in a module;
// the Policy template argument determines which slots should be returned
// and with what types.
570 | struct TORCH_API ModulePolicy { |
571 | // the type of the value being returned |
572 | using value_type = Module; |
573 | |
  // the logic for creating the value being returned, given the raw IValue
  // of that slot.
576 | static value_type create( |
577 | const std::vector<detail::SlotCursor>& cursors, |
578 | IValue v) { |
579 | return Module(std::move(v).toObject()); |
580 | } |
  // is slot i in typ something that this iterator should return?
  // if not, we skip it.
583 | static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) { |
584 | return typ->getAttribute(i)->is_module(); |
585 | } |
  // are we going to return everything? If so, we can optimize the
  // calculation of the size of the list.
588 | static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = false; |
589 | }; |
590 | |
591 | struct TORCH_API ParameterPolicy { |
592 | using value_type = at::Tensor; |
593 | static value_type create( |
594 | const std::vector<detail::SlotCursor>& cursors, |
595 | IValue v) { |
596 | return std::move(v).toTensor(); |
597 | } |
598 | static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) { |
599 | return typ->is_parameter(i) && v.isTensor(); |
600 | } |
601 | static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = false; |
602 | }; |
603 | |
604 | struct TORCH_API BufferPolicy { |
605 | using value_type = at::Tensor; |
606 | static value_type create( |
607 | const std::vector<detail::SlotCursor>& cursors, |
608 | IValue v) { |
609 | return std::move(v).toTensor(); |
610 | } |
611 | static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) { |
612 | return typ->getAttribute(i)->isSubtypeOf(*TensorType::get()) && |
613 | typ->is_buffer(i); |
614 | } |
615 | static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = false; |
616 | }; |
617 | |
618 | struct TORCH_API AttributePolicy { |
619 | using value_type = IValue; |
620 | static value_type create( |
621 | const std::vector<detail::SlotCursor>& cursors, |
622 | IValue v) { |
623 | return v; |
624 | } |
625 | static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) { |
626 | return true; |
627 | } |
628 | static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = true; |
629 | }; |
630 | |
// take a Policy object, and make a version of it that returns the slot
// along with the fully qualified name of that slot. This is used for the
// named_ variants like named_parameters().
634 | template <typename Policy> |
635 | struct NamedPolicy { |
636 | using value_type = Named<typename Policy::value_type>; |
637 | static value_type create( |
638 | const std::vector<detail::SlotCursor>& cursors, |
639 | IValue v) { |
640 | std::string name; |
641 | if (cursors.size() == 1) { |
642 | name = (cursors.back().i_ == -1) ? "" : nameFragment(cursors.back()); |
643 | } else { |
644 | std::ostringstream ss; |
645 | for (const auto i : c10::irange(cursors.size())) { |
646 | if (i > 0) { |
647 | ss << "." ; |
648 | } |
649 | ss << nameFragment(cursors[i]); |
650 | } |
651 | name = ss.str(); |
652 | } |
653 | return value_type{std::move(name), Policy::create(cursors, std::move(v))}; |
654 | } |
655 | static bool valid(const ClassTypePtr& t, size_t i, const IValue& v) { |
656 | return Policy::valid(t, i, v); |
657 | } |
658 | static constexpr bool all_slots = Policy::all_slots; |
659 | |
660 | private: |
661 | static std::string nameFragment(const detail::SlotCursor& f) { |
662 | return f.module_.type()->getAttributeName(f.i_); |
663 | } |
664 | }; |
665 | |
666 | } // namespace detail |
667 | |
668 | TORCH_API bool& getInlineEverythingMode(); |
669 | |
670 | namespace script { |
671 | // We once had a `script::` namespace that was deleted. This is for backcompat |
672 | // of the public API; new code should not use this type alias. |
673 | using Module = ::torch::jit::Module; |
using ExtraFilesMap = ::torch::jit::ExtraFilesMap;
675 | } // namespace script |
676 | |
677 | } // namespace jit |
678 | } // namespace torch |
679 | |