Skip to content

Commit

Permalink
Added function to balance the steering angle histogram.
Browse files Browse the repository at this point in the history
  • Loading branch information
SvenMuc committed May 22, 2017
1 parent bac5203 commit d026f08
Showing 1 changed file with 70 additions and 11 deletions.
81 changes: 70 additions & 11 deletions toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
import os
import math
from random import randint
from random import randint, shuffle
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -48,15 +48,16 @@ def prepare_datasets(csv_filename, validation_set_proportion=0.0):
return train_samples, validation_samples


def plot_steering_angle_histogram(steering_angles, title='Histogram of steering angle', show=True):
def plot_steering_angle_histogram(steering_angles, title='Histogram of steering angle', nb_bins=200, show=True):
""" Plot steering angle histogram.
:param steering_angles: Array containing steering angles in radiant.
:param steering_angles: Array containing steering angles in degree.
:param title: Title of the histogram.
:param nb_bins: Number of bins.
:param show: If true, `plot.show()` is called at the end.
"""
fig, ax1 = plt.subplots()
n, bins, patches = ax1.hist(steering_angles, 200, normed=False, facecolor='green', edgecolor='black', alpha=0.75,
n, bins, patches = ax1.hist(steering_angles, nb_bins, normed=False, facecolor='green', edgecolor='black', alpha=0.75,
histtype='bar', rwidth=0.85, label='steering angles')
ax1.set_xlabel('steering angle [degree]')
ax1.set_ylabel('frequency')
Expand All @@ -68,15 +69,16 @@ def plot_steering_angle_histogram(steering_angles, title='Histogram of steering
plt.show()


def plot_normed_steering_angle_histogram(steering_angles, title='Histogram of steering angle', show=True):
def plot_normed_steering_angle_histogram(steering_angles, title='Histogram of steering angle', nb_bins=200, show=True):
""" Plot normed steering angle histogram with expected distribution.
:param steering_angles: Array containing steering angles in radiant.
:param steering_angles: Array containing steering angles in degree.
:param title: Title of the histogram.
:param nb_bins: Number of bins.
:param show: If true, `plot.show()` is called at the end.
"""
fig, ax1 = plt.subplots()
n, bins, patches = ax1.hist(steering_angles, 200, normed=True, facecolor='green', edgecolor='black', alpha=0.75,
n, bins, patches = ax1.hist(steering_angles, nb_bins, normed=True, facecolor='green', edgecolor='black', alpha=0.75,
histtype='bar', rwidth=0.85, label='steering angles')
ax1.set_xlabel('steering angle [degree]')
ax1.set_ylabel('frequency')
Expand Down Expand Up @@ -572,20 +574,77 @@ def balance_steering_angles(source_csv_file, destination_csv_file, max_samples_p
print('Bin size: {:.2f}'.format(bin_size))
print('Number of bins: {:d}'.format(nb_bins))
print('')
print('Reduce dataset...', end='', flush=True)
print('Balance dataset...', end='', flush=True)

angles = np.array([])

# get all steering angles
for sample in samples:
angles = np.append(angles,float(sample[3]) * 25.)

# TODO: calculate histogram an limit number of samples per bin
# calculate histogram and sort angles according assigned bin
hist, bin_edges = np.histogram(angles, bins=nb_bins)
bin_idx = np.digitize(angles, bins=bin_edges)
bin_idx = np.digitize(angles, bins=bin_edges) - 1
angle_idx = np.arange(0, len(angles), 1)
bin_idx_sorted = bin_idx.argsort(axis=0)
bin_idx = bin_idx[bin_idx_sorted]
angle_idx = angle_idx[bin_idx_sorted]

bin_angle_idx_map = [[]] * (nb_bins + 1) # contains the mapping between bin and angle idx
buf = []
last_idx = bin_idx[0]

for i, idx in enumerate(bin_idx):
if last_idx == idx:
buf.append(angle_idx[i])
else:
bin_angle_idx_map[last_idx] = buf
buf = []
buf.append(angle_idx[i])
last_idx = idx

bin_angle_idx_map[last_idx] = buf

# limit number of samples per bin (if > max_samples_per_bin then take random samples)
remaining_angle_idx = []

for angle_idx_list in bin_angle_idx_map:
if len(angle_idx_list) > max_samples_per_bin:
shuffle(angle_idx_list)
remaining_angle_idx.append(angle_idx_list[0:max_samples_per_bin])
elif len(angle_idx_list) > 0:
remaining_angle_idx.append(angle_idx_list)

remaining_angle_idx = sum(remaining_angle_idx, [])
samples_balanced = samples[remaining_angle_idx]

print('done')

# write balances dataset into destination CSV file
print('Save balanced dataset...', end='', flush=True)

csv_file = open(destination_csv_file, 'w')
fieldnames = ['center', 'left', 'right', 'steering', 'throttle', 'brake', 'speed']
writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
writer.writeheader()

print(bin_idx)
for sample in samples_balanced:
writer.writerow({'center': sample[0],
'left': sample[1],
'right': sample[2],
'steering': sample[3],
'throttle': sample[4],
'brake': sample[5],
'speed': sample[6]})

print('done')
print('Close the figures to continue...')

plot_steering_angle_histogram(angles, title='Steering angle histogram in original dataset (N={:d})'.format(len(angles)),
nb_bins=nb_bins, show=False)
plot_steering_angle_histogram(angles[remaining_angle_idx],
title='Steering angle histogram in balanced dataset (N={:d})'.format(len(angles[remaining_angle_idx])),
nb_bins=nb_bins, show=True)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Data preparation and model analysis toolbox')
Expand Down

0 comments on commit d026f08

Please sign in to comment.