@@ -163,27 +163,41 @@ def do_iterate():
163
163
164
164
165
165
def get_value_range (value , target_type ):
166
+ """
167
+ This function converts all possible supplied values for augmentation
168
+ into the [start,end,r] ValueRange type. The expected inputs are of the form:
169
+
170
+ <number>
171
+ <number>~<number>
172
+ <number>:<number>~<number>
173
+
174
+ Any "missing" values are filled so that ValueRange always includes [start,end,r].
175
+ """
166
176
if isinstance (value , str ):
167
- r = target_type (0 )
168
- parts = value .split ('~' )
169
- if len (parts ) == 2 :
177
+ if '~' in value :
178
+ parts = value .split ('~' )
179
+ if len (parts ) != 2 :
180
+ raise ValueError ('Cannot parse value range' )
170
181
value = parts [0 ]
171
- r = target_type ( parts [1 ])
172
- elif len ( parts ) > 2 :
173
- raise ValueError ( 'Cannot parse value range' )
182
+ r = parts [1 ]
183
+ else :
184
+ r = 0 # if no <r> supplied, use 0
174
185
parts = value .split (':' )
175
186
if len (parts ) == 1 :
176
- parts .append (parts [0 ])
177
- elif len (parts ) > 2 :
187
+ parts .append (parts [0 ]) # only one <value> given, so double it
188
+ if len (parts ) != 2 :
178
189
raise ValueError ('Cannot parse value range' )
179
- return ValueRange (target_type (parts [0 ]), target_type (parts [1 ]), r )
190
+ return ValueRange (target_type (parts [0 ]), target_type (parts [1 ]), target_type ( r ) )
180
191
if isinstance (value , tuple ):
181
192
if len (value ) == 2 :
182
- return ValueRange (target_type (value [0 ]), target_type (value [1 ]), 0 )
193
+ return ValueRange (target_type (value [0 ]), target_type (value [1 ]), target_type ( 0 ) )
183
194
if len (value ) == 3 :
184
195
return ValueRange (target_type (value [0 ]), target_type (value [1 ]), target_type (value [2 ]))
185
- raise ValueError ('Cannot convert to ValueRange: Wrong tuple size' )
186
- return ValueRange (target_type (value ), target_type (value ), 0 )
196
+ else :
197
+ raise ValueError ('Cannot convert to ValueRange: Wrong tuple size' )
198
+ if isinstance (value , int ) or isinstance (value , float ):
199
+ return ValueRange (target_type (value ), target_type (value ), target_type (0 ))
200
+ raise ValueError ('Cannot convert to ValueRange: Wrong tuple size' )
187
201
188
202
189
203
def int_range (value ):
@@ -203,14 +217,20 @@ def pick_value_from_range(value_range, clock=None):
203
217
204
218
def tf_pick_value_from_range (value_range , clock = None , double_precision = False ):
205
219
import tensorflow as tf # pylint: disable=import-outside-toplevel
206
- clock = (tf .random .stateless_uniform ([], seed = (- 1 , 1 ), dtype = tf .float64 ) if clock is None
207
- else tf .maximum (tf .constant (0.0 , dtype = tf .float64 ), tf .minimum (tf .constant (1.0 , dtype = tf .float64 ), clock )))
220
+ if clock is None :
221
+ clock = tf .random .stateless_uniform ([], seed = (- 1 , 1 ), dtype = tf .float64 )
222
+ else :
223
+ clock = tf .maximum (tf .constant (0.0 , dtype = tf .float64 ),
224
+ tf .minimum (tf .constant (1.0 , dtype = tf .float64 ), clock ))
208
225
value = value_range .start + clock * (value_range .end - value_range .start )
209
- value = tf .random .stateless_uniform ([],
210
- minval = value - value_range .r ,
211
- maxval = value + value_range .r ,
212
- seed = (clock * tf .int32 .min , clock * tf .int32 .max ),
213
- dtype = tf .float64 )
226
+ if value_range .r :
227
+ # if the option <r> (<value>~<r>, randomization radius) is supplied,
228
+ # sample the value from a uniform distribution with "radius" <r>
229
+ value = tf .random .stateless_uniform ([],
230
+ minval = value - value_range .r ,
231
+ maxval = value + value_range .r ,
232
+ seed = (clock * tf .int32 .min , clock * tf .int32 .max ),
233
+ dtype = tf .float64 )
214
234
if isinstance (value_range .start , int ):
215
235
return tf .cast (tf .math .round (value ), tf .int64 if double_precision else tf .int32 )
216
236
return tf .cast (value , tf .float64 if double_precision else tf .float32 )
0 commit comments