diff --git a/README.md b/README.md new file mode 100644 index 0000000..ccb898f --- /dev/null +++ b/README.md @@ -0,0 +1,34 @@ +# STransfer + +Compile and Run STSender and STReceiver in sender and receiver respectively. +And then run STClient.py file as follows: +On Sender Side: +```bash +javac STSender.java; +java -Xmx3296m STSender; +``` +On Receiver Side: +```bash +javac STReceiver.java; +java -Xmx3296m STSender; +``` +On controller Side: +```bash +python STClient.py 192.168.1.8:52005/dir/of/the/files/to/send/form/ 192.168.1.7:53823/dir/where/to/save/the/files/on/receiver/side #(--conv=avg --method=GA --generation=5 --population=10) --> Optional +``` +Optional variables which can be passed as parameter while running java functions are: +```bash +--method +--generation +--population +--conv +--convTime +``` +Method refers to the method which will be used for finding optimal throughput. It can be either "GA" or "random".
+Generation refers to number of generation in genetic algorithm.
+Population refers to number of individual in each generation of GA.
+Conv refers to convergence function used in which getting the fitness of each individual. The variables for conv can be "avg" and "ar".
+convTime refers to total throughput data to calculate the average convergence time with. + +And to run the python code, Scikit-learn, Scipy and Numpy will need to be installed. + diff --git a/STClient.py b/STClient.py new file mode 100644 index 0000000..45fe03b --- /dev/null +++ b/STClient.py @@ -0,0 +1,790 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Fri Jan 10 18:25:34 2020 + +@author: hem +""" + +import socket +import sys +import random as rd +import math +import time +from threading import Thread +import bisect as _bisect +import argparse +from statsmodels.tsa.ar_model import AR +from sklearn.externals import joblib +import adaptive_iterative +import warnings +warnings.filterwarnings('ignore', 'statsmodels.tsa.ar_model.AR', FutureWarning) +warnings.filterwarnings("ignore",category=FutureWarning) + +def print_message(message): + print(message) + + +class STClient: + def __init__(self, args): + self.max_write_threads = 10 + self.max_read_threads = 10 + self.max_transfer_threads = 10 + self.max_write_queue = 100 + self.max_buffer_size = 128*1024 + self.max_read_queue = 100 + + self.debug = False + self.total_transfer_done = 0 + self.transfer_done = False + self.stop_time = 0 + self.transfer_port = 0 + self.start_time = 0 + if(not args.sender or not args.receiver): + print("[+] Sender and receiver not specified") + exit(0) + sender_info = self.parse_IP(args.sender) + receiver_info = self.parse_IP(args.receiver) + self.sender_ip = sender_info[0] + self.sender_port = int(sender_info[1]) + self.sender_path = sender_info[2] + print_message("[Sender] is %s:%d%s"%(self.sender_ip, self.sender_port, self.sender_path)) + + self.receiver_ip = receiver_info[0] + self.receiver_port = int(receiver_info[1]) + self.receiver_path = receiver_info[2] + self.interface_ip = self.receiver_ip + print_message("[Receiver] is %s:%d%s"%(self.receiver_ip, self.receiver_port, self.receiver_path)) + if(args.interface): + self.interface_ip = args.interface + + def parse_IP(self, ip): + parsed_ip = [] + messages = ip.split(":") + parsed_ip.append(messages[0]) + if "~" in messages[1]: + new_msg = messages[1].split("~") + parsed_ip.append(new_msg[0]) + parsed_ip.append("~"+new_msg[1]) + else: + new_msg = messages[1].split("/", 1) + parsed_ip.append(new_msg[0]) + parsed_ip.append("/"+new_msg[1]) + return parsed_ip + + + +class TalkSend(Thread): + def __init__(self, st_client_param): + Thread.__init__(self) + self.st_client = st_client_param + self.stop_probing = False + self.sender_finish_blocks = False + self.start_current_probing = False + self.stop_current_probing = False + self.can_start_probing = False + self.current_probe_started = False + self.parameters = "" + self.parameter_list = {} + self.serversocket = None + self.client_data, self.client_addr = None, None + + def send_message(self, message): + message = str(message)+"\n" + self.socket.send(message.encode()) + def get_next_line(self): + message = self.socket.recv(1024).decode() + while("\n" not in message): + message += self.socket.recv(1024).decode() + message=message.strip() + return message + + def probe(self): + self.send_message("Parameter:"+self.parameters) + message = self.get_next_line() + if(message == "ok"): + self.sender_finish_blocks = True + while(not self.start_current_probing): + time.sleep(0.010) + self.send_message("Start:currentProbing") + message = self.get_next_line() + if(message.lower() == "ok"): + self.current_probe_started = True + while(not self.stop_current_probing): + time.sleep(0.010) + self.send_message("Check:done") + message = self.get_next_line() + msg = message.strip().split(":") + if(msg[0].lower() == "true"): + self.st_client.stop_time = int(time.time()*1000) + self.st_client.transfer_done = True + if(len(msg) >= 2): + self.st_client.total_transfer_done = int(msg[1].strip()) + self.stop_current_probing = True + self.stop_probing = True + return True + self.send_message("Stop:currentProbing") + message = self.get_next_line() + + self.stop_current_probing = False + self.current_probe_started = False + self.sender_finish_blocks = False + return True + + def run(self): + self.socket = socket.socket() +# self.server_socket.bind((self.st_client.sender_ip, self.st_client.sender_port)) + self.client_data = self.socket.connect((self.st_client.sender_ip, self.st_client.sender_port)) + + self.send_message(str(self.st_client.interface_ip)) + while(self.st_client.transfer_port == 0): + time.sleep(10) + #Sender data + self.send_message(str(self.st_client.transfer_port)) + self.send_message(str(self.st_client.sender_path)) + + #Max Parameters + self.send_message(str(self.st_client.max_transfer_threads)) + self.send_message(str(self.st_client.max_read_threads)) + self.send_message(str(self.st_client.max_read_queue)) + + self.send_message("Start:Probing") + self.get_next_line() + self.can_start_probing = True + self.st_client.start_time = int(time.time()*1000) + print_message("Starting time is: %d"%self.st_client.start_time) + + while(not self.stop_probing): + while(self.parameters == ""): + time.sleep(0.010) + self.probe() + if not self.st_client.transfer_done: + self.send_message("Start:Transfer") + print_message("[Sender] Normal transfer has been started") + self.get_next_line() + self.probe() + self.socket.close() + def set_parameter(self, param): + self.parameters = param + + +class TalkReceive(Thread): + def __init__(self, st_client_param, st_send_param): + Thread.__init__(self) + self.stop_probing = False + self.start_current_probing = False + self.stop_current_probing = False + self.convergence_found = False + self.can_start_probing = False + self.parameters = "" + self.parameter_list = {} + self.thpt_list = [] + self.talk_send = st_send_param + self.st_client = st_client_param + self.serversocket = None + self.client_data, self.client_addr = None, None + self.zero_count = 0 + + def send_message(self, message): + message = str(message)+"\n" + self.socket.send(message.encode()) + def get_next_line(self): + message = self.socket.recv(1024).decode() + while("\n" not in message): + message += self.socket.recv(1024).decode() + message=message.strip() + return message + def add_throughput(self, message, throughput_started): + if(message != ""): + thpts = message.split(",") + for i in thpts: + if(i != ""): + throughput = 0. + try: + throughput = float(i) + except: + pass + if(throughput <= 0.000001 and not len(self.thpt_list) and not throughput_started): + self.zero_count += 1 + if self.zero_count > 15: + break + continue + throughput_started = True + self.thpt_list.append(throughput) + return throughput_started + def probe(self): + self.zero_count = 0 + self.send_message("Parameter:"+self.parameters) + message = self.get_next_line() + if(message.lower() == "ok"): + self.talk_send.start_current_probing = True + while(not self.talk_send.sender_finish_blocks or not self.start_current_probing): + time.sleep(0.01) + self.start_current_probing = self.talk_send.current_probe_started + self.send_message("Start:currentProbing") + message = self.get_next_line() + self.thpt_list = [] + throughput_started = False + while(not self.stop_current_probing): + self.send_message("Get Throughput:New") + message = self.get_next_line() + throughput_started = self.add_throughput(message, throughput_started) + time.sleep(0.01) + tmp = self.thpt_list + self.thpt_list = [0]*self.zero_count + tmp + self.talk_send.stop_current_probing = True + self.send_message("Stop:currentProbing") + message = self.get_next_line() + self.stop_current_probing = False + self.start_current_probing = False + self.talk_send.start_current_probing = False + + + def run(self): + self.socket = socket.socket() + self.client_data = self.socket.connect((self.st_client.receiver_ip, self.st_client.receiver_port)) + print_message("Receiver: %s and port: %s" % (self.st_client.receiver_ip, str(self.st_client.receiver_port))) + self.send_message(self.st_client.receiver_path) + self.send_message(self.st_client.max_transfer_threads) + self.send_message(self.st_client.max_write_threads) + self.send_message(self.st_client.max_write_queue) + self.send_message(self.st_client.max_buffer_size) + self.st_client.transfer_port = int(self.get_next_line()) + self.send_message("Start:Probing") + self.get_next_line() + self.st_client.start_time = int(time.time()*1000) + + self.can_start_probing = True + while(not self.stop_probing): + while(self.parameters == ""): + time.sleep(0.01) + self.probe() + if(not self.st_client.transfer_done): + self.send_message("Start:Transfer") + print_message("[Receiver] Normal transfer has been started") + self.get_next_line() + self.probe() + + self.socket.close() + def set_parameter(self, param): + self.parameters = param + +class GA: + def __init__(self, args): + self.number_of_generations = 4 + self.number_of_population = 6 + if(args.generation): + self.number_of_generations = args.generation + if(args.population): + self.number_of_population = args.population + self.evaluation = "adaptive" + self.crossover_type = "favour_zero" + self.mutation_type = "bit_flip" + self.selection_type = "standard_deviation_elitist" + + self.params = [[1, 128], [1, 256], [10, 260], [128, 132], [1, 128], [1, 256], [1, 128]] + self.param_length = [4,8,0,0,4,8,4]#[4, 6, 6, 6, 4, 6, 4] + self.population_length = 0 + self.agents = [] + self.talk_send = None + self.talk_receive = None + self.st_client = None + self.best_score = 0 + self.best_pop = "" + self.previous_best_pop = "" + self.initiate(args) + for i in self.param_length: + self.population_length += i + for i in range(self.number_of_population): + self.agents.append(Agents(population="", population_length=self.population_length)) + if(self.conv == "dnn"): + self.load_clfs() + print("[+] CLF's loaded") + if(self.conv == "rand"): + adaptive_iterative.regression_train() + adaptive_iterative.classification_train() + print("[+] CLF's loaded") + print("[+] Conv is ", self.conv) + self.mutation_probab = 1.0#/self.population_length + self.crossover_probab = 1.0 + print("[+] Method is ", args.method) + if(args.method.lower() == "random"): + self.run_random() + else: + self.run_GA() + print_message("Best Population: %s and best score: %d" % (self.best_pop, self.best_score)) + self.talk_send.parameters = self.best_pop + self.talk_receive.parameters = self.best_pop + + self.talk_send.stop_probing = True + self.talk_receive.stop_probing = True + self.talk_receive.stop_current_probing = True + while(not self.st_client.transfer_done): + time.sleep(0.01) + self.check_transfer_done() + + + def check_transfer_done(self): + message = False + if(self.st_client.transfer_done): + print_message("Sending message to receiver") + self.talk_receive.send_message("Done:transfer") + print_message("Message send to receiver") + try: + msg = self.talk_receive.get_next_line() + print(msg) + msg = msg.split(":") + print(msg) + if("ok" in msg): + self.talk_send.stop_probing = True + self.talk_receive.stop_probing = True + self.st_client.transfer_done = True + print_message("[+] Total time for this transfer is %.3f Seconds"%((self.st_client.stop_time - self.st_client.start_time)/1000.)) + message = True + if(len(msg) == 2): + print_message(str(self.st_client.total_transfer_done) + " Bytes") + print_message("[+] Total file transfered is %.3f Bytes" % (int(self.st_client.total_transfer_done)/(1024.*1024.*1024.))) + print_message("[+] Average throughput is %.3f Mbps" % ((1000 * 8 * int(self.st_client.total_transfer_done))/((self.st_client.stop_time - self.st_client.start_time)*1000.*1000.))) + exit(0) + except: + self.talk_send.stop_probing = True + self.talk_receive.stop_probing = True + self.st_client.transfer_done = True + print_message("[+ err] Total time for this transfer is %.3f Seconds"%((self.st_client.stop_time - self.st_client.start_time)/1000.)) + print_message("[+ err] Average throughput is %.3f Mbps" % ((1000 * 8 * int(self.st_client.total_transfer_done))/((self.st_client.stop_time - self.st_client.start_time)*1000.*1000.))) + message = True + exit(0) + return message + + def initiate(self, args): + self.st_client = STClient(args) + self.st_client.max_read_threads = self.params[0][0]+2**self.param_length[0]+1 + self.st_client.max_transfer_threads = self.params[4][0]+2**self.param_length[4]+1 + self.st_client.max_write_threads = self.params[6][0]+2**self.param_length[6]+1 + self.st_client.start_time = int(time.time()*1000) + self.talk_send = TalkSend(self.st_client) + self.talk_receive = TalkReceive(self.st_client, self.talk_send) + self.talk_send.start() + self.talk_receive.start() + self.conv = args.conv +# self.talk_send.join() +# self.talk_receive.join() + + def run_GA(self): + ags = self.agents + for i in range(self.number_of_generations): + self.evaluate_population(ags) +# self.evaluate_population(ags) + if self.st_client.transfer_done: + exit(0) +# if i >= (self.number_of_generations/2): +# self.best_score = 0. + ags = self.selection(ags+self.agents, i) + if (i+1) != self.number_of_generations: + ags = self.crossover(ags) + ags = self.mutation(ags) + if i>1: + self.agents = ags + print_message("[+] Current complete generation is: %d"%(i+1)) + self.choose_best_pop() + def run_random(self): + self.agents = [] + for i in range(self.number_of_generations*self.number_of_population): + self.agents.append(Agents(population="", population_length=self.population_length)) + self.evaluate_population(self.agents) +# self.evaluate_population(self.agents) + self.selection(self.agents, 0) + + def selection(self, agents, generation): + ags = [] + tmp = [] + + + #Comment this out in final version. + is_chameleon = False + if(generation == 0 and is_chameleon): + tmp = agents[:3] + agents = agents[5:] + + + + agents = sorted(agents, reverse=True) + if self.best_score (1+higher_than) * avg_thpt: + ags.append(agent) + if len(ags)<2: + ags = self.percentage_elitist(agents, higher_than=higher_than-0.05) + for i in range(len(ags), self.number_of_population): + new_ags.append(rd.choice(ags)) + return ags+new_ags + def standard_deviation_elitist(self, agents, number_of_std=1): + ags = [] + new_ags = [] + avg_thpt = self.get_average_generation_thpt(agents) + std_ = self.get_generation_std(agents) + print_message("[Genaration] average throughput of generation is %.3f Mbps and larger than %.3f Mbps" % (avg_thpt, avg_thpt+(number_of_std*std_))) + for agent in agents: + if agent.get_score() >= (avg_thpt+(std_*number_of_std)): + ags.append(agent) + if len(ags)<2: + ags = self.standard_deviation_elitist(agents, number_of_std-0.5) + for i in range(len(ags), self.number_of_population): + new_ags.append(rd.choice(ags)) + return ags+new_ags + + def get_ranking(self, length): + ranks = [i for i in range(length, 0, -1)] + sum_r = sum(ranks) + return [(1.0*i)/sum_r for i in ranks] + + def crossover(self, agents): + ags = [] + for i in range(self.number_of_population//2): + p1 = agents[2*i].get_population() + p2 = agents[2*i+1].get_population() + c1, c2 = p1, p2 + if rd.random() <= self.crossover_probab: + if self.crossover_type == "single_point": + c1, c2 = self.single_point_crossover(p1, p2) + elif self.crossover_type == "favor_zero": + c1, c2 = self.favor_zero(p1, p2) + ags.append(Agents(c1)) + ags.append(Agents(c2)) + return ags + def favor_zero(self, p1, p2): + p1 = self.get_bin_array(p1) + p2 = self.get_bin_array(p2) + c1 = "" + c2 = "" + for i in range(len(p1)): + curr_1 = p1[i] + curr_2 = p2[i] + for j in range(len(curr_1)): + char_1 = int(curr_1[j]) + char_2 = int(curr_2[j]) + if char_1 == char_2: + c1 += str(char_1) + c2 += str(char_2) + else: + if char_1 == 0: + c1 += str(char_1) + c2 += '0' if rd.random() < 0.5+((len(curr_2)-j)/50) else '1' + else: + c1 += '0' if rd.random() < 0.5+((len(curr_2)-j)/50) else '1' + c2 += str(char_2) + return c1, c2 + def get_bin_array(self, p1): + string_array = [] + since_last = 0 + for i in range(len(self.param_length)): + till = since_last + self.param_length[i] + current = p1[since_last:till] + string_array.append(current) + since_last = till + return string_array + + def single_point_crossover(self, p1, p2): + index = rd.randint(0, len(p1)-1) + c1 = p1[:index] + p2[index:] + c2 = p2[:index] + p1[index:] + return c1, c2 + + def mutation(self, agents): + ags = [] + for agent in agents: + if rd.random() <= self.mutation_probab: + if self.mutation_type == "bit_flip": + agent = Agents(self.bit_flip_mutation(agent.get_population())) + ags.append(agent) + return ags + + def bit_flip_mutation(self, population): + index = rd.randint(0, len(population)-1) + tmp = "0" if population[index]=="1" else "1" + return population[:index] + tmp + population[index+1:] + + def get_param_string(self, population): + value_string = "" + since_last = 0 + for i in range(len(self.param_length)): + till = since_last + self.param_length[i] + current = population[since_last:till] + if i: + value_string += "," + since_last = till + value_string += str(self.get_value(current, i)) + return value_string + def get_bin_string(self, param_str): + val_str = param_str.split(",") + pop = "" + for i in range(len(val_str)): + value = str(bin(int(val_str[i]) - self.params[i][0]))[2:] + pop += self.get_bin_value(value, self.param_length[i]) + return pop + def get_bin_value(self, value, length): + if length == 0: + return "" + zeros = "0" * (length-len(value)) + return zeros + value + + def get_value(self, current, i): + if not current: + return self.params[i][0] + return int(current, 2)+self.params[i][0] + + def evaluate_population(self, agents): + for agent in agents: + print_message("[Agent] for agent "+self.get_param_string(agent.get_population())) + + evaluation_start = time.time() + agent.set_thpt(self.evaluate(agent)) + if(self.check_transfer_done()): + print_message("Transfer is done in %.3f Seconds"%((self.st_client.stop_time - self.st_client.start_time)/1000.)) + break + print_message("[Agent] for agent %s throughput is: %.3f Mbps avgThpt: %.3f in total time %.3f"%(self.get_param_string(agent.get_population()), agent.get_thpt(), agent.get_avg_thpt(), time.time() - evaluation_start)) + + def evaluate(self, agent): + self.convergence_thpt = {} + parameter = self.get_param_string(agent.get_population()) + throughput = 0. + self.talk_receive.set_parameter(parameter) + self.talk_send.set_parameter(parameter) + while(not self.talk_send.can_start_probing or not self.talk_receive.can_start_probing): + time.sleep(0.01) + self.talk_send.start_current_probing = True + while(throughput == 0.0 and len(self.talk_receive.thpt_list)<1500 and not self.st_client.transfer_done): + time.sleep(0.01) + thpt_list = self.talk_receive.thpt_list + if len(thpt_list) not in self.convergence_thpt: + throughput = self.find_convergence(thpt_list) + self.convergence_thpt[len(thpt_list)] = throughput + else: + throughput = self.convergence_thpt[len(thpt_list)] +# print("[Throughput] list" + str(self.talk_receive.thpt_list)) + if throughput == 0.: + throughput = self.find_average_thpt(self.talk_receive.thpt_list) + agent.set_avg_thpt(self.find_average_thpt(self.talk_receive.thpt_list)) + self.talk_receive.thpt_list = [] + self.talk_receive.stop_current_probing = True + return throughput + def find_convergence(self, thpt_list): + if(self.conv == "ar"): + return self.find_convergence_timeseries(thpt_list) + elif(self.conv == "dnn"): + return self.find_convergence_dnn(thpt_list) + elif self.conv == "rand": + if len(thpt_list) >= 2 and adaptive_iterative.is_predictable(thpt_list): + print(thpt_list) + return adaptive_iterative.make_prediction(thpt_list) + else: + return 0.0 + else: + return self.find_average_thpt(thpt_list) if len(thpt_list)>=10 else 0.0 + def find_convergence_timeseries(self, thpt_list): + if(len(thpt_list)<4): + return 0. + if(len(thpt_list)>15): + return self.find_average_thpt(thpt_list) + tmp = [0]+thpt_list[:-1] + model = AR(tmp) + start_params = [0, 0, 1] + + model_fit = model.fit(maxlag=1, start_params=start_params, disp=-1) + predicted_last = model_fit.predict(len(tmp), len(tmp))[0] + last_pt = thpt_list[-1] + + if( (last_pt != 0.) and (predicted_last - last_pt)/last_pt < 0.1): + return predicted_last + return 0. + def load_clfs(self): + self.all_clfs = {} + for i in range(3, 16): + self.all_clfs[i] = joblib.load("./clfs/pronghorn-10-%d-42-percentage-optimal.pkl"%i) + def get_percentage_change_thpts(self, thpt_list): + if len(thpt_list) <= 1: + return [] + new_thpt = [] + prev_thpt = thpt_list[0] + for index in range(1, len(thpt_list)): + perc = thpt_list[index] - prev_thpt + new_thpt.append(perc / (prev_thpt+1.5)) + prev_thpt = thpt_list[index] + return new_thpt + + def find_convergence_dnn(self, thpt_list): + threshold = 1.0 +# print("Actual Thpt list", thpt_list) +# print(thpt_list) + prev_thpt_list = thpt_list + thpt_list = self.get_percentage_change_thpts(thpt_list) +# print("Percentage Change thpt", thpt_list) +# print(thpt_list) + if len(thpt_list)<3: + return 0 + elif len(thpt_list)>=15: + return self.find_average_thpt(thpt_list) + i = len(thpt_list) + y_pred = self.all_clfs[i].predict_proba([thpt_list])[0] + max_, ind_ = self.get_max_and_index(y_pred) + print("[+] ", max_, threshold - 0.05*(len(thpt_list) - 2), ind_, i, len(prev_thpt_list)) +# print("CT", ind_, " Prediction probability", max_, " Threshold", threshold - 0.05*(len(thpt_list) - 2)) + if(max_ > (threshold - 0.05*(len(thpt_list) - 2)) and ind_+2 <= i+1): + return self.find_average_thpt(prev_thpt_list) + return 0.0 + def get_max_and_index(self, lis): + max_ = lis[0] + ind_ = 0 + for i in range(len(lis)): + if max_ <= lis[i]: + max_ = lis[i] + ind_ = i + return max_, ind_ + def find_average_thpt(self, thpt_list): + if not thpt_list: + return 0. + return (1.0*sum(thpt_list))/len(thpt_list) + def choose_best_pop(self): + pass + def get_random_choices(self, population, weights=None, cum_weights=None, k=1): + random = rd.random + if cum_weights is None: + if weights is None: + _int = int + total = len(population) + return [population[_int(random() * total)] for i in range(k)] + cum_weights = [] + last_val = 0 + for i in weights: + last_val += i + cum_weights.append(last_val) + elif weights is not None: + raise TypeError('Cannot specify both weights and cumulative weights') + if len(cum_weights) != len(population): + raise ValueError('The number of weights does not match the population') + bisect = _bisect.bisect + total = cum_weights[-1] + hi = len(cum_weights) - 1 + return [population[bisect(cum_weights, random() * total, 0, hi)] + for i in range(k)] + def get_average_generation_thpt(self, agents): + total_score = 0 + for agent in agents: + total_score += agent.get_score() + return total_score/len(agents) +# def get_generation_std(self, agents): +# scores = [] +# for i in agents: +# scores.append(i.get_score()) +# return np.array(scores).std() + def get_generation_std(self, agents): + n = len(agents) + if n <= 1: + return 0.0 + mean, sd = self.get_average_generation_thpt(agents), 0.0 + # calculate stan. dev. + for el in agents: + sd += (float(el.get_score()) - mean)**2 + sd = math.sqrt(sd / float(n-1)) + return sd + +class Agents: + def __init__(self, population="", population_length=0): + self.population = population + self.throughput = -1 + self.average_throughput = 0 + self.memory_error = False + if population == "": + self.population = self.get_random_string(population_length) + + def get_random_string(self, population_length): + to_ret = "" + for i in range(population_length): + to_ret += rd.choice(["0", "1"]) + return to_ret + def set_avg_thpt(self, thpt): + if self.average_throughput: + self.average_throughput += thpt + self.average_throughput = self.average_throughput/2. + else: + self.average_throughput = thpt + def get_avg_thpt(self): + return self.average_throughput + def set_thpt(self, thpt): + if self.throughput != -1: + self.throughput += thpt + self.throughput /= 2. + else: + self.throughput = thpt + def get_thpt(self): + return self.throughput + def get_population(self): + return self.population + def get_score(self): + if self.get_thpt()>0.0: + return self.get_thpt() + return self.get_avg_thpt() + self.get_thpt() + def __lt__ (self, other): + return self.get_score() < other.get_score() + def __gt__(self, other): + return self.get_score() > other.get_score() + def __eq__(self, other): + return self.get_score() == other.get_score() + +if __name__=="__main__": + parser = argparse.ArgumentParser(description='Parameters in the application') + parser.add_argument('sender', type=str, + help='Sender information') + parser.add_argument('receiver', type=str, + help='Receiver information') + parser.add_argument('--interface', type=str, + help='Interface to send information in the receiver') + parser.add_argument('--generation', type=int, + help='Number of generation for GA') + parser.add_argument('--population', type=int, + help='Number of population for GA') + parser.add_argument('--method', type=str, + help='Method of algorithm to use') + parser.add_argument('--conv', type=str, + help='Method of algorithm to use') + args = parser.parse_args() + print("Argument values:") + print(args.sender) + print(args.receiver) + if(not args.method): + args.method = "GA" + if(not args.conv): + args.method = "avg" + ga = GA(args) + + \ No newline at end of file diff --git a/STReceiver.java b/STReceiver.java new file mode 100644 index 0000000..4a31de9 --- /dev/null +++ b/STReceiver.java @@ -0,0 +1,545 @@ +import java.io.*; +import java.net.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicLong; + +public class STReceiver { + private int numberOfWrite = 2; + private int sizeOfQueue = 100; + private boolean debug = false; + private int transferPort = 48892; + private int maxBufferSize = 128*1024; + private long throughputCycle = 1000l; + private int clientPort = 53823; + private String throughputMessage = ""; + private Boolean transferDone = false; + private long numberOfMaxConnections = 0; + private AtomicLong totalWriteDone = new AtomicLong(0); + private AtomicLong totalTransferDone = new AtomicLong(0); + private long tillLastWrite = 0l; + private long sinceLastTime = 0l; + private long tillLastTransfer = 0l; + private long startTime = 0l; + private String status = "TRANSFER"; + private String toDir = "/data/hem/"; + private ExecutorService writePool = Executors.newFixedThreadPool(200); + private HashMap writeBlocks = new HashMap<>(); + private ExecutorService receivePool = Executors.newFixedThreadPool(200); + private HashMap receiveBlocks = new HashMap<>(); + private static LinkedBlockingQueue blocks = new LinkedBlockingQueue<>(1500); + private HashMap filesNames = new HashMap<>(); + + public STReceiver(){ + + } + void startListeningToSender(){ + STReceiver.OpenTransferThread ott = this.new OpenTransferThread(this, this.transferPort); + Thread tott = new Thread(ott, "0"); + tott.start(); + } + public static void main(String[] args){ + STReceiver str = new STReceiver(); + STReceiver.TalkClient tc = str.new TalkClient(str); + Thread ttc = new Thread(tc, "TalkClient"); + ttc.start(); + System.out.println("Started Receiver at "+str.clientPort+" port"); + long iter = 0l; + long tillCertainTransfer = 0l; + long lastCertainTime = 0l; + while(!str.transferDone){ + if(str.startTime != 0) { + long totalTransfer = str.totalTransferDone.get(); + long lastTransfer = str.tillLastTransfer; + str.tillLastTransfer = totalTransfer; + long totalW = str.totalWriteDone.get(); + long lastW = str.tillLastWrite; + str.tillLastWrite = totalW; + + long lastTime = str.sinceLastTime; + str.sinceLastTime = System.currentTimeMillis(); + + double timeInterval = str.sinceLastTime - lastTime; + if(timeInterval>=1000) { + double thpt = 8 * 1000 * (totalW - lastW) / (timeInterval * 1000 * 1000); + double thpt_transfer = 8*1000*(totalTransfer-lastTransfer)/(timeInterval *1000*1000); + if (str.status.equalsIgnoreCase("probing") && (iter % 1) == 0) { + long thisTime = System.currentTimeMillis(); + double avgThpt = Math.round(100.0*8*1000*(totalW-tillCertainTransfer)/((thisTime-lastCertainTime)*1000.*1000.))/100.0; + + /* + System.out.println("Transfer Done till: " + totalW / (1024. * 1024.) + + "MB in time: " + (System.currentTimeMillis() - str.startTime) / 1000. + + " seconds and Throughput is:" + avgThpt + "Mbps time: " + timeInterval + " tnsferDone: " + (totalW-tillCertainTransfer) + " blocks: "+STReceiver.blocks.size()); + //*/ + System.out.println("Transfer Done till: " + totalW / (1024. * 1024.) + "MB in time: "+(System.currentTimeMillis() - str.startTime) / 1000. + + " seconds and thpt is: "+avgThpt+" Mbps" + " blocks: "+STReceiver.blocks.size()); + tillCertainTransfer = totalW; + lastCertainTime = thisTime; + } + + + if (str.status.equalsIgnoreCase("probing") && tc.currentProbeStarted) {//PROBING + synchronized (str.throughputMessage) { + if (!str.throughputMessage.equalsIgnoreCase("")) { + str.throughputMessage += "," + (thpt*0.95+thpt_transfer*0.05); + } else { + str.throughputMessage += thpt; + } + } + } + } + } + try{ + for(int i = 0; i < 100; i++) { + Thread.sleep(str.throughputCycle/100); + } + }catch(InterruptedException e){ + e.printStackTrace(); + } + iter++; + } + + System.out.println("[+] Receiver has stopped"); + try{ + Thread.sleep(5000); + }catch(InterruptedException e){ + + } + System.exit(0); + + } + public int getTransferPort(){ + return this.transferPort; + } + public void closeConnections(){ + for (int i = 0; i < this.receiveBlocks.size(); i++) { + try { + this.receiveBlocks.get(i).close(); + }catch (IOException e){} + } + } + public void startWriteThreads(int count){ + this.stopAllWriteThreads(); + if(this.writeBlocks.size() < count){ + for(int i=this.writeBlocks.size(); i(); + } + ServerSocket socketReceive = new ServerSocket(this.communicationPort); + Socket clientSock = socketReceive.accept(); + DataInputStream dataInputStream = new DataInputStream(clientSock.getInputStream()); + DataOutputStream dos = new DataOutputStream(clientSock.getOutputStream()); + this.stReceiver.numberOfMaxConnections = dataInputStream.readLong(); + this.stReceiver.startTime = System.currentTimeMillis(); + int startPort = 61024; + for(int i=startPort;i byteArray; + long tillNow = 0l; + boolean bufferLoaded = false; + long startTime = 0l; + + + Block(long offset, long length){ + this.offset = offset; + this.length = length; + byteArray = new ArrayList(); + + } + void add_buffer(byte[] bff, int buffer_size){ + byteArray.add(new Buffer(bff, buffer_size)); + tillNow += buffer_size; + + } + void remove_buffer(){ + this.byteArray = null; + } + void setOffset(long offset){this.offset=offset;} + void setFilename(String fn){this.filename=fn;} + void setFileId(long fi){this.fileId=fi;} + void setBlockId(long bi){this.blockId=bi;} + } + class Buffer { + byte[] small_buffer; + int length; + + Buffer(byte[] buffer, int buffer_size){ + small_buffer = Arrays.copyOf(buffer, buffer_size); + length = buffer_size; + } + } + class WriteBlock implements Runnable{ + STReceiver stReceiver = null; + boolean waitNext = false; + WriteBlock(STReceiver str){ + this.stReceiver = str; + } + void stopThread(){this.waitNext = true;} + void startThread(){this.waitNext=false;} + + @Override + public void run() { + if(this.stReceiver.debug) { + System.out.println("[+] Write Thread-" + Thread.currentThread().getName() + " has Started."); + } + //Start reading block and write to this.stSender.blocks + while(true){ + Block currentBlock = null; + while(STReceiver.blocks.isEmpty()){ + try { + Thread.sleep(10); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + try { + currentBlock = STReceiver.blocks.poll(50, TimeUnit.MILLISECONDS); + if(currentBlock == null) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + e.printStackTrace(); + } + if(this.stReceiver.transferDone && + STReceiver.blocks.size() == 0){ + break; + } + continue; + } + long st_wite = System.currentTimeMillis(); + String filename = this.stReceiver.toDir + "/" + currentBlock.filename; + RandomAccessFile randomAccessFile = filesNames.remove(filename); + if(randomAccessFile == null) { + randomAccessFile = new RandomAccessFile(filename, "rw"); + } + if (currentBlock.offset > 0) { + randomAccessFile.getChannel().position(currentBlock.offset); + } + for(Buffer buffer: currentBlock.byteArray){ + randomAccessFile.write(buffer.small_buffer, 0, buffer.length); + + buffer.small_buffer = null; + } + if(this.stReceiver.filesNames.containsKey(filename)){ + randomAccessFile.close(); + }else { + filesNames.put(filename, randomAccessFile); + } + + long done_wrt = System.currentTimeMillis(); + //System.out.println("[+] Block is done of blockId"+currentBlock.blockId+ " in total time "+(done_wrt-currentBlock.startTime)+"ms & and write time "+(done_wrt-st_wite)+"ms total bolcks is "+STReceiver.blocks.size()); + + if(this.stReceiver.debug){ + System.out.println("[Block Written "+System.currentTimeMillis()+"] fileId: "+currentBlock.fileId+" blockId: "+currentBlock.blockId + " offset: "+currentBlock.offset + + " threadName:"+Thread.currentThread().getName()); + } + currentBlock = null; + }catch(Exception e){ + e.printStackTrace(); + } + } + } + } + class ReceiveFile implements Runnable{ + STReceiver stReceiver = null; + boolean waitNext = false; + boolean closeConnection = false; + Socket clientSock = null; + ServerSocket socket_receive = null; + int sendFileId = 0; + Socket s = null; + int port = 0; + + DataInputStream dataInputStream = null; + + ReceiveFile(STReceiver str, int port){ + //Start a connection here + this.stReceiver = str; + this.port = port; + } + void stopThread(){this.waitNext = true;} + void startThread(){this.waitNext=false;} + void close() throws IOException{ + if(clientSock!=null){ + clientSock.close(); + } + if(socket_receive!=null){ + socket_receive.close(); + } + } + @Override + public void run() { + if (this.stReceiver.debug) { + System.out.println("[+] Receive Thread-" + Thread.currentThread().getName() + " has Started."); + } + try { + socket_receive = new ServerSocket(port); + socket_receive.setReceiveBufferSize(32*1024*1024); + clientSock = socket_receive.accept(); + + dataInputStream = new DataInputStream(clientSock.getInputStream()); + while (!closeConnection) { + while (STReceiver.blocks.remainingCapacity() == 0) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + try { + int read = 0; + String filename = dataInputStream.readUTF(); + if (filename.equals("done")) { + break; + } + long length = dataInputStream.readLong(); + long offset = dataInputStream.readLong(); + long fileId = dataInputStream.readLong(); + long blockId = dataInputStream.readLong(); + + Block currentBlock = new Block(offset, length); + currentBlock.startTime = System.currentTimeMillis(); + currentBlock.setFilename(filename); + currentBlock.setFileId(fileId); + currentBlock.setBlockId(blockId); + if(this.stReceiver.debug) { + System.out.println("[Block Received "+System.currentTimeMillis()+" Started] " + currentBlock.filename + " fileId: " + currentBlock.fileId + " blockId: " + currentBlock.blockId + + " offset: " + currentBlock.offset + " length: " + currentBlock.length + " blockLoaded: " + currentBlock.written); + } + byte[] buffer = new byte[maxBufferSize]; + + + while (currentBlock.tillNow < currentBlock.length) { + //long st_tim = System.currentTimeMillis(); + read = dataInputStream.read(buffer, 0, (int) Math.min(buffer.length, currentBlock.length - currentBlock.tillNow)); + totalTransferDone.addAndGet(read); + //long new_st_tim = System.currentTimeMillis(); + currentBlock.add_buffer(buffer, read); + totalWriteDone.addAndGet(read); + //totalWriteDone.addAndGet(read); + //long latest_st = System.currentTimeMillis(); + //System.out.println("Receive Blockid = "+currentBlock.blockId+" and read time "+ (new_st_tim-st_tim)+"ms and add time "+(latest_st-new_st_tim)+"ms"); + } + boolean doneAddition = true; + + if(this.stReceiver.debug) { + System.out.println("[Block Received "+System.currentTimeMillis()+" Done] " + currentBlock.filename + " fileId: " + currentBlock.fileId + " blockId: " + currentBlock.blockId + + " offset: " + currentBlock.offset + " length: " + currentBlock.length + " blockLoaded: " + currentBlock.written); + } + currentBlock.written = true; + + } catch (IOException e) { + e.printStackTrace(); + } + } +// this.stReceiver.transferDone = true; + }catch(IOException e){ + e.printStackTrace(); + } + } + } + class TalkClient implements Runnable{ + private Socket talkSocket = null; + private ServerSocket serverSocket = null; + private STReceiver stReceiver = null; + + + BufferedReader readFromClient = null; + PrintStream sendClient = null; + boolean stopProbing = false; + boolean startCurrentProbing = false; + boolean stopCurrentProbing = false; + boolean currentProbeStarted = false; + + + public TalkClient(STReceiver str){ + this.stReceiver = str; + } + String parseMessage(String message){ + String[] messages = message.split(":"); + if(messages[0].equalsIgnoreCase("start")){ + if(messages[1].equalsIgnoreCase("transfer")){ + this.stReceiver.status = "TRANSFER"; + System.out.println("[+] Normal transfer has started"); + this.stReceiver.startProbing(); + return "ok"; + }else if(messages[1].equalsIgnoreCase("probing")){ + this.stReceiver.status = "PROBING"; + return "ok"; + }else if(messages[1].equalsIgnoreCase("currentProbing")){ + this.stReceiver.status = "PROBING"; + this.stReceiver.startProbing(); + this.currentProbeStarted = true; + this.stReceiver.sinceLastTime = System.currentTimeMillis(); + return "ok"; + } + }else if(messages[0].equalsIgnoreCase("stop")){ + if(messages[1].equalsIgnoreCase("currentProbing")){ + this.currentProbeStarted = false; + this.stReceiver.stopEverything(); + return "ok"; + }else if(messages[1].equalsIgnoreCase("everything")){ + return "ok"; + } + }else if(messages[0].equalsIgnoreCase("parameter")){ + System.out.println("Parameter: Done"); + while(STReceiver.blocks.size() !=0){ + try{ + Thread.sleep(10); + }catch(InterruptedException e){ } + } + String[] params = messages[1].split(","); + this.stReceiver.sizeOfQueue = Integer.parseInt(params[5]); + ///* + synchronized (STReceiver.blocks){ + STReceiver.blocks = new LinkedBlockingQueue<>(this.stReceiver.sizeOfQueue); + } + //*/ + this.stReceiver.numberOfWrite = Integer.parseInt(params[6]); + System.out.println("Parameter: "+messages[1]); + return "ok"; + }else if(messages[0].equalsIgnoreCase("get throughput")){ + String thpts = ""; + synchronized (this.stReceiver.throughputMessage){ + thpts = this.stReceiver.throughputMessage; + this.stReceiver.throughputMessage = ""; + } + return thpts; + }else if(messages[0].equalsIgnoreCase("done")){ + while(STReceiver.blocks.size() != 0){ + try{ + Thread.sleep(10); + }catch(InterruptedException e){} + } + System.out.println("Receiver is done"); + System.out.println("Transfer should be done"); + this.stReceiver.transferIsDone(); + this.stopProbing = true; +// double write_done = this.stReceiver.totalWriteDone.get() / 1000000000; +// double thpt = Math.ceil(write_done) * 1.024*1.024*8; + System.out.println(this.stReceiver.totalWriteDone); + this.stReceiver.transferDone = true; + return ":ok:"+this.stReceiver.totalWriteDone; + } + return "ok"; + } + public void run(){ + try { + this.serverSocket = new ServerSocket(this.stReceiver.clientPort); + this.talkSocket = serverSocket.accept(); + + this.readFromClient= new BufferedReader(new InputStreamReader(this.talkSocket.getInputStream())); + this.sendClient = new PrintStream(this.talkSocket.getOutputStream()); + + this.stReceiver.toDir = this.readFromClient.readLine().trim(); + this.stReceiver.receivePool = Executors.newFixedThreadPool(Integer.parseInt(this.readFromClient.readLine().trim())); + this.stReceiver.writePool = Executors.newFixedThreadPool(Integer.parseInt(this.readFromClient.readLine().trim())); + this.stReceiver.blocks = new LinkedBlockingQueue<>(Integer.parseInt(this.readFromClient.readLine().trim())); + this.stReceiver.maxBufferSize = Integer.parseInt(this.readFromClient.readLine().trim()); + this.stReceiver.startListeningToSender(); + this.sendMessage(""+this.stReceiver.getTransferPort()); + + //String init_ = this.readFromClient.readLine(); + //this.sendMessage("ok"); + + while(!this.stopProbing){ + String receivedMessage = this.readFromClient.readLine(); + String message = this.parseMessage(receivedMessage); + this.sendMessage(message); + } + try{ + Thread.sleep(1000); + }catch(InterruptedException e){ + + } + this.close(); + }catch(IOException e){ + e.printStackTrace(); + } + } + void sendMessage(String message){ + this.sendClient.println(message); + } + public void close() throws IOException{ + talkSocket.close(); + readFromClient.close(); + sendClient.close(); + } + } +} diff --git a/STSender.java b/STSender.java new file mode 100644 index 0000000..db1fa86 --- /dev/null +++ b/STSender.java @@ -0,0 +1,587 @@ +import java.io.*; +import java.net.*; +import java.util.*; +import java.util.Queue; +import java.util.concurrent.*; + +public class STSender { + private String receiverIP = "192.168.1.1"; + private int receiverPort = 48892; + int numberOfRead = 2; + private int sizeOfQueue = 100; + private int numberOfConnection = 1; + private int numberOfMaxConnections = 10; + int clientPort = 52005; + private long blockSize = 256*1024*1024l; + private long bufferSize = 128*1024l; + private boolean debug = false; + String status = "TRANSFER"; + private Boolean transferDone = false; + private String fromDir = "/data/aalhussen/10Mfiles"; + private static Queue files = new LinkedList();//Collections.asLifoQueue(new ArrayDeque<>()); + private ExecutorService readPool = Executors.newFixedThreadPool(200); + private HashMap readBlocks = new HashMap<>(); + private ExecutorService sendPool = Executors.newFixedThreadPool(200); + private HashMap sendBlocks = new HashMap<>(); + private static LinkedBlockingQueue blocks = new LinkedBlockingQueue<>(1500); + private boolean sendingDone = false; + private boolean fileFinished = false; + private static Long totalByteToSend = 0l; + private long last_file_id = 0l; + private long total_files = 0l; + private STSender.OpenTransferThread ott = null; + private HashMap> openFiles = new HashMap<>(); + + static int fileNum = 0; + + public STSender(){ + + } + public static void main(String[] args){ + STSender sts = new STSender(); + STSender.TalkClient tc = sts.new TalkClient(sts); + Thread ttc = new Thread(tc, "TalkClient"); + ttc.start(); + try { + ttc.join(); + }catch(InterruptedException e){ + + } + System.out.println("done"); + try{ + Thread.sleep(15000); + }catch(InterruptedException e){ + + } + System.exit(0); + } + void talkWithReceiver(STSender sts){ + ott = sts.new OpenTransferThread(sts, sts.receiverIP, sts.receiverPort); + Thread tott = new Thread(ott, "0"); + tott.start(); + } + public void startReadThreads(int count){ + this.stopAllReadThreads(); + if(this.readBlocks.size() < count){ + for(int i=this.readBlocks.size(); i= this.stSender.sizeOfQueue){ + try { + Thread.sleep(10); + continue; + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + long newOffset = 0; + long length = 0; + long blockId = 0; + int n; + FileInputStream fis = null; + synchronized (STSender.files) { + if(STSender.files.size()==0){ + this.stSender.fileFinished = true; + } + head = STSender.files.poll(); + if(head!=null){ + blockId = head.blockId; + head.blockId++; + newOffset = head.offset; + length = Math.min(this.stSender.blockSize, head.length - newOffset); + head.offset += length; + if (head.offset < head.length) { + STSender.files.add(head); + }else if (this.stSender.last_file_id < this.stSender.total_files * 3.5){ + STSender.files.add(new TransferFile(head.file, head.filename, 0l, head.length, this.stSender.last_file_id++)); + STSender.totalByteToSend += head.length; + } + } + } + try { + if(head == null) { + if(STSender.files.size() == 0) { + } + try { + Thread.sleep(100); + } catch (InterruptedException e) { + e.printStackTrace(); + } + continue; + } + currentBlock = new Block(newOffset, length); + currentBlock.filename = head.filename; + currentBlock.fileId = head.fileId; + currentBlock.blockId = blockId; + currentBlock.fileLength = head.length; + + byte[] buffer = new byte[(int)buffer_size]; + + if(this.stSender.debug) { + System.out.println("[Block Load "+System.currentTimeMillis()+" started] "+currentBlock.filename + " fileId: " + currentBlock.fileId + " blockId: " + currentBlock.blockId + + " offset: " + currentBlock.offset + " length: " + currentBlock.length + " fileLength: " + + currentBlock.fileLength + " blockLoaded: "+ currentBlock.bufferLoaded + " threadName: "+Thread.currentThread().getName()); + } + while ((int) Math.min(buffer.length, currentBlock.length - currentBlock.tillNow) > 0) { + n = (int) Math.min(buffer.length, currentBlock.length - currentBlock.tillNow); + currentBlock.add_buffer(buffer, n); + if (currentBlock.tillNow >= currentBlock.length) { + boolean isSuccess = false; + while (!isSuccess) { + isSuccess = STSender.blocks.offer(currentBlock, 100, TimeUnit.MILLISECONDS); + } + currentBlock.bufferLoaded = true; + if(this.stSender.debug) { + System.out.println("[Block Loaded "+System.currentTimeMillis()+"] "+currentBlock.filename + " fileId: " + currentBlock.fileId + " blockId: " + currentBlock.blockId + + " offset: " + currentBlock.offset + " length: " + currentBlock.length + " fileLength: " + + currentBlock.fileLength + " blockLoaded: "+ currentBlock.bufferLoaded + " threadName: "+Thread.currentThread().getName()); + } + } + } + //Might need to initiate buffer to null and to buffer_size again + } catch (Exception e) { + e.printStackTrace(); + try { + if (fis != null) { + fis.close(); + } + }catch(IOException ee){ + ee.printStackTrace(); + } + if(currentBlock!=null && head!=null){ + synchronized (STSender.files){ + //File file, String filename, long offset, long length, long fileId + STSender.files.add(new TransferFile(head.file, currentBlock.filename, currentBlock.offset, currentBlock.length, currentBlock.fileId)); + currentBlock.bufferLoaded = false; + } + } + break; + } + } + System.out.println("Blocks in blocks: "+STSender.blocks.size()); + } + } + class Block{ + long offset = 0l; + long length = 0l; + String filename = ""; + long fileId = 0l; + long fileLength = 0l; + boolean written = false; + long blockId = 0l; + List byteArray; + long tillNow = 0l; + boolean bufferLoaded = false; + + + Block(long offset, long length){ + this.offset = offset; + this.length = length; + byteArray = new ArrayList(); + + + } + void add_buffer(byte[] bff, int buffer_size){ + byteArray.add(new Buffer(bff, buffer_size)); + tillNow += buffer_size; + + } + void remove_buffer(){ + this.byteArray = null; + } + void setOffset(long offset){this.offset=offset;} + void setFilename(String fn){this.filename=fn;} + void setFileId(long fi){this.fileId=fi;} + void setBlockId(long bi){this.blockId=bi;} + } + class Buffer { + byte[] small_buffer; + int length; + + Buffer(byte[] buffer, int buffer_size){ + small_buffer = Arrays.copyOf(buffer, buffer_size); + length = buffer_size; + } + } + void collect_files(){ + File file = new File(this.fromDir); + long fileId = 0l; + if (file.isDirectory()) { + for (File f : file.listFiles()) { + long fileLeng = f.length(); + STSender.totalByteToSend += fileLeng; + STSender.files.add(new TransferFile(f, f.getName(), 0, f.length(), fileId)); + this.openFiles.put(this.fromDir + "/" + f.getName(), new ArrayList<>()); + fileId++; + } + } else { + long fileLeng = file.length(); + STSender.totalByteToSend += fileLeng; + STSender.files.add(new TransferFile(file, file.getName(), 0, file.length(), fileId)); + this.fromDir = file.getParent(); + this.openFiles.put(this.fromDir + "/" + file.getName(), new ArrayList<>()); + fileId++; + } + this.last_file_id = fileId; + this.total_files = fileId - 1; + System.out.println("[+] Sending "+files.size()+" files to receiver."); + System.out.println("[+] Sending "+Math.round(100*STSender.totalByteToSend/(1024.0*1024*1024.))/100.0+" GB data to receiver."); + + } + class TransferFile{ + String filename = ""; + long offset = 0l; + long length = 0l; + long fileId = 0l; + File file = null; + long blockId = 0l; + TransferFile(File file, String filename, long offset, long length, long fileId){ + this.file = file; + this.filename = filename; + this.offset = offset; + this.length = length; + this.fileId = fileId; + } + } + + + class TalkClient implements Runnable{ + private Socket talkSocket = null; + private ServerSocket serverSocket = null; + private STSender stSender = null; + + + BufferedReader readFromClient = null; + PrintStream sendClient = null; + boolean stopProbing = false; + + + public TalkClient(STSender sts){ + this.stSender = sts; + } + String parseMessage(String message){ + String[] messages = message.split(":"); + if(messages[0].equalsIgnoreCase("start")){ + if(messages[1].equalsIgnoreCase("transfer")){ + this.stSender.status = "TRANSFER"; + return "ok"; + }else if(messages[1].equalsIgnoreCase("probing")){ + this.stSender.status = "PROBING"; + return "ok"; + }else if(messages[1].equalsIgnoreCase("currentprobing")){ + this.stSender.startProbing(); + return "ok"; + } + }else if(messages[0].equalsIgnoreCase("stop")){ + if(messages[1].equalsIgnoreCase("currentprobing")){ + this.stSender.stopProbing(); + return "ok"; + } + }else if(messages[0].equalsIgnoreCase("parameter")){ +// while(STSender.blocks.size() !=0){ +// try{ +// Thread.sleep(100); +// }catch(InterruptedException e){ +// +// } +// } + String[] params = messages[1].split(","); + + this.stSender.numberOfRead = Integer.parseInt(params[0]); + this.stSender.sizeOfQueue = Integer.parseInt(params[1]); + +// synchronized (STSender.blocks){ +// STSender.blocks = new LinkedBlockingQueue<>(this.stSender.sizeOfQueue); +// } + this.stSender.blockSize = Long.parseLong(params[2])*1024*1024; + this.stSender.bufferSize = Long.parseLong(params[3])*1024; + this.stSender.numberOfConnection = Integer.parseInt(params[4]); + return "ok"; + }else if(messages[0].equalsIgnoreCase("check")){ + if(messages[1].equalsIgnoreCase("done")){ + if(STSender.files.size() == 0 && STSender.blocks.size()==0 && this.stSender.sendingDone){ + //System.out.println("[+] Sender is done"); + this.stSender.transferIsDone(); + this.stopProbing = true; + this.stSender.ott.stopThreads = true; + return "true:"+STSender.totalByteToSend; + } + } + } + return "ok"; + } + public void run(){ + try { + this.serverSocket = new ServerSocket(stSender.clientPort); + System.out.println("Listening to port "+stSender.clientPort); + this.talkSocket = serverSocket.accept(); + System.out.println("Socket accepted"); + + this.readFromClient= new BufferedReader(new InputStreamReader(this.talkSocket.getInputStream())); + this.sendClient = new PrintStream(this.talkSocket.getOutputStream()); + this.stSender.receiverIP = this.readFromClient.readLine().trim(); + System.out.println(this.stSender.receiverIP); + this.stSender.receiverPort = Integer.parseInt(this.readFromClient.readLine()); + + this.stSender.fromDir = this.readFromClient.readLine().trim(); + this.stSender.collect_files(); + this.stSender.numberOfMaxConnections = Integer.parseInt(this.readFromClient.readLine().trim()); + this.stSender.sendPool = Executors.newFixedThreadPool(this.stSender.numberOfMaxConnections); + this.stSender.readPool = Executors.newFixedThreadPool(Integer.parseInt(this.readFromClient.readLine().trim())); + this.stSender.blocks = new LinkedBlockingQueue<>(Integer.parseInt(this.readFromClient.readLine().trim())); + this.stSender.talkWithReceiver(this.stSender); + + while(!this.stopProbing){ + String receivedMessage = this.readFromClient.readLine(); +// System.out.println("[+] Message is "+receivedMessage); + String message = this.parseMessage(receivedMessage); + this.sendMessage(message); + } + try{ + Thread.sleep(2000); + }catch(InterruptedException e){ + + } + this.close(); + }catch(IOException e){ + e.printStackTrace(); + } + System.out.println("TalkClient thread done"); + } + void sendMessage(String message){ + this.sendClient.println(message); + } + public void close() throws IOException{ + talkSocket.close(); + readFromClient.close(); + sendClient.close(); + } + } +} diff --git a/adaptive_iterative.py b/adaptive_iterative.py new file mode 100644 index 0000000..3373279 --- /dev/null +++ b/adaptive_iterative.py @@ -0,0 +1,113 @@ +import warnings +warnings.filterwarnings('ignore') +#import pandas as pd +import numpy as np +from joblib import load +import os +import pathlib +from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier +# from sklearn.preprocessing import StandardScaler + +#root = '~/hpcn/network_probing/' +clf_dir = pathlib.Path('./trained_classifiers/').absolute() +reg_dir = pathlib.Path('./trained_regressors/').absolute() + +threshold = 10 +min_points, max_points = 2, 15 +classifiers = {} +regressors = {} +num_of_tree = 50 + + +def regression_train(): + for n in range(min_points, max_points+1): + file_path = "./trained_regressors/regressor_%d.joblib" % n + if os.path.exists(file_path): + regressors[n] = load(file_path) + + +def classification_train(): + for n in range(min_points, max_points+1): + file_path = "./trained_classifiers/classifier_%d.joblib" % n + if os.path.exists(file_path): + classifiers[n] = load(file_path) + + +def is_predictable(test_data): + n = len(test_data) + test_value = np.reshape(test_data, (1,n)) + if classifiers[n].predict(test_value)[0]==1 or n==max_points: + return True + else: + return False + + +def make_prediction(test_data): + n = len(test_data) + test_value = np.reshape(test_data, (1,n)) + return regressors[n].predict(test_value)[0] + + +#def evaluate(X,y): +# times = [] +# errors = [] +# for i in range(len(X)): +# for n in range(min_points, max_points+1): +# try: +# test_value = np.reshape(X[i, :n],(1,n)) +# if classifiers[n].predict(test_value)[0]==1 or n==max_points: +# predicted = regressors[n].predict(test_value)[0] +# original = y[i] +# error_rate = np.abs((predicted - original)/original) * 100 +# times.append(n) +# errors.append(error_rate) +# break +# except Exception as e: +# raise e +# +# print("Duration: {0}s\nError Rate: {1}%".format(np.round(np.mean(times), 2), np.round(np.mean(errors), 2))) +# + +#def main(): +# df = pd.read_csv(root+'data/esnet.csv') +# # df = df[(df.mean_throughput / df.stdv_throughput)>2] +# df.dropna(inplace=True) +# +# test_data = df.sample(frac=.2) +# y_reg_test = test_data["mean_throughput"].values +# # X_test = StandardScaler().fit(test_data.iloc[:, :max_points]).transform(test_data.iloc[:, :max_points]) +# X_test = test_data.iloc[:, :max_points].values +# +# new_dataset = df.drop(test_data.index) +# clf_train_data = new_dataset.sample(frac=.3) +# clf_reg_test = clf_train_data["mean_throughput"].values +# # X_train_clf = StandardScaler().fit(clf_train_data.iloc[:, :max_points]).transform(clf_train_data.iloc[:, :max_points]) +# X_train_clf = clf_train_data.iloc[:, :max_points].values +# +# train_data = new_dataset.drop(clf_train_data.index) +# y_reg_train = train_data["mean_throughput"].values +# # X_train = StandardScaler().fit(train_data.iloc[:, :max_points]).transform(train_data.iloc[:, :max_points]) +# X_train = train_data.iloc[:, :max_points].values +# +# regression_train(X_train, y_reg_train) +# classification_train(X_train_clf, clf_reg_test) +# +# +# # Incremental Prediction +# for i in range(len(X_test)): +# for n in range(min_points, max_points+1): +# # print(is_predictable(X_test[i, :n])) +# if is_predictable(X_test[i, :n]): +# predicted = make_prediction(X_test[i, :n]) +# print("Index: {0}, Duration: {1}s, Predicted={2}MBps, Actual: {3}MBps".format(i, n, +# np.round(predicted), +# np.round(y_reg_test[i]))) +# break +# +# ## Evaluate +# evaluate(X_test, y_reg_test) + + +#if __name__ == "__main__": +# main() + \ No newline at end of file