Skip to content

Commit

Permalink
Update training module name to open_clip_train
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Jul 4, 2024
1 parent 148ad69 commit 82a244d
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 17 deletions.
14 changes: 7 additions & 7 deletions src/open_clip_train/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@
hvd = None

from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss
from training.data import get_data
from training.distributed import is_master, init_distributed_device, broadcast_object
from training.logger import setup_logging
from training.params import parse_args
from training.scheduler import cosine_lr, const_lr, const_lr_cooldown
from training.train import train_one_epoch, evaluate
from training.file_utils import pt_load, check_exists, start_sync_process, remote_sync
from open_clip_train.data import get_data
from open_clip_train.distributed import is_master, init_distributed_device, broadcast_object
from open_clip_train.logger import setup_logging
from open_clip_train.params import parse_args
from open_clip_train.scheduler import cosine_lr, const_lr, const_lr_cooldown
from open_clip_train.train import train_one_epoch, evaluate
from open_clip_train.file_utils import pt_load, check_exists, start_sync_process, remote_sync


LATEST_CHECKPOINT_NAME = "epoch_latest.pt"
Expand Down
6 changes: 3 additions & 3 deletions src/open_clip_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
wandb = None

from open_clip import get_input_dtype, CLIP, CustomTextCLIP
from .distributed import is_master
from .zero_shot import zero_shot_eval
from .precision import get_autocast
from open_clip_train.distributed import is_master
from open_clip_train.zero_shot import zero_shot_eval
from open_clip_train.precision import get_autocast


class AverageMeter(object):
Expand Down
2 changes: 1 addition & 1 deletion src/open_clip_train/zero_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \
IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
from .precision import get_autocast
from open_clip_train.precision import get_autocast


def accuracy(output, target, topk=(1,)):
Expand Down
1 change: 1 addition & 0 deletions tests/test_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from open_clip.hf_model import _POOLERS, HFTextEncoder
from transformers import AutoConfig
from transformers.modeling_outputs import BaseModelOutput

# test poolers
def test_poolers():
bs, sl, d = 2, 10, 5
Expand Down
2 changes: 1 addition & 1 deletion tests/test_num_shards.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from training.data import get_dataset_size
from open_clip_train.data import get_dataset_size

@pytest.mark.parametrize(
"shards,expected_size",
Expand Down
3 changes: 1 addition & 2 deletions tests/test_training_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
import os
import sys
import pytest
from PIL import Image
import torch
from training.main import main
from open_clip_train.main import main

os.environ["CUDA_VISIBLE_DEVICES"] = ""

Expand Down
6 changes: 3 additions & 3 deletions tests/test_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import io
from PIL import Image

from training.data import get_wds_dataset
from training.params import parse_args
from training.main import random_seed
from open_clip_train.data import get_wds_dataset
from open_clip_train.params import parse_args
from open_clip_train.main import random_seed

TRAIN_NUM_SAMPLES = 10_000
RTOL = 0.2
Expand Down

0 comments on commit 82a244d

Please sign in to comment.