   `C`: The number of feature channels.
 """
 
+from typing import Optional
+
 import tensorflow as tf
 from tensorflow_graphics.util import export_api
 
@@ -62,13 +64,15 @@ def __init__(self, channels, momentum):
     self.channels = channels
     self.momentum = momentum
 
-  def build(self, input_shape):
+  def build(self, input_shape: tf.Tensor):
     """Builds the layer with a specified input_shape."""
     self.conv = tf.keras.layers.Conv2D(
         self.channels, (1, 1), input_shape=input_shape)
     self.bn = tf.keras.layers.BatchNormalization(momentum=self.momentum)
 
-  def call(self, inputs, training=None):  # pylint: disable=arguments-differ
+  def call(self,
+           inputs: tf.Tensor,
+           training: Optional[bool] = None) -> tf.Tensor:  # pylint: disable=arguments-differ
     """Executes the convolution.
 
     Args:
@@ -96,12 +100,14 @@ def __init__(self, channels, momentum):
     self.momentum = momentum
     self.channels = channels
 
-  def build(self, input_shape):
+  def build(self, input_shape: tf.Tensor):
     """Builds the layer with a specified input_shape."""
     self.dense = tf.keras.layers.Dense(self.channels, input_shape=input_shape)
     self.bn = tf.keras.layers.BatchNormalization(momentum=self.momentum)
 
-  def call(self, inputs, training=None):  # pylint: disable=arguments-differ
+  def call(self,
+           inputs: tf.Tensor,
+           training: Optional[bool] = None) -> tf.Tensor:  # pylint: disable=arguments-differ
     """Executes the convolution.
 
     Args:
@@ -125,7 +131,7 @@ class VanillaEncoder(tf.keras.layers.Layer):
   https://github.com/charlesq34/pointnet/blob/master/models/pointnet_cls_basic.py
   """
 
-  def __init__(self, momentum=.5):
+  def __init__(self, momentum: float = .5):
     """Constructs a VanillaEncoder keras layer.
 
     Args:
@@ -138,7 +144,9 @@ def __init__(self, momentum=.5):
     self.conv4 = PointNetConv2Layer(128, momentum)
     self.conv5 = PointNetConv2Layer(1024, momentum)
 
-  def call(self, inputs, training=None):  # pylint: disable=arguments-differ
+  def call(self,
+           inputs: tf.Tensor,
+           training: Optional[bool] = None) -> tf.Tensor:  # pylint: disable=arguments-differ
     """Computes the PointNet features.
 
     Args:
@@ -166,7 +174,10 @@ class ClassificationHead(tf.keras.layers.Layer):
   logits of the num_classes classes.
   """
 
-  def __init__(self, num_classes=40, momentum=0.5, dropout_rate=0.3):
+  def __init__(self,
+               num_classes: int = 40,
+               momentum: float = 0.5,
+               dropout_rate: float = 0.3):
     """Constructor.
 
     Args:
@@ -180,7 +191,9 @@ def __init__(self, num_classes=40, momentum=0.5, dropout_rate=0.3):
     self.dropout = tf.keras.layers.Dropout(dropout_rate)
     self.dense3 = tf.keras.layers.Dense(num_classes, activation="linear")
 
-  def call(self, inputs, training=None):  # pylint: disable=arguments-differ
+  def call(self,
+           inputs: tf.Tensor,
+           training: Optional[bool] = None) -> tf.Tensor:  # pylint: disable=arguments-differ
     """Computes the classifiation logits given features (note: without softmax).
 
     Args:
@@ -199,7 +212,10 @@ def call(self, inputs, training=None):  # pylint: disable=arguments-differ
 class PointNetVanillaClassifier(tf.keras.layers.Layer):
   """The PointNet 'Vanilla' classifier (i.e. without spatial transformer)."""
 
-  def __init__(self, num_classes=40, momentum=.5, dropout_rate=.3):
+  def __init__(self,
+               num_classes: int = 40,
+               momentum: float = .5,
+               dropout_rate: float = .3):
     """Constructor.
 
     Args:
@@ -212,7 +228,9 @@ def __init__(self, num_classes=40, momentum=.5, dropout_rate=.3):
     self.classifier = ClassificationHead(
         num_classes=num_classes, momentum=momentum, dropout_rate=dropout_rate)
 
-  def call(self, points, training=None):  # pylint: disable=arguments-differ
+  def call(self,
+           points: tf.Tensor,
+           training: Optional[bool] = None) -> tf.Tensor:  # pylint: disable=arguments-differ
     """Computes the classifiation logits of a point set.
 
     Args:
@@ -227,7 +245,8 @@ def call(self, points, training=None):  # pylint: disable=arguments-differ
     return logits
 
   @staticmethod
-  def loss(labels, logits):
+  def loss(labels: tf.Tensor,
+           logits: tf.Tensor) -> tf.Tensor:
     """The classification model training loss.
 
     Note:
@@ -236,6 +255,9 @@ def loss(labels, logits):
     Args:
       labels: a tensor with shape `[B,]`
       logits: a tensor with shape `[B,num_classes]`
+
+    Returns:
+      A tensor with the same shape as labels and of the same type as logits.
     """
     cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits
     residual = cross_entropy(labels, logits)
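For reference, a minimal usage sketch of the classifier with the newly annotated signatures. This is illustrative only and not part of the change; the import path (tensorflow_graphics.nn.layer.pointnet) and the [B, N, 3] input layout are assumptions, and the data is random.

# Usage sketch (not part of this commit). Assumes the layers live in
# tensorflow_graphics.nn.layer.pointnet and that point clouds are batched
# as [B, N, 3]; substitute real data and the correct import path as needed.
import tensorflow as tf
from tensorflow_graphics.nn.layer.pointnet import PointNetVanillaClassifier

points = tf.random.uniform((32, 1024, 3))                     # [B, N, 3] toy point clouds
labels = tf.random.uniform((32,), maxval=40, dtype=tf.int32)  # [B,] class ids in [0, 40)

classifier = PointNetVanillaClassifier(
    num_classes=40, momentum=.5, dropout_rate=.3)

# training=True enables dropout and batch-norm statistic updates; the new
# annotations type this flag as Optional[bool].
logits = classifier(points, training=True)                    # [B, num_classes]
loss = PointNetVanillaClassifier.loss(labels, logits)         # sparse softmax cross-entropy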