
Adding classes #10

Open
wjn0807 opened this issue Jun 11, 2023 · 2 comments

Comments

@wjn0807

wjn0807 commented Jun 11, 2023

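The source code defines three classes: text, logo, and deco. If I want to add more classes, which part of the code should I modify? I couldn't find where the classes are defined while reading the code. Looking forward to your answer.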

@theKinsley
Contributor

theKinsley commented Jun 15, 2023

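Hi wjn0807, line 57 of dataloader.py, cls = list(sliced_df["cls_elem"]), reads the integers that represent the classes. Adding classes also requires adjusting the dimension of the initialized label accordingly. Here is a sketch for your reference: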

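import ast

import numpy as np
import torch

# In the Dataset's __getitem__, where n is the total number of classes;
# class index 0 is reserved for padding, so real classes use 1..n
def __getitem__(self, idx):
    # Load the image according to your own setup (omitted here)
    im = ...

    # Load the layout label. Assumes the source DataFrame is self.dfs[idx] and that
    # cls_elem and box_elem have the same format as in PKU PosterLayout
    n = self.num_classes  # hypothetical attribute: total number of classes
    label = np.zeros((self.max_elem, 2, n + 1))
    sliced_df = self.dfs[idx]
    cls = list(sliced_df["cls_elem"])
    box = torch.tensor(list(map(ast.literal_eval, sliced_df["box_elem"])))

    # Row 0 of each slot is the one-hot class, row 1 holds the bounding box
    for i in range(len(cls)):
        label[i][0][int(cls[i])] = 1
        label[i][1][:4] = box[i]
        # label[i][1][:4] (the bounding box) still needs the box_xyxy_to_cxcywh
        # conversion and scaling to the image size; omitted here
        ...

    # Mark the remaining empty slots as padding (class index 0)
    for i in range(len(cls), self.max_elem):
        label[i][0][0] = 1

    return im, torch.tensor(label).float()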
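For reference, the label layout described in the reply can be sketched as a standalone function, independent of the Dataset class. This is a minimal sketch under several assumptions: the class table CLASS_NAMES (including the example "new_class") and its ordering are hypothetical, so check the actual integer values in your cls_elem column; index 0 is treated as padding because the reply marks empty slots with class 0; and boxes are assumed to be normalized by dividing x by the image width and y by the image height, which may differ from the scaling the repository actually uses. The box_xyxy_to_cxcywh helper here is written out locally rather than imported from the project.

```python
import numpy as np

# Hypothetical class table: index 0 is reserved for padding ("no element"),
# and "new_class" is an example of an added class. Check the real cls_elem ids.
CLASS_NAMES = ["<pad>", "text", "logo", "deco", "new_class"]


def box_xyxy_to_cxcywh(box):
    """Convert corner format (x1, y1, x2, y2) to center format (cx, cy, w, h)."""
    x1, y1, x2, y2 = box
    return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]


def build_label(cls_ids, boxes_xyxy, max_elem, img_w, img_h,
                num_classes=len(CLASS_NAMES) - 1):
    """Build a (max_elem, 2, num_classes + 1) layout label.

    Row 0 of each slot is a one-hot class vector (index 0 = padding).
    Row 1 holds the box as (cx, cy, w, h), assumed normalized by image size.
    """
    if num_classes + 1 < 4:
        raise ValueError("row 1 needs at least 4 entries to hold a box")
    if len(cls_ids) > max_elem:
        raise ValueError("more elements than max_elem slots")

    label = np.zeros((max_elem, 2, num_classes + 1), dtype=np.float32)
    scale = np.array([img_w, img_h, img_w, img_h], dtype=np.float32)

    for i, (c, box) in enumerate(zip(cls_ids, boxes_xyxy)):
        c = int(c)
        if not 1 <= c <= num_classes:
            raise ValueError(f"class id {c} outside 1..{num_classes}")
        label[i, 0, c] = 1  # one-hot class
        label[i, 1, :4] = np.array(box_xyxy_to_cxcywh(box), dtype=np.float32) / scale

    # Remaining slots are marked as padding (class index 0)
    label[len(cls_ids):, 0, 0] = 1
    return label


if __name__ == "__main__":
    demo = build_label([1, 4], [[0, 0, 100, 50], [20, 40, 60, 80]],
                       max_elem=4, img_w=200, img_h=100)
    print(demo.shape)  # (4, 2, 5)
```

Adding a class in this sketch only means appending a name to CLASS_NAMES; the last dimension of the label (num_classes + 1) grows automatically, which is the "adjust the dimension of the initialized label" step from the reply. Any model head that predicts the class vector would also need its output size changed to match.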

@wjn0807
Author

wjn0807 commented Jun 15, 2023

Thank you very much!
