1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
16#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
17
#include <algorithm>
#include <atomic>
#include <deque>
#include <functional>
#include <limits>
#include <list>
#include <memory>
#include <string>
// TODO(b/114492873): Move this include into core/platform.
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringprintf.h"
43
44namespace tensorflow {
45namespace data {
46namespace model {
47
// Sentinel value requesting that a parameter be auto-tuned. A `SharedState`
// constructed with this value is marked `tunable`.
constexpr int64_t kAutotune{-1};

// Well-known string constants shared between tf.data transformations and the
// model (presumably used as parameter / metadata names).
constexpr char kParallelism[] = "parallelism";
constexpr char kBufferSize[] = "buffer_size";
constexpr char kCycleLength[] = "cycle_length";
constexpr char kDeterministic[] = "deterministic";
constexpr char kMaxBufferedElements[] = "max_buffered_elements";

// Key under which the input time of the model is identified.
constexpr char kModelInputTimeKey[] = "model_input_time";

// Default fraction of the available RAM that the model's internal buffers may
// use.
constexpr double kRamBudgetShare{0.5};

// Weight given to the most recent per-element processing time when updating
// the exponential moving average of processing time per element.
constexpr double kProcessingTimeEmaWeight{0.1};
65
66enum class TraversalOrder {
67 BFS = 0,
68 REVERSE_BFS = 1,
69};
70
71// Represents thread-safe state that can be shared between an input pipeline and
72// the performance model.
73struct SharedState {
74 public:
75 SharedState(int64_t value, std::shared_ptr<mutex> mu,
76 std::shared_ptr<condition_variable> cond_var)
77 : value(value),
78 mu(std::move(mu)),
79 cond_var(std::move(cond_var)),
80 tunable(value == kAutotune) {}
81
82 double value;
83 const std::shared_ptr<mutex> mu;
84 const std::shared_ptr<condition_variable> cond_var;
85 const bool tunable;
86};
87
88// Represents a parameter.
89struct Parameter {
90 Parameter(const string& name, std::shared_ptr<SharedState> state, double min,
91 double max)
92 : name(name),
93 // Sometimes non-autotune nodes (with `autotune_=false`) may contain
94 // parameters (for example inputs of parallel interleave dataset which
95 // are not in the current cycle). To avoid unrealistic situation
96 // (say `buffer_size=-1` or `parallelism=-1`) in the optimization
97 // computation, if the state value is `kAutotune=-1` (just to indicate
98 // the `SharedState` is tunable), we initialize the parameter value to
99 // be the minimal value of the state.
100 value(state == nullptr || state->value == kAutotune ? min
101 : state->value),
102 min(min),
103 max(max),
104 state(std::move(state)) {}
105
106 // Human-readable name of the parameter.
107 const string name;
108
109 // Identifies the model value of the parameter. This can be different from
110 // the actual value (e.g. during optimization search).
111 double value;
112
113 // Identifies the minimum value of the parameter.
114 const double min;
115
116 // Identifies the maximum value of the parameter.
117 const double max;
118
119 // Shared state of the parameter.
120 std::shared_ptr<SharedState> state;
121};
122
123// Returns a new tunable parameter.
124std::shared_ptr<Parameter> MakeParameter(const string& name,
125 std::shared_ptr<SharedState> state,
126 double min, double max);
127
128// Returns a new non-tunable parameter.
129std::shared_ptr<Parameter> MakeNonTunableParameter(const string& name,
130 double value);
131
132// Abstract representation of a TensorFlow input pipeline node. It collects
133// information about inputs to this node, processing time spent executing the
134// node logic, number of elements produced by the node, various other
135// information (e.g. batch size or execution parallelism).
136//
137// Developers of tf.data transformations are not expected to interact with
138// this class directly. Boiler plate code for creating the abstract
139// representation of the input pipeline and collecting common information has
140// been added to the implementation of `DatasetBase` and `DatasetBaseIterator`
141// respectively.
142//
143// In addition, `DatasetBaseIterator` provides wrappers that can be used for
144// transformation-specific information collection. The `SetMetadata` wrapper
145// can be used to pass arbitrary metadata to the modeling framework, while the
146// `StartWork` and `StopWork` wrappers should be used to correctly account for
147// processing time of multi-threaded transformation that yield the CPU; such
148// transformations should invoke `StartWork()` when a transformation thread
149// starts executing (e.g. when created or woken up) and `StopWork()` when a
150// transformation thread stops executing (e.g. when returning or waiting).
151class Node {
152 public:
153 // Arguments for `Node` constructor.
154 struct Args {
155 int64_t id;
156 string name;
157 std::shared_ptr<Node> output;
158 };
159
160 using Factory = std::function<std::shared_ptr<Node>(Args)>;
161 using NodeVector = std::vector<std::shared_ptr<Node>>;
162 using NodePairList =
163 std::list<std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>>;
164 using ModelParameters =
165 std::vector<std::pair<string, std::shared_ptr<Parameter>>>;
166 using NodeValues = absl::flat_hash_map<string, double>;
167 using ParameterGradients =
168 absl::flat_hash_map<std::pair<string, string>, double>;
169
170 explicit Node(Args args)
171 : id_(args.id),
172 name_(std::move(args.name)),
173 autotune_(true),
174 buffered_bytes_(0),
175 buffered_elements_(0),
176 buffered_elements_low_(std::numeric_limits<int64_t>::max()),
177 buffered_elements_high_(std::numeric_limits<int64_t>::min()),
178 bytes_consumed_(0),
179 bytes_produced_(0),
180 num_elements_(0),
181 processing_time_(0),
182 record_metrics_(true),
183 metrics_(name_),
184 output_(args.output.get()) {}
185
186 virtual ~Node() {
187 // Clear the sub-nodes instead of relying on implicit shared pointer
188 // destructor to avoid potential stack overflow when the tree is deep.
189 std::deque<std::shared_ptr<Node>> queue;
190 {
191 mutex_lock l(mu_);
192 while (inputs_.size() > 0) {
193 queue.push_back(inputs_.front());
194 inputs_.pop_front();
195 }
196 }
197 while (!queue.empty()) {
198 auto node = queue.back();
199 queue.pop_back();
200 {
201 mutex_lock l(node->mu_);
202 while (node->inputs_.size() > 0) {
203 queue.push_back(node->inputs_.front());
204 node->inputs_.pop_front();
205 }
206 }
207 }
208
209 FlushMetrics();
210 }
211
212 // Adds an input.
213 void add_input(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_) {
214 mutex_lock l(mu_);
215 inputs_.push_back(node);
216 }
217
218 // Increments the aggregate processing time by the given delta.
219 void add_processing_time(int64_t delta) TF_LOCKS_EXCLUDED(mu_) {
220 processing_time_ += delta;
221 }
222
223 // Returns an indication whether autotuning is enabled for this node.
224 bool autotune() const TF_LOCKS_EXCLUDED(mu_) { return autotune_; }
225
226 // Returns the number of bytes stored in this node's buffer.
227 int64_t buffered_bytes() const TF_LOCKS_EXCLUDED(mu_) {
228 return buffered_bytes_;
229 }
230
231 // Returns the number of elements stored in this node's buffer.
232 int64_t buffered_elements() const TF_LOCKS_EXCLUDED(mu_) {
233 return buffered_elements_;
234 }
235
236 // Returns the low watermark of the number of elements stored in this node's
237 // buffer. The watermarks are reset at the beginning of the execution time and
238 // each time the buffer is upsized or downsized.
239 int64_t buffered_elements_low() const TF_LOCKS_EXCLUDED(mu_) {
240 return buffered_elements_low_;
241 }
242
243 // Returns the high watermark of the number of elements stored in this node's
244 // buffer. The watermarks are reset at the beginning of the execution time and
245 // each time the buffer is upsized or downsized.
246 int64_t buffered_elements_high() const TF_LOCKS_EXCLUDED(mu_) {
247 return buffered_elements_high_;
248 }
249
250 // Returns the number of bytes consumed by the node.
251 int64_t bytes_consumed() const TF_LOCKS_EXCLUDED(mu_) {
252 return bytes_consumed_;
253 }
254
255 // Returns the number of bytes produced by the node.
256 int64_t bytes_produced() const TF_LOCKS_EXCLUDED(mu_) {
257 return bytes_produced_;
258 }
259
260 // Indicates whether the node has tunable parameters.
261 bool has_tunable_parameters() const TF_LOCKS_EXCLUDED(mu_) {
262 tf_shared_lock l(mu_);
263 for (const auto& pair : parameters_) {
264 if (pair.second->state->tunable) return true;
265 }
266 return false;
267 }
268
269 // Returns the unique node ID.
270 int64_t id() const TF_LOCKS_EXCLUDED(mu_) { return id_; }
271
272 // Returns the node inputs.
273 std::list<std::shared_ptr<Node>> inputs() const TF_LOCKS_EXCLUDED(mu_) {
274 tf_shared_lock l(mu_);
275 return inputs_;
276 }
277
278 // Returns a longer node name that is guaranteed to be unique.
279 string long_name() const { return strings::StrCat(name_, "(id:", id_, ")"); }
280
281 // Returns the node name.
282 const string& name() const { return name_; }
283
284 // Returns the number of elements produced by the node.
285 int64_t num_elements() const TF_LOCKS_EXCLUDED(mu_) { return num_elements_; }
286
287 // Returns the node output.
288 Node* output() const { return output_; }
289
290 // Returns the parameter value.
291 double parameter_value(const string& name) const TF_LOCKS_EXCLUDED(mu_) {
292 tf_shared_lock l(mu_);
293 return parameters_.at(name)->state->value;
294 }
295
296 // Returns the aggregate processing time.
297 int64_t processing_time() const TF_LOCKS_EXCLUDED(mu_) {
298 return processing_time_;
299 }
300
301 // Records that the node consumed the given number of bytes.
302 void record_bytes_consumed(int64_t num_bytes) {
303 bytes_consumed_ += num_bytes;
304 }
305
306 // Records that the node produced the given number of bytes.
307 void record_bytes_produced(int64_t num_bytes) {
308 bytes_produced_ += num_bytes;
309 }
310
311 // Records the change in this node's buffer.
312 void record_buffer_event(int64_t bytes_delta, int64_t elements_delta) {
313 buffered_bytes_ += bytes_delta;
314 buffered_elements_ += elements_delta;
315 // There is no need to maintain watermarks for synchronous ops because we
316 // will not upsize or downsize the buffers of synchronous ops.
317 if (IsAsync()) {
318 int64_t low_watermark =
319 std::min(buffered_elements_low_, buffered_elements_);
320 buffered_elements_low_ = low_watermark;
321 int64_t high_watermark =
322 std::max(buffered_elements_high_, buffered_elements_);
323 buffered_elements_high_ = high_watermark;
324 }
325 }
326
327 // Records that the node produced an element.
328 void record_element() TF_LOCKS_EXCLUDED(mu_) {
329 num_elements_++;
330 {
331 mutex_lock l(mu_);
332 UpdateProcessingTimeEma();
333 }
334 }
335
336 // Records that a node thread has started executing.
337 void record_start(int64_t time_nanos) TF_LOCKS_EXCLUDED(mu_) {
338 DCHECK_EQ(work_start_, 0);
339 work_start_ = time_nanos;
340 }
341
342 // Records that a node thread has stopped executing.
343 void record_stop(int64_t time_nanos) TF_LOCKS_EXCLUDED(mu_) {
344 // TODO(jsimsa): Use DCHECK_NE(work_start_, 0) here.
345 if (work_start_ != 0) {
346 processing_time_ += time_nanos - work_start_;
347 work_start_ = 0;
348 } else {
349 VLOG(1) << "Encountered a stop event without a matching start event.";
350 }
351 }
352
353 // Returns whether work is currently being recorded, i.e. whether we are
354 // currently between a `record_start` and a `record_stop`.
355 bool is_recording() TF_LOCKS_EXCLUDED(mu_) { return work_start_ > 0; }
356
357 // Removes an input.
358 void remove_input(std::shared_ptr<Node> input) TF_LOCKS_EXCLUDED(mu_) {
359 mutex_lock l(mu_);
360 inputs_.remove(input);
361 }
362
363 // Sets the value that determines whether autotuning is enabled for this node.
364 void set_autotune(bool autotune) TF_LOCKS_EXCLUDED(mu_) {
365 autotune_.store(autotune);
366 }
367
368 // Resets buffer watermarks to the current buffered elements.
369 void ResetBufferWatermarks() {
370 if (!IsAsync()) {
371 return;
372 }
373 int64_t current_buffer_size = buffered_elements_;
374 buffered_elements_low_ = current_buffer_size;
375 buffered_elements_high_ = current_buffer_size;
376 }
377
378 // Returns true for asynchronous nodes; false otherwise.
379 virtual bool IsAsync() const { return false; }
380
381 // Returns the ratio of the node, which is defined as the number of elements
382 // per input needed by the node to produce an element, e.g. batch size of a
383 // `Batch`. It can be 0 if the ratio is unknown.
384 virtual double Ratio() const { return 1.0; }
385
386 // Computes the self time in nanoseconds of the node to produce one element.
387 virtual double ComputeSelfTime() const;
388
389 // Returns the parameter value if it exists, not ok status otherwise.
390 StatusOr<double> ParameterValue(const std::string& parameter_name) const
391 TF_LOCKS_EXCLUDED(mu_) {
392 tf_shared_lock l(mu_);
393 if (parameters_.contains(parameter_name)) {
394 return parameters_.at(parameter_name)->value;
395 }
396 return errors::NotFound("Parameter ", parameter_name,
397 " was not found in model node ", long_name());
398 }
399
400 // Given the average time between events when the elements in the buffer are
401 // produced (`producer_time`), the average time between events when elements
402 // in the buffer are consumed (`consumer_time`) and the buffer size, the
403 // method computes the expected time an consumer event will have to wait.
404 //
405 // The wait time is approximated as the product of the probability the buffer
406 // will be empty and the time it takes to produce an element into the buffer.
407 //
408 // The formula used for computing the probability is derived by modeling the
409 // problem as an M/M/1/K queue
410 // (https://en.wikipedia.org/wiki/Birth%E2%80%93death_process#M/M/1/K_queue).
411 //
412 // Collects derivatives of `ComputeWaitTime` w.r.t `producer_time`,
413 // `consumer_time' and `buffer_size` if the corresponding pointers are not
414 // `nullptr`.
415 static double ComputeWaitTime(const double& producer_time,
416 const double& consumer_time,
417 const double& buffer_size,
418 double* producer_time_derivative,
419 double* consumer_time_derivative,
420 double* buffer_size_derivative);
421
422 // Collects tunable parameters in the subtree rooted in this node.
423 ModelParameters CollectTunableParameters() const TF_LOCKS_EXCLUDED(mu_);
424
425 // Collects tunable parameters in this node.
426 ModelParameters CollectNodeTunableParameters() const TF_LOCKS_EXCLUDED(mu_);
427
428 // Returns a human-readable representation of this node.
429 string DebugString() const TF_LOCKS_EXCLUDED(mu_);
430
431 // Flushes the metrics recorded by this node.
432 void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);
433
434 // Returns the per-element output time for this node and if `gradients` is not
435 // `nullptr`, collects the output time gradient w.r.t. tunable parameters of
436 // the subtree rooted in this node.
437 double OutputTime(NodeValues* input_times,
438 ParameterGradients* gradients) const TF_LOCKS_EXCLUDED(mu_);
439
440 // Returns a copy of this node, making a deep copy of its inputs and a
441 // shallow copy of its tunable parameters.
442 //
443 // The purpose for this method is to allow the model optimization logic to
444 // operate over immutable state while allowing concurrent model updates.
445 std::shared_ptr<Node> Snapshot() const TF_LOCKS_EXCLUDED(mu_);
446
447 // Returns the per-element processing time in nanoseconds spent in this node.
448 double SelfProcessingTime() const TF_LOCKS_EXCLUDED(mu_);
449
450 // Returns the total number of bytes buffered in all nodes in the subtree for
451 // which autotuning is enabled.
452 double TotalBufferedBytes() const TF_LOCKS_EXCLUDED(mu_);
453
454 // Collects the total buffer limit of all nodes in the subtree for which
455 // autotuning is enabled. This number represents the amount of memory that
456 // would be used by the subtree nodes if all of their buffers were full.
457 double TotalMaximumBufferedBytes() const TF_LOCKS_EXCLUDED(mu_);
458
459 // Returns the per-element CPU time in nanoseconds spent in the subtree rooted
460 // in this node. If `processing_times` is not `nullptr`, collects the
461 // per-element CPU time spent in each node of the subtree.
462 double TotalProcessingTime(NodeValues* processing_times)
463 TF_LOCKS_EXCLUDED(mu_);
464
465 // Produces a proto for this node. Does not produce a proto for input nodes.
466 virtual Status ToProto(ModelProto::Node* node_proto) const;
467
468 // Restores a node from the proto. Does not restore input nodes.
469 static Status FromProto(ModelProto::Node node_proto,
470 std::shared_ptr<Node> output,
471 std::shared_ptr<Node>* node);
472
473 // Returns a vector of nodes of the subtree rooted in this node. The nodes are
474 // either in breadth-first search or reverse breadth-first search order
475 // depending on the `order` argument. The nodes are collected based on the
476 // results of the `collect_node` predicate: if the predicate returns `false`
477 // for a given node, then the subtree rooted in this node is excluded. The
478 // root node itself is not collected.
479 NodeVector CollectNodes(TraversalOrder order,
480 bool collect_node(const std::shared_ptr<Node>)) const
481 TF_LOCKS_EXCLUDED(mu_);
482
483 // Downsizes buffer parameters of this node. Returns true if any buffer is
484 // downsized.
485 bool TryDownsizeBuffer();
486
487 // Collects buffer parameters of this node that should be upsized.
488 void CollectBufferParametersToUpsize(
489 absl::flat_hash_map<Node*, Parameter*>& node_parameters);
490
491 protected:
492 // Used for (incrementally) recording metrics. The class is thread-safe.
493 class Metrics {
494 public:
495 explicit Metrics(const string& name)
496 : bytes_consumed_counter_(metrics::GetTFDataBytesConsumedCounter(name)),
497 bytes_produced_counter_(metrics::GetTFDataBytesProducedCounter(name)),
498 num_elements_counter_(metrics::GetTFDataElementsCounter(name)),
499 recorded_bytes_consumed_(0),
500 recorded_bytes_produced_(0),
501 recorded_num_elements_(0) {}
502
503 // Expects the total number of bytes consumed and records the delta since
504 // last invocation.
505 void record_bytes_consumed(int64_t total_bytes) {
506 int64_t delta =
507 total_bytes - recorded_bytes_consumed_.exchange(total_bytes);
508 bytes_consumed_counter_->IncrementBy(delta);
509 }
510
511 // Expects the total number of bytes produced and records the delta since
512 // last invocation.
513 void record_bytes_produced(int64_t total_bytes) {
514 int64_t delta =
515 total_bytes - recorded_bytes_produced_.exchange(total_bytes);
516 bytes_produced_counter_->IncrementBy(delta);
517 }
518
519 // Expects the total number of elements produced and records the delta since
520 // last invocation.
521 void record_num_elements(int64_t total_elements) {
522 int64_t delta =
523 total_elements - recorded_num_elements_.exchange(total_elements);
524 num_elements_counter_->IncrementBy(delta);
525 }
526
527 private:
528 monitoring::CounterCell* const bytes_consumed_counter_;
529 monitoring::CounterCell* const bytes_produced_counter_;
530 monitoring::CounterCell* const num_elements_counter_;
531 std::atomic<int64_t> recorded_bytes_consumed_;
532 std::atomic<int64_t> recorded_bytes_produced_;
533 std::atomic<int64_t> recorded_num_elements_;
534 };
535
536 // Computes the exponential moving average of processing time per element.
537 void UpdateProcessingTimeEma() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
538 if (previous_processing_time_ == 0) {
539 if (num_elements_ > 0) {
540 processing_time_ema_ = static_cast<double>(processing_time_) /
541 static_cast<double>(num_elements_);
542 } else {
543 processing_time_ema_ = static_cast<double>(processing_time_);
544 }
545 } else {
546 processing_time_ema_ =
547 (1.0 - kProcessingTimeEmaWeight) * processing_time_ema_ +
548 kProcessingTimeEmaWeight *
549 static_cast<double>(processing_time_ - previous_processing_time_);
550 }
551 previous_processing_time_ = processing_time_;
552 }
553
554 // Returns the number of inputs.
555 int64_t num_inputs() const TF_SHARED_LOCKS_REQUIRED(mu_) {
556 int64_t num_inputs = 0;
557 for (auto& input : inputs_) {
558 // Inputs for which autotuning is disabled are excluded.
559 if (input->autotune()) {
560 ++num_inputs;
561 }
562 }
563 return num_inputs;
564 }
565
566 // Creates a clone of this node.
567 virtual std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const
568 TF_SHARED_LOCKS_REQUIRED(mu_) = 0;
569
570 // Returns the average size of an element buffered in this node.
571 double AverageBufferedElementSize() const TF_SHARED_LOCKS_REQUIRED(mu_);
572
573 // Returns the sum of per-element output time for the tunable inputs of this
574 // node.
575 double OutputTimeForInputs(const NodeValues& output_times) const
576 TF_SHARED_LOCKS_REQUIRED(mu_);
577
578 // Returns the sum of output time gradient w.r.t. input time for the tunable
579 // inputs of this node.
580 double OutputTimeGradientsForInputs(const NodeValues& output_time_gradients)
581 const TF_SHARED_LOCKS_REQUIRED(mu_);
582
583 // Computes the input time for this node and stores it in `input_times`.
584 virtual void InputTimeLocked(NodeValues* input_times) const
585 TF_SHARED_LOCKS_REQUIRED(mu_) = 0;
586
587 // Computes the per-element output time for this node and stores it in
588 // `output_times`. If `gradients` is not `nullptr`, computes the output time
589 // gradient w.r.t. tunable parameters of the subtree rooted in this node and
590 // stores it in `gradients`, also computes the output time gradient w.r.t.
591 // input time and stores it in `output_time_gradients`.
592 virtual void OutputTimeLocked(const NodeValues& input_times,
593 ParameterGradients* gradients,
594 NodeValues* output_times,
595 NodeValues* output_time_gradients) const
596 TF_SHARED_LOCKS_REQUIRED(mu_) = 0;
597
598 // Returns the sum of per-element processing time for the inputs of this node
599 // by adding values for input nodes in `total_processing_times`. Processing
600 // time for a given input is a weighted combination of a statistic based on
601 // history of input processing time and the actual time. This is done to
602 // improve accuracy of processing time estimation for newly created inputs.
603 //
604 // Uniform distribution of per-element processing times across different
605 // inputs is assumed.
606 double TotalProcessingTimeForInputs(const NodeValues& total_processing_times)
607 TF_SHARED_LOCKS_REQUIRED(mu_);
608
609 // Returns the per-element processing time spent in this node.
610 double SelfProcessingTimeLocked() const TF_SHARED_LOCKS_REQUIRED(mu_);
611
612 // Computes the per-element CPU time spent in the subtree rooted in this node
613 // and stores it in `total_processing_times`. If `processing_times` is not
614 // `nullptr`, collects the per-element CPU time spent in each node of the
615 // subtree.
616 virtual void TotalProcessingTimeLocked(NodeValues* processing_times,
617 NodeValues* total_processing_times)
618 TF_SHARED_LOCKS_REQUIRED(mu_) = 0;
619
620 // This is the locked version of the public `CollectNodes`.
621 NodeVector CollectNodesLocked(TraversalOrder order,
622 bool collect_node(const std::shared_ptr<Node>))
623 const TF_SHARED_LOCKS_REQUIRED(mu_);
624
625 // Collects tunable parameters in the subtree rooted in this node assuming
626 // mutex locked.
627 ModelParameters CollectTunableParametersLocked() const
628 TF_SHARED_LOCKS_REQUIRED(mu_);
629
630 // Collect tunable parameters on the nodes which have recorded
631 // elements.
632 void CollectTunableParametersHelper(ModelParameters* parameters) const
633 TF_SHARED_LOCKS_REQUIRED(mu_);
634
635 // Build up debug string for the node and store in the debug strings map.
636 void DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
637 const TF_SHARED_LOCKS_REQUIRED(mu_);
638
639 // Copy the node and add the (input, copy) pairs to the NodePairList.
640 std::shared_ptr<Node> SnapshotHelper(std::shared_ptr<Node> cloned_output,
641 NodePairList* node_pairs) const;
642
643 // Compute total buffered bytes for the node and store in the total bytes map.
644 void TotalBufferedBytesHelper(NodeValues* total_bytes) const
645 TF_SHARED_LOCKS_REQUIRED(mu_);
646
647 // Compute total maximum buffered bytes for the node and store in the total
648 // bytes map.
649 void TotalMaximumBufferedBytesHelper(NodeValues* total_bytes) const
650 TF_SHARED_LOCKS_REQUIRED(mu_);
651
652 // Compute and return the maximum buffered bytes on the node itself. By
653 // default non-tunable nodes are assumed not to buffer any bytes, so the
654 // tunable nodes as subclasses are expected to override this method to ensure
655 // that the optimization algorithm respects the memory budget.
656 virtual double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_);
657
658 // Restores node from the proto. Note that this is not done recursively, i.e.
659 // input nodes are not restored.
660 static Status FromProtoHelper(ModelProto::Node node_proto,
661 std::shared_ptr<Node> node);
662
663 // Stores the time passed to the last call to `Node::record_start()` on the
664 // current thread.
665 //
666 // NOTE: This thread-local variable is shared between all instances of `Node`
667 // on which the same thread calls `record_start()` or `record_stop()`. It
668 // relies on the invariant that at most one `Node` can be "active" on a
669 // particular thread at any time. Therefore if `n->record_start()` is called
670 // on thread `t`, then `n->record_stop()` must be called before another call
671 // to `Node::record_start()` (for any node).
672 static thread_local int64_t work_start_; // Will be initialized to zero.
673
674 mutable mutex mu_;
675 const int64_t id_;
676 const string name_;
677
678 // Indicates whether the subtree rooted in this node should be included in
679 // autotuning. In particular, if this is `false`, then the subtree is excluded
680 // from computation of output time and processing time.
681 std::atomic<bool> autotune_;
682 std::atomic<int64_t> buffered_bytes_;
683 std::atomic<int64_t> buffered_elements_;
684 std::atomic<int64_t> buffered_elements_low_;
685 std::atomic<int64_t> buffered_elements_high_;
686 std::atomic<int64_t> bytes_consumed_;
687 std::atomic<int64_t> bytes_produced_;
688 std::atomic<int64_t> num_elements_;
689 std::atomic<int64_t> processing_time_;
690 std::atomic<bool> record_metrics_;
691 Metrics metrics_;
692 absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_
693 TF_GUARDED_BY(mu_);
694
695 // Statistic of inputs processing time history.
696 double input_processing_time_sum_ = 0.0L;
697 int64_t input_processing_time_count_ = 0;
698
699 // Holds the previous processing time and the per element processing time
700 // exponential moving average.
701 int64_t previous_processing_time_ TF_GUARDED_BY(mu_) = 0;
702 double processing_time_ema_ TF_GUARDED_BY(mu_) = 0.0;
703
704 // Inputs of this node. These can represent an iterator created from the input
705 // dataset but also other input iterators (e.g. created by the user-defined
706 // functions of `flat_map` or `interleave`).
707 std::list<std::shared_ptr<Node>> inputs_ TF_GUARDED_BY(mu_);
708
709 // The reference to the output node is not owned so that deletion of a
710 // node results in recursive deletion of the subtree rooted in the node.
711 Node* const output_;
712};
713
714// InterleaveMany is used to model datasets whose inputs are used to create
715// datasets whose elements are then interleaved.
716std::shared_ptr<Node> MakeInterleaveManyNode(
717 Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters);
718
719// AsyncInterleaveMany nodes are the asynchronous version of InterleaveMany
720// nodes.
721std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
722 Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters);
723
724// KnownMany nodes model datasets that synchronously consume known number of
725// input element per output element.
726std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio);
727
728// AsyncKnownRatio nodes are the asynchronous version of KnownRate nodes.
729std::shared_ptr<Node> MakeAsyncKnownRatioNode(
730 Node::Args args, double ratio, double memory_ratio,
731 std::vector<std::shared_ptr<Parameter>> parameters);
732
733std::shared_ptr<Node> MakeAsyncKnownRatioNode(
734 Node::Args args, double ratio,
735 std::vector<std::shared_ptr<Parameter>> parameters);
736
737// Source nodes represent data sources.
738std::shared_ptr<Node> MakeSourceNode(Node::Args args);
739
740// UnknownMany nodes represent datasets that synchronously consume an
741// unknown number of input elements per output.
742//
743// Unlike KnownRatio nodes which expect the ratio between inputs and outputs is
744// specified as a parameter, UnknownRatio estimates the ratio empirically.
745std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args);
746
747// AsyncUnknownRatio nodes are the asynchronous version of unknown ratio nodes.
748std::shared_ptr<Node> MakeAsyncUnknownRatioNode(
749 Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters);
750
751// Unknown nodes represent datasets for which we do not have a model. It acts
752// as pass-through between inputs and output.
753std::shared_ptr<Node> MakeUnknownNode(Node::Args args);
754
755// Abstract representation of a TensorFlow input pipeline that can be used
756// for collecting runtime information and optimizing performance. It collects
757// runtime information about execution of the input pipeline that is used to
758// create a performance model, which is in turn used to identify optimal values
759// of tunable parameters.
760//
761// Developers of tf.data transformations are not expected to interact with this
762// class directly. Boiler plate code for creating the abstract representation of
763// the input pipeline and collecting runtime information has been added to the
764// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
765//
766// The order of locks acquired is SharedState lock, Model lock, Node lock.
767// SharedState lock is acquired first because it shares the same lock as the
768// dataset iterator that contains it.
769class Model {
770 public:
771 using OptimizationParams = ModelProto::OptimizationParams;
772 using ModelParameters = Node::ModelParameters;
773 using NodeValues = Node::NodeValues;
774 using ParameterGradients = Node::ParameterGradients;
775
776 Model();
777 ~Model();
778
779 // Returns a pointer to the model's output node.
780 const std::shared_ptr<Node> output() const {
781 mutex_lock l(mu_);
782 return output_;
783 }
784
785 // Set the experiment that this job is part of.
786 void SetExperiment(const string& experiment) { experiment_ = experiment; }
787
788 // Adds a node with the given name and given parent.
789 void AddNode(Node::Factory factory, const string& name,
790 std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node)
791 TF_LOCKS_EXCLUDED(mu_);
792
793 // Returns a human-readable string representation of the model. This method
794 // can be invoked automatically by monitoring gauges and to avoid frequent
795 // recomputation, the implementation caches the result.
796 std::string DebugString();
797
798 // Uses the given algorithm and resource budgets to periodically perform the
799 // autotuning optimization.
800 //
801 // To terminate the execution of the optimization loop, the caller needs to
802 // invoke `cancellation_mgr->StartCancel()`.
803 Status OptimizeLoop(AutotuneAlgorithm algorithm, int64_t cpu_budget,
804 int64_t ram_budget,
805 CancellationManager* cancellation_manager);
806
807 // Uses the given algorithm and resource budgets to perform the autotuning
808 // optimization.
809 void Optimize(AutotuneAlgorithm algorithm, int64_t cpu_budget,
810 int64_t ram_budget, double model_input_time,
811 CancellationManager* cancellation_manager);
812
813 // Optimizes buffers in the pipeline rooted at `snapshot`. It downsizes
814 // buffers that are too large and upsizes buffers that are too small while
815 // respecting the ram budget. If any node is downsized or upsized, the
816 // watermarks of all nodes are reset to the buffered elements.
817 void OptimizeBuffers(std::shared_ptr<Node> snapshot, int64_t ram_budget);
818
819 // Collects the output time and if `gradients` is not `nullptr`, the output
820 // time gradient w.r.t. tunable parameters of the subtree rooted in the given
821 // node.
822 double OutputTime(std::shared_ptr<Node> node, double model_input_time,
823 ParameterGradients* gradients);
824
825 // Removes the given node.
826 void RemoveNode(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_);
827
828 // Produces a proto for this model.
829 Status ToProto(ModelProto* model_proto);
830
831 // Restores a model from the proto.
832 static Status FromProto(ModelProto model_proto,
833 std::unique_ptr<Model>* model);
834
835 // Saves this model with a given snapshot and its optimization parameters to a
836 // file. Note that the file directory must already exist.
837 Status Save(const string& fname, std::shared_ptr<Node> snapshot,
838 const OptimizationParams& optimization_params);
839
840 // Loads a model and its optimization parameters from a file with the given
841 // name.
842 static Status Load(const string& fname, std::unique_ptr<Model>* model,
843 OptimizationParams* optimization_params);
844
845 // Records gap time between consecutive `GetNext()` calls.
846 void RecordIteratorGapTime(uint64_t duration_usec);
847
848 // Computes the target time in nsecs to use for `STAGE_BASED` autotune
849 // algorithm.
850 double ComputeTargetTimeNsec();
851
852 private:
853 // Determines whether optimization should stop given total processing time,
854 // estimated output time, and estimated number of buffers bytes.
855 using StopPredicate =
856 std::function<bool(const ModelParameters&, double, double, double)>;
857
858 static constexpr int64_t kOptimizationPeriodMinMs = 10;
859 static constexpr int64_t kOptimizationPeriodMaxMs =
860 60 * EnvTime::kSecondsToMillis;
861
862 // Collects tunable parameters in the tree rooted in the given node, returning
863 // a vector which contains pairs of node names and tunable parameters.
864 ModelParameters CollectTunableParameters(std::shared_ptr<Node> node);
865
866 // Downsizes buffers that are too large for all nodes rooted at `snapshot`.
867 // Returns true if any buffer is downsized.
868 bool DownsizeBuffers(std::shared_ptr<Node> snapshot);
869
870 // Upsizes buffers that are too small for all nodes rooted at `snapshot` while
871 // respecting the ram budget. Returns true if any buffer is upsized.
872 bool UpsizeBuffers(std::shared_ptr<Node> snapshot, int64_t ram_budget);
873
874 // Reset buffer watermarks of all asynchronous nodes to their buffered
875 // elements.
876 void ResetBufferWatermarks();
877
878 // Collects buffer parameters of all nodes in the model that should be
879 // upsized.
880 absl::flat_hash_map<Node*, Parameter*> CollectBufferParametersToUpsize(
881 std::shared_ptr<Node> snapshot);
882
883 // Flushes metrics recorded by the model.
884 void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);
885
886 // This optimization algorithm starts by setting all tunable parallelism
887 // parameters to the minimum value. It then improves current parameters by
888 // making a step in the direction opposite to the gradient of `OutputTime` and
889 // projecting resulting values on the feasible intervals. Improvement step is
890 // repeated until either the output time improvement is smaller than threshold
891 // value or the output time is less than the processing time needed to produce
892 // an element divided by CPU budget.
893 void OptimizeGradientDescent(std::shared_ptr<Node> snapshot,
894 const OptimizationParams& optimization_params,
895 CancellationManager* cancellation_manager);
896
897 // Helper method for implementing hill-climb optimization that can be
898 // parametrized by a predicate to use for stopping the optimization.
899 void OptimizeHillClimbHelper(std::shared_ptr<Node> snapshot,
900 const OptimizationParams& optimization_params,
901 CancellationManager* cancellation_manager,
902 StopPredicate should_stop);
903
904 // This optimization algorithm starts by setting all tunable parallelism
905 // parameters to the minimum value. It then repeatedly identifies the
906 // parameter whose increase in parallelism decreases the output time the most.
907 // This process is repeated until all parameters reach their maximum values or
908 // the projected output time is less than or equal to the processing time
909 // needed to produce an element divided by CPU budget.
910 void OptimizeHillClimb(std::shared_ptr<Node> snapshot,
911 const OptimizationParams& optimization_params,
912 CancellationManager* cancellation_manager);
913
914 // This optimization behaves similarly to the hill climb optimization but uses
915 // a relaxed stoping condition, allowing the optimization to oversubscribe
916 // CPU.
917 void OptimizeMaxParallelism(std::shared_ptr<Node> snapshot,
918 const OptimizationParams& optimization_params,
919 CancellationManager* cancellation_manager);
920
921 // This optimization starts by setting all tunable parallelism parameters to
922 // their minimum values. It then repeatedly increases the parallelism
923 // parameter of the longest stage by 1 until either the longest stage is
924 // faster than the target time or the memory or CPU budget is fully utilized.
925 // TODO(b/226910071): The second part of this algorithm optimizes the buffer
926 // sizes of parallel ops.
927 void OptimizeStageBased(std::shared_ptr<Node> snapshot,
928 const OptimizationParams& optimization_params,
929 CancellationManager* cancellation_manager);
930
931 // This is the first part of the stage-based optimization that optimizes
932 // tunable parallelism parameters.
933 void OptimizeStageBasedParallelism(
934 std::shared_ptr<Node> snapshot, double target_time_nsec,
935 const OptimizationParams& optimization_params,
936 CancellationManager* cancellation_manager);
937
938 // Determines if we should stop the gradient descent optimization iterations
939 // based on number of increasable parameters, CPU budget, RAM budget and
940 // current resource usage.
941 bool ShouldStop(int64_t cpu_budget, int64_t ram_budget,
942 const ModelParameters& parameters,
943 const ModelParameters& parallelism_parameters,
944 const ModelParameters& buffer_size_parameters,
945 std::shared_ptr<Node> snapshot, bool* cpu_budget_reached);
946
947 // Collects the processing time for the given node.
948 double TotalProcessingTime(std::shared_ptr<Node> node);
949
950 // Collects the total number of bytes buffered in all nodes in the subtree
951 // rooted in the given node for which autotuning is enabled.
952 double TotalBufferedBytes(std::shared_ptr<Node> node);
953
954 // Collects the total buffer limit of all nodes in the subtree rooted in the
955 // given node for which autotuning is enabled. This number represents the
956 // amount of memory that would be used by the subtree nodes if all of their
957 // buffers were full.
958 double TotalMaximumBufferedBytes(std::shared_ptr<Node> node);
959
960 // Used for coordination between different input pipeline threads. Exclusive
961 // access is required only when adding or removing nodes. Concurrent access to
962 // existing nodes is protected by a node mutex.
963 mutable mutex mu_;
964 // Used for coordinating the optimization loop and model modifications.
965 condition_variable optimize_cond_var_;
966 int64_t id_counter_ TF_GUARDED_BY(mu_) = 1;
967 std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_) = nullptr;
968
969 // Determines the time the optimization loop should wait between
970 // running optimizations.
971 int64_t optimization_period_ms_ TF_GUARDED_BY(mu_);
972
973 // Gauge cell that can be used to collect the state of the model.
974 monitoring::GaugeCell<std::function<std::string()>>* model_gauge_cell_ =
975 nullptr;
976 // Time use for rate limitting the recomputation of human-readable string
977 // represention of the model.
978 absl::Time cache_until_ = absl::InfinitePast();
979 // Cached result of the `DebugString()` invocation used to implement rate
980 // limitting of the computation.
981 std::string cached_debug_string_ = "";
982 // Used to coordinate gap time updates between different threads. Gap time is
983 // the time between the completion of the previous `GetNext()` and the start
984 // of the next `GetNext()`.
985 mutable mutex gap_mu_;
986 // Stores the latest gap times between consecutive `GetNext()`.
987 std::deque<uint64_t> gap_times_usec_ TF_GUARDED_BY(gap_mu_);
988 // The experiment that this job is part of.
989 std::string experiment_ = "";
990};
991
992// Class to compute timing information for a model.
993class ModelTiming {
994 public:
995 struct NodeTiming {
996 // Pipeline ratio is the number of elements this node needs to produce in
997 // order to produce an element at the root of the pipeline.
998 double pipeline_ratio = 0.0;
999 // The self time it takes this node to produce the elements needed to
1000 // produce one element of the root of the pipeline.
1001 double self_time_nsec = 0.0;
1002 // The total time it takes this node and the subtree rooted at this node to
1003 // produce the elements needed to produce one element at the root of the
1004 // pipeline.
1005 double total_time_nsec = 0.0;
1006 };
1007
1008 explicit ModelTiming(std::shared_ptr<Node> root);
1009
1010 // Returns the timing data for `node`.
1011 const NodeTiming* GetTiming(const Node* node) const;
1012
1013 // Returns the root nodes of all stages.
1014 std::vector<std::shared_ptr<Node>> GetStageRoots() const;
1015
1016 // Returns all the nodes of a stage given the stage root.
1017 std::vector<std::shared_ptr<Node>> GetStageNodes(
1018 std::shared_ptr<Node> stage_root) const;
1019
1020 // Computes the total time for a node.
1021 void ComputeNodeTotalTime(const Node& node);
1022
1023 private:
1024 // Computes the pipeline ratios of all nodes.
1025 void ComputePipelineRatios(const Node::NodeVector& bfs_nodes);
1026
1027 // Computes the total time for all nodes. The `reverse_bfs_nodes` are assumed
1028 // to be a vector of model nodes in reversed BFS manner.
1029 void ComputeTotalTimes(const Node::NodeVector& reverse_bfs_nodes);
1030
1031 // Computes the total time of a node that is not an async interleave node.
1032 void ComputeNonAsyncInterleaveManyTotalTime(const Node& node);
1033
1034 // Computes the total time of an async interleave node.
1035 void ComputeAsyncInterleaveManyTotalTime(const Node& node);
1036
1037 // Returns a vector of all nodes in the model. The nodes are either in
1038 // breadth-first search or reverse breadth-first search order depending on the
1039 // `order` argument. The nodes are collected based on the results of the
1040 // `collect_node` predicate: if the predicate returns `false` for a given
1041 // node, then the subtree rooted in this node is excluded. The root node
1042 // itself is not collected.
1043 Node::NodeVector CollectNodes(
1044 std::shared_ptr<Node> root, TraversalOrder order,
1045 bool collect_node(const std::shared_ptr<Node>)) const;
1046
1047 // Stores a pointer to the root of a model.
1048 std::shared_ptr<Node> root_;
1049
1050 // Holds a mapping from node to its timing node.
1051 absl::flat_hash_map<const Node*, NodeTiming> timing_nodes_;
1052};
1053
1054} // namespace model
1055} // namespace data
1056} // namespace tensorflow
1057
1058#endif // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
1059