diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ccb898f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,34 @@
+# STransfer
+
+Compile and Run STSender and STReceiver in sender and receiver respectively.
+And then run STClient.py file as follows:
+On Sender Side:
+```bash
+javac STSender.java;
+java -Xmx3296m STSender;
+```
+On Receiver Side:
+```bash
+javac STReceiver.java;
+java -Xmx3296m STSender;
+```
+On controller Side:
+```bash
+python STClient.py 192.168.1.8:52005/dir/of/the/files/to/send/form/ 192.168.1.7:53823/dir/where/to/save/the/files/on/receiver/side #(--conv=avg --method=GA --generation=5 --population=10) --> Optional
+```
+Optional variables which can be passed as parameter while running java functions are:
+```bash
+--method
+--generation
+--population
+--conv
+--convTime
+```
+Method refers to the method which will be used for finding optimal throughput. It can be either "GA" or "random".
+Generation refers to number of generation in genetic algorithm.
+Population refers to number of individual in each generation of GA.
+Conv refers to convergence function used in which getting the fitness of each individual. The variables for conv can be "avg" and "ar".
+convTime refers to total throughput data to calculate the average convergence time with.
+
+And to run the python code, Scikit-learn, Scipy and Numpy will need to be installed.
+
diff --git a/STClient.py b/STClient.py
new file mode 100644
index 0000000..45fe03b
--- /dev/null
+++ b/STClient.py
@@ -0,0 +1,790 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Jan 10 18:25:34 2020
+
+@author: hem
+"""
+
+import socket
+import sys
+import random as rd
+import math
+import time
+from threading import Thread
+import bisect as _bisect
+import argparse
+from statsmodels.tsa.ar_model import AR
+from sklearn.externals import joblib
+import adaptive_iterative
+import warnings
+warnings.filterwarnings('ignore', 'statsmodels.tsa.ar_model.AR', FutureWarning)
+warnings.filterwarnings("ignore",category=FutureWarning)
+
+def print_message(message):
+ print(message)
+
+
+class STClient:
+ def __init__(self, args):
+ self.max_write_threads = 10
+ self.max_read_threads = 10
+ self.max_transfer_threads = 10
+ self.max_write_queue = 100
+ self.max_buffer_size = 128*1024
+ self.max_read_queue = 100
+
+ self.debug = False
+ self.total_transfer_done = 0
+ self.transfer_done = False
+ self.stop_time = 0
+ self.transfer_port = 0
+ self.start_time = 0
+ if(not args.sender or not args.receiver):
+ print("[+] Sender and receiver not specified")
+ exit(0)
+ sender_info = self.parse_IP(args.sender)
+ receiver_info = self.parse_IP(args.receiver)
+ self.sender_ip = sender_info[0]
+ self.sender_port = int(sender_info[1])
+ self.sender_path = sender_info[2]
+ print_message("[Sender] is %s:%d%s"%(self.sender_ip, self.sender_port, self.sender_path))
+
+ self.receiver_ip = receiver_info[0]
+ self.receiver_port = int(receiver_info[1])
+ self.receiver_path = receiver_info[2]
+ self.interface_ip = self.receiver_ip
+ print_message("[Receiver] is %s:%d%s"%(self.receiver_ip, self.receiver_port, self.receiver_path))
+ if(args.interface):
+ self.interface_ip = args.interface
+
+ def parse_IP(self, ip):
+ parsed_ip = []
+ messages = ip.split(":")
+ parsed_ip.append(messages[0])
+ if "~" in messages[1]:
+ new_msg = messages[1].split("~")
+ parsed_ip.append(new_msg[0])
+ parsed_ip.append("~"+new_msg[1])
+ else:
+ new_msg = messages[1].split("/", 1)
+ parsed_ip.append(new_msg[0])
+ parsed_ip.append("/"+new_msg[1])
+ return parsed_ip
+
+
+
+class TalkSend(Thread):
+ def __init__(self, st_client_param):
+ Thread.__init__(self)
+ self.st_client = st_client_param
+ self.stop_probing = False
+ self.sender_finish_blocks = False
+ self.start_current_probing = False
+ self.stop_current_probing = False
+ self.can_start_probing = False
+ self.current_probe_started = False
+ self.parameters = ""
+ self.parameter_list = {}
+ self.serversocket = None
+ self.client_data, self.client_addr = None, None
+
+ def send_message(self, message):
+ message = str(message)+"\n"
+ self.socket.send(message.encode())
+ def get_next_line(self):
+ message = self.socket.recv(1024).decode()
+ while("\n" not in message):
+ message += self.socket.recv(1024).decode()
+ message=message.strip()
+ return message
+
+ def probe(self):
+ self.send_message("Parameter:"+self.parameters)
+ message = self.get_next_line()
+ if(message == "ok"):
+ self.sender_finish_blocks = True
+ while(not self.start_current_probing):
+ time.sleep(0.010)
+ self.send_message("Start:currentProbing")
+ message = self.get_next_line()
+ if(message.lower() == "ok"):
+ self.current_probe_started = True
+ while(not self.stop_current_probing):
+ time.sleep(0.010)
+ self.send_message("Check:done")
+ message = self.get_next_line()
+ msg = message.strip().split(":")
+ if(msg[0].lower() == "true"):
+ self.st_client.stop_time = int(time.time()*1000)
+ self.st_client.transfer_done = True
+ if(len(msg) >= 2):
+ self.st_client.total_transfer_done = int(msg[1].strip())
+ self.stop_current_probing = True
+ self.stop_probing = True
+ return True
+ self.send_message("Stop:currentProbing")
+ message = self.get_next_line()
+
+ self.stop_current_probing = False
+ self.current_probe_started = False
+ self.sender_finish_blocks = False
+ return True
+
+ def run(self):
+ self.socket = socket.socket()
+# self.server_socket.bind((self.st_client.sender_ip, self.st_client.sender_port))
+ self.client_data = self.socket.connect((self.st_client.sender_ip, self.st_client.sender_port))
+
+ self.send_message(str(self.st_client.interface_ip))
+ while(self.st_client.transfer_port == 0):
+ time.sleep(10)
+ #Sender data
+ self.send_message(str(self.st_client.transfer_port))
+ self.send_message(str(self.st_client.sender_path))
+
+ #Max Parameters
+ self.send_message(str(self.st_client.max_transfer_threads))
+ self.send_message(str(self.st_client.max_read_threads))
+ self.send_message(str(self.st_client.max_read_queue))
+
+ self.send_message("Start:Probing")
+ self.get_next_line()
+ self.can_start_probing = True
+ self.st_client.start_time = int(time.time()*1000)
+ print_message("Starting time is: %d"%self.st_client.start_time)
+
+ while(not self.stop_probing):
+ while(self.parameters == ""):
+ time.sleep(0.010)
+ self.probe()
+ if not self.st_client.transfer_done:
+ self.send_message("Start:Transfer")
+ print_message("[Sender] Normal transfer has been started")
+ self.get_next_line()
+ self.probe()
+ self.socket.close()
+ def set_parameter(self, param):
+ self.parameters = param
+
+
+class TalkReceive(Thread):
+ def __init__(self, st_client_param, st_send_param):
+ Thread.__init__(self)
+ self.stop_probing = False
+ self.start_current_probing = False
+ self.stop_current_probing = False
+ self.convergence_found = False
+ self.can_start_probing = False
+ self.parameters = ""
+ self.parameter_list = {}
+ self.thpt_list = []
+ self.talk_send = st_send_param
+ self.st_client = st_client_param
+ self.serversocket = None
+ self.client_data, self.client_addr = None, None
+ self.zero_count = 0
+
+ def send_message(self, message):
+ message = str(message)+"\n"
+ self.socket.send(message.encode())
+ def get_next_line(self):
+ message = self.socket.recv(1024).decode()
+ while("\n" not in message):
+ message += self.socket.recv(1024).decode()
+ message=message.strip()
+ return message
+ def add_throughput(self, message, throughput_started):
+ if(message != ""):
+ thpts = message.split(",")
+ for i in thpts:
+ if(i != ""):
+ throughput = 0.
+ try:
+ throughput = float(i)
+ except:
+ pass
+ if(throughput <= 0.000001 and not len(self.thpt_list) and not throughput_started):
+ self.zero_count += 1
+ if self.zero_count > 15:
+ break
+ continue
+ throughput_started = True
+ self.thpt_list.append(throughput)
+ return throughput_started
+ def probe(self):
+ self.zero_count = 0
+ self.send_message("Parameter:"+self.parameters)
+ message = self.get_next_line()
+ if(message.lower() == "ok"):
+ self.talk_send.start_current_probing = True
+ while(not self.talk_send.sender_finish_blocks or not self.start_current_probing):
+ time.sleep(0.01)
+ self.start_current_probing = self.talk_send.current_probe_started
+ self.send_message("Start:currentProbing")
+ message = self.get_next_line()
+ self.thpt_list = []
+ throughput_started = False
+ while(not self.stop_current_probing):
+ self.send_message("Get Throughput:New")
+ message = self.get_next_line()
+ throughput_started = self.add_throughput(message, throughput_started)
+ time.sleep(0.01)
+ tmp = self.thpt_list
+ self.thpt_list = [0]*self.zero_count + tmp
+ self.talk_send.stop_current_probing = True
+ self.send_message("Stop:currentProbing")
+ message = self.get_next_line()
+ self.stop_current_probing = False
+ self.start_current_probing = False
+ self.talk_send.start_current_probing = False
+
+
+ def run(self):
+ self.socket = socket.socket()
+ self.client_data = self.socket.connect((self.st_client.receiver_ip, self.st_client.receiver_port))
+ print_message("Receiver: %s and port: %s" % (self.st_client.receiver_ip, str(self.st_client.receiver_port)))
+ self.send_message(self.st_client.receiver_path)
+ self.send_message(self.st_client.max_transfer_threads)
+ self.send_message(self.st_client.max_write_threads)
+ self.send_message(self.st_client.max_write_queue)
+ self.send_message(self.st_client.max_buffer_size)
+ self.st_client.transfer_port = int(self.get_next_line())
+ self.send_message("Start:Probing")
+ self.get_next_line()
+ self.st_client.start_time = int(time.time()*1000)
+
+ self.can_start_probing = True
+ while(not self.stop_probing):
+ while(self.parameters == ""):
+ time.sleep(0.01)
+ self.probe()
+ if(not self.st_client.transfer_done):
+ self.send_message("Start:Transfer")
+ print_message("[Receiver] Normal transfer has been started")
+ self.get_next_line()
+ self.probe()
+
+ self.socket.close()
+ def set_parameter(self, param):
+ self.parameters = param
+
+class GA:
+ def __init__(self, args):
+ self.number_of_generations = 4
+ self.number_of_population = 6
+ if(args.generation):
+ self.number_of_generations = args.generation
+ if(args.population):
+ self.number_of_population = args.population
+ self.evaluation = "adaptive"
+ self.crossover_type = "favour_zero"
+ self.mutation_type = "bit_flip"
+ self.selection_type = "standard_deviation_elitist"
+
+ self.params = [[1, 128], [1, 256], [10, 260], [128, 132], [1, 128], [1, 256], [1, 128]]
+ self.param_length = [4,8,0,0,4,8,4]#[4, 6, 6, 6, 4, 6, 4]
+ self.population_length = 0
+ self.agents = []
+ self.talk_send = None
+ self.talk_receive = None
+ self.st_client = None
+ self.best_score = 0
+ self.best_pop = ""
+ self.previous_best_pop = ""
+ self.initiate(args)
+ for i in self.param_length:
+ self.population_length += i
+ for i in range(self.number_of_population):
+ self.agents.append(Agents(population="", population_length=self.population_length))
+ if(self.conv == "dnn"):
+ self.load_clfs()
+ print("[+] CLF's loaded")
+ if(self.conv == "rand"):
+ adaptive_iterative.regression_train()
+ adaptive_iterative.classification_train()
+ print("[+] CLF's loaded")
+ print("[+] Conv is ", self.conv)
+ self.mutation_probab = 1.0#/self.population_length
+ self.crossover_probab = 1.0
+ print("[+] Method is ", args.method)
+ if(args.method.lower() == "random"):
+ self.run_random()
+ else:
+ self.run_GA()
+ print_message("Best Population: %s and best score: %d" % (self.best_pop, self.best_score))
+ self.talk_send.parameters = self.best_pop
+ self.talk_receive.parameters = self.best_pop
+
+ self.talk_send.stop_probing = True
+ self.talk_receive.stop_probing = True
+ self.talk_receive.stop_current_probing = True
+ while(not self.st_client.transfer_done):
+ time.sleep(0.01)
+ self.check_transfer_done()
+
+
+ def check_transfer_done(self):
+ message = False
+ if(self.st_client.transfer_done):
+ print_message("Sending message to receiver")
+ self.talk_receive.send_message("Done:transfer")
+ print_message("Message send to receiver")
+ try:
+ msg = self.talk_receive.get_next_line()
+ print(msg)
+ msg = msg.split(":")
+ print(msg)
+ if("ok" in msg):
+ self.talk_send.stop_probing = True
+ self.talk_receive.stop_probing = True
+ self.st_client.transfer_done = True
+ print_message("[+] Total time for this transfer is %.3f Seconds"%((self.st_client.stop_time - self.st_client.start_time)/1000.))
+ message = True
+ if(len(msg) == 2):
+ print_message(str(self.st_client.total_transfer_done) + " Bytes")
+ print_message("[+] Total file transfered is %.3f Bytes" % (int(self.st_client.total_transfer_done)/(1024.*1024.*1024.)))
+ print_message("[+] Average throughput is %.3f Mbps" % ((1000 * 8 * int(self.st_client.total_transfer_done))/((self.st_client.stop_time - self.st_client.start_time)*1000.*1000.)))
+ exit(0)
+ except:
+ self.talk_send.stop_probing = True
+ self.talk_receive.stop_probing = True
+ self.st_client.transfer_done = True
+ print_message("[+ err] Total time for this transfer is %.3f Seconds"%((self.st_client.stop_time - self.st_client.start_time)/1000.))
+ print_message("[+ err] Average throughput is %.3f Mbps" % ((1000 * 8 * int(self.st_client.total_transfer_done))/((self.st_client.stop_time - self.st_client.start_time)*1000.*1000.)))
+ message = True
+ exit(0)
+ return message
+
+ def initiate(self, args):
+ self.st_client = STClient(args)
+ self.st_client.max_read_threads = self.params[0][0]+2**self.param_length[0]+1
+ self.st_client.max_transfer_threads = self.params[4][0]+2**self.param_length[4]+1
+ self.st_client.max_write_threads = self.params[6][0]+2**self.param_length[6]+1
+ self.st_client.start_time = int(time.time()*1000)
+ self.talk_send = TalkSend(self.st_client)
+ self.talk_receive = TalkReceive(self.st_client, self.talk_send)
+ self.talk_send.start()
+ self.talk_receive.start()
+ self.conv = args.conv
+# self.talk_send.join()
+# self.talk_receive.join()
+
+ def run_GA(self):
+ ags = self.agents
+ for i in range(self.number_of_generations):
+ self.evaluate_population(ags)
+# self.evaluate_population(ags)
+ if self.st_client.transfer_done:
+ exit(0)
+# if i >= (self.number_of_generations/2):
+# self.best_score = 0.
+ ags = self.selection(ags+self.agents, i)
+ if (i+1) != self.number_of_generations:
+ ags = self.crossover(ags)
+ ags = self.mutation(ags)
+ if i>1:
+ self.agents = ags
+ print_message("[+] Current complete generation is: %d"%(i+1))
+ self.choose_best_pop()
+ def run_random(self):
+ self.agents = []
+ for i in range(self.number_of_generations*self.number_of_population):
+ self.agents.append(Agents(population="", population_length=self.population_length))
+ self.evaluate_population(self.agents)
+# self.evaluate_population(self.agents)
+ self.selection(self.agents, 0)
+
+ def selection(self, agents, generation):
+ ags = []
+ tmp = []
+
+
+ #Comment this out in final version.
+ is_chameleon = False
+ if(generation == 0 and is_chameleon):
+ tmp = agents[:3]
+ agents = agents[5:]
+
+
+
+ agents = sorted(agents, reverse=True)
+ if self.best_score (1+higher_than) * avg_thpt:
+ ags.append(agent)
+ if len(ags)<2:
+ ags = self.percentage_elitist(agents, higher_than=higher_than-0.05)
+ for i in range(len(ags), self.number_of_population):
+ new_ags.append(rd.choice(ags))
+ return ags+new_ags
+ def standard_deviation_elitist(self, agents, number_of_std=1):
+ ags = []
+ new_ags = []
+ avg_thpt = self.get_average_generation_thpt(agents)
+ std_ = self.get_generation_std(agents)
+ print_message("[Genaration] average throughput of generation is %.3f Mbps and larger than %.3f Mbps" % (avg_thpt, avg_thpt+(number_of_std*std_)))
+ for agent in agents:
+ if agent.get_score() >= (avg_thpt+(std_*number_of_std)):
+ ags.append(agent)
+ if len(ags)<2:
+ ags = self.standard_deviation_elitist(agents, number_of_std-0.5)
+ for i in range(len(ags), self.number_of_population):
+ new_ags.append(rd.choice(ags))
+ return ags+new_ags
+
+ def get_ranking(self, length):
+ ranks = [i for i in range(length, 0, -1)]
+ sum_r = sum(ranks)
+ return [(1.0*i)/sum_r for i in ranks]
+
+ def crossover(self, agents):
+ ags = []
+ for i in range(self.number_of_population//2):
+ p1 = agents[2*i].get_population()
+ p2 = agents[2*i+1].get_population()
+ c1, c2 = p1, p2
+ if rd.random() <= self.crossover_probab:
+ if self.crossover_type == "single_point":
+ c1, c2 = self.single_point_crossover(p1, p2)
+ elif self.crossover_type == "favor_zero":
+ c1, c2 = self.favor_zero(p1, p2)
+ ags.append(Agents(c1))
+ ags.append(Agents(c2))
+ return ags
+ def favor_zero(self, p1, p2):
+ p1 = self.get_bin_array(p1)
+ p2 = self.get_bin_array(p2)
+ c1 = ""
+ c2 = ""
+ for i in range(len(p1)):
+ curr_1 = p1[i]
+ curr_2 = p2[i]
+ for j in range(len(curr_1)):
+ char_1 = int(curr_1[j])
+ char_2 = int(curr_2[j])
+ if char_1 == char_2:
+ c1 += str(char_1)
+ c2 += str(char_2)
+ else:
+ if char_1 == 0:
+ c1 += str(char_1)
+ c2 += '0' if rd.random() < 0.5+((len(curr_2)-j)/50) else '1'
+ else:
+ c1 += '0' if rd.random() < 0.5+((len(curr_2)-j)/50) else '1'
+ c2 += str(char_2)
+ return c1, c2
+ def get_bin_array(self, p1):
+ string_array = []
+ since_last = 0
+ for i in range(len(self.param_length)):
+ till = since_last + self.param_length[i]
+ current = p1[since_last:till]
+ string_array.append(current)
+ since_last = till
+ return string_array
+
+ def single_point_crossover(self, p1, p2):
+ index = rd.randint(0, len(p1)-1)
+ c1 = p1[:index] + p2[index:]
+ c2 = p2[:index] + p1[index:]
+ return c1, c2
+
+ def mutation(self, agents):
+ ags = []
+ for agent in agents:
+ if rd.random() <= self.mutation_probab:
+ if self.mutation_type == "bit_flip":
+ agent = Agents(self.bit_flip_mutation(agent.get_population()))
+ ags.append(agent)
+ return ags
+
+ def bit_flip_mutation(self, population):
+ index = rd.randint(0, len(population)-1)
+ tmp = "0" if population[index]=="1" else "1"
+ return population[:index] + tmp + population[index+1:]
+
+ def get_param_string(self, population):
+ value_string = ""
+ since_last = 0
+ for i in range(len(self.param_length)):
+ till = since_last + self.param_length[i]
+ current = population[since_last:till]
+ if i:
+ value_string += ","
+ since_last = till
+ value_string += str(self.get_value(current, i))
+ return value_string
+ def get_bin_string(self, param_str):
+ val_str = param_str.split(",")
+ pop = ""
+ for i in range(len(val_str)):
+ value = str(bin(int(val_str[i]) - self.params[i][0]))[2:]
+ pop += self.get_bin_value(value, self.param_length[i])
+ return pop
+ def get_bin_value(self, value, length):
+ if length == 0:
+ return ""
+ zeros = "0" * (length-len(value))
+ return zeros + value
+
+ def get_value(self, current, i):
+ if not current:
+ return self.params[i][0]
+ return int(current, 2)+self.params[i][0]
+
+ def evaluate_population(self, agents):
+ for agent in agents:
+ print_message("[Agent] for agent "+self.get_param_string(agent.get_population()))
+
+ evaluation_start = time.time()
+ agent.set_thpt(self.evaluate(agent))
+ if(self.check_transfer_done()):
+ print_message("Transfer is done in %.3f Seconds"%((self.st_client.stop_time - self.st_client.start_time)/1000.))
+ break
+ print_message("[Agent] for agent %s throughput is: %.3f Mbps avgThpt: %.3f in total time %.3f"%(self.get_param_string(agent.get_population()), agent.get_thpt(), agent.get_avg_thpt(), time.time() - evaluation_start))
+
+ def evaluate(self, agent):
+ self.convergence_thpt = {}
+ parameter = self.get_param_string(agent.get_population())
+ throughput = 0.
+ self.talk_receive.set_parameter(parameter)
+ self.talk_send.set_parameter(parameter)
+ while(not self.talk_send.can_start_probing or not self.talk_receive.can_start_probing):
+ time.sleep(0.01)
+ self.talk_send.start_current_probing = True
+ while(throughput == 0.0 and len(self.talk_receive.thpt_list)<1500 and not self.st_client.transfer_done):
+ time.sleep(0.01)
+ thpt_list = self.talk_receive.thpt_list
+ if len(thpt_list) not in self.convergence_thpt:
+ throughput = self.find_convergence(thpt_list)
+ self.convergence_thpt[len(thpt_list)] = throughput
+ else:
+ throughput = self.convergence_thpt[len(thpt_list)]
+# print("[Throughput] list" + str(self.talk_receive.thpt_list))
+ if throughput == 0.:
+ throughput = self.find_average_thpt(self.talk_receive.thpt_list)
+ agent.set_avg_thpt(self.find_average_thpt(self.talk_receive.thpt_list))
+ self.talk_receive.thpt_list = []
+ self.talk_receive.stop_current_probing = True
+ return throughput
+ def find_convergence(self, thpt_list):
+ if(self.conv == "ar"):
+ return self.find_convergence_timeseries(thpt_list)
+ elif(self.conv == "dnn"):
+ return self.find_convergence_dnn(thpt_list)
+ elif self.conv == "rand":
+ if len(thpt_list) >= 2 and adaptive_iterative.is_predictable(thpt_list):
+ print(thpt_list)
+ return adaptive_iterative.make_prediction(thpt_list)
+ else:
+ return 0.0
+ else:
+ return self.find_average_thpt(thpt_list) if len(thpt_list)>=10 else 0.0
+ def find_convergence_timeseries(self, thpt_list):
+ if(len(thpt_list)<4):
+ return 0.
+ if(len(thpt_list)>15):
+ return self.find_average_thpt(thpt_list)
+ tmp = [0]+thpt_list[:-1]
+ model = AR(tmp)
+ start_params = [0, 0, 1]
+
+ model_fit = model.fit(maxlag=1, start_params=start_params, disp=-1)
+ predicted_last = model_fit.predict(len(tmp), len(tmp))[0]
+ last_pt = thpt_list[-1]
+
+ if( (last_pt != 0.) and (predicted_last - last_pt)/last_pt < 0.1):
+ return predicted_last
+ return 0.
+ def load_clfs(self):
+ self.all_clfs = {}
+ for i in range(3, 16):
+ self.all_clfs[i] = joblib.load("./clfs/pronghorn-10-%d-42-percentage-optimal.pkl"%i)
+ def get_percentage_change_thpts(self, thpt_list):
+ if len(thpt_list) <= 1:
+ return []
+ new_thpt = []
+ prev_thpt = thpt_list[0]
+ for index in range(1, len(thpt_list)):
+ perc = thpt_list[index] - prev_thpt
+ new_thpt.append(perc / (prev_thpt+1.5))
+ prev_thpt = thpt_list[index]
+ return new_thpt
+
+ def find_convergence_dnn(self, thpt_list):
+ threshold = 1.0
+# print("Actual Thpt list", thpt_list)
+# print(thpt_list)
+ prev_thpt_list = thpt_list
+ thpt_list = self.get_percentage_change_thpts(thpt_list)
+# print("Percentage Change thpt", thpt_list)
+# print(thpt_list)
+ if len(thpt_list)<3:
+ return 0
+ elif len(thpt_list)>=15:
+ return self.find_average_thpt(thpt_list)
+ i = len(thpt_list)
+ y_pred = self.all_clfs[i].predict_proba([thpt_list])[0]
+ max_, ind_ = self.get_max_and_index(y_pred)
+ print("[+] ", max_, threshold - 0.05*(len(thpt_list) - 2), ind_, i, len(prev_thpt_list))
+# print("CT", ind_, " Prediction probability", max_, " Threshold", threshold - 0.05*(len(thpt_list) - 2))
+ if(max_ > (threshold - 0.05*(len(thpt_list) - 2)) and ind_+2 <= i+1):
+ return self.find_average_thpt(prev_thpt_list)
+ return 0.0
+ def get_max_and_index(self, lis):
+ max_ = lis[0]
+ ind_ = 0
+ for i in range(len(lis)):
+ if max_ <= lis[i]:
+ max_ = lis[i]
+ ind_ = i
+ return max_, ind_
+ def find_average_thpt(self, thpt_list):
+ if not thpt_list:
+ return 0.
+ return (1.0*sum(thpt_list))/len(thpt_list)
+ def choose_best_pop(self):
+ pass
+ def get_random_choices(self, population, weights=None, cum_weights=None, k=1):
+ random = rd.random
+ if cum_weights is None:
+ if weights is None:
+ _int = int
+ total = len(population)
+ return [population[_int(random() * total)] for i in range(k)]
+ cum_weights = []
+ last_val = 0
+ for i in weights:
+ last_val += i
+ cum_weights.append(last_val)
+ elif weights is not None:
+ raise TypeError('Cannot specify both weights and cumulative weights')
+ if len(cum_weights) != len(population):
+ raise ValueError('The number of weights does not match the population')
+ bisect = _bisect.bisect
+ total = cum_weights[-1]
+ hi = len(cum_weights) - 1
+ return [population[bisect(cum_weights, random() * total, 0, hi)]
+ for i in range(k)]
+ def get_average_generation_thpt(self, agents):
+ total_score = 0
+ for agent in agents:
+ total_score += agent.get_score()
+ return total_score/len(agents)
+# def get_generation_std(self, agents):
+# scores = []
+# for i in agents:
+# scores.append(i.get_score())
+# return np.array(scores).std()
+ def get_generation_std(self, agents):
+ n = len(agents)
+ if n <= 1:
+ return 0.0
+ mean, sd = self.get_average_generation_thpt(agents), 0.0
+ # calculate stan. dev.
+ for el in agents:
+ sd += (float(el.get_score()) - mean)**2
+ sd = math.sqrt(sd / float(n-1))
+ return sd
+
+class Agents:
+ def __init__(self, population="", population_length=0):
+ self.population = population
+ self.throughput = -1
+ self.average_throughput = 0
+ self.memory_error = False
+ if population == "":
+ self.population = self.get_random_string(population_length)
+
+ def get_random_string(self, population_length):
+ to_ret = ""
+ for i in range(population_length):
+ to_ret += rd.choice(["0", "1"])
+ return to_ret
+ def set_avg_thpt(self, thpt):
+ if self.average_throughput:
+ self.average_throughput += thpt
+ self.average_throughput = self.average_throughput/2.
+ else:
+ self.average_throughput = thpt
+ def get_avg_thpt(self):
+ return self.average_throughput
+ def set_thpt(self, thpt):
+ if self.throughput != -1:
+ self.throughput += thpt
+ self.throughput /= 2.
+ else:
+ self.throughput = thpt
+ def get_thpt(self):
+ return self.throughput
+ def get_population(self):
+ return self.population
+ def get_score(self):
+ if self.get_thpt()>0.0:
+ return self.get_thpt()
+ return self.get_avg_thpt() + self.get_thpt()
+ def __lt__ (self, other):
+ return self.get_score() < other.get_score()
+ def __gt__(self, other):
+ return self.get_score() > other.get_score()
+ def __eq__(self, other):
+ return self.get_score() == other.get_score()
+
+if __name__=="__main__":
+ parser = argparse.ArgumentParser(description='Parameters in the application')
+ parser.add_argument('sender', type=str,
+ help='Sender information')
+ parser.add_argument('receiver', type=str,
+ help='Receiver information')
+ parser.add_argument('--interface', type=str,
+ help='Interface to send information in the receiver')
+ parser.add_argument('--generation', type=int,
+ help='Number of generation for GA')
+ parser.add_argument('--population', type=int,
+ help='Number of population for GA')
+ parser.add_argument('--method', type=str,
+ help='Method of algorithm to use')
+ parser.add_argument('--conv', type=str,
+ help='Method of algorithm to use')
+ args = parser.parse_args()
+ print("Argument values:")
+ print(args.sender)
+ print(args.receiver)
+ if(not args.method):
+ args.method = "GA"
+ if(not args.conv):
+ args.method = "avg"
+ ga = GA(args)
+
+
\ No newline at end of file
diff --git a/STReceiver.java b/STReceiver.java
new file mode 100644
index 0000000..4a31de9
--- /dev/null
+++ b/STReceiver.java
@@ -0,0 +1,545 @@
+import java.io.*;
+import java.net.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.concurrent.*;
+import java.util.concurrent.atomic.AtomicLong;
+
+public class STReceiver {
+ private int numberOfWrite = 2;
+ private int sizeOfQueue = 100;
+ private boolean debug = false;
+ private int transferPort = 48892;
+ private int maxBufferSize = 128*1024;
+ private long throughputCycle = 1000l;
+ private int clientPort = 53823;
+ private String throughputMessage = "";
+ private Boolean transferDone = false;
+ private long numberOfMaxConnections = 0;
+ private AtomicLong totalWriteDone = new AtomicLong(0);
+ private AtomicLong totalTransferDone = new AtomicLong(0);
+ private long tillLastWrite = 0l;
+ private long sinceLastTime = 0l;
+ private long tillLastTransfer = 0l;
+ private long startTime = 0l;
+ private String status = "TRANSFER";
+ private String toDir = "/data/hem/";
+ private ExecutorService writePool = Executors.newFixedThreadPool(200);
+ private HashMap writeBlocks = new HashMap<>();
+ private ExecutorService receivePool = Executors.newFixedThreadPool(200);
+ private HashMap receiveBlocks = new HashMap<>();
+ private static LinkedBlockingQueue blocks = new LinkedBlockingQueue<>(1500);
+ private HashMap filesNames = new HashMap<>();
+
+ public STReceiver(){
+
+ }
+ void startListeningToSender(){
+ STReceiver.OpenTransferThread ott = this.new OpenTransferThread(this, this.transferPort);
+ Thread tott = new Thread(ott, "0");
+ tott.start();
+ }
+ public static void main(String[] args){
+ STReceiver str = new STReceiver();
+ STReceiver.TalkClient tc = str.new TalkClient(str);
+ Thread ttc = new Thread(tc, "TalkClient");
+ ttc.start();
+ System.out.println("Started Receiver at "+str.clientPort+" port");
+ long iter = 0l;
+ long tillCertainTransfer = 0l;
+ long lastCertainTime = 0l;
+ while(!str.transferDone){
+ if(str.startTime != 0) {
+ long totalTransfer = str.totalTransferDone.get();
+ long lastTransfer = str.tillLastTransfer;
+ str.tillLastTransfer = totalTransfer;
+ long totalW = str.totalWriteDone.get();
+ long lastW = str.tillLastWrite;
+ str.tillLastWrite = totalW;
+
+ long lastTime = str.sinceLastTime;
+ str.sinceLastTime = System.currentTimeMillis();
+
+ double timeInterval = str.sinceLastTime - lastTime;
+ if(timeInterval>=1000) {
+ double thpt = 8 * 1000 * (totalW - lastW) / (timeInterval * 1000 * 1000);
+ double thpt_transfer = 8*1000*(totalTransfer-lastTransfer)/(timeInterval *1000*1000);
+ if (str.status.equalsIgnoreCase("probing") && (iter % 1) == 0) {
+ long thisTime = System.currentTimeMillis();
+ double avgThpt = Math.round(100.0*8*1000*(totalW-tillCertainTransfer)/((thisTime-lastCertainTime)*1000.*1000.))/100.0;
+
+ /*
+ System.out.println("Transfer Done till: " + totalW / (1024. * 1024.) +
+ "MB in time: " + (System.currentTimeMillis() - str.startTime) / 1000. +
+ " seconds and Throughput is:" + avgThpt + "Mbps time: " + timeInterval + " tnsferDone: " + (totalW-tillCertainTransfer) + " blocks: "+STReceiver.blocks.size());
+ //*/
+ System.out.println("Transfer Done till: " + totalW / (1024. * 1024.) + "MB in time: "+(System.currentTimeMillis() - str.startTime) / 1000. +
+ " seconds and thpt is: "+avgThpt+" Mbps" + " blocks: "+STReceiver.blocks.size());
+ tillCertainTransfer = totalW;
+ lastCertainTime = thisTime;
+ }
+
+
+ if (str.status.equalsIgnoreCase("probing") && tc.currentProbeStarted) {//PROBING
+ synchronized (str.throughputMessage) {
+ if (!str.throughputMessage.equalsIgnoreCase("")) {
+ str.throughputMessage += "," + (thpt*0.95+thpt_transfer*0.05);
+ } else {
+ str.throughputMessage += thpt;
+ }
+ }
+ }
+ }
+ }
+ try{
+ for(int i = 0; i < 100; i++) {
+ Thread.sleep(str.throughputCycle/100);
+ }
+ }catch(InterruptedException e){
+ e.printStackTrace();
+ }
+ iter++;
+ }
+
+ System.out.println("[+] Receiver has stopped");
+ try{
+ Thread.sleep(5000);
+ }catch(InterruptedException e){
+
+ }
+ System.exit(0);
+
+ }
+ public int getTransferPort(){
+ return this.transferPort;
+ }
+ public void closeConnections(){
+ for (int i = 0; i < this.receiveBlocks.size(); i++) {
+ try {
+ this.receiveBlocks.get(i).close();
+ }catch (IOException e){}
+ }
+ }
+ public void startWriteThreads(int count){
+ this.stopAllWriteThreads();
+ if(this.writeBlocks.size() < count){
+ for(int i=this.writeBlocks.size(); i();
+ }
+ ServerSocket socketReceive = new ServerSocket(this.communicationPort);
+ Socket clientSock = socketReceive.accept();
+ DataInputStream dataInputStream = new DataInputStream(clientSock.getInputStream());
+ DataOutputStream dos = new DataOutputStream(clientSock.getOutputStream());
+ this.stReceiver.numberOfMaxConnections = dataInputStream.readLong();
+ this.stReceiver.startTime = System.currentTimeMillis();
+ int startPort = 61024;
+ for(int i=startPort;i byteArray;
+ long tillNow = 0l;
+ boolean bufferLoaded = false;
+ long startTime = 0l;
+
+
+ Block(long offset, long length){
+ this.offset = offset;
+ this.length = length;
+ byteArray = new ArrayList();
+
+ }
+ void add_buffer(byte[] bff, int buffer_size){
+ byteArray.add(new Buffer(bff, buffer_size));
+ tillNow += buffer_size;
+
+ }
+ void remove_buffer(){
+ this.byteArray = null;
+ }
+ void setOffset(long offset){this.offset=offset;}
+ void setFilename(String fn){this.filename=fn;}
+ void setFileId(long fi){this.fileId=fi;}
+ void setBlockId(long bi){this.blockId=bi;}
+ }
+ class Buffer {
+ byte[] small_buffer;
+ int length;
+
+ Buffer(byte[] buffer, int buffer_size){
+ small_buffer = Arrays.copyOf(buffer, buffer_size);
+ length = buffer_size;
+ }
+ }
+ class WriteBlock implements Runnable{
+ STReceiver stReceiver = null;
+ boolean waitNext = false;
+ WriteBlock(STReceiver str){
+ this.stReceiver = str;
+ }
+ void stopThread(){this.waitNext = true;}
+ void startThread(){this.waitNext=false;}
+
+ @Override
+ public void run() {
+ if(this.stReceiver.debug) {
+ System.out.println("[+] Write Thread-" + Thread.currentThread().getName() + " has Started.");
+ }
+ //Start reading block and write to this.stSender.blocks
+ while(true){
+ Block currentBlock = null;
+ while(STReceiver.blocks.isEmpty()){
+ try {
+ Thread.sleep(10);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+ try {
+ currentBlock = STReceiver.blocks.poll(50, TimeUnit.MILLISECONDS);
+ if(currentBlock == null) {
+ try {
+ Thread.sleep(10);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ if(this.stReceiver.transferDone &&
+ STReceiver.blocks.size() == 0){
+ break;
+ }
+ continue;
+ }
+ long st_wite = System.currentTimeMillis();
+ String filename = this.stReceiver.toDir + "/" + currentBlock.filename;
+ RandomAccessFile randomAccessFile = filesNames.remove(filename);
+ if(randomAccessFile == null) {
+ randomAccessFile = new RandomAccessFile(filename, "rw");
+ }
+ if (currentBlock.offset > 0) {
+ randomAccessFile.getChannel().position(currentBlock.offset);
+ }
+ for(Buffer buffer: currentBlock.byteArray){
+ randomAccessFile.write(buffer.small_buffer, 0, buffer.length);
+
+ buffer.small_buffer = null;
+ }
+ if(this.stReceiver.filesNames.containsKey(filename)){
+ randomAccessFile.close();
+ }else {
+ filesNames.put(filename, randomAccessFile);
+ }
+
+ long done_wrt = System.currentTimeMillis();
+ //System.out.println("[+] Block is done of blockId"+currentBlock.blockId+ " in total time "+(done_wrt-currentBlock.startTime)+"ms & and write time "+(done_wrt-st_wite)+"ms total bolcks is "+STReceiver.blocks.size());
+
+ if(this.stReceiver.debug){
+ System.out.println("[Block Written "+System.currentTimeMillis()+"] fileId: "+currentBlock.fileId+" blockId: "+currentBlock.blockId + " offset: "+currentBlock.offset +
+ " threadName:"+Thread.currentThread().getName());
+ }
+ currentBlock = null;
+ }catch(Exception e){
+ e.printStackTrace();
+ }
+ }
+ }
+ }
+ class ReceiveFile implements Runnable{
+ STReceiver stReceiver = null;
+ boolean waitNext = false;
+ boolean closeConnection = false;
+ Socket clientSock = null;
+ ServerSocket socket_receive = null;
+ int sendFileId = 0;
+ Socket s = null;
+ int port = 0;
+
+ DataInputStream dataInputStream = null;
+
+ ReceiveFile(STReceiver str, int port){
+ //Start a connection here
+ this.stReceiver = str;
+ this.port = port;
+ }
+ void stopThread(){this.waitNext = true;}
+ void startThread(){this.waitNext=false;}
+ void close() throws IOException{
+ if(clientSock!=null){
+ clientSock.close();
+ }
+ if(socket_receive!=null){
+ socket_receive.close();
+ }
+ }
+ @Override
+ public void run() {
+ if (this.stReceiver.debug) {
+ System.out.println("[+] Receive Thread-" + Thread.currentThread().getName() + " has Started.");
+ }
+ try {
+ socket_receive = new ServerSocket(port);
+ socket_receive.setReceiveBufferSize(32*1024*1024);
+ clientSock = socket_receive.accept();
+
+ dataInputStream = new DataInputStream(clientSock.getInputStream());
+ while (!closeConnection) {
+ while (STReceiver.blocks.remainingCapacity() == 0) {
+ try {
+ Thread.sleep(10);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+ try {
+ int read = 0;
+ String filename = dataInputStream.readUTF();
+ if (filename.equals("done")) {
+ break;
+ }
+ long length = dataInputStream.readLong();
+ long offset = dataInputStream.readLong();
+ long fileId = dataInputStream.readLong();
+ long blockId = dataInputStream.readLong();
+
+ Block currentBlock = new Block(offset, length);
+ currentBlock.startTime = System.currentTimeMillis();
+ currentBlock.setFilename(filename);
+ currentBlock.setFileId(fileId);
+ currentBlock.setBlockId(blockId);
+ if(this.stReceiver.debug) {
+ System.out.println("[Block Received "+System.currentTimeMillis()+" Started] " + currentBlock.filename + " fileId: " + currentBlock.fileId + " blockId: " + currentBlock.blockId +
+ " offset: " + currentBlock.offset + " length: " + currentBlock.length + " blockLoaded: " + currentBlock.written);
+ }
+ byte[] buffer = new byte[maxBufferSize];
+
+
+ while (currentBlock.tillNow < currentBlock.length) {
+ //long st_tim = System.currentTimeMillis();
+ read = dataInputStream.read(buffer, 0, (int) Math.min(buffer.length, currentBlock.length - currentBlock.tillNow));
+ totalTransferDone.addAndGet(read);
+ //long new_st_tim = System.currentTimeMillis();
+ currentBlock.add_buffer(buffer, read);
+ totalWriteDone.addAndGet(read);
+ //totalWriteDone.addAndGet(read);
+ //long latest_st = System.currentTimeMillis();
+ //System.out.println("Receive Blockid = "+currentBlock.blockId+" and read time "+ (new_st_tim-st_tim)+"ms and add time "+(latest_st-new_st_tim)+"ms");
+ }
+ boolean doneAddition = true;
+
+ if(this.stReceiver.debug) {
+ System.out.println("[Block Received "+System.currentTimeMillis()+" Done] " + currentBlock.filename + " fileId: " + currentBlock.fileId + " blockId: " + currentBlock.blockId +
+ " offset: " + currentBlock.offset + " length: " + currentBlock.length + " blockLoaded: " + currentBlock.written);
+ }
+ currentBlock.written = true;
+
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+// this.stReceiver.transferDone = true;
+ }catch(IOException e){
+ e.printStackTrace();
+ }
+ }
+ }
+ class TalkClient implements Runnable{
+ private Socket talkSocket = null;
+ private ServerSocket serverSocket = null;
+ private STReceiver stReceiver = null;
+
+
+ BufferedReader readFromClient = null;
+ PrintStream sendClient = null;
+ boolean stopProbing = false;
+ boolean startCurrentProbing = false;
+ boolean stopCurrentProbing = false;
+ boolean currentProbeStarted = false;
+
+
+ public TalkClient(STReceiver str){
+ this.stReceiver = str;
+ }
+ String parseMessage(String message){
+ String[] messages = message.split(":");
+ if(messages[0].equalsIgnoreCase("start")){
+ if(messages[1].equalsIgnoreCase("transfer")){
+ this.stReceiver.status = "TRANSFER";
+ System.out.println("[+] Normal transfer has started");
+ this.stReceiver.startProbing();
+ return "ok";
+ }else if(messages[1].equalsIgnoreCase("probing")){
+ this.stReceiver.status = "PROBING";
+ return "ok";
+ }else if(messages[1].equalsIgnoreCase("currentProbing")){
+ this.stReceiver.status = "PROBING";
+ this.stReceiver.startProbing();
+ this.currentProbeStarted = true;
+ this.stReceiver.sinceLastTime = System.currentTimeMillis();
+ return "ok";
+ }
+ }else if(messages[0].equalsIgnoreCase("stop")){
+ if(messages[1].equalsIgnoreCase("currentProbing")){
+ this.currentProbeStarted = false;
+ this.stReceiver.stopEverything();
+ return "ok";
+ }else if(messages[1].equalsIgnoreCase("everything")){
+ return "ok";
+ }
+ }else if(messages[0].equalsIgnoreCase("parameter")){
+ System.out.println("Parameter: Done");
+ while(STReceiver.blocks.size() !=0){
+ try{
+ Thread.sleep(10);
+ }catch(InterruptedException e){ }
+ }
+ String[] params = messages[1].split(",");
+ this.stReceiver.sizeOfQueue = Integer.parseInt(params[5]);
+ ///*
+ synchronized (STReceiver.blocks){
+ STReceiver.blocks = new LinkedBlockingQueue<>(this.stReceiver.sizeOfQueue);
+ }
+ //*/
+ this.stReceiver.numberOfWrite = Integer.parseInt(params[6]);
+ System.out.println("Parameter: "+messages[1]);
+ return "ok";
+ }else if(messages[0].equalsIgnoreCase("get throughput")){
+ String thpts = "";
+ synchronized (this.stReceiver.throughputMessage){
+ thpts = this.stReceiver.throughputMessage;
+ this.stReceiver.throughputMessage = "";
+ }
+ return thpts;
+ }else if(messages[0].equalsIgnoreCase("done")){
+ while(STReceiver.blocks.size() != 0){
+ try{
+ Thread.sleep(10);
+ }catch(InterruptedException e){}
+ }
+ System.out.println("Receiver is done");
+ System.out.println("Transfer should be done");
+ this.stReceiver.transferIsDone();
+ this.stopProbing = true;
+// double write_done = this.stReceiver.totalWriteDone.get() / 1000000000;
+// double thpt = Math.ceil(write_done) * 1.024*1.024*8;
+ System.out.println(this.stReceiver.totalWriteDone);
+ this.stReceiver.transferDone = true;
+ return ":ok:"+this.stReceiver.totalWriteDone;
+ }
+ return "ok";
+ }
+ public void run(){
+ try {
+ this.serverSocket = new ServerSocket(this.stReceiver.clientPort);
+ this.talkSocket = serverSocket.accept();
+
+ this.readFromClient= new BufferedReader(new InputStreamReader(this.talkSocket.getInputStream()));
+ this.sendClient = new PrintStream(this.talkSocket.getOutputStream());
+
+ this.stReceiver.toDir = this.readFromClient.readLine().trim();
+ this.stReceiver.receivePool = Executors.newFixedThreadPool(Integer.parseInt(this.readFromClient.readLine().trim()));
+ this.stReceiver.writePool = Executors.newFixedThreadPool(Integer.parseInt(this.readFromClient.readLine().trim()));
+ this.stReceiver.blocks = new LinkedBlockingQueue<>(Integer.parseInt(this.readFromClient.readLine().trim()));
+ this.stReceiver.maxBufferSize = Integer.parseInt(this.readFromClient.readLine().trim());
+ this.stReceiver.startListeningToSender();
+ this.sendMessage(""+this.stReceiver.getTransferPort());
+
+ //String init_ = this.readFromClient.readLine();
+ //this.sendMessage("ok");
+
+ while(!this.stopProbing){
+ String receivedMessage = this.readFromClient.readLine();
+ String message = this.parseMessage(receivedMessage);
+ this.sendMessage(message);
+ }
+ try{
+ Thread.sleep(1000);
+ }catch(InterruptedException e){
+
+ }
+ this.close();
+ }catch(IOException e){
+ e.printStackTrace();
+ }
+ }
+ void sendMessage(String message){
+ this.sendClient.println(message);
+ }
+ public void close() throws IOException{
+ talkSocket.close();
+ readFromClient.close();
+ sendClient.close();
+ }
+ }
+}
diff --git a/STSender.java b/STSender.java
new file mode 100644
index 0000000..db1fa86
--- /dev/null
+++ b/STSender.java
@@ -0,0 +1,587 @@
+import java.io.*;
+import java.net.*;
+import java.util.*;
+import java.util.Queue;
+import java.util.concurrent.*;
+
+public class STSender {
+ private String receiverIP = "192.168.1.1";
+ private int receiverPort = 48892;
+ int numberOfRead = 2;
+ private int sizeOfQueue = 100;
+ private int numberOfConnection = 1;
+ private int numberOfMaxConnections = 10;
+ int clientPort = 52005;
+ private long blockSize = 256*1024*1024l;
+ private long bufferSize = 128*1024l;
+ private boolean debug = false;
+ String status = "TRANSFER";
+ private Boolean transferDone = false;
+ private String fromDir = "/data/aalhussen/10Mfiles";
+ private static Queue files = new LinkedList();//Collections.asLifoQueue(new ArrayDeque<>());
+ private ExecutorService readPool = Executors.newFixedThreadPool(200);
+ private HashMap readBlocks = new HashMap<>();
+ private ExecutorService sendPool = Executors.newFixedThreadPool(200);
+ private HashMap sendBlocks = new HashMap<>();
+ private static LinkedBlockingQueue blocks = new LinkedBlockingQueue<>(1500);
+ private boolean sendingDone = false;
+ private boolean fileFinished = false;
+ private static Long totalByteToSend = 0l;
+ private long last_file_id = 0l;
+ private long total_files = 0l;
+ private STSender.OpenTransferThread ott = null;
+ private HashMap> openFiles = new HashMap<>();
+
+ static int fileNum = 0;
+
+ public STSender(){
+
+ }
+ public static void main(String[] args){
+ STSender sts = new STSender();
+ STSender.TalkClient tc = sts.new TalkClient(sts);
+ Thread ttc = new Thread(tc, "TalkClient");
+ ttc.start();
+ try {
+ ttc.join();
+ }catch(InterruptedException e){
+
+ }
+ System.out.println("done");
+ try{
+ Thread.sleep(15000);
+ }catch(InterruptedException e){
+
+ }
+ System.exit(0);
+ }
+ void talkWithReceiver(STSender sts){
+ ott = sts.new OpenTransferThread(sts, sts.receiverIP, sts.receiverPort);
+ Thread tott = new Thread(ott, "0");
+ tott.start();
+ }
+ public void startReadThreads(int count){
+ this.stopAllReadThreads();
+ if(this.readBlocks.size() < count){
+ for(int i=this.readBlocks.size(); i= this.stSender.sizeOfQueue){
+ try {
+ Thread.sleep(10);
+ continue;
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+ long newOffset = 0;
+ long length = 0;
+ long blockId = 0;
+ int n;
+ FileInputStream fis = null;
+ synchronized (STSender.files) {
+ if(STSender.files.size()==0){
+ this.stSender.fileFinished = true;
+ }
+ head = STSender.files.poll();
+ if(head!=null){
+ blockId = head.blockId;
+ head.blockId++;
+ newOffset = head.offset;
+ length = Math.min(this.stSender.blockSize, head.length - newOffset);
+ head.offset += length;
+ if (head.offset < head.length) {
+ STSender.files.add(head);
+ }else if (this.stSender.last_file_id < this.stSender.total_files * 3.5){
+ STSender.files.add(new TransferFile(head.file, head.filename, 0l, head.length, this.stSender.last_file_id++));
+ STSender.totalByteToSend += head.length;
+ }
+ }
+ }
+ try {
+ if(head == null) {
+ if(STSender.files.size() == 0) {
+ }
+ try {
+ Thread.sleep(100);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ continue;
+ }
+ currentBlock = new Block(newOffset, length);
+ currentBlock.filename = head.filename;
+ currentBlock.fileId = head.fileId;
+ currentBlock.blockId = blockId;
+ currentBlock.fileLength = head.length;
+
+ byte[] buffer = new byte[(int)buffer_size];
+
+ if(this.stSender.debug) {
+ System.out.println("[Block Load "+System.currentTimeMillis()+" started] "+currentBlock.filename + " fileId: " + currentBlock.fileId + " blockId: " + currentBlock.blockId +
+ " offset: " + currentBlock.offset + " length: " + currentBlock.length + " fileLength: " +
+ currentBlock.fileLength + " blockLoaded: "+ currentBlock.bufferLoaded + " threadName: "+Thread.currentThread().getName());
+ }
+ while ((int) Math.min(buffer.length, currentBlock.length - currentBlock.tillNow) > 0) {
+ n = (int) Math.min(buffer.length, currentBlock.length - currentBlock.tillNow);
+ currentBlock.add_buffer(buffer, n);
+ if (currentBlock.tillNow >= currentBlock.length) {
+ boolean isSuccess = false;
+ while (!isSuccess) {
+ isSuccess = STSender.blocks.offer(currentBlock, 100, TimeUnit.MILLISECONDS);
+ }
+ currentBlock.bufferLoaded = true;
+ if(this.stSender.debug) {
+ System.out.println("[Block Loaded "+System.currentTimeMillis()+"] "+currentBlock.filename + " fileId: " + currentBlock.fileId + " blockId: " + currentBlock.blockId +
+ " offset: " + currentBlock.offset + " length: " + currentBlock.length + " fileLength: " +
+ currentBlock.fileLength + " blockLoaded: "+ currentBlock.bufferLoaded + " threadName: "+Thread.currentThread().getName());
+ }
+ }
+ }
+ //Might need to initiate buffer to null and to buffer_size again
+ } catch (Exception e) {
+ e.printStackTrace();
+ try {
+ if (fis != null) {
+ fis.close();
+ }
+ }catch(IOException ee){
+ ee.printStackTrace();
+ }
+ if(currentBlock!=null && head!=null){
+ synchronized (STSender.files){
+ //File file, String filename, long offset, long length, long fileId
+ STSender.files.add(new TransferFile(head.file, currentBlock.filename, currentBlock.offset, currentBlock.length, currentBlock.fileId));
+ currentBlock.bufferLoaded = false;
+ }
+ }
+ break;
+ }
+ }
+ System.out.println("Blocks in blocks: "+STSender.blocks.size());
+ }
+ }
+ class Block{
+ long offset = 0l;
+ long length = 0l;
+ String filename = "";
+ long fileId = 0l;
+ long fileLength = 0l;
+ boolean written = false;
+ long blockId = 0l;
+ List byteArray;
+ long tillNow = 0l;
+ boolean bufferLoaded = false;
+
+
+ Block(long offset, long length){
+ this.offset = offset;
+ this.length = length;
+ byteArray = new ArrayList();
+
+
+ }
+ void add_buffer(byte[] bff, int buffer_size){
+ byteArray.add(new Buffer(bff, buffer_size));
+ tillNow += buffer_size;
+
+ }
+ void remove_buffer(){
+ this.byteArray = null;
+ }
+ void setOffset(long offset){this.offset=offset;}
+ void setFilename(String fn){this.filename=fn;}
+ void setFileId(long fi){this.fileId=fi;}
+ void setBlockId(long bi){this.blockId=bi;}
+ }
+ class Buffer {
+ byte[] small_buffer;
+ int length;
+
+ Buffer(byte[] buffer, int buffer_size){
+ small_buffer = Arrays.copyOf(buffer, buffer_size);
+ length = buffer_size;
+ }
+ }
+ void collect_files(){
+ File file = new File(this.fromDir);
+ long fileId = 0l;
+ if (file.isDirectory()) {
+ for (File f : file.listFiles()) {
+ long fileLeng = f.length();
+ STSender.totalByteToSend += fileLeng;
+ STSender.files.add(new TransferFile(f, f.getName(), 0, f.length(), fileId));
+ this.openFiles.put(this.fromDir + "/" + f.getName(), new ArrayList<>());
+ fileId++;
+ }
+ } else {
+ long fileLeng = file.length();
+ STSender.totalByteToSend += fileLeng;
+ STSender.files.add(new TransferFile(file, file.getName(), 0, file.length(), fileId));
+ this.fromDir = file.getParent();
+ this.openFiles.put(this.fromDir + "/" + file.getName(), new ArrayList<>());
+ fileId++;
+ }
+ this.last_file_id = fileId;
+ this.total_files = fileId - 1;
+ System.out.println("[+] Sending "+files.size()+" files to receiver.");
+ System.out.println("[+] Sending "+Math.round(100*STSender.totalByteToSend/(1024.0*1024*1024.))/100.0+" GB data to receiver.");
+
+ }
+ class TransferFile{
+ String filename = "";
+ long offset = 0l;
+ long length = 0l;
+ long fileId = 0l;
+ File file = null;
+ long blockId = 0l;
+ TransferFile(File file, String filename, long offset, long length, long fileId){
+ this.file = file;
+ this.filename = filename;
+ this.offset = offset;
+ this.length = length;
+ this.fileId = fileId;
+ }
+ }
+
+
+ class TalkClient implements Runnable{
+ private Socket talkSocket = null;
+ private ServerSocket serverSocket = null;
+ private STSender stSender = null;
+
+
+ BufferedReader readFromClient = null;
+ PrintStream sendClient = null;
+ boolean stopProbing = false;
+
+
+ public TalkClient(STSender sts){
+ this.stSender = sts;
+ }
+ String parseMessage(String message){
+ String[] messages = message.split(":");
+ if(messages[0].equalsIgnoreCase("start")){
+ if(messages[1].equalsIgnoreCase("transfer")){
+ this.stSender.status = "TRANSFER";
+ return "ok";
+ }else if(messages[1].equalsIgnoreCase("probing")){
+ this.stSender.status = "PROBING";
+ return "ok";
+ }else if(messages[1].equalsIgnoreCase("currentprobing")){
+ this.stSender.startProbing();
+ return "ok";
+ }
+ }else if(messages[0].equalsIgnoreCase("stop")){
+ if(messages[1].equalsIgnoreCase("currentprobing")){
+ this.stSender.stopProbing();
+ return "ok";
+ }
+ }else if(messages[0].equalsIgnoreCase("parameter")){
+// while(STSender.blocks.size() !=0){
+// try{
+// Thread.sleep(100);
+// }catch(InterruptedException e){
+//
+// }
+// }
+ String[] params = messages[1].split(",");
+
+ this.stSender.numberOfRead = Integer.parseInt(params[0]);
+ this.stSender.sizeOfQueue = Integer.parseInt(params[1]);
+
+// synchronized (STSender.blocks){
+// STSender.blocks = new LinkedBlockingQueue<>(this.stSender.sizeOfQueue);
+// }
+ this.stSender.blockSize = Long.parseLong(params[2])*1024*1024;
+ this.stSender.bufferSize = Long.parseLong(params[3])*1024;
+ this.stSender.numberOfConnection = Integer.parseInt(params[4]);
+ return "ok";
+ }else if(messages[0].equalsIgnoreCase("check")){
+ if(messages[1].equalsIgnoreCase("done")){
+ if(STSender.files.size() == 0 && STSender.blocks.size()==0 && this.stSender.sendingDone){
+ //System.out.println("[+] Sender is done");
+ this.stSender.transferIsDone();
+ this.stopProbing = true;
+ this.stSender.ott.stopThreads = true;
+ return "true:"+STSender.totalByteToSend;
+ }
+ }
+ }
+ return "ok";
+ }
+ public void run(){
+ try {
+ this.serverSocket = new ServerSocket(stSender.clientPort);
+ System.out.println("Listening to port "+stSender.clientPort);
+ this.talkSocket = serverSocket.accept();
+ System.out.println("Socket accepted");
+
+ this.readFromClient= new BufferedReader(new InputStreamReader(this.talkSocket.getInputStream()));
+ this.sendClient = new PrintStream(this.talkSocket.getOutputStream());
+ this.stSender.receiverIP = this.readFromClient.readLine().trim();
+ System.out.println(this.stSender.receiverIP);
+ this.stSender.receiverPort = Integer.parseInt(this.readFromClient.readLine());
+
+ this.stSender.fromDir = this.readFromClient.readLine().trim();
+ this.stSender.collect_files();
+ this.stSender.numberOfMaxConnections = Integer.parseInt(this.readFromClient.readLine().trim());
+ this.stSender.sendPool = Executors.newFixedThreadPool(this.stSender.numberOfMaxConnections);
+ this.stSender.readPool = Executors.newFixedThreadPool(Integer.parseInt(this.readFromClient.readLine().trim()));
+ this.stSender.blocks = new LinkedBlockingQueue<>(Integer.parseInt(this.readFromClient.readLine().trim()));
+ this.stSender.talkWithReceiver(this.stSender);
+
+ while(!this.stopProbing){
+ String receivedMessage = this.readFromClient.readLine();
+// System.out.println("[+] Message is "+receivedMessage);
+ String message = this.parseMessage(receivedMessage);
+ this.sendMessage(message);
+ }
+ try{
+ Thread.sleep(2000);
+ }catch(InterruptedException e){
+
+ }
+ this.close();
+ }catch(IOException e){
+ e.printStackTrace();
+ }
+ System.out.println("TalkClient thread done");
+ }
+ void sendMessage(String message){
+ this.sendClient.println(message);
+ }
+ public void close() throws IOException{
+ talkSocket.close();
+ readFromClient.close();
+ sendClient.close();
+ }
+ }
+}
diff --git a/adaptive_iterative.py b/adaptive_iterative.py
new file mode 100644
index 0000000..3373279
--- /dev/null
+++ b/adaptive_iterative.py
@@ -0,0 +1,113 @@
+import warnings
+warnings.filterwarnings('ignore')
+#import pandas as pd
+import numpy as np
+from joblib import load
+import os
+import pathlib
+from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
+# from sklearn.preprocessing import StandardScaler
+
+#root = '~/hpcn/network_probing/'
+clf_dir = pathlib.Path('./trained_classifiers/').absolute()
+reg_dir = pathlib.Path('./trained_regressors/').absolute()
+
+threshold = 10
+min_points, max_points = 2, 15
+classifiers = {}
+regressors = {}
+num_of_tree = 50
+
+
+def regression_train():
+ for n in range(min_points, max_points+1):
+ file_path = "./trained_regressors/regressor_%d.joblib" % n
+ if os.path.exists(file_path):
+ regressors[n] = load(file_path)
+
+
+def classification_train():
+ for n in range(min_points, max_points+1):
+ file_path = "./trained_classifiers/classifier_%d.joblib" % n
+ if os.path.exists(file_path):
+ classifiers[n] = load(file_path)
+
+
+def is_predictable(test_data):
+ n = len(test_data)
+ test_value = np.reshape(test_data, (1,n))
+ if classifiers[n].predict(test_value)[0]==1 or n==max_points:
+ return True
+ else:
+ return False
+
+
+def make_prediction(test_data):
+ n = len(test_data)
+ test_value = np.reshape(test_data, (1,n))
+ return regressors[n].predict(test_value)[0]
+
+
+#def evaluate(X,y):
+# times = []
+# errors = []
+# for i in range(len(X)):
+# for n in range(min_points, max_points+1):
+# try:
+# test_value = np.reshape(X[i, :n],(1,n))
+# if classifiers[n].predict(test_value)[0]==1 or n==max_points:
+# predicted = regressors[n].predict(test_value)[0]
+# original = y[i]
+# error_rate = np.abs((predicted - original)/original) * 100
+# times.append(n)
+# errors.append(error_rate)
+# break
+# except Exception as e:
+# raise e
+#
+# print("Duration: {0}s\nError Rate: {1}%".format(np.round(np.mean(times), 2), np.round(np.mean(errors), 2)))
+#
+
+#def main():
+# df = pd.read_csv(root+'data/esnet.csv')
+# # df = df[(df.mean_throughput / df.stdv_throughput)>2]
+# df.dropna(inplace=True)
+#
+# test_data = df.sample(frac=.2)
+# y_reg_test = test_data["mean_throughput"].values
+# # X_test = StandardScaler().fit(test_data.iloc[:, :max_points]).transform(test_data.iloc[:, :max_points])
+# X_test = test_data.iloc[:, :max_points].values
+#
+# new_dataset = df.drop(test_data.index)
+# clf_train_data = new_dataset.sample(frac=.3)
+# clf_reg_test = clf_train_data["mean_throughput"].values
+# # X_train_clf = StandardScaler().fit(clf_train_data.iloc[:, :max_points]).transform(clf_train_data.iloc[:, :max_points])
+# X_train_clf = clf_train_data.iloc[:, :max_points].values
+#
+# train_data = new_dataset.drop(clf_train_data.index)
+# y_reg_train = train_data["mean_throughput"].values
+# # X_train = StandardScaler().fit(train_data.iloc[:, :max_points]).transform(train_data.iloc[:, :max_points])
+# X_train = train_data.iloc[:, :max_points].values
+#
+# regression_train(X_train, y_reg_train)
+# classification_train(X_train_clf, clf_reg_test)
+#
+#
+# # Incremental Prediction
+# for i in range(len(X_test)):
+# for n in range(min_points, max_points+1):
+# # print(is_predictable(X_test[i, :n]))
+# if is_predictable(X_test[i, :n]):
+# predicted = make_prediction(X_test[i, :n])
+# print("Index: {0}, Duration: {1}s, Predicted={2}MBps, Actual: {3}MBps".format(i, n,
+# np.round(predicted),
+# np.round(y_reg_test[i])))
+# break
+#
+# ## Evaluate
+# evaluate(X_test, y_reg_test)
+
+
+#if __name__ == "__main__":
+# main()
+
\ No newline at end of file