Skip to content

Commit

Permalink
improved retraining (#1), added a history model and an evaluation part (
Browse files Browse the repository at this point in the history
  • Loading branch information
sennierer committed Aug 14, 2018
1 parent 01ed71b commit 4765055
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 32 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setup(
name='acdh-spacyal',
version='0.3.6',
version='0.3.7',
packages=find_packages(
exclude=['spacyal/__pycache__']),
include_package_data=True,
Expand Down
2 changes: 2 additions & 0 deletions spacyal/api_urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@
name='download_model'),
path('download_cases/', api_views.DownloadCasesView.as_view(),
name='download_cases'),
path('project_history/', api_views.GetProjectHistory.as_view(),
name='project_history'),
]
18 changes: 16 additions & 2 deletions spacyal/api_views.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from rest_framework.views import APIView
from .models import al_project, case
from .models import al_project, case, project_history
from spacyal.tasks import get_cases, retrain_model
from rest_framework.response import Response
from rest_framework import status
Expand Down Expand Up @@ -42,7 +42,7 @@ def get(self, request):
project_id, model=c['folder'], retrained=c['retrained'])
else:
get_cases.delay(project_id, retrained=False)
c = case.objects.filter(project_id=project_id, decission__isnull=True)
c = case.objects.filter(project_id=project_id, decission__isnull=True).distinct()
res = [obj.as_dict() for obj in c]
return Response(res)

Expand Down Expand Up @@ -132,3 +132,17 @@ def get(self, request):
res[idx][1]['entities'] = [(x[0], x[1], x[2], choices[x[3]])
for x in e[1]['entities']]
return Response(res)


class GetProjectHistory(APIView):

def get(self, request):
project_pk = request.query_params.get('project_pk', None)
hist_obj = project_history.objects.filter(project_id=project_pk).order_by('timestamp')
prec_data = [x.eval_precission for x in hist_obj]
rec_data = [x.eval_recall for x in hist_obj]
f1_data = [x.eval_f1 for x in hist_obj]
labels = list(hist_obj.values_list('timestamp', flat=True))
print(labels)
return Response({'f1': f1_data, 'precission': prec_data, 'recall': rec_data, 'labels': labels})

18 changes: 18 additions & 0 deletions spacyal/migrations/0009_auto_20180814_0729.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 2.0.8 on 2018-08-14 07:29

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('spacyal', '0008_auto_20180809_1311'),
]

operations = [
migrations.AlterField(
model_name='project_history',
name='timestamp',
field=models.DateTimeField(auto_now=True),
),
]
2 changes: 1 addition & 1 deletion spacyal/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class project_history(models.Model):
eval_precission = models.FloatField(blank=True, null=True)
eval_recall = models.FloatField(blank=True, null=True)
model_path = models.CharField(max_length=255)
timestamp = models.TimeField(auto_now=True)
timestamp = models.DateTimeField(auto_now=True)


def start_get_cases(sender, instance, created, **kwargs):
Expand Down
30 changes: 24 additions & 6 deletions spacyal/tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from spacy import util
import re
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import precision_recall_fscore_support
Expand Down Expand Up @@ -139,8 +140,14 @@ def __init__(self, text, nlp, attribute='entities'):


@shared_task(time_limit=1800)
def retrain_model(project, model=None, n_iter=50):
def retrain_model(project, model=None, n_iter=30):
"""Load the model, set up the pipeline and train the entity recognizer."""
dropout_rates = util.decaying(util.env_opt('dropout_from', 0.2),
util.env_opt('dropout_to', 0.2),
util.env_opt('dropout_decay', 0.0))
batch_sizes = util.compounding(util.env_opt('batch_from', 1),
util.env_opt('batch_to', 16),
util.env_opt('batch_compound', 1.001))
if model == 'model_1':
output_model = 'model_2'
else:
Expand All @@ -156,7 +163,17 @@ def retrain_model(project, model=None, n_iter=50):
return message
TRAIN_DATA, eval_data, hist_object = project.get_training_data(
include_all=True, include_negative=True)
nlp = spacy.load(os.path.join(base_d, model)) # load existing spaCy model
nlp = spacy.load(os.path.join(base_d, model))# load existing spaCy model
if project.project_history_set.all().count() == 1:
project_history = ContentType.objects.get(
app_label="spacyal", model="project_history").model_class()
ev = test_model(eval_data, nlp)
f1 = ev.compute_f1()
hist2 = project_history.objects.create(
project=project, eval_f1=f1['fbeta'],
eval_precission=f1['precission'], eval_recall=f1['recall'])
hist2.cases_training.add(*list(hist_object.cases_training.all()))
hist2.cases_evaluation.add(*list(hist_object.cases_evaluation.all()))
TRAIN_DATA = mix_train_data(nlp, TRAIN_DATA)
with open(os.path.join(base_d, 'training_data.json'), 'w') as outp:
json.dump(TRAIN_DATA, outp)
Expand All @@ -174,11 +191,12 @@ def retrain_model(project, model=None, n_iter=50):
'model': output_model, 'project': project.pk})
random.shuffle(TRAIN_DATA)
losses = {}
for text, annotations in TRAIN_DATA:
for batch in util.minibatch(TRAIN_DATA, size=batch_sizes):
texts, annotations = zip(*batch)
nlp.update(
[text], # batch of texts
[annotations], # batch of annotations
drop=0.5, # dropout - make it harder to memorise data
texts, # batch of texts
annotations, # batch of annotations
drop=next(dropout_rates), # dropout - make it harder to memorise data
sgd=optimizer, # callable to update weights
losses=losses)
if not Path(output_dir).exists():
Expand Down
113 changes: 91 additions & 22 deletions spacyal/templates/spacyal/al_project.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
{% block title %} About {% endblock %}
{% block scriptHeader %}
<script src="{% static 'spacyal/js/csrf.js' %}"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/Chart.js/2.7.2/Chart.bundle.min.js"></script>
<link rel="stylesheet" href="{% static 'spacyal/css/spacyal.css' %}" />
{% endblock scriptHeader %}
{% block content %}
Expand All @@ -24,6 +25,9 @@ <h4 style="font-style: italic;">training new model</h4>
</div>
<small class="form-text text-muted">Hint: Training is done in the background. You can proceed annotating, but we will use your work only in the next iteration.</small>
</div>
<hr/>
<h4>Evaluation</h4>
<canvas id="evaluation_chart" width="400" height="200"></canvas>
<h4>decission history</h4>
<table class="table" id="last-5-dec">
<thead><th>sent</th><th>decission</th></thead>
Expand Down Expand Up @@ -73,8 +77,56 @@ <h4>decission history</h4>
return res
};

function update_evaluation_chart(project_id){
$.ajax({
type: "GET",
contentType: "application/json",
dataType: "JSON",
url: "{% url 'spacyal_api:project_history' %}",
data: {'project_pk': project_id},
success: function(data_input) {
console.log(data_input.f1);
console.log(data_input.labels);
var ctx = $("#evaluation_chart");
var history_chart = new Chart(ctx, {
type: 'line',
data: {datasets: [{
label: 'Precission',
fill: false,
borderColor: '#449d44',
data: data_input.precission,
}, {
fill: false,
borderColor: '#23527c',
label: 'F1',
data: data_input.f1
}, {
fill: false,
label: 'Recall',
borderColor: '#ec971f',
data: data_input.recall
}],
labels: data_input.labels,

},
options: {
scales: {
xAxes: [{
ticks: {
display: false
}
}]
},
}

});

}
});
};

function save_case(project_id, case_id, decission, correction=false){
if ($.AL_projects[project_id.toString()]['cases'].length < 5 && !correction){
if ($.AL_projects[project_id.toString()]['cases'].length < {{object.num_plus_retrain}} && !correction){
var retrain = true;
$.AL_projects[project_id.toString()]['retrain'] = true;
} else {
Expand All @@ -94,56 +146,69 @@ <h4>decission history</h4>
success: function (data) {
console.log('success');
console.log(data);
if (data.m_hash) {
check_progress_model(data.m_hash);
};
if (!correction) {

if ($.AL_projects[project_id.toString()]['cases'].length == 0){
s = {}
s['sentence'] = "<h3>No more examples</h3>"
} else {
if ($("#last-5-dec > tbody > tr").length >= 5) {
$("#last-5-dec > tbody > tr").last().remove();
};
var ael = $.AL_projects[project_id.toString()]['actual_case'];
$.AL_projects[project_id.toString()]['decided_cases'].push(ael.id);
$("#last-5-dec > tbody").prepend($.parseHTML(
"<tr><td>"+ael.sentence+"</td><td>"+return_decission(decission, ael.id)+"</td></tr>"
))
if ($.AL_projects[project_id.toString()]['cases'].length == 0){
s = {}
s['sentence'] = "<h3>No more examples</h3>"
} else {
"<tr><td>"+ael.sentence+"</td><td>"+return_decission(decission, ael.id)+"</td></tr>"));

var s = $.AL_projects[project_id.toString()]['cases'].shift();
while ($.AL_projects[project_id.toString()]['decided_cases'].includes(s.id)) {
console.log('double element');
console.log(s);
s = $.AL_projects[project_id.toString()]['cases'].shift();
}
$.AL_projects[project_id.toString()]['actual_case'] = s;
console.log(s);
console.log($.AL_projects[project_id.toString()]['cases']);
};

$('#sent_decide').html($.parseHTML(s.sentence));
if (data.m_hash) {
check_progress_model(data.m_hash);
};
}}
});
if (data['retrain']){
load_cases({{object.pk}}, true);
}
load_cases({{object.pk}}, call_save=true);
};
return true
};
function load_cases(project_id, call_save=false){
console.log(project_id);
if (!$.AL_projects) {
$.AL_projects = {};
var s = {'cases': [], 'retrain': false};
var s = {'cases': [], 'retrain': false, 'decided_cases': []};
$.AL_projects[project_id.toString()] = s;
};
if (!$.AL_projects[project_id.toString()]) {
$.AL_projects[project_id.toString()]['cases'] = [];
$.AL_projects[project_id.toString()]['decided_cases'] = [];
$.AL_projects[project_id.toString()]['retrain'] = false;
};
$.ajax({
type: "GET",
url: "{% url 'spacyal_api:retrievecases' %}",
data: {'project_id': project_id},
success: function (data) {
console.log(data);
if ($('#sent_decide').is(':empty')){
console.log('empty');
var s = data.shift();
$('#sent_decide').html($.parseHTML(s.sentence));
};
if (call_save){
var s = data.shift();
$('#sent_decide').html($.parseHTML(s.sentence));
};
$.AL_projects[project_id.toString()]['actual_case'] = s;
$.extend($.AL_projects[project_id.toString()]['cases'], data)
};

//$.extend($.AL_projects[project_id.toString()]['cases'], data)
$.AL_projects[project_id.toString()]['cases'].push(...data);
$.AL_projects[project_id.toString()]['retrain'] = false;
},
statusCode: {
Expand Down Expand Up @@ -171,7 +236,8 @@ <h4>decission history</h4>
setTimeout('check_progress_model("'+hash+'")', 5000);
} else if (data.status == 'SUCCESS') {
$('#progress_model').css('display', 'none');
load_cases({{object.pk}}, false);
//load_cases({{object.pk}}, false);
update_evaluation_chart({{object.pk}});
} else if (data.status == 'NOT STARTED') {
$('#not_started').css('display', '');
setTimeout('check_progress_model("'+hash+'")', 5000);
Expand All @@ -191,7 +257,9 @@ <h4>decission history</h4>
setTimeout('check_progress_model("'+hash+'")', 5000);
}

}})};
}});
return true
};

$(document).ready(function () {
load_cases({{object.pk}});
Expand All @@ -217,7 +285,8 @@ <h4>decission history</h4>
save_case({{object.pk}}, case_id, 0);
}
}
})
});
update_evaluation_chart({{object.pk}});
})
</script>
{% endblock scripts %}

0 comments on commit 4765055

Please sign in to comment.