Skip to content

Commit 60422ed

Browse files
authored
Add files via upload
1 parent 4806390 commit 60422ed

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

prepare.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,20 @@
1919
train_lst, eval_lst = list(), list()
2020
font = ImageFont.truetype(random.choice(font_list)+".ttf",14)
2121
for i in range(args.train_num):
22-
img = Image.new("RGBA", (110,20),(255,255,255))
23-
word = ''.join(random.choice(string.ascii_letters) for i in range(random.randrange(5,10)))
22+
img = Image.new("RGBA", (20,20),(255,255,255))
23+
word = ''.join(random.choice(string.ascii_letters))
2424
ImageDraw.Draw(img).text((5, 0), word, (0,0,0), font=font)
2525
img = np.array(img.convert("L"), dtype=np.float32)
2626
train_lst.append((img, img + img * np.random.normal(0,1,img.size).reshape(img.shape[0], img.shape[1]).astype('uint8')))
27+
if i%1000 == 0: print("Train image range ", i)
2728

2829
for i in range(args.eval_num):
2930
img = Image.new("RGBA", (110,20),(255,255,255))
3031
word = ''.join(random.choice(string.ascii_letters) for i in range(random.randrange(5,10)))
3132
ImageDraw.Draw(img).text((5, 0), word, (0,0,0), font=font)
3233
img = np.array(img.convert("L"), dtype=np.float32)
3334
eval_lst.append((img, img + img * np.random.normal(0,1,img.size).reshape(img.shape[0], img.shape[1]).astype('uint8')))
35+
if i%1000 == 0: print("Test image range ", i)
3436
out1, out2 = open(args.train_path, "wb"), open(args.eval_path, "wb")
3537
pickle.dump(train_lst, out1), pickle.dump(eval_lst, out2)
3638
out1.close(), out2.close()

train.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def tensor_to_image(tensor):
7272
for data in train_dataloader:
7373
inputs, labels = data[0].to(device), data[1]
7474
preds = model(inputs)
75-
tensor_to_image(preds[0].detach().numpy().reshape((20, 110))).save("1.jpg")
75+
tensor_to_image(preds[0].detach().numpy().reshape((20, 20))).save("1.jpg")
7676
loss_f = nn.MSELoss()
7777
loss = loss_f(preds, labels)
7878
epoch_losses.update(loss.item(), len(inputs))
@@ -89,5 +89,6 @@ def tensor_to_image(tensor):
8989
for data in eval_dataloader:
9090
inputs, labels = data[0].to(device), data[1]
9191
with torch.no_grad(): preds = model(inputs).clamp(0.0, 255.0)
92-
tensor_to_image(preds[0].reshape((20, 110))).save(args.outputs_dir + str(eval_counter) + ".jpg")
92+
tensor_to_image(preds[0].reshape((20, 20))).save(args.outputs_dir + str(eval_counter) + ".jpg")
9393
eval_counter+=1
94+
# 0.000370268038*1000*200 \approx 37%

0 commit comments

Comments
 (0)