Skip to content

Commit

Permalink
Adding model and dataset to the LFS
Browse files Browse the repository at this point in the history
  • Loading branch information
nkise-nlab committed Feb 3, 2021
1 parent d838910 commit 3f66183
Show file tree
Hide file tree
Showing 12 changed files with 98 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
keypoint.csv filter=lfs diff=lfs merge=lfs -text
point_history.csv filter=lfs diff=lfs merge=lfs -text
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,5 @@ dmypy.json
.pyre/

.DS_Store

.idea/
3 changes: 3 additions & 0 deletions model/keypoint_classifier/keypoint.csv
Git LFS file not shown
Binary file not shown.
36 changes: 36 additions & 0 deletions model/keypoint_classifier/keypoint_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf


class KeyPointClassifier(object):
def __init__(
self,
model_path='model/keypoint_classifier/keypoint_classifier.tflite',
num_threads=1,
):
self.interpreter = tf.lite.Interpreter(model_path=model_path,
num_threads=num_threads)

self.interpreter.allocate_tensors()
self.input_details = self.interpreter.get_input_details()
self.output_details = self.interpreter.get_output_details()

def __call__(
self,
landmark_list,
):
input_details_tensor_index = self.input_details[0]['index']
self.interpreter.set_tensor(
input_details_tensor_index,
np.array([landmark_list], dtype=np.float32))
self.interpreter.invoke()

output_details_tensor_index = self.output_details[0]['index']

result = self.interpreter.get_tensor(output_details_tensor_index)

result_index = np.argmax(np.squeeze(result))

return result_index
Binary file not shown.
4 changes: 4 additions & 0 deletions model/keypoint_classifier/keypoint_classifier_label.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Open
Close
Pointer
OK
3 changes: 3 additions & 0 deletions model/point_history_classifier/point_history.csv
Git LFS file not shown
Binary file not shown.
44 changes: 44 additions & 0 deletions model/point_history_classifier/point_history_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf


class PointHistoryClassifier(object):
def __init__(
self,
model_path='model/point_history_classifier/point_history_classifier.tflite',
score_th=0.5,
invalid_value=0,
num_threads=1,
):
self.interpreter = tf.lite.Interpreter(model_path=model_path,
num_threads=num_threads)

self.interpreter.allocate_tensors()
self.input_details = self.interpreter.get_input_details()
self.output_details = self.interpreter.get_output_details()

self.score_th = score_th
self.invalid_value = invalid_value

def __call__(
self,
point_history,
):
input_details_tensor_index = self.input_details[0]['index']
self.interpreter.set_tensor(
input_details_tensor_index,
np.array([point_history], dtype=np.float32))
self.interpreter.invoke()

output_details_tensor_index = self.output_details[0]['index']

result = self.interpreter.get_tensor(output_details_tensor_index)

result_index = np.argmax(np.squeeze(result))

if np.squeeze(result)[result_index] < self.score_th:
result_index = self.invalid_value

return result_index
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Stop
Clockwise
Counter Clockwise
Move

0 comments on commit 3f66183

Please sign in to comment.