Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
xxaier committed Jul 31, 2023
1 parent eb58f2b commit 1edd232
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion onnx/export/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@ def export(txt, img):


export(TXT, IMG)
# export(TXT_NORM, IMG_NORM)
export(TXT_NORM, IMG_NORM)
22 changes: 12 additions & 10 deletions onnx/misc/clip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .device import DEVICE
from .config import MODEL_FP
from flagai.model.mm.AltCLIP import CLIPHF
from misc.norm import norm

MODEL = CLIPHF.from_pretrained(MODEL_FP)

Expand All @@ -24,10 +25,10 @@ def forward(self, image):
return self.model.get_image_features(image)


# class ImgNorm(Img):
#
# def forward(self, image):
# return norm(super(ImgNorm, self).forward(image))
class ImgNorm(Img):

def forward(self, image):
return norm(super(ImgNorm, self).forward(image))


class Txt(nn.Module):
Expand All @@ -43,13 +44,14 @@ def forward(self, text, attention_mask):
return self.model.get_text_features(text, attention_mask=attention_mask)


# class TxtNorm(Txt):
#
# def forward(self, text, attention_mask):
# return norm(super(TxtNorm, self).forward(text, attention_mask))
class TxtNorm(Txt):

def forward(self, text, attention_mask):
return norm(super(TxtNorm, self).forward(text, attention_mask))


IMG = Img()
# IMG_NORM = ImgNorm()
IMG_NORM = ImgNorm()

TXT = Txt()
# TXT_NORM = TxtNorm()
TXT_NORM = TxtNorm()

0 comments on commit 1edd232

Please sign in to comment.