Sourcery Starbot ⭐ refactored lamhoangtung/LineHTR #15

Open: wants to merge 1 commit into base: master
8 changes: 5 additions & 3 deletions src/DataLoader.py
@@ -51,15 +51,17 @@ def __init__(self, filePath, batchSize, imgSize, maxTextLen):

# Read json lables file
# Dataset folder should contain a labels.json file inside, with key is the file name of images and value is the label
- with open(filePath + 'labels.json') as json_data:
+ with open(f'{filePath}labels.json') as json_data:
label_file = json.load(json_data)

# Log
print("Loaded", len(label_file), "images")

# Put sample into list
- for fileName, gtText in label_file.items():
- self.samples.append(Sample(gtText, filePath + fileName))
+ self.samples.extend(
+ Sample(gtText, filePath + fileName)
+ for fileName, gtText in label_file.items()
+ )
Comment on lines -54 to +64
Function DataLoader.__init__ refactored with the following changes:
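
A minimal sketch of why the loop-to-`extend` rewrite in this hunk should preserve behavior. `Sample` below is a hypothetical stand-in (a namedtuple), not the project's own class, and the label dictionary and `data/` prefix are made-up data for illustration only.

```python
from collections import namedtuple

# Hypothetical stand-in for the project's Sample class (assumption).
Sample = namedtuple("Sample", ["gtText", "filePath"])


def load_with_loop(label_file, file_path):
    """Original style: append one Sample per loop iteration."""
    samples = []
    for file_name, gt_text in label_file.items():
        samples.append(Sample(gt_text, file_path + file_name))
    return samples


def load_with_extend(label_file, file_path):
    """Refactored style: extend the list from a generator expression."""
    samples = []
    samples.extend(
        Sample(gt_text, file_path + file_name)
        for file_name, gt_text in label_file.items()
    )
    return samples


labels = {"a.png": "hello", "b.png": "world"}  # made-up labels
# Both versions build the same list in the same (dict insertion) order.
assert load_with_loop(labels, "data/") == load_with_extend(labels, "data/")
```

Both forms iterate `label_file.items()` once and keep insertion order, so the change is stylistic rather than functional.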


self.charList = list(open(FilePaths.fnCharList).read())

10 changes: 5 additions & 5 deletions src/Model.py
@@ -199,19 +199,19 @@ def setupCTC(self, ctcIn3d):

def setupTF(self):
""" Initialize TensorFlow """
- print('Python: ' + sys.version)
- print('Tensorflow: ' + tf.__version__)
+ print(f'Python: {sys.version}')
+ print(f'Tensorflow: {tf.__version__}')
Comment on lines -202 to +203
Function Model.setupTF refactored with the following changes:
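
A small sketch comparing string concatenation with the f-string form used in this hunk. The helper names `concat_message` and `fstring_message` are hypothetical and exist only for this illustration. For `str` arguments the two forms produce identical output; they differ only when the value is not a string.

```python
def concat_message(model_dir):
    # Original style: '+' requires both operands to be str.
    return 'No saved model found in: ' + model_dir


def fstring_message(model_dir):
    # Refactored style: the value is formatted via format(), so any type works.
    return f'No saved model found in: {model_dir}'


# Identical output for string input.
assert concat_message('../model/') == fstring_message('../model/')

# Concatenation raises TypeError for a non-str value such as None ...
try:
    concat_message(None)
    raised = False
except TypeError:
    raised = True
assert raised

# ... while the f-string formats it instead.
assert fstring_message(None) == 'No saved model found in: None'
```

In the hunk above, `latestSnapshot` is only formatted inside `if latestSnapshot:`, so the refactor should not change what is printed there.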

sess = tf.Session() # Tensorflow session
saver = tf.train.Saver(max_to_keep=5) # Saver saves model to file
modelDir = '../model/'
latestSnapshot = tf.train.latest_checkpoint(
modelDir) # Is there a saved model?
# If model must be restored (for inference), there must be a snapshot
if self.mustRestore and not latestSnapshot:
- raise Exception('No saved model found in: ' + modelDir)
+ raise Exception(f'No saved model found in: {modelDir}')
# Load saved model if available
if latestSnapshot:
- print('Init with stored values from ' + latestSnapshot)
+ print(f'Init with stored values from {latestSnapshot}')
saver.restore(sess, latestSnapshot)
else:
print('Init with new values')
@@ -247,7 +247,7 @@ def toSpare(self, texts):
def decoderOutputToText(self, ctcOutput):
""" Extract texts from output of CTC decoder """
# Contains string of labels for each batch element
- encodedLabelStrs = [[] for i in range(Model.batchSize)]
+ encodedLabelStrs = [[] for _ in range(Model.batchSize)]
Function Model.decoderOutputToText refactored with the following changes:
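
A brief sketch of the renamed loop variable. `batch_size` is a hypothetical value standing in for `Model.batchSize`. Using `_` only signals that the variable is unused; the comprehension still creates one independent list per batch element, unlike list multiplication.

```python
batch_size = 3  # hypothetical stand-in for Model.batchSize

with_i = [[] for i in range(batch_size)]
with_underscore = [[] for _ in range(batch_size)]
# Renaming the unused loop variable does not change the result.
assert with_i == with_underscore

# Each inner list is a separate object, so appends stay local.
with_underscore[0].append(7)
assert with_underscore == [[7], [], []]

# By contrast, list multiplication repeats the same inner list.
aliased = [[]] * batch_size
aliased[0].append(7)
assert aliased == [[7], [7], [7]]
```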

# Word beam search: label strings terminated by blank
if self.decoderType == DecoderType.WordBeamSearch:
blank = len(self.charList)