diff --git a/toolbox.py b/toolbox.py index 2eb576e..fecfd37 100644 --- a/toolbox.py +++ b/toolbox.py @@ -4,7 +4,7 @@ import sys import os import math -from random import randint +from random import randint, shuffle import matplotlib.mlab as mlab import matplotlib.pyplot as plt import numpy as np @@ -48,15 +48,16 @@ def prepare_datasets(csv_filename, validation_set_proportion=0.0): return train_samples, validation_samples -def plot_steering_angle_histogram(steering_angles, title='Histogram of steering angle', show=True): +def plot_steering_angle_histogram(steering_angles, title='Histogram of steering angle', nb_bins=200, show=True): """ Plot steering angle histogram. - :param steering_angles: Array containing steering angles in radiant. + :param steering_angles: Array containing steering angles in degree. :param title: Title of the histogram. + :param nb_bins: Number of bins. :param show: If true, `plot.show()` is called at the end. """ fig, ax1 = plt.subplots() - n, bins, patches = ax1.hist(steering_angles, 200, normed=False, facecolor='green', edgecolor='black', alpha=0.75, + n, bins, patches = ax1.hist(steering_angles, nb_bins, normed=False, facecolor='green', edgecolor='black', alpha=0.75, histtype='bar', rwidth=0.85, label='steering angles') ax1.set_xlabel('steering angle [degree]') ax1.set_ylabel('frequency') @@ -68,15 +69,16 @@ def plot_steering_angle_histogram(steering_angles, title='Histogram of steering plt.show() -def plot_normed_steering_angle_histogram(steering_angles, title='Histogram of steering angle', show=True): +def plot_normed_steering_angle_histogram(steering_angles, title='Histogram of steering angle', nb_bins=200, show=True): """ Plot normed steering angle histogram with expected distribution. - :param steering_angles: Array containing steering angles in radiant. + :param steering_angles: Array containing steering angles in degree. :param title: Title of the histogram. + :param nb_bins: Number of bins. 
:param show: If true, `plot.show()` is called at the end. """ fig, ax1 = plt.subplots() - n, bins, patches = ax1.hist(steering_angles, 200, normed=True, facecolor='green', edgecolor='black', alpha=0.75, + n, bins, patches = ax1.hist(steering_angles, nb_bins, normed=True, facecolor='green', edgecolor='black', alpha=0.75, histtype='bar', rwidth=0.85, label='steering angles') ax1.set_xlabel('steering angle [degree]') ax1.set_ylabel('frequency') @@ -572,7 +574,7 @@ def balance_steering_angles(source_csv_file, destination_csv_file, max_samples_p print('Bin size: {:.2f}'.format(bin_size)) print('Number of bins: {:d}'.format(nb_bins)) print('') - print('Reduce dataset...', end='', flush=True) + print('Balance dataset...', end='', flush=True) angles = np.array([]) @@ -580,12 +582,69 @@ def balance_steering_angles(source_csv_file, destination_csv_file, max_samples_p for sample in samples: angles = np.append(angles,float(sample[3]) * 25.) - # TODO: calculate histogram an limit number of samples per bin + # calculate histogram and sort angles according to their assigned bin hist, bin_edges = np.histogram(angles, bins=nb_bins) - bin_idx = np.digitize(angles, bins=bin_edges) + bin_idx = np.digitize(angles, bins=bin_edges) - 1 + angle_idx = np.arange(0, len(angles), 1) + bin_idx_sorted = bin_idx.argsort(axis=0) + bin_idx = bin_idx[bin_idx_sorted] + angle_idx = angle_idx[bin_idx_sorted] + + bin_angle_idx_map = [[]] * (nb_bins + 1) # contains the mapping between bin and angle idx + buf = [] + last_idx = bin_idx[0] + + for i, idx in enumerate(bin_idx): + if last_idx == idx: + buf.append(angle_idx[i]) + else: + bin_angle_idx_map[last_idx] = buf + buf = [] + buf.append(angle_idx[i]) + last_idx = idx + + bin_angle_idx_map[last_idx] = buf + + # limit number of samples per bin (if > max_samples_per_bin then take random samples) + remaining_angle_idx = [] + + for angle_idx_list in bin_angle_idx_map: + if len(angle_idx_list) > max_samples_per_bin: + shuffle(angle_idx_list) + 
remaining_angle_idx.append(angle_idx_list[0:max_samples_per_bin]) + elif len(angle_idx_list) > 0: + remaining_angle_idx.append(angle_idx_list) + + remaining_angle_idx = sum(remaining_angle_idx, []) + samples_balanced = samples[remaining_angle_idx] + + print('done') + + # write balanced dataset into destination CSV file + print('Save balanced dataset...', end='', flush=True) + + csv_file = open(destination_csv_file, 'w') + fieldnames = ['center', 'left', 'right', 'steering', 'throttle', 'brake', 'speed'] + writer = csv.DictWriter(csv_file, fieldnames=fieldnames) + writer.writeheader() - print(bin_idx) + for sample in samples_balanced: + writer.writerow({'center': sample[0], + 'left': sample[1], + 'right': sample[2], + 'steering': sample[3], + 'throttle': sample[4], + 'brake': sample[5], + 'speed': sample[6]}) + + print('done') + print('Close the figures to continue...') + plot_steering_angle_histogram(angles, title='Steering angle histogram in original dataset (N={:d})'.format(len(angles)), + nb_bins=nb_bins, show=False) + plot_steering_angle_histogram(angles[remaining_angle_idx], + title='Steering angle histogram in balanced dataset (N={:d})'.format(len(angles[remaining_angle_idx])), + nb_bins=nb_bins, show=True) if __name__ == "__main__": parser = argparse.ArgumentParser(description='Data preparation and model analysis toolbox')