Skip to content

Commit bab2352

Browse files
G4Gcopybara-github
authored andcommitted
Adds typing information to the nn.layer.pointnet module.
PiperOrigin-RevId: 417443268
1 parent 75fe1a1 commit bab2352

File tree

1 file changed

+33
-11
lines changed

1 file changed

+33
-11
lines changed

tensorflow_graphics/nn/layer/pointnet.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
`C`: The number of feature channels.
3939
"""
4040

41+
from typing import Optional
42+
4143
import tensorflow as tf
4244
from tensorflow_graphics.util import export_api
4345

@@ -62,13 +64,15 @@ def __init__(self, channels, momentum):
6264
self.channels = channels
6365
self.momentum = momentum
6466

65-
def build(self, input_shape):
67+
def build(self, input_shape: tf.Tensor):
6668
"""Builds the layer with a specified input_shape."""
6769
self.conv = tf.keras.layers.Conv2D(
6870
self.channels, (1, 1), input_shape=input_shape)
6971
self.bn = tf.keras.layers.BatchNormalization(momentum=self.momentum)
7072

71-
def call(self, inputs, training=None): # pylint: disable=arguments-differ
73+
def call(self,
74+
inputs: tf.Tensor,
75+
training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ
7276
"""Executes the convolution.
7377
7478
Args:
@@ -96,12 +100,14 @@ def __init__(self, channels, momentum):
96100
self.momentum = momentum
97101
self.channels = channels
98102

99-
def build(self, input_shape):
103+
def build(self, input_shape: tf.Tensor):
100104
"""Builds the layer with a specified input_shape."""
101105
self.dense = tf.keras.layers.Dense(self.channels, input_shape=input_shape)
102106
self.bn = tf.keras.layers.BatchNormalization(momentum=self.momentum)
103107

104-
def call(self, inputs, training=None): # pylint: disable=arguments-differ
108+
def call(self,
109+
inputs: tf.Tensor,
110+
training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ
105111
"""Executes the convolution.
106112
107113
Args:
@@ -125,7 +131,7 @@ class VanillaEncoder(tf.keras.layers.Layer):
125131
https://github.com/charlesq34/pointnet/blob/master/models/pointnet_cls_basic.py
126132
"""
127133

128-
def __init__(self, momentum=.5):
134+
def __init__(self, momentum: float = .5):
129135
"""Constructs a VanillaEncoder keras layer.
130136
131137
Args:
@@ -138,7 +144,9 @@ def __init__(self, momentum=.5):
138144
self.conv4 = PointNetConv2Layer(128, momentum)
139145
self.conv5 = PointNetConv2Layer(1024, momentum)
140146

141-
def call(self, inputs, training=None): # pylint: disable=arguments-differ
147+
def call(self,
148+
inputs: tf.Tensor,
149+
training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ
142150
"""Computes the PointNet features.
143151
144152
Args:
@@ -166,7 +174,10 @@ class ClassificationHead(tf.keras.layers.Layer):
166174
logits of the num_classes classes.
167175
"""
168176

169-
def __init__(self, num_classes=40, momentum=0.5, dropout_rate=0.3):
177+
def __init__(self,
178+
num_classes: int = 40,
179+
momentum: float = 0.5,
180+
dropout_rate: float = 0.3):
170181
"""Constructor.
171182
172183
Args:
@@ -180,7 +191,9 @@ def __init__(self, num_classes=40, momentum=0.5, dropout_rate=0.3):
180191
self.dropout = tf.keras.layers.Dropout(dropout_rate)
181192
self.dense3 = tf.keras.layers.Dense(num_classes, activation="linear")
182193

183-
def call(self, inputs, training=None): # pylint: disable=arguments-differ
194+
def call(self,
195+
inputs: tf.Tensor,
196+
training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ
184197
"""Computes the classifiation logits given features (note: without softmax).
185198
186199
Args:
@@ -199,7 +212,10 @@ def call(self, inputs, training=None): # pylint: disable=arguments-differ
199212
class PointNetVanillaClassifier(tf.keras.layers.Layer):
200213
"""The PointNet 'Vanilla' classifier (i.e. without spatial transformer)."""
201214

202-
def __init__(self, num_classes=40, momentum=.5, dropout_rate=.3):
215+
def __init__(self,
216+
num_classes: int = 40,
217+
momentum: float = .5,
218+
dropout_rate: float = .3):
203219
"""Constructor.
204220
205221
Args:
@@ -212,7 +228,9 @@ def __init__(self, num_classes=40, momentum=.5, dropout_rate=.3):
212228
self.classifier = ClassificationHead(
213229
num_classes=num_classes, momentum=momentum, dropout_rate=dropout_rate)
214230

215-
def call(self, points, training=None): # pylint: disable=arguments-differ
231+
def call(self,
232+
points: tf.Tensor,
233+
training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ
216234
"""Computes the classifiation logits of a point set.
217235
218236
Args:
@@ -227,7 +245,8 @@ def call(self, points, training=None): # pylint: disable=arguments-differ
227245
return logits
228246

229247
@staticmethod
230-
def loss(labels, logits):
248+
def loss(labels: tf.Tensor,
249+
logits: tf.Tensor) -> tf.Tensor:
231250
"""The classification model training loss.
232251
233252
Note:
@@ -236,6 +255,9 @@ def loss(labels, logits):
236255
Args:
237256
labels: a tensor with shape `[B,]`
238257
logits: a tensor with shape `[B,num_classes]`
258+
259+
Returns:
260+
A tensor with the same shape as labels and of the same type as logits.
239261
"""
240262
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits
241263
residual = cross_entropy(labels, logits)

0 commit comments

Comments
 (0)