Skip to content

Commit

Permalink
select model
Browse files Browse the repository at this point in the history
  • Loading branch information
ooe1123 committed Oct 1, 2023
1 parent c2fd944 commit 1244a91
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
27 changes: 22 additions & 5 deletions image_classification/japanese-clip/japanese-clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@
# Parameters
# ======================

WEIGHT_IMAGE_PATH = 'CLIP-ViT-B16-image.onnx'
MODEL_IMAGE_PATH = 'CLIP-ViT-B16-image.onnx.prototxt'
WEIGHT_TEXT_PATH = 'CLIP-ViT-B16-text.onnx'
MODEL_TEXT_PATH = 'CLIP-ViT-B16-text.onnx.prototxt'
REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/japanese_clip/'
WEIGHT_CLIP_IMAGE_PATH = 'CLIP-ViT-B16-image.onnx'
MODEL_CLIP_IMAGE_PATH = 'CLIP-ViT-B16-image.onnx.prototxt'
WEIGHT_CLIP_TEXT_PATH = 'CLIP-ViT-B16-text.onnx'
MODEL_CLIP_TEXT_PATH = 'CLIP-ViT-B16-text.onnx.prototxt'
WEIGHT_CLOOB_IMAGE_PATH = 'CLOOB-ViT-B16-image.onnx'
MODEL_CLOOB_IMAGE_PATH = 'CLOOB-ViT-B16-image.onnx.prototxt'
WEIGHT_CLOOB_TEXT_PATH = 'CLOOB-ViT-B16-text.onnx'
MODEL_CLOOB_TEXT_PATH = 'CLOOB-ViT-B16-text.onnx.prototxt'
REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/japanese-clip/'

IMAGE_PATH = 'dog.jpeg'
Expand All @@ -50,6 +53,10 @@
action='append',
help='Input text. (can be specified multiple times)'
)
parser.add_argument(
'-m', '--model_type', default='clip', choices=('clip', 'cloob'),
help='model type'
)
parser.add_argument(
'--onnx',
action='store_true',
Expand Down Expand Up @@ -249,6 +256,16 @@ def recognize_from_video(models):


def main():
dic_model = {
'clip': (
(WEIGHT_CLIP_IMAGE_PATH, MODEL_CLIP_IMAGE_PATH),
(WEIGHT_CLIP_TEXT_PATH, MODEL_CLIP_TEXT_PATH)),
'cloob': (
(WEIGHT_CLOOB_IMAGE_PATH, MODEL_CLOOB_IMAGE_PATH),
(WEIGHT_CLOOB_TEXT_PATH, MODEL_CLOOB_TEXT_PATH)),
}
(WEIGHT_IMAGE_PATH, MODEL_IMAGE_PATH), (WEIGHT_TEXT_PATH, MODEL_TEXT_PATH) = dic_model[args.model_type]

# model files check and download
check_and_download_models(WEIGHT_IMAGE_PATH, MODEL_IMAGE_PATH, REMOTE_PATH)
check_and_download_models(WEIGHT_TEXT_PATH, MODEL_TEXT_PATH, REMOTE_PATH)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
{
"additional_special_tokens": [],
"bos_token": "<s>",
"clean_up_tokenization_spaces": true,
"cls_token": "[CLS]",
"do_lower_case": true,
"eos_token": "</s>",
"extra_ids": 0,
"legacy": true,
"mask_token": "[MASK]",
"model_max_length": 1000000000000000019884624838656,
"name_or_path": "rinna/japanese-roberta-base",
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"sp_model_kwargs": {},
"special_tokens_map_file": "/home/ooe/.cache/huggingface/hub/models--rinna--japanese-roberta-base/snapshots/80adab4b7cc7e4d44468ccb9c77d1b130d23cfb4/special_tokens_map.json",
"tokenizer_class": "T5Tokenizer",
"unk_token": "<unk>"
}

0 comments on commit 1244a91

Please sign in to comment.