-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_knn.py
80 lines (51 loc) · 2.18 KB
/
train_knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/python
# -*- coding: utf-8 -*-
from __future__ import print_function
import argparse
import os
import pandas as pd
try:
    import joblib
except ImportError:
    # Old scikit-learn (< 0.21) vendored joblib here; it was removed in 0.23.
    from sklearn.externals import joblib
from sklearn.neighbors import KNeighborsClassifier
def model_fn(model_dir):
    """Load the fitted KNeighborsClassifier saved by the training section.

    The ``__main__`` section of this script writes the model to
    ``<model_dir>/model.joblib_KNN``; this function reads that same file
    back. The name ``model_fn`` follows the SageMaker scikit-learn
    model-loading hook convention.

    Args:
        model_dir: directory containing ``model.joblib_KNN``.

    Returns:
        The deserialized estimator object.
    """
    print('Loading model.')
    # joblib.load unpickles the file: only point model_dir at trusted artifacts.
    artifact_path = os.path.join(model_dir, 'model.joblib_KNN')
    classifier = joblib.load(artifact_path)
    print('Done loading model.')
    return classifier
def parse_args(argv=None):
    """Parse the training job's command-line arguments.

    The directory flags default to the ``SM_*`` environment variables that
    SageMaker sets automatically in a training job. If a variable is not
    set (e.g. running the script locally), the matching flag becomes
    required instead of the script crashing with ``KeyError`` before any
    argument is parsed.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``output_data_dir``, ``model_dir``,
        ``data_dir`` and ``n_neighbors``.

    Raises:
        SystemExit: on missing/invalid arguments (via ``parser.error``).
    """
    parser = argparse.ArgumentParser()

    def add_dir_argument(flag, env_var):
        # Env var supplies the default when present; otherwise the flag is required.
        parser.add_argument(flag, type=str,
                            default=os.environ.get(env_var),
                            required=env_var not in os.environ)

    # SageMaker parameters: directories for training data and saving models.
    add_dir_argument('--output-data-dir', 'SM_OUTPUT_DATA_DIR')
    add_dir_argument('--model-dir', 'SM_MODEL_DIR')
    add_dir_argument('--data-dir', 'SM_CHANNEL_TRAIN')

    # Model hyperparameters passed in by the training job.
    parser.add_argument('--n_neighbors', type=int, default=5)

    args = parser.parse_args(argv)
    if args.n_neighbors < 1:
        parser.error('--n_neighbors must be a positive integer')
    return args


def load_training_data(data_dir):
    """Read ``train.csv`` from *data_dir* and split it into features and labels.

    The file has no header row; the labels are in the first column and all
    remaining columns are features.

    Args:
        data_dir: directory containing ``train.csv``.

    Returns:
        (train_x, train_y): a DataFrame of features and a Series of labels.
    """
    train_data = pd.read_csv(os.path.join(data_dir, 'train.csv'), header=None)
    train_y = train_data.iloc[:, 0]
    train_x = train_data.iloc[:, 1:]
    return train_x, train_y


## The main code for KNeighborsClassifier
if __name__ == '__main__':
    args = parse_args()

    # test.csv is deliberately not read: fitting needs only train.csv, and
    # requiring test.csv would fail training when it is absent from the channel.
    train_x, train_y = load_training_data(args.data_dir)

    # Define and train the model
    model = KNeighborsClassifier(n_neighbors=args.n_neighbors)
    model.fit(train_x, train_y)

    # Save under the same file name that model_fn loads from.
    joblib.dump(model, os.path.join(args.model_dir, 'model.joblib_KNN'))