1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CC_OPS_WHILE_LOOP_H_
17#define TENSORFLOW_CC_OPS_WHILE_LOOP_H_
18
#include <functional>
#include <string>
#include <vector>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
24
25namespace tensorflow {
26namespace ops {
27
28// Function that takes cond graph inputs and returns cond graph boolean output.
29// 'output' need not be set if an error is returned.
30typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
31 Output* output)>
32 CondGraphBuilderFn;
33
34// Function that takes body graph inputs and returns body graph outputs.
35// 'outputs' need not be populated if an error is returned.
36typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
37 std::vector<Output>* outputs)>
38 BodyGraphBuilderFn;
39
40// Constructs a while loop.
41//
42// Arguments:
43// * scope: used to construct the while loop.
44// * inputs: the initial values of the loop variables. Must be non-empty.
45// * cond: a function that builds the condition graph of the loop. Takes the
46// current loop variables as inputs and returns a scalar boolean Output
47// indicating whether the loop should continue.
48// * body: a function that builds the body graph of the loop. Takes the current
49// loop variables as inputs and returns the updated loop variables.
50// * frame_name: the frame name to use for this while loop. This should be a
51// unique name. This will be used as a prefix for created operations.
52// * outputs: output param that returns final loop variable outputs in non-error
53// case. Must be non-null and empty.
54// * create_while_ctx: if true, a WhileContext is created and populated for this
55// loop. See core/graph/while_context.h for more details on
56// WhileContexts. This is set to false for loops used as part of gradient
57// computations, since they're part of the gradient for a loop in the
58// forward-pass.
59// TODO(skyewm): revisit this. Should we create WhileContexts for all loops,
60// even if we don't need them?
61// * cond_output: if non-null, the output of the predicate is returned. This
62// will always be a LoopCond node.
63//
64// Returns an error if the while loop could not be fully constructed.
65//
66// TODO(skyewm): clean up partially-constructed loop in error case
67// TODO(skyewm): create public interface to this method
68Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
69 const CondGraphBuilderFn& cond,
70 const BodyGraphBuilderFn& body, const string& frame_name,
71 OutputList* outputs, bool create_while_ctx = true,
72 Output* cond_output = nullptr);
73
74} // namespace ops
75} // namespace tensorflow
76
77#endif // TENSORFLOW_CC_OPS_WHILE_LOOP_H_
78