If statement (TensorFlow)
Using tf.where() is an effective way to compute an “if statement” within TensorFlow graphs.
tf.where(condition, x=None, y=None, name=None)
Returns the elements, either from x or y, depending on the condition. In other words, a call looks like this:
tf.where(condition_bool, value_if_true, value_if_false)
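As a minimal sketch of this elementwise “if statement”, assuming TensorFlow 2.x with eager execution (the default there), the snippet below keeps positive values and replaces the rest with zero. The tensor values are purely illustrative; condition, x and y all have the same shape here, so this case behaves the same under the TF1 and TF2 versions of tf.where.

```python
import tensorflow as tf

# Illustrative input values (not from the original example)
x = tf.constant([-1.0, 2.0, -3.0, 4.0])

# Elementwise: "if x > 0 then x else 0"; condition, x and y all have shape [4]
clipped = tf.where(x > 0, x, tf.zeros_like(x))
print(clipped.numpy())  # [0. 2. 0. 4.]
```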
If both x and y are None, then this operation returns the coordinates of the true elements of condition. The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements. Keep in mind that the shape of the output tensor depends on how many true values the input contains. Indices are output in row-major order.
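A short sketch of the coordinate mode, assuming TensorFlow 2.x eager execution; the boolean mask is made up for illustration. Each row of the result is the (row, column) index of one true element, listed in row-major order.

```python
import tensorflow as tf

# Illustrative 2x3 boolean mask with three True entries
condition = tf.constant([[True, False, False],
                         [False, True, True]])

# With x and y omitted, tf.where returns the coordinates of the True elements
coords = tf.where(condition)
print(coords.dtype)             # <dtype: 'int64'>
print(coords.numpy().tolist())  # [[0, 0], [1, 1], [1, 2]]
print(coords.shape)             # (3, 2): 3 True elements, 2 coordinates each
```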
If both x and y are non-None, they must have the same shape. The condition tensor must be a scalar if x and y are scalars. If x and y are tensors of higher rank, then condition must either be a vector whose size matches the first dimension of x, or have the same shape as x.
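The shape rules above describe the TF1 version of tf.where. In TensorFlow 2.x, tf.where broadcasts its arguments instead, so this sketch assumes TF 2.x and calls tf.compat.v1.where to reproduce the rules as stated. It shows a scalar condition with scalar x and y, and a vector condition whose size does not match the first dimension of x being rejected; the exact exception type is an assumption, so both likely types are caught.

```python
import tensorflow as tf

# tf.compat.v1.where keeps the TF1 shape rules described above
where_v1 = tf.compat.v1.where

# Scalar x and y: the condition is a scalar too
picked = where_v1(tf.constant(False), tf.constant(1.0), tf.constant(2.0))
print(picked.numpy())  # 2.0

x = tf.zeros([2, 3])
y = tf.ones([2, 3])

# A vector condition needs one entry per row of x (2 here), so size 3 is rejected
try:
    where_v1(tf.constant([True, False, True]), x, y)
    size_mismatch_rejected = False
except (tf.errors.InvalidArgumentError, ValueError):
    size_mismatch_rejected = True
print(size_mismatch_rejected)  # True
```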
The condition tensor acts as a mask that chooses, based on the value at each element, whether the corresponding element or row in the output is taken from x (if true) or from y (if false).
If condition is a vector and x and y are higher-rank tensors, then it chooses which row (outer dimension) to copy from x or y. If condition has the same shape as x and y, then it chooses which element to copy from x or y.
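The two masking modes can be sketched as follows, assuming TensorFlow 2.x eager execution. Row selection with a vector condition is a TF1 behaviour, so the sketch uses tf.compat.v1.where for it; the last lines show that the TF 2.x tf.where, which broadcasts, needs the row mask reshaped to a column (shape [2, 1]) to give the same result. The matrices are illustrative values.

```python
import tensorflow as tf

# Illustrative 2x3 inputs
x = tf.constant([[1, 2, 3], [4, 5, 6]])
y = tf.constant([[10, 20, 30], [40, 50, 60]])

# Vector condition: one bool per row, so whole rows are copied from x or y
row_mask = tf.constant([True, False])
rows = tf.compat.v1.where(row_mask, x, y)
print(rows.numpy().tolist())     # [[1, 2, 3], [40, 50, 60]]

# Same-shape condition: one bool per element
elem_mask = tf.constant([[True, False, True], [False, True, False]])
elems = tf.compat.v1.where(elem_mask, x, y)
print(elems.numpy().tolist())    # [[1, 20, 3], [40, 5, 60]]

# TF 2.x tf.where broadcasts, so the row mask needs an explicit column axis
rows_v2 = tf.where(row_mask[:, tf.newaxis], x, y)
print(rows_v2.numpy().tolist())  # [[1, 2, 3], [40, 50, 60]]
```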
1. If statement in TensorFlow example:
import tensorflow.compat.v1 as tf  # TF1-style graph API; also available under TF 2.x
tf.disable_v2_behavior()

# Variables holding the two input matrices
a = tf.Variable([[1, 1], [1, 1]], dtype=tf.float32, name="a")
b = tf.Variable([[2, 2], [2, 2]], dtype=tf.float32, name="b")
# Scalar boolean switch; defaults to False unless a value is fed
training_mode = tf.placeholder_with_default(False, shape=[], name="training_mode")

# Augment while training: return a @ b when training, a otherwise
val_if_false = a
val_if_true = tf.matmul(a, b, name="op_matmul")
multiplication = tf.where(training_mode, val_if_true, val_if_false)

# Initialize the session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Training: condition is True, so the matmul result is returned
    results = sess.run(multiplication, feed_dict={training_mode: True})
    print("true", results)
    # Not training: condition is False, so a is returned unchanged
    results = sess.run(multiplication, feed_dict={training_mode: False})
    print("false", results)
output:
true [[4. 4.]
 [4. 4.]]
false [[1. 1.]
 [1. 1.]]
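For TensorFlow 2.x, where sessions and placeholders are gone, the same switch might be sketched with a tf.function that takes the mode as a boolean tensor. This is an illustrative port assuming TF 2.x, not the original code. Note that tf.where only selects between results: both branches, including the matmul, are computed regardless of the condition.

```python
import tensorflow as tf

a = tf.Variable([[1, 1], [1, 1]], dtype=tf.float32, name="a")
b = tf.Variable([[2, 2], [2, 2]], dtype=tf.float32, name="b")

@tf.function
def multiplication(training_mode):
    # Both branches are computed; tf.where only picks one of the results
    val_if_true = tf.matmul(a, b, name="op_matmul")
    val_if_false = a
    # The scalar condition broadcasts against the 2x2 branches
    return tf.where(training_mode, val_if_true, val_if_false)

print("true", multiplication(tf.constant(True)).numpy())
print("false", multiplication(tf.constant(False)).numpy())
```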