Skip to content

Commit 0a84872

Browse files
committed
Add dataset generator and test.
1 parent 9e6ef77 commit 0a84872

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java

+1
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ public void testTf2()
214214
testTf2("tf2_test_dataset7.py", "add", 2, 2, 2, 3);
215215
testTf2("tf2_test_dataset8.py", "add", 2, 2, 2, 3);
216216
testTf2("tf2_test_dataset9.py", "add", 2, 2, 2, 3);
217+
testTf2("tf2_test_dataset10.py", "add", 2, 2, 2, 3);
217218
testTf2("tf2_test_tensor_list.py", "add", 2, 2, 2, 3);
218219
testTf2("tf2_test_tensor_list2.py", "add", 0, 0);
219220
testTf2("tf2_test_tensor_list3.py", "add", 0, 0);

com.ibm.wala.cast.python.ml/data/tensorflow.xml

+10
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
<putfield class="LRoot" field="numpy_input_fn" fieldType="LRoot" ref="inputs" value="numpy_input_fn" />
6262
<new def="from_tensor_slices" class="Ltensorflow/data/Dataset/from_tensor_slices" />
6363
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="Dataset" value="from_tensor_slices" />
64+
<new def="from_generator" class="Ltensorflow/data/Dataset/from_generator" />
65+
<putfield class="LRoot" field="from_generator" fieldType="LRoot" ref="Dataset" value="from_generator" />
6466
<new def="reshape" class="Ltensorflow/functions/reshape" />
6567
<putfield class="LRoot" field="reshape" fieldType="LRoot" ref="x" value="reshape" />
6668
<new def="conv2d" class="Ltensorflow/functions/conv2d" />
@@ -790,6 +792,14 @@
790792
<return value="xx" />
791793
</method>
792794
</class>
795+
<class name="from_generator" allocatable="true">
796+
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/data/Dataset#from_generator -->
797+
<method name="do" descriptor="()LRoot;" numArgs="6" paramNames="generator output_types output_shapes args output_signature name">
798+
<new def="x" class="Ltensorflow/data/Dataset" />
799+
<call class="Ltensorflow/data/Dataset" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="x" def="xx" />
800+
<return value="xx" />
801+
</method>
802+
</class>
793803
</package>
794804
<package name="tensorflow/estimator/train">
795805
<class name="train" allocatable="true">
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import tensorflow as tf
2+
3+
4+
def gen():
5+
ragged_tensor = tf.ragged.constant([[1, 2], [3]])
6+
yield 42, ragged_tensor
7+
8+
9+
def add(a, b):
10+
return a + b
11+
12+
13+
dataset = tf.data.Dataset.from_generator(
14+
gen,
15+
output_signature=(
16+
tf.TensorSpec(shape=(), dtype=tf.int32),
17+
tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))
18+
19+
for element in dataset:
20+
c = add(element, element)

0 commit comments

Comments
 (0)