1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ |
16 | #define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ |
17 | |
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <deque>
#include <functional>
#include <limits>
#include <list>
#include <memory>
#include <string>
// TODO(b/114492873): Move this include into core/platform.
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringprintf.h"
43 | |
44 | namespace tensorflow { |
45 | namespace data { |
46 | namespace model { |
47 | |
// Sentinel value that requests auto-tuning of a parameter. `SharedState`
// treats a state constructed with this value as tunable.
constexpr int64_t kAutotune = -1;

// String keys for well-known parameters.
// NOTE(review): presumably these are the names under which the corresponding
// parameters are registered on a node -- confirm against the call sites.
constexpr char kParallelism[] = "parallelism";
constexpr char kBufferSize[] = "buffer_size";
constexpr char kCycleLength[] = "cycle_length";
constexpr char kDeterministic[] = "deterministic";
constexpr char kMaxBufferedElements[] = "max_buffered_elements";

// Key that identifies the input time of the model.
constexpr char kModelInputTimeKey[] = "model_input_time";

// Default fraction of the available RAM that the model's internal buffers are
// allowed to use.
constexpr double kRamBudgetShare = 0.5;

// Weight given to the most recent processing time when updating the
// exponential moving average of the per-element processing time.
constexpr double kProcessingTimeEmaWeight = 0.1;
65 | |
// Order in which the nodes of a subtree are visited.
enum class TraversalOrder {
  // Breadth-first search order.
  BFS = 0,
  // Reverse of the breadth-first search order.
  REVERSE_BFS = 1,
};
70 | |
71 | // Represents thread-safe state that can be shared between an input pipeline and |
72 | // the performance model. |
73 | struct SharedState { |
74 | public: |
75 | SharedState(int64_t value, std::shared_ptr<mutex> mu, |
76 | std::shared_ptr<condition_variable> cond_var) |
77 | : value(value), |
78 | mu(std::move(mu)), |
79 | cond_var(std::move(cond_var)), |
80 | tunable(value == kAutotune) {} |
81 | |
82 | double value; |
83 | const std::shared_ptr<mutex> mu; |
84 | const std::shared_ptr<condition_variable> cond_var; |
85 | const bool tunable; |
86 | }; |
87 | |
88 | // Represents a parameter. |
89 | struct Parameter { |
90 | Parameter(const string& name, std::shared_ptr<SharedState> state, double min, |
91 | double max) |
92 | : name(name), |
93 | // Sometimes non-autotune nodes (with `autotune_=false`) may contain |
94 | // parameters (for example inputs of parallel interleave dataset which |
95 | // are not in the current cycle). To avoid unrealistic situation |
96 | // (say `buffer_size=-1` or `parallelism=-1`) in the optimization |
97 | // computation, if the state value is `kAutotune=-1` (just to indicate |
98 | // the `SharedState` is tunable), we initialize the parameter value to |
99 | // be the minimal value of the state. |
100 | value(state == nullptr || state->value == kAutotune ? min |
101 | : state->value), |
102 | min(min), |
103 | max(max), |
104 | state(std::move(state)) {} |
105 | |
106 | // Human-readable name of the parameter. |
107 | const string name; |
108 | |
109 | // Identifies the model value of the parameter. This can be different from |
110 | // the actual value (e.g. during optimization search). |
111 | double value; |
112 | |
113 | // Identifies the minimum value of the parameter. |
114 | const double min; |
115 | |
116 | // Identifies the maximum value of the parameter. |
117 | const double max; |
118 | |
119 | // Shared state of the parameter. |
120 | std::shared_ptr<SharedState> state; |
121 | }; |
122 | |
// Returns a new tunable parameter.
//
// The arguments have the same names and types as those of the `Parameter`
// constructor: `name` is the parameter name, `state` is the state shared with
// the input pipeline, and `min` / `max` bound the parameter value.
// NOTE(review): the definition is not part of this header; presumably the
// arguments are forwarded to `Parameter` unchanged -- confirm.
std::shared_ptr<Parameter> MakeParameter(const string& name,
                                         std::shared_ptr<SharedState> state,
                                         double min, double max);
127 | |
// Returns a new non-tunable parameter called `name` with the given `value`.
// NOTE(review): the definition is not part of this header; presumably the
// parameter is created without a `SharedState` -- confirm, since code that
// reads `Parameter::state` must then handle `nullptr`.
std::shared_ptr<Parameter> MakeNonTunableParameter(const string& name,
                                                   double value);
131 | |
132 | // Abstract representation of a TensorFlow input pipeline node. It collects |
133 | // information about inputs to this node, processing time spent executing the |
134 | // node logic, number of elements produced by the node, various other |
135 | // information (e.g. batch size or execution parallelism). |
136 | // |
137 | // Developers of tf.data transformations are not expected to interact with |
138 | // this class directly. Boiler plate code for creating the abstract |
139 | // representation of the input pipeline and collecting common information has |
140 | // been added to the implementation of `DatasetBase` and `DatasetBaseIterator` |
141 | // respectively. |
142 | // |
143 | // In addition, `DatasetBaseIterator` provides wrappers that can be used for |
144 | // transformation-specific information collection. The `SetMetadata` wrapper |
145 | // can be used to pass arbitrary metadata to the modeling framework, while the |
146 | // `StartWork` and `StopWork` wrappers should be used to correctly account for |
147 | // processing time of multi-threaded transformation that yield the CPU; such |
148 | // transformations should invoke `StartWork()` when a transformation thread |
149 | // starts executing (e.g. when created or woken up) and `StopWork()` when a |
150 | // transformation thread stops executing (e.g. when returning or waiting). |
151 | class Node { |
152 | public: |
153 | // Arguments for `Node` constructor. |
154 | struct Args { |
155 | int64_t id; |
156 | string name; |
157 | std::shared_ptr<Node> output; |
158 | }; |
159 | |
160 | using Factory = std::function<std::shared_ptr<Node>(Args)>; |
161 | using NodeVector = std::vector<std::shared_ptr<Node>>; |
162 | using NodePairList = |
163 | std::list<std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>>; |
164 | using ModelParameters = |
165 | std::vector<std::pair<string, std::shared_ptr<Parameter>>>; |
166 | using NodeValues = absl::flat_hash_map<string, double>; |
167 | using ParameterGradients = |
168 | absl::flat_hash_map<std::pair<string, string>, double>; |
169 | |
170 | explicit Node(Args args) |
171 | : id_(args.id), |
172 | name_(std::move(args.name)), |
173 | autotune_(true), |
174 | buffered_bytes_(0), |
175 | buffered_elements_(0), |
176 | buffered_elements_low_(std::numeric_limits<int64_t>::max()), |
177 | buffered_elements_high_(std::numeric_limits<int64_t>::min()), |
178 | bytes_consumed_(0), |
179 | bytes_produced_(0), |
180 | num_elements_(0), |
181 | processing_time_(0), |
182 | record_metrics_(true), |
183 | metrics_(name_), |
184 | output_(args.output.get()) {} |
185 | |
186 | virtual ~Node() { |
187 | // Clear the sub-nodes instead of relying on implicit shared pointer |
188 | // destructor to avoid potential stack overflow when the tree is deep. |
189 | std::deque<std::shared_ptr<Node>> queue; |
190 | { |
191 | mutex_lock l(mu_); |
192 | while (inputs_.size() > 0) { |
193 | queue.push_back(inputs_.front()); |
194 | inputs_.pop_front(); |
195 | } |
196 | } |
197 | while (!queue.empty()) { |
198 | auto node = queue.back(); |
199 | queue.pop_back(); |
200 | { |
201 | mutex_lock l(node->mu_); |
202 | while (node->inputs_.size() > 0) { |
203 | queue.push_back(node->inputs_.front()); |
204 | node->inputs_.pop_front(); |
205 | } |
206 | } |
207 | } |
208 | |
209 | FlushMetrics(); |
210 | } |
211 | |
212 | // Adds an input. |
213 | void add_input(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_) { |
214 | mutex_lock l(mu_); |
215 | inputs_.push_back(node); |
216 | } |
217 | |
218 | // Increments the aggregate processing time by the given delta. |
219 | void add_processing_time(int64_t delta) TF_LOCKS_EXCLUDED(mu_) { |
220 | processing_time_ += delta; |
221 | } |
222 | |
223 | // Returns an indication whether autotuning is enabled for this node. |
224 | bool autotune() const TF_LOCKS_EXCLUDED(mu_) { return autotune_; } |
225 | |
226 | // Returns the number of bytes stored in this node's buffer. |
227 | int64_t buffered_bytes() const TF_LOCKS_EXCLUDED(mu_) { |
228 | return buffered_bytes_; |
229 | } |
230 | |
231 | // Returns the number of elements stored in this node's buffer. |
232 | int64_t buffered_elements() const TF_LOCKS_EXCLUDED(mu_) { |
233 | return buffered_elements_; |
234 | } |
235 | |
236 | // Returns the low watermark of the number of elements stored in this node's |
237 | // buffer. The watermarks are reset at the beginning of the execution time and |
238 | // each time the buffer is upsized or downsized. |
239 | int64_t buffered_elements_low() const TF_LOCKS_EXCLUDED(mu_) { |
240 | return buffered_elements_low_; |
241 | } |
242 | |
243 | // Returns the high watermark of the number of elements stored in this node's |
244 | // buffer. The watermarks are reset at the beginning of the execution time and |
245 | // each time the buffer is upsized or downsized. |
246 | int64_t buffered_elements_high() const TF_LOCKS_EXCLUDED(mu_) { |
247 | return buffered_elements_high_; |
248 | } |
249 | |
250 | // Returns the number of bytes consumed by the node. |
251 | int64_t bytes_consumed() const TF_LOCKS_EXCLUDED(mu_) { |
252 | return bytes_consumed_; |
253 | } |
254 | |
255 | // Returns the number of bytes produced by the node. |
256 | int64_t bytes_produced() const TF_LOCKS_EXCLUDED(mu_) { |
257 | return bytes_produced_; |
258 | } |
259 | |
260 | // Indicates whether the node has tunable parameters. |
261 | bool has_tunable_parameters() const TF_LOCKS_EXCLUDED(mu_) { |
262 | tf_shared_lock l(mu_); |
263 | for (const auto& pair : parameters_) { |
264 | if (pair.second->state->tunable) return true; |
265 | } |
266 | return false; |
267 | } |
268 | |
269 | // Returns the unique node ID. |
270 | int64_t id() const TF_LOCKS_EXCLUDED(mu_) { return id_; } |
271 | |
272 | // Returns the node inputs. |
273 | std::list<std::shared_ptr<Node>> inputs() const TF_LOCKS_EXCLUDED(mu_) { |
274 | tf_shared_lock l(mu_); |
275 | return inputs_; |
276 | } |
277 | |
278 | // Returns a longer node name that is guaranteed to be unique. |
279 | string long_name() const { return strings::StrCat(name_, "(id:" , id_, ")" ); } |
280 | |
281 | // Returns the node name. |
282 | const string& name() const { return name_; } |
283 | |
284 | // Returns the number of elements produced by the node. |
285 | int64_t num_elements() const TF_LOCKS_EXCLUDED(mu_) { return num_elements_; } |
286 | |
287 | // Returns the node output. |
288 | Node* output() const { return output_; } |
289 | |
290 | // Returns the parameter value. |
291 | double parameter_value(const string& name) const TF_LOCKS_EXCLUDED(mu_) { |
292 | tf_shared_lock l(mu_); |
293 | return parameters_.at(name)->state->value; |
294 | } |
295 | |
296 | // Returns the aggregate processing time. |
297 | int64_t processing_time() const TF_LOCKS_EXCLUDED(mu_) { |
298 | return processing_time_; |
299 | } |
300 | |
301 | // Records that the node consumed the given number of bytes. |
302 | void record_bytes_consumed(int64_t num_bytes) { |
303 | bytes_consumed_ += num_bytes; |
304 | } |
305 | |
306 | // Records that the node produced the given number of bytes. |
307 | void record_bytes_produced(int64_t num_bytes) { |
308 | bytes_produced_ += num_bytes; |
309 | } |
310 | |
311 | // Records the change in this node's buffer. |
312 | void record_buffer_event(int64_t bytes_delta, int64_t elements_delta) { |
313 | buffered_bytes_ += bytes_delta; |
314 | buffered_elements_ += elements_delta; |
315 | // There is no need to maintain watermarks for synchronous ops because we |
316 | // will not upsize or downsize the buffers of synchronous ops. |
317 | if (IsAsync()) { |
318 | int64_t low_watermark = |
319 | std::min(buffered_elements_low_, buffered_elements_); |
320 | buffered_elements_low_ = low_watermark; |
321 | int64_t high_watermark = |
322 | std::max(buffered_elements_high_, buffered_elements_); |
323 | buffered_elements_high_ = high_watermark; |
324 | } |
325 | } |
326 | |
327 | // Records that the node produced an element. |
328 | void record_element() TF_LOCKS_EXCLUDED(mu_) { |
329 | num_elements_++; |
330 | { |
331 | mutex_lock l(mu_); |
332 | UpdateProcessingTimeEma(); |
333 | } |
334 | } |
335 | |
336 | // Records that a node thread has started executing. |
337 | void record_start(int64_t time_nanos) TF_LOCKS_EXCLUDED(mu_) { |
338 | DCHECK_EQ(work_start_, 0); |
339 | work_start_ = time_nanos; |
340 | } |
341 | |
342 | // Records that a node thread has stopped executing. |
343 | void record_stop(int64_t time_nanos) TF_LOCKS_EXCLUDED(mu_) { |
344 | // TODO(jsimsa): Use DCHECK_NE(work_start_, 0) here. |
345 | if (work_start_ != 0) { |
346 | processing_time_ += time_nanos - work_start_; |
347 | work_start_ = 0; |
348 | } else { |
349 | VLOG(1) << "Encountered a stop event without a matching start event." ; |
350 | } |
351 | } |
352 | |
353 | // Returns whether work is currently being recorded, i.e. whether we are |
354 | // currently between a `record_start` and a `record_stop`. |
355 | bool is_recording() TF_LOCKS_EXCLUDED(mu_) { return work_start_ > 0; } |
356 | |
357 | // Removes an input. |
358 | void remove_input(std::shared_ptr<Node> input) TF_LOCKS_EXCLUDED(mu_) { |
359 | mutex_lock l(mu_); |
360 | inputs_.remove(input); |
361 | } |
362 | |
363 | // Sets the value that determines whether autotuning is enabled for this node. |
364 | void set_autotune(bool autotune) TF_LOCKS_EXCLUDED(mu_) { |
365 | autotune_.store(autotune); |
366 | } |
367 | |
368 | // Resets buffer watermarks to the current buffered elements. |
369 | void ResetBufferWatermarks() { |
370 | if (!IsAsync()) { |
371 | return; |
372 | } |
373 | int64_t current_buffer_size = buffered_elements_; |
374 | buffered_elements_low_ = current_buffer_size; |
375 | buffered_elements_high_ = current_buffer_size; |
376 | } |
377 | |
378 | // Returns true for asynchronous nodes; false otherwise. |
379 | virtual bool IsAsync() const { return false; } |
380 | |
381 | // Returns the ratio of the node, which is defined as the number of elements |
382 | // per input needed by the node to produce an element, e.g. batch size of a |
383 | // `Batch`. It can be 0 if the ratio is unknown. |
384 | virtual double Ratio() const { return 1.0; } |
385 | |
386 | // Computes the self time in nanoseconds of the node to produce one element. |
387 | virtual double ComputeSelfTime() const; |
388 | |
389 | // Returns the parameter value if it exists, not ok status otherwise. |
390 | StatusOr<double> ParameterValue(const std::string& parameter_name) const |
391 | TF_LOCKS_EXCLUDED(mu_) { |
392 | tf_shared_lock l(mu_); |
393 | if (parameters_.contains(parameter_name)) { |
394 | return parameters_.at(parameter_name)->value; |
395 | } |
396 | return errors::NotFound("Parameter " , parameter_name, |
397 | " was not found in model node " , long_name()); |
398 | } |
399 | |
400 | // Given the average time between events when the elements in the buffer are |
401 | // produced (`producer_time`), the average time between events when elements |
402 | // in the buffer are consumed (`consumer_time`) and the buffer size, the |
403 | // method computes the expected time an consumer event will have to wait. |
404 | // |
405 | // The wait time is approximated as the product of the probability the buffer |
406 | // will be empty and the time it takes to produce an element into the buffer. |
407 | // |
408 | // The formula used for computing the probability is derived by modeling the |
409 | // problem as an M/M/1/K queue |
410 | // (https://en.wikipedia.org/wiki/Birth%E2%80%93death_process#M/M/1/K_queue). |
411 | // |
412 | // Collects derivatives of `ComputeWaitTime` w.r.t `producer_time`, |
413 | // `consumer_time' and `buffer_size` if the corresponding pointers are not |
414 | // `nullptr`. |
415 | static double ComputeWaitTime(const double& producer_time, |
416 | const double& consumer_time, |
417 | const double& buffer_size, |
418 | double* producer_time_derivative, |
419 | double* consumer_time_derivative, |
420 | double* buffer_size_derivative); |
421 | |
422 | // Collects tunable parameters in the subtree rooted in this node. |
423 | ModelParameters CollectTunableParameters() const TF_LOCKS_EXCLUDED(mu_); |
424 | |
425 | // Collects tunable parameters in this node. |
426 | ModelParameters CollectNodeTunableParameters() const TF_LOCKS_EXCLUDED(mu_); |
427 | |
428 | // Returns a human-readable representation of this node. |
429 | string DebugString() const TF_LOCKS_EXCLUDED(mu_); |
430 | |
431 | // Flushes the metrics recorded by this node. |
432 | void FlushMetrics() TF_LOCKS_EXCLUDED(mu_); |
433 | |
434 | // Returns the per-element output time for this node and if `gradients` is not |
435 | // `nullptr`, collects the output time gradient w.r.t. tunable parameters of |
436 | // the subtree rooted in this node. |
437 | double OutputTime(NodeValues* input_times, |
438 | ParameterGradients* gradients) const TF_LOCKS_EXCLUDED(mu_); |
439 | |
440 | // Returns a copy of this node, making a deep copy of its inputs and a |
441 | // shallow copy of its tunable parameters. |
442 | // |
443 | // The purpose for this method is to allow the model optimization logic to |
444 | // operate over immutable state while allowing concurrent model updates. |
445 | std::shared_ptr<Node> Snapshot() const TF_LOCKS_EXCLUDED(mu_); |
446 | |
447 | // Returns the per-element processing time in nanoseconds spent in this node. |
448 | double SelfProcessingTime() const TF_LOCKS_EXCLUDED(mu_); |
449 | |
450 | // Returns the total number of bytes buffered in all nodes in the subtree for |
451 | // which autotuning is enabled. |
452 | double TotalBufferedBytes() const TF_LOCKS_EXCLUDED(mu_); |
453 | |
454 | // Collects the total buffer limit of all nodes in the subtree for which |
455 | // autotuning is enabled. This number represents the amount of memory that |
456 | // would be used by the subtree nodes if all of their buffers were full. |
457 | double TotalMaximumBufferedBytes() const TF_LOCKS_EXCLUDED(mu_); |
458 | |
459 | // Returns the per-element CPU time in nanoseconds spent in the subtree rooted |
460 | // in this node. If `processing_times` is not `nullptr`, collects the |
461 | // per-element CPU time spent in each node of the subtree. |
462 | double TotalProcessingTime(NodeValues* processing_times) |
463 | TF_LOCKS_EXCLUDED(mu_); |
464 | |
465 | // Produces a proto for this node. Does not produce a proto for input nodes. |
466 | virtual Status ToProto(ModelProto::Node* node_proto) const; |
467 | |
468 | // Restores a node from the proto. Does not restore input nodes. |
469 | static Status FromProto(ModelProto::Node node_proto, |
470 | std::shared_ptr<Node> output, |
471 | std::shared_ptr<Node>* node); |
472 | |
473 | // Returns a vector of nodes of the subtree rooted in this node. The nodes are |
474 | // either in breadth-first search or reverse breadth-first search order |
475 | // depending on the `order` argument. The nodes are collected based on the |
476 | // results of the `collect_node` predicate: if the predicate returns `false` |
477 | // for a given node, then the subtree rooted in this node is excluded. The |
478 | // root node itself is not collected. |
479 | NodeVector CollectNodes(TraversalOrder order, |
480 | bool collect_node(const std::shared_ptr<Node>)) const |
481 | TF_LOCKS_EXCLUDED(mu_); |
482 | |
483 | // Downsizes buffer parameters of this node. Returns true if any buffer is |
484 | // downsized. |
485 | bool TryDownsizeBuffer(); |
486 | |
487 | // Collects buffer parameters of this node that should be upsized. |
488 | void CollectBufferParametersToUpsize( |
489 | absl::flat_hash_map<Node*, Parameter*>& node_parameters); |
490 | |
491 | protected: |
492 | // Used for (incrementally) recording metrics. The class is thread-safe. |
493 | class Metrics { |
494 | public: |
495 | explicit Metrics(const string& name) |
496 | : bytes_consumed_counter_(metrics::GetTFDataBytesConsumedCounter(name)), |
497 | bytes_produced_counter_(metrics::GetTFDataBytesProducedCounter(name)), |
498 | num_elements_counter_(metrics::GetTFDataElementsCounter(name)), |
499 | recorded_bytes_consumed_(0), |
500 | recorded_bytes_produced_(0), |
501 | recorded_num_elements_(0) {} |
502 | |
503 | // Expects the total number of bytes consumed and records the delta since |
504 | // last invocation. |
505 | void record_bytes_consumed(int64_t total_bytes) { |
506 | int64_t delta = |
507 | total_bytes - recorded_bytes_consumed_.exchange(total_bytes); |
508 | bytes_consumed_counter_->IncrementBy(delta); |
509 | } |
510 | |
511 | // Expects the total number of bytes produced and records the delta since |
512 | // last invocation. |
513 | void record_bytes_produced(int64_t total_bytes) { |
514 | int64_t delta = |
515 | total_bytes - recorded_bytes_produced_.exchange(total_bytes); |
516 | bytes_produced_counter_->IncrementBy(delta); |
517 | } |
518 | |
519 | // Expects the total number of elements produced and records the delta since |
520 | // last invocation. |
521 | void record_num_elements(int64_t total_elements) { |
522 | int64_t delta = |
523 | total_elements - recorded_num_elements_.exchange(total_elements); |
524 | num_elements_counter_->IncrementBy(delta); |
525 | } |
526 | |
527 | private: |
528 | monitoring::CounterCell* const bytes_consumed_counter_; |
529 | monitoring::CounterCell* const bytes_produced_counter_; |
530 | monitoring::CounterCell* const num_elements_counter_; |
531 | std::atomic<int64_t> recorded_bytes_consumed_; |
532 | std::atomic<int64_t> recorded_bytes_produced_; |
533 | std::atomic<int64_t> recorded_num_elements_; |
534 | }; |
535 | |
536 | // Computes the exponential moving average of processing time per element. |
537 | void UpdateProcessingTimeEma() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
538 | if (previous_processing_time_ == 0) { |
539 | if (num_elements_ > 0) { |
540 | processing_time_ema_ = static_cast<double>(processing_time_) / |
541 | static_cast<double>(num_elements_); |
542 | } else { |
543 | processing_time_ema_ = static_cast<double>(processing_time_); |
544 | } |
545 | } else { |
546 | processing_time_ema_ = |
547 | (1.0 - kProcessingTimeEmaWeight) * processing_time_ema_ + |
548 | kProcessingTimeEmaWeight * |
549 | static_cast<double>(processing_time_ - previous_processing_time_); |
550 | } |
551 | previous_processing_time_ = processing_time_; |
552 | } |
553 | |
554 | // Returns the number of inputs. |
555 | int64_t num_inputs() const TF_SHARED_LOCKS_REQUIRED(mu_) { |
556 | int64_t num_inputs = 0; |
557 | for (auto& input : inputs_) { |
558 | // Inputs for which autotuning is disabled are excluded. |
559 | if (input->autotune()) { |
560 | ++num_inputs; |
561 | } |
562 | } |
563 | return num_inputs; |
564 | } |
565 | |
566 | // Creates a clone of this node. |
567 | virtual std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const |
568 | TF_SHARED_LOCKS_REQUIRED(mu_) = 0; |
569 | |
570 | // Returns the average size of an element buffered in this node. |
571 | double AverageBufferedElementSize() const TF_SHARED_LOCKS_REQUIRED(mu_); |
572 | |
573 | // Returns the sum of per-element output time for the tunable inputs of this |
574 | // node. |
575 | double OutputTimeForInputs(const NodeValues& output_times) const |
576 | TF_SHARED_LOCKS_REQUIRED(mu_); |
577 | |
578 | // Returns the sum of output time gradient w.r.t. input time for the tunable |
579 | // inputs of this node. |
580 | double OutputTimeGradientsForInputs(const NodeValues& output_time_gradients) |
581 | const TF_SHARED_LOCKS_REQUIRED(mu_); |
582 | |
583 | // Computes the input time for this node and stores it in `input_times`. |
584 | virtual void InputTimeLocked(NodeValues* input_times) const |
585 | TF_SHARED_LOCKS_REQUIRED(mu_) = 0; |
586 | |
587 | // Computes the per-element output time for this node and stores it in |
588 | // `output_times`. If `gradients` is not `nullptr`, computes the output time |
589 | // gradient w.r.t. tunable parameters of the subtree rooted in this node and |
590 | // stores it in `gradients`, also computes the output time gradient w.r.t. |
591 | // input time and stores it in `output_time_gradients`. |
592 | virtual void OutputTimeLocked(const NodeValues& input_times, |
593 | ParameterGradients* gradients, |
594 | NodeValues* output_times, |
595 | NodeValues* output_time_gradients) const |
596 | TF_SHARED_LOCKS_REQUIRED(mu_) = 0; |
597 | |
598 | // Returns the sum of per-element processing time for the inputs of this node |
599 | // by adding values for input nodes in `total_processing_times`. Processing |
600 | // time for a given input is a weighted combination of a statistic based on |
601 | // history of input processing time and the actual time. This is done to |
602 | // improve accuracy of processing time estimation for newly created inputs. |
603 | // |
604 | // Uniform distribution of per-element processing times across different |
605 | // inputs is assumed. |
606 | double TotalProcessingTimeForInputs(const NodeValues& total_processing_times) |
607 | TF_SHARED_LOCKS_REQUIRED(mu_); |
608 | |
609 | // Returns the per-element processing time spent in this node. |
610 | double SelfProcessingTimeLocked() const TF_SHARED_LOCKS_REQUIRED(mu_); |
611 | |
612 | // Computes the per-element CPU time spent in the subtree rooted in this node |
613 | // and stores it in `total_processing_times`. If `processing_times` is not |
614 | // `nullptr`, collects the per-element CPU time spent in each node of the |
615 | // subtree. |
616 | virtual void TotalProcessingTimeLocked(NodeValues* processing_times, |
617 | NodeValues* total_processing_times) |
618 | TF_SHARED_LOCKS_REQUIRED(mu_) = 0; |
619 | |
620 | // This is the locked version of the public `CollectNodes`. |
621 | NodeVector CollectNodesLocked(TraversalOrder order, |
622 | bool collect_node(const std::shared_ptr<Node>)) |
623 | const TF_SHARED_LOCKS_REQUIRED(mu_); |
624 | |
625 | // Collects tunable parameters in the subtree rooted in this node assuming |
626 | // mutex locked. |
627 | ModelParameters CollectTunableParametersLocked() const |
628 | TF_SHARED_LOCKS_REQUIRED(mu_); |
629 | |
630 | // Collect tunable parameters on the nodes which have recorded |
631 | // elements. |
632 | void CollectTunableParametersHelper(ModelParameters* parameters) const |
633 | TF_SHARED_LOCKS_REQUIRED(mu_); |
634 | |
635 | // Build up debug string for the node and store in the debug strings map. |
636 | void DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings) |
637 | const TF_SHARED_LOCKS_REQUIRED(mu_); |
638 | |
639 | // Copy the node and add the (input, copy) pairs to the NodePairList. |
640 | std::shared_ptr<Node> SnapshotHelper(std::shared_ptr<Node> cloned_output, |
641 | NodePairList* node_pairs) const; |
642 | |
643 | // Compute total buffered bytes for the node and store in the total bytes map. |
644 | void TotalBufferedBytesHelper(NodeValues* total_bytes) const |
645 | TF_SHARED_LOCKS_REQUIRED(mu_); |
646 | |
647 | // Compute total maximum buffered bytes for the node and store in the total |
648 | // bytes map. |
649 | void TotalMaximumBufferedBytesHelper(NodeValues* total_bytes) const |
650 | TF_SHARED_LOCKS_REQUIRED(mu_); |
651 | |
652 | // Compute and return the maximum buffered bytes on the node itself. By |
653 | // default non-tunable nodes are assumed not to buffer any bytes, so the |
654 | // tunable nodes as subclasses are expected to override this method to ensure |
655 | // that the optimization algorithm respects the memory budget. |
656 | virtual double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_); |
657 | |
658 | // Restores node from the proto. Note that this is not done recursively, i.e. |
659 | // input nodes are not restored. |
660 | static Status FromProtoHelper(ModelProto::Node node_proto, |
661 | std::shared_ptr<Node> node); |
662 | |
663 | // Stores the time passed to the last call to `Node::record_start()` on the |
664 | // current thread. |
665 | // |
666 | // NOTE: This thread-local variable is shared between all instances of `Node` |
667 | // on which the same thread calls `record_start()` or `record_stop()`. It |
668 | // relies on the invariant that at most one `Node` can be "active" on a |
669 | // particular thread at any time. Therefore if `n->record_start()` is called |
670 | // on thread `t`, then `n->record_stop()` must be called before another call |
671 | // to `Node::record_start()` (for any node). |
672 | static thread_local int64_t work_start_; // Will be initialized to zero. |
673 | |
674 | mutable mutex mu_; |
675 | const int64_t id_; |
676 | const string name_; |
677 | |
678 | // Indicates whether the subtree rooted in this node should be included in |
679 | // autotuning. In particular, if this is `false`, then the subtree is excluded |
680 | // from computation of output time and processing time. |
681 | std::atomic<bool> autotune_; |
682 | std::atomic<int64_t> buffered_bytes_; |
683 | std::atomic<int64_t> buffered_elements_; |
684 | std::atomic<int64_t> buffered_elements_low_; |
685 | std::atomic<int64_t> buffered_elements_high_; |
686 | std::atomic<int64_t> bytes_consumed_; |
687 | std::atomic<int64_t> bytes_produced_; |
688 | std::atomic<int64_t> num_elements_; |
689 | std::atomic<int64_t> processing_time_; |
690 | std::atomic<bool> record_metrics_; |
691 | Metrics metrics_; |
692 | absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_ |
693 | TF_GUARDED_BY(mu_); |
694 | |
695 | // Statistic of inputs processing time history. |
696 | double input_processing_time_sum_ = 0.0L; |
697 | int64_t input_processing_time_count_ = 0; |
698 | |
699 | // Holds the previous processing time and the per element processing time |
700 | // exponential moving average. |
701 | int64_t previous_processing_time_ TF_GUARDED_BY(mu_) = 0; |
702 | double processing_time_ema_ TF_GUARDED_BY(mu_) = 0.0; |
703 | |
704 | // Inputs of this node. These can represent an iterator created from the input |
705 | // dataset but also other input iterators (e.g. created by the user-defined |
706 | // functions of `flat_map` or `interleave`). |
707 | std::list<std::shared_ptr<Node>> inputs_ TF_GUARDED_BY(mu_); |
708 | |
709 | // The reference to the output node is not owned so that deletion of a |
710 | // node results in recursive deletion of the subtree rooted in the node. |
711 | Node* const output_; |
712 | }; |
713 | |
// InterleaveMany is used to model datasets whose inputs are used to create
// datasets whose elements are then interleaved.
//
// NOTE(review): the definition is not part of this header; presumably `args`
// is handed to the `Node` constructor and `parameters` become the parameters
// of the new node -- confirm.
std::shared_ptr<Node> MakeInterleaveManyNode(
    Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters);
718 | |
719 | // AsyncInterleaveMany nodes are the asynchronous version of InterleaveMany |
720 | // nodes. |
721 | std::shared_ptr<Node> MakeAsyncInterleaveManyNode( |
722 | Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters); |
723 | |
724 | // KnownMany nodes model datasets that synchronously consume known number of |
725 | // input element per output element. |
726 | std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio); |
727 | |
728 | // AsyncKnownRatio nodes are the asynchronous version of KnownRate nodes. |
729 | std::shared_ptr<Node> MakeAsyncKnownRatioNode( |
730 | Node::Args args, double ratio, double memory_ratio, |
731 | std::vector<std::shared_ptr<Parameter>> parameters); |
732 | |
733 | std::shared_ptr<Node> MakeAsyncKnownRatioNode( |
734 | Node::Args args, double ratio, |
735 | std::vector<std::shared_ptr<Parameter>> parameters); |
736 | |
737 | // Source nodes represent data sources. |
738 | std::shared_ptr<Node> MakeSourceNode(Node::Args args); |
739 | |
740 | // UnknownMany nodes represent datasets that synchronously consume an |
741 | // unknown number of input elements per output. |
742 | // |
743 | // Unlike KnownRatio nodes which expect the ratio between inputs and outputs is |
744 | // specified as a parameter, UnknownRatio estimates the ratio empirically. |
745 | std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args); |
746 | |
747 | // AsyncUnknownRatio nodes are the asynchronous version of unknown ratio nodes. |
748 | std::shared_ptr<Node> MakeAsyncUnknownRatioNode( |
749 | Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters); |
750 | |
751 | // Unknown nodes represent datasets for which we do not have a model. It acts |
752 | // as pass-through between inputs and output. |
753 | std::shared_ptr<Node> MakeUnknownNode(Node::Args args); |
754 | |
755 | // Abstract representation of a TensorFlow input pipeline that can be used |
756 | // for collecting runtime information and optimizing performance. It collects |
757 | // runtime information about execution of the input pipeline that is used to |
758 | // create a performance model, which is in turn used to identify optimal values |
759 | // of tunable parameters. |
760 | // |
761 | // Developers of tf.data transformations are not expected to interact with this |
762 | // class directly. Boiler plate code for creating the abstract representation of |
763 | // the input pipeline and collecting runtime information has been added to the |
764 | // implementation of `DatasetBase` and `DatasetBaseIterator` respectively. |
765 | // |
766 | // The order of locks acquired is SharedState lock, Model lock, Node lock. |
767 | // SharedState lock is acquired first because it shares the same lock as the |
768 | // dataset iterator that contains it. |
769 | class Model { |
770 | public: |
771 | using OptimizationParams = ModelProto::OptimizationParams; |
772 | using ModelParameters = Node::ModelParameters; |
773 | using NodeValues = Node::NodeValues; |
774 | using ParameterGradients = Node::ParameterGradients; |
775 | |
776 | Model(); |
777 | ~Model(); |
778 | |
779 | // Returns a pointer to the model's output node. |
780 | const std::shared_ptr<Node> output() const { |
781 | mutex_lock l(mu_); |
782 | return output_; |
783 | } |
784 | |
785 | // Set the experiment that this job is part of. |
786 | void SetExperiment(const string& experiment) { experiment_ = experiment; } |
787 | |
788 | // Adds a node with the given name and given parent. |
789 | void AddNode(Node::Factory factory, const string& name, |
790 | std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node) |
791 | TF_LOCKS_EXCLUDED(mu_); |
792 | |
793 | // Returns a human-readable string representation of the model. This method |
794 | // can be invoked automatically by monitoring gauges and to avoid frequent |
795 | // recomputation, the implementation caches the result. |
796 | std::string DebugString(); |
797 | |
798 | // Uses the given algorithm and resource budgets to periodically perform the |
799 | // autotuning optimization. |
800 | // |
801 | // To terminate the execution of the optimization loop, the caller needs to |
802 | // invoke `cancellation_mgr->StartCancel()`. |
803 | Status OptimizeLoop(AutotuneAlgorithm algorithm, int64_t cpu_budget, |
804 | int64_t ram_budget, |
805 | CancellationManager* cancellation_manager); |
806 | |
807 | // Uses the given algorithm and resource budgets to perform the autotuning |
808 | // optimization. |
809 | void Optimize(AutotuneAlgorithm algorithm, int64_t cpu_budget, |
810 | int64_t ram_budget, double model_input_time, |
811 | CancellationManager* cancellation_manager); |
812 | |
813 | // Optimizes buffers in the pipeline rooted at `snapshot`. It downsizes |
814 | // buffers that are too large and upsizes buffers that are too small while |
815 | // respecting the ram budget. If any node is downsized or upsized, the |
816 | // watermarks of all nodes are reset to the buffered elements. |
817 | void OptimizeBuffers(std::shared_ptr<Node> snapshot, int64_t ram_budget); |
818 | |
819 | // Collects the output time and if `gradients` is not `nullptr`, the output |
820 | // time gradient w.r.t. tunable parameters of the subtree rooted in the given |
821 | // node. |
822 | double OutputTime(std::shared_ptr<Node> node, double model_input_time, |
823 | ParameterGradients* gradients); |
824 | |
825 | // Removes the given node. |
826 | void RemoveNode(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_); |
827 | |
828 | // Produces a proto for this model. |
829 | Status ToProto(ModelProto* model_proto); |
830 | |
831 | // Restores a model from the proto. |
832 | static Status FromProto(ModelProto model_proto, |
833 | std::unique_ptr<Model>* model); |
834 | |
835 | // Saves this model with a given snapshot and its optimization parameters to a |
836 | // file. Note that the file directory must already exist. |
837 | Status Save(const string& fname, std::shared_ptr<Node> snapshot, |
838 | const OptimizationParams& optimization_params); |
839 | |
840 | // Loads a model and its optimization parameters from a file with the given |
841 | // name. |
842 | static Status Load(const string& fname, std::unique_ptr<Model>* model, |
843 | OptimizationParams* optimization_params); |
844 | |
845 | // Records gap time between consecutive `GetNext()` calls. |
846 | void RecordIteratorGapTime(uint64_t duration_usec); |
847 | |
848 | // Computes the target time in nsecs to use for `STAGE_BASED` autotune |
849 | // algorithm. |
850 | double ComputeTargetTimeNsec(); |
851 | |
852 | private: |
853 | // Determines whether optimization should stop given total processing time, |
854 | // estimated output time, and estimated number of buffers bytes. |
855 | using StopPredicate = |
856 | std::function<bool(const ModelParameters&, double, double, double)>; |
857 | |
858 | static constexpr int64_t kOptimizationPeriodMinMs = 10; |
859 | static constexpr int64_t kOptimizationPeriodMaxMs = |
860 | 60 * EnvTime::kSecondsToMillis; |
861 | |
862 | // Collects tunable parameters in the tree rooted in the given node, returning |
863 | // a vector which contains pairs of node names and tunable parameters. |
864 | ModelParameters CollectTunableParameters(std::shared_ptr<Node> node); |
865 | |
866 | // Downsizes buffers that are too large for all nodes rooted at `snapshot`. |
867 | // Returns true if any buffer is downsized. |
868 | bool DownsizeBuffers(std::shared_ptr<Node> snapshot); |
869 | |
870 | // Upsizes buffers that are too small for all nodes rooted at `snapshot` while |
871 | // respecting the ram budget. Returns true if any buffer is upsized. |
872 | bool UpsizeBuffers(std::shared_ptr<Node> snapshot, int64_t ram_budget); |
873 | |
874 | // Reset buffer watermarks of all asynchronous nodes to their buffered |
875 | // elements. |
876 | void ResetBufferWatermarks(); |
877 | |
878 | // Collects buffer parameters of all nodes in the model that should be |
879 | // upsized. |
880 | absl::flat_hash_map<Node*, Parameter*> CollectBufferParametersToUpsize( |
881 | std::shared_ptr<Node> snapshot); |
882 | |
883 | // Flushes metrics recorded by the model. |
884 | void FlushMetrics() TF_LOCKS_EXCLUDED(mu_); |
885 | |
886 | // This optimization algorithm starts by setting all tunable parallelism |
887 | // parameters to the minimum value. It then improves current parameters by |
888 | // making a step in the direction opposite to the gradient of `OutputTime` and |
889 | // projecting resulting values on the feasible intervals. Improvement step is |
890 | // repeated until either the output time improvement is smaller than threshold |
891 | // value or the output time is less than the processing time needed to produce |
892 | // an element divided by CPU budget. |
893 | void OptimizeGradientDescent(std::shared_ptr<Node> snapshot, |
894 | const OptimizationParams& optimization_params, |
895 | CancellationManager* cancellation_manager); |
896 | |
897 | // Helper method for implementing hill-climb optimization that can be |
898 | // parametrized by a predicate to use for stopping the optimization. |
899 | void OptimizeHillClimbHelper(std::shared_ptr<Node> snapshot, |
900 | const OptimizationParams& optimization_params, |
901 | CancellationManager* cancellation_manager, |
902 | StopPredicate should_stop); |
903 | |
904 | // This optimization algorithm starts by setting all tunable parallelism |
905 | // parameters to the minimum value. It then repeatedly identifies the |
906 | // parameter whose increase in parallelism decreases the output time the most. |
907 | // This process is repeated until all parameters reach their maximum values or |
908 | // the projected output time is less than or equal to the processing time |
909 | // needed to produce an element divided by CPU budget. |
910 | void OptimizeHillClimb(std::shared_ptr<Node> snapshot, |
911 | const OptimizationParams& optimization_params, |
912 | CancellationManager* cancellation_manager); |
913 | |
914 | // This optimization behaves similarly to the hill climb optimization but uses |
915 | // a relaxed stoping condition, allowing the optimization to oversubscribe |
916 | // CPU. |
917 | void OptimizeMaxParallelism(std::shared_ptr<Node> snapshot, |
918 | const OptimizationParams& optimization_params, |
919 | CancellationManager* cancellation_manager); |
920 | |
921 | // This optimization starts by setting all tunable parallelism parameters to |
922 | // their minimum values. It then repeatedly increases the parallelism |
923 | // parameter of the longest stage by 1 until either the longest stage is |
924 | // faster than the target time or the memory or CPU budget is fully utilized. |
925 | // TODO(b/226910071): The second part of this algorithm optimizes the buffer |
926 | // sizes of parallel ops. |
927 | void OptimizeStageBased(std::shared_ptr<Node> snapshot, |
928 | const OptimizationParams& optimization_params, |
929 | CancellationManager* cancellation_manager); |
930 | |
931 | // This is the first part of the stage-based optimization that optimizes |
932 | // tunable parallelism parameters. |
933 | void OptimizeStageBasedParallelism( |
934 | std::shared_ptr<Node> snapshot, double target_time_nsec, |
935 | const OptimizationParams& optimization_params, |
936 | CancellationManager* cancellation_manager); |
937 | |
938 | // Determines if we should stop the gradient descent optimization iterations |
939 | // based on number of increasable parameters, CPU budget, RAM budget and |
940 | // current resource usage. |
941 | bool ShouldStop(int64_t cpu_budget, int64_t ram_budget, |
942 | const ModelParameters& parameters, |
943 | const ModelParameters& parallelism_parameters, |
944 | const ModelParameters& buffer_size_parameters, |
945 | std::shared_ptr<Node> snapshot, bool* cpu_budget_reached); |
946 | |
947 | // Collects the processing time for the given node. |
948 | double TotalProcessingTime(std::shared_ptr<Node> node); |
949 | |
950 | // Collects the total number of bytes buffered in all nodes in the subtree |
951 | // rooted in the given node for which autotuning is enabled. |
952 | double TotalBufferedBytes(std::shared_ptr<Node> node); |
953 | |
954 | // Collects the total buffer limit of all nodes in the subtree rooted in the |
955 | // given node for which autotuning is enabled. This number represents the |
956 | // amount of memory that would be used by the subtree nodes if all of their |
957 | // buffers were full. |
958 | double TotalMaximumBufferedBytes(std::shared_ptr<Node> node); |
959 | |
960 | // Used for coordination between different input pipeline threads. Exclusive |
961 | // access is required only when adding or removing nodes. Concurrent access to |
962 | // existing nodes is protected by a node mutex. |
963 | mutable mutex mu_; |
964 | // Used for coordinating the optimization loop and model modifications. |
965 | condition_variable optimize_cond_var_; |
966 | int64_t id_counter_ TF_GUARDED_BY(mu_) = 1; |
967 | std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_) = nullptr; |
968 | |
969 | // Determines the time the optimization loop should wait between |
970 | // running optimizations. |
971 | int64_t optimization_period_ms_ TF_GUARDED_BY(mu_); |
972 | |
973 | // Gauge cell that can be used to collect the state of the model. |
974 | monitoring::GaugeCell<std::function<std::string()>>* model_gauge_cell_ = |
975 | nullptr; |
976 | // Time use for rate limitting the recomputation of human-readable string |
977 | // represention of the model. |
978 | absl::Time cache_until_ = absl::InfinitePast(); |
979 | // Cached result of the `DebugString()` invocation used to implement rate |
980 | // limitting of the computation. |
981 | std::string cached_debug_string_ = "" ; |
982 | // Used to coordinate gap time updates between different threads. Gap time is |
983 | // the time between the completion of the previous `GetNext()` and the start |
984 | // of the next `GetNext()`. |
985 | mutable mutex gap_mu_; |
986 | // Stores the latest gap times between consecutive `GetNext()`. |
987 | std::deque<uint64_t> gap_times_usec_ TF_GUARDED_BY(gap_mu_); |
988 | // The experiment that this job is part of. |
989 | std::string experiment_ = "" ; |
990 | }; |
991 | |
992 | // Class to compute timing information for a model. |
993 | class ModelTiming { |
994 | public: |
995 | struct NodeTiming { |
996 | // Pipeline ratio is the number of elements this node needs to produce in |
997 | // order to produce an element at the root of the pipeline. |
998 | double pipeline_ratio = 0.0; |
999 | // The self time it takes this node to produce the elements needed to |
1000 | // produce one element of the root of the pipeline. |
1001 | double self_time_nsec = 0.0; |
1002 | // The total time it takes this node and the subtree rooted at this node to |
1003 | // produce the elements needed to produce one element at the root of the |
1004 | // pipeline. |
1005 | double total_time_nsec = 0.0; |
1006 | }; |
1007 | |
1008 | explicit ModelTiming(std::shared_ptr<Node> root); |
1009 | |
1010 | // Returns the timing data for `node`. |
1011 | const NodeTiming* GetTiming(const Node* node) const; |
1012 | |
1013 | // Returns the root nodes of all stages. |
1014 | std::vector<std::shared_ptr<Node>> GetStageRoots() const; |
1015 | |
1016 | // Returns all the nodes of a stage given the stage root. |
1017 | std::vector<std::shared_ptr<Node>> GetStageNodes( |
1018 | std::shared_ptr<Node> stage_root) const; |
1019 | |
1020 | // Computes the total time for a node. |
1021 | void ComputeNodeTotalTime(const Node& node); |
1022 | |
1023 | private: |
1024 | // Computes the pipeline ratios of all nodes. |
1025 | void ComputePipelineRatios(const Node::NodeVector& bfs_nodes); |
1026 | |
1027 | // Computes the total time for all nodes. The `reverse_bfs_nodes` are assumed |
1028 | // to be a vector of model nodes in reversed BFS manner. |
1029 | void ComputeTotalTimes(const Node::NodeVector& reverse_bfs_nodes); |
1030 | |
1031 | // Computes the total time of a node that is not an async interleave node. |
1032 | void ComputeNonAsyncInterleaveManyTotalTime(const Node& node); |
1033 | |
1034 | // Computes the total time of an async interleave node. |
1035 | void ComputeAsyncInterleaveManyTotalTime(const Node& node); |
1036 | |
1037 | // Returns a vector of all nodes in the model. The nodes are either in |
1038 | // breadth-first search or reverse breadth-first search order depending on the |
1039 | // `order` argument. The nodes are collected based on the results of the |
1040 | // `collect_node` predicate: if the predicate returns `false` for a given |
1041 | // node, then the subtree rooted in this node is excluded. The root node |
1042 | // itself is not collected. |
1043 | Node::NodeVector CollectNodes( |
1044 | std::shared_ptr<Node> root, TraversalOrder order, |
1045 | bool collect_node(const std::shared_ptr<Node>)) const; |
1046 | |
1047 | // Stores a pointer to the root of a model. |
1048 | std::shared_ptr<Node> root_; |
1049 | |
1050 | // Holds a mapping from node to its timing node. |
1051 | absl::flat_hash_map<const Node*, NodeTiming> timing_nodes_; |
1052 | }; |
1053 | |
1054 | } // namespace model |
1055 | } // namespace data |
1056 | } // namespace tensorflow |
1057 | |
1058 | #endif // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ |
1059 | |