Skip to content

Commit 261d087

Browse files
authored
Merge branch 'master' into postprocessing
2 parents 53fad73 + 1283ce3 commit 261d087

File tree

2 files changed

+39
-22
lines changed

2 files changed

+39
-22
lines changed

docker-compose.yml

+5-5
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@ services:
99
volumes:
1010
# source on local machine: place in container
1111
- ./models:/app/code2seq/models
12-
- ./datasets/codesearchnet/raw:/app/code2seq/datasets/codesearchnet/raw:ro
13-
- ./datasets/codesearchnet/preprocessed:/app/code2seq/datasets/codesearchnet/preprocessed:rw
12+
- ./datasets/funcom/raw:/app/code2seq/datasets/funcom/raw:ro
13+
- ./datasets/funcom/preprocessed:/app/code2seq/datasets/funcom/preprocessed:rw
1414

1515
environment:
1616
dataset: "codesearchnet"
1717
variant: "comments"
1818
# Preprocessing variables
1919
preprocess: true
2020
includeComments: true
21-
excludeStopwords: false
22-
useTfidf: true
21+
excludeStopwords: true
22+
useTfidf: false
2323
numberOfTfidfKeywords: "50"
2424
# Training variables
2525
train: true
@@ -32,4 +32,4 @@ services:
3232
devices:
3333
- driver: nvidia
3434
count: 1
35-
capabilities: [gpu]
35+
capabilities: [gpu]

modelrunner.py

+34-17
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,23 @@ def train(self):
141141
else:
142142
print("Initializing model from scratch.")
143143

144+
if self.config.LOAD_PATH and not self.config.TRAIN_PATH:
145+
model_dirname = self.config.LOAD_PATH
146+
elif self.config.MODEL_PATH:
147+
model_dirname = self.config.MODEL_PATH
148+
else:
149+
model_dirname = None
150+
print('Model directory is missing')
151+
exit(-1)
152+
153+
stats_file_name = os.path.join(model_dirname, "stats.txt")
154+
loss_file_name = os.path.join(model_dirname, "avg_loss.txt")
155+
try:
156+
os.remove(stats_file_name)
157+
os.remove(loss_file_name)
158+
except OSError:
159+
pass
160+
144161
sum_loss = 0
145162
batch_num = 0
146163
epochs_trained = 0
@@ -194,7 +211,7 @@ def train(self):
194211
batch_num += 1
195212

196213
if batch_num % self.num_batches_to_log == 0:
197-
self.trace(pbar, sum_loss, batch_num, multi_batch_start_time)
214+
self.trace(pbar, sum_loss, batch_num, multi_batch_start_time, loss_file_name)
198215
sum_loss = 0
199216
multi_batch_start_time = time.time()
200217

@@ -211,7 +228,12 @@ def train(self):
211228
checkpoint_manager.save()
212229

213230
# validate model to calculate metrics or stop training
214-
results, precision, recall, f1, rouge = self.evaluate()
231+
results, precision, recall, f1, rouge = self.evaluate(model_dirname)
232+
233+
# Add results to a stats file for later processing of graphs
234+
with open(stats_file_name, "a+") as stats_file:
235+
stats_file.write("{0}, {1}, {2}, {3}, {4}\n".format(epochs_trained, results, precision, recall, f1))
236+
215237
if self.config.BEAM_WIDTH == 0:
216238
print("Accuracy after %d epochs: %.5f" % (epochs_trained, results))
217239
else:
@@ -252,22 +274,15 @@ def train(self):
252274
% ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60)
253275
)
254276

255-
def evaluate(self):
277+
def evaluate(self, model_dirname):
256278
if not self.model:
257279
print("Model is not initialized")
258280
exit(-1)
259281

260282
print("Testing...")
261283
eval_start_time = time.time()
262284

263-
if self.config.LOAD_PATH and not self.config.TRAIN_PATH:
264-
model_dirname = self.config.LOAD_PATH
265-
elif self.config.MODEL_PATH:
266-
model_dirname = self.config.MODEL_PATH
267-
else:
268-
model_dirname = None
269-
print('Model directory is missing')
270-
exit(-1)
285+
271286

272287
ref_file_name = os.path.join(model_dirname, "ref.txt")
273288
predicted_file_name = os.path.join(model_dirname, "pred.txt")
@@ -388,15 +403,16 @@ def evaluate(self):
388403

389404
elapsed = int(time.time() - eval_start_time)
390405
precision, recall, f1 = calculate_results(true_positive, false_positive, false_negative)
391-
406+
accuracy = num_correct_predictions / total_predictions
407+
392408
try:
393409
files_rouge = FilesRouge(predicted_file_name, ref_file_name)
394410
rouge = files_rouge.get_scores(avg=True, ignore_empty=True)
395411
except ValueError:
396412
rouge = 0
397413

398414
print("Evaluation time: %sh%sm%ss" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60))
399-
return num_correct_predictions / total_predictions, precision, recall, f1, rouge
415+
return accuracy, precision, recall, f1, rouge
400416

401417
def print_hyperparams(self):
402418
print("Training batch size:\t\t\t", self.config.BATCH_SIZE)
@@ -422,17 +438,18 @@ def print_hyperparams(self):
422438
print("LSTM dropout keep_prob:\t\t\t", self.config.RNN_DROPOUT_KEEP_PROB)
423439
print("============================================")
424440

425-
def trace(self, pbar, sum_loss, batch_num, multi_batch_start_time):
441+
def trace(self, pbar, sum_loss, batch_num, multi_batch_start_time, loss_file_name):
426442
multi_batch_elapsed = time.time() - multi_batch_start_time
427443
avg_loss = sum_loss / self.num_batches_to_log
444+
throughput = self.config.BATCH_SIZE * self.num_batches_to_log / (multi_batch_elapsed if multi_batch_elapsed > 0 else 1)
428445
msg = "Average loss at batch {0}: {1}, \tthroughput: {2} samples/sec".format(
429446
batch_num,
430447
avg_loss,
431-
self.config.BATCH_SIZE
432-
* self.num_batches_to_log
433-
/ (multi_batch_elapsed if multi_batch_elapsed > 0 else 1),
448+
throughput
434449
)
435450
pbar.set_description(msg)
451+
with open(loss_file_name, "a+") as loss_file:
452+
loss_file.write("{0}, {1}, {2}\n".format(batch_num, avg_loss, throughput))
436453

437454
def encode(self, predict_data_lines):
438455
if not self.model:

0 commit comments

Comments
 (0)