Skip to content

Commit

Permalink
linter updates
Browse files (browse the repository at this point in the history)
  • Loading branch information
oguzhanbsolak committed Aug 28, 2024
1 parent fbc40f5 commit d81d61e
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 36 deletions.
9 changes: 5 additions & 4 deletions izer/backend/max7800x.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3560,11 +3560,12 @@ def run_eltwise(
for i in range(terminating_layer + 1):
if output_layer[i]:
if output_width[i] != 32:
count2add = (output_chan[i] * output_dim[i][0] * output_dim[i][1] \
+ (32 // output_width[i] - 1)) // (32 // output_width[i])
if scale_output:
count2add *= 2
output_count += count2add
output_count += (output_chan[i] * output_dim[i][0] * output_dim[i][1] \
+ (32 // (2 * output_width[i]) - 1)) // (32 // (2 * output_width[i]))
else:
output_count += (output_chan[i] * output_dim[i][0] * output_dim[i][1] \
+ (32 // output_width[i] - 1)) // (32 // output_width[i])
else:
output_count += output_chan[i] * output_dim[i][0] * output_dim[i][1]
insert = summary_stats + \
Expand Down
3 changes: 2 additions & 1 deletion izer/quantize.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -244,7 +244,8 @@ def get_max_bit_shift(t, clamp_bits, shift_quantile, return_bit_shift=False):
threshold_name = '.'.join([layer, 'threshold'])
if threshold_name in checkpoint_state:
threshold = checkpoint_state[threshold_name]
out_shift = (out_shift - threshold).clamp(min=-7.-clamp_bits, max=23.-clamp_bits)
out_shift = (out_shift - threshold).clamp(min=-7.-clamp_bits,
max=23.-clamp_bits)
new_checkpoint_state[out_shift_name] = out_shift
if new_masks_dict is not None:
new_masks_dict[out_shift_name] = out_shift
Expand Down
6 changes: 3 additions & 3 deletions izer/toplevel.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -806,7 +806,7 @@ def main(
memfile.write(' softmax_layer();\n')
elif unload:
memfile.write(' cnn_unload((uint32_t *) '
f'ml_data);\n')
'ml_data);\n')

if embedded_code:
memfile.write('\n printf("\\n*** PASS ***\\n\\n");\n\n'
Expand Down Expand Up @@ -844,7 +844,7 @@ def main(
' digs = (1000 * ml_softmax[i] + 0x4000) >> 15;\n'
' tens = digs % 10;\n'
' digs = digs / 10;\n'
f' printf("[%7d] -> Class %d: %d.%d%%\\n", ml_data{output_width if output_width!=32 else ""}[i], '
' printf("[%7d] -> Class %d: %d.%d%%\\n", ml_data[i], '
'i, digs, tens);\n'
' }\n')
else:
Expand Down Expand Up @@ -912,7 +912,7 @@ def softmax_layer(
function_header(memfile, prefix='',
function='softmax_layer',
return_type='void')
memfile.write(f' cnn_unload((uint32_t *) ml_data);\n')
memfile.write(' cnn_unload((uint32_t *) ml_data);\n')

if output_width == 32:
if shift == 0:
Expand Down
74 changes: 46 additions & 28 deletions izer/unload.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -71,12 +71,16 @@ def scaled_mlator_write_one(
f'{prefix} out_buf[offs++] = ((val >> 16) & 0xff) << {scale};\n' \
f'{prefix} out_buf[offs++] = ((val >> 24) & 0xff) << {scale};\n'

elif scale < 0:
if scale < 0:
return f'{prefix} val = *mlat;{comment}\n' \
f'{prefix} out_buf[offs++] = (int16_t)((val & 0xff) << 8) >> {abs(scale) + 8};\n' \
f'{prefix} out_buf[offs++] = (int16_t)(((val >> 8) & 0xff) << 8) >> {abs(scale) + 8};\n' \
f'{prefix} out_buf[offs++] = (int16_t)(((val >> 16) & 0xff) << 8) >> {abs(scale) + 8};\n' \
f'{prefix} out_buf[offs++] = (int16_t)(((val >> 24) & 0xff) << 8) >> {abs(scale) + 8};\n'
f'{prefix} out_buf[offs++] = (int16_t)((val & 0xff) << 8) >> '\
f'{abs(scale) + 8};\n' \
f'{prefix} out_buf[offs++] = (int16_t)(((val >> 8) & 0xff) << 8) >> '\
f'{abs(scale) + 8};\n' \
f'{prefix} out_buf[offs++] = (int16_t)(((val >> 16) & 0xff) << 8) >> '\
f'{abs(scale) + 8};\n' \
f'{prefix} out_buf[offs++] = (int16_t)(((val >> 24) & 0xff) << 8) >> '\
f'{abs(scale) + 8};\n'
else:
return mlator_write_one(prefix, comment, out_size)

Expand All @@ -98,13 +102,13 @@ def scaled_mlator_write_one(
break

if not scale_output and final_scale_detected:
wprint(f'Non-zero output scale detected, but --scale-output not set. '
f'Unload operation will be realized without scaling. '
wprint('Non-zero output scale detected, but --scale-output not set. '
'Unload operation will be realized without scaling. '
f'Final scales are {final_scale}.')

if scale_output and not final_scale_detected:
nprint(f'--scale-output set, but all output scales are zero. '
f'Unload operation will be realized without scaling.')
nprint('--scale-output set, but all output scales are zero. '
'Unload operation will be realized without scaling.')

assert not state.block_mode or not mlator

Expand Down Expand Up @@ -245,6 +249,7 @@ def scaled_mlator_write_one(
else: # mlator
def mlator_loop(
num: int = 1,
ll: int = ll,
) -> None:
"""
Print multiple mlator unload lines using a partially unrolled loop
Expand All @@ -259,7 +264,8 @@ def mlator_loop(
result += f' for (i = 0; i < {num // mlator_chunk}; i++) {{\n'
for _ in range(mlator_chunk):
if scale_output:
result += scaled_mlator_write_one(' ', '', out_size, final_scale[ll])
result += scaled_mlator_write_one(' ', '', out_size,
final_scale[ll])
else:
result += mlator_write_one(' ', '', out_size)
result += ' }\n'
Expand Down Expand Up @@ -410,14 +416,16 @@ def mlator_loop(
prefix = ''
for _ in range(min(remaining, chunk)):
if final_scale[ll] != 0 and final_scale[ll] > 0 and scale_output:
out_text += f'{prefix} *out_buf++ = (*addr++) << {final_scale[ll]};\n'
out_text += f'{prefix} *out_buf++ = (*addr++) <<'\
f' {final_scale[ll]};\n'
elif final_scale[ll] != 0 and final_scale[ll] < 0 and scale_output:
out_text += f'{prefix} *out_buf++ = (int{o_width}_t)(*addr++) >> {abs(final_scale[ll])};\n'
out_text += f'{prefix} *out_buf++ = (int{o_width}_t)(*addr++)'\
f' >> {abs(final_scale[ll])};\n'
else:
out_text += f'{prefix} *out_buf++ = *addr++;\n'
if delta_r != 4:
f'{prefix} addr {"+" if delta_r >= 0 else "-"}= ' \
f'0x{abs(delta_r) // 4:04x};\n'
f'{prefix} addr {"+" if delta_r >= 0 else "-"}= ' \
f'0x{abs(delta_r) // 4:04x};\n'
if loop_runs > 1:
out_text += ' }\n'
remaining -= loop_runs * chunk
Expand All @@ -432,13 +440,17 @@ def mlator_loop(
out_text += ' offs = 0x0000;\n'
if not first_output:
if scale_output and out_size == 1:
out_text += f' out_buf = ((uint{o_width*2}_t *) out_buf32) + 0x{(written // 2):04x};\n'
out_text += f' out_buf = ((uint{o_width*2}_t *) out_buf32)'\
f'+ 0x{(written // 2):04x};\n'
elif scale_output and out_size == 4:
out_text += f' temp_out_buf = ((uint32_t *) out_buf32) + 0x{(written // 4):04x};\n'
out_text += f' temp_out_buf = ((uint32_t *) out_buf32)'\
f'+ 0x{(written // 4):04x};\n'
elif not scale_output and out_size == 4:
out_text += f' temp_out_buf = ((uint32_t *) out_buf32) + 0x{(written // 4):04x};\n'
out_text += f' temp_out_buf = ((uint32_t *) out_buf32)'\
f'+ 0x{(written // 4):04x};\n'
else:
out_text += f' out_buf = ((uint{o_width}_t *) out_buf32) + 0x{written:04x};\n'
out_text += f' out_buf = ((uint{o_width}_t *) out_buf32)'\
f'+ 0x{written:04x};\n'
while idx < len(emit_list):
# Find how many have the same r/w addresses with different shift,
# then how many the same deltas between rs and ws with the same set of shifts.
Expand Down Expand Up @@ -482,18 +494,20 @@ def mlator_loop(
for _ in range(min(remaining, chunk)):
if out_size == 4:
if final_scale[ll] != 0 and final_scale[ll] > 0 and scale_output:
out_text += f'{prefix} *temp_out_buf++ = (*addr++) << {final_scale[ll]};\n'
out_text += f'{prefix} *temp_out_buf++ = (*addr++)'\
f'<< {final_scale[ll]};\n'
elif final_scale[ll] != 0 and final_scale[ll] < 0 and scale_output:
out_text += f'{prefix} *temp_out_buf++ = (int32_t)(*addr++) >> {abs(final_scale[ll])};\n'
out_text += f'{prefix} *temp_out_buf++ = (int32_t)(*addr++)'\
f' >> {abs(final_scale[ll])};\n'
else:
out_text += f'{prefix} *temp_out_buf++ = *addr++;\n'
else:
if delta_r == 4:
out_text += f'{prefix} val = *addr++;\n'
else:
out_text += f'{prefix} val = *addr;\n' \
f'{prefix} addr {"+" if delta_r >= 0 else "-"}= ' \
f'0x{abs(delta_r) // 4:04x};\n'
f'{prefix} addr {"+" if delta_r >= 0 else "-"}= '\
f' 0x{abs(delta_r) // 4:04x};\n'
for shift in shift_list:
if not short_write:
out_text += f'{prefix} out_buf[offs'
Expand All @@ -509,15 +523,16 @@ def mlator_loop(
out_text += f'(int{o_width*2}_t)(((val >> {shift * 8})'
else:
if shift == 0:
out_text += f'(val'
out_text += '(val'
else:
out_text += f'((val >> {shift * 8})'
if not scale_output or final_scale[ll] == 0:
out_text += ' & 0xff);\n'
elif final_scale[ll] > 0:
out_text += f' & 0xff)) << {final_scale[ll]};\n'
elif final_scale[ll] < 0:
out_text += f' & 0xff) << 8) >> {8 + abs(final_scale[ll])};\n'
out_text += ' & 0xff) << 8) >>'\
f'{8 + abs(final_scale[ll])};\n'

if not short_write:
out_text += f'{prefix} offs++;\n'
Expand All @@ -527,7 +542,8 @@ def mlator_loop(
out_addr += 4 * loop_runs * chunk

idx += (run + 1) * shift_count
if not short_write and idx < len(emit_list) and shift_count > 1 and out_size == 1:
if not short_write and idx < len(emit_list)\
and shift_count > 1 and out_size == 1:
out_text += f' offs += 0x{xy_dim * (shift_count - 1):04x};\n'


Expand All @@ -536,9 +552,11 @@ def mlator_loop(
* out_size
else:
if scale_output:
written += ((input_shape[ll][0] * input_shape[ll][1] * input_shape[ll][2] + 1) // 2) * 4
written += ((input_shape[ll][0] * input_shape[ll][1] *
input_shape[ll][2] + 1) // 2) * 4
else:
written += ((input_shape[ll][0] * input_shape[ll][1] * input_shape[ll][2] + 3) // 4) * 4
written += ((input_shape[ll][0] * input_shape[ll][1] *
input_shape[ll][2] + 3) // 4) * 4

first_output = False
prev_out_size = out_size
Expand All @@ -550,7 +568,7 @@ def mlator_loop(
memfile.write(f' uint{o_width*2}_t *out_buf = (uint{o_width*2}_t *) out_buf32;\n')
memfile.write(' uint32_t val;\n')
if 32 in o_widths and o_width != 32:
memfile.write(f' uint32_t *temp_out_buf;\n')
memfile.write(' uint32_t *temp_out_buf;\n')
if o_width == 32 or have_non_mlator:
memfile.write(' volatile uint32_t *addr;\n')
if mlator_layers:
Expand Down

0 comments on commit d81d61e

Please sign in to comment.