Skip to content

Commit

Permalink
Update RankingTrainer usage and Remove BLAS
Browse files Browse the repository at this point in the history
  • Loading branch information
Wei-Cheng Chang committed Oct 11, 2024
1 parent 05acc37 commit a4870fb
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 92 deletions.
19 changes: 1 addition & 18 deletions .github/build_pypi_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,8 @@ echo "pip: $($PIP --version)"


# Install dependencies
# TODO: remove pin on setuptools after removing numpy.distutils
echo "Install dependencies..."
$PIP install 'setuptools<=73.0.1' wheel twine auditwheel

# Install OpenBLAS
# Using pre-build OpenBLAS lib v0.3.27 hosted on Anaconda
# Refer to: https://github.com/MacPython/openblas-libs
# OpenBLAS64 is for ILP64, which is not our case
if [ "$PLAT" = "manylinux2014_x86_64" ] || [ "$PLAT" = "manylinux2014_aarch64" ]; then
OPENBLAS_VER="v0.3.27"
OPENBLAS_LIB="openblas-${OPENBLAS_VER}-${PLAT}.tar.gz"
OPENBLAS_LIB_URL="https://anaconda.org/multibuild-wheels-staging/openblas-libs/$OPENBLAS_VER/download/$OPENBLAS_LIB"
yum install wget -y
wget $OPENBLAS_LIB_URL
tar -xvf $OPENBLAS_LIB
else
echo "$PLAT not supported."
exit 1
fi
$PIP install 'setuptools' wheel twine auditwheel


# Build wheel
Expand Down
3 changes: 1 addition & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,9 @@ mypy:


# Install and unit test
# TODO: remove pin on pip and setuptools after removing numpy.distutils
libpecos:
python3 -m pip install pip==23.0.1
python3 -m pip install "setuptools<=73.0.1"
python3 -m pip install setuptools
${WARN_AS_ERROR_CMD} python3 -m pip install ${VFLAG} --editable .

.PHONY: test
Expand Down
63 changes: 36 additions & 27 deletions pecos/core/utils/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,37 +764,46 @@ namespace pecos {
}
};

// ===== BLAS C++ Wrapper =====

extern "C" {
double ddot_(ptrdiff_t *, double *, ptrdiff_t *, double *, ptrdiff_t *);
float sdot_(ptrdiff_t *, float *, ptrdiff_t *, float *, ptrdiff_t *);

ptrdiff_t dscal_(ptrdiff_t *, double *, double *, ptrdiff_t *);
ptrdiff_t sscal_(ptrdiff_t *, float *, float *, ptrdiff_t *);

ptrdiff_t daxpy_(ptrdiff_t *, double *, double *, ptrdiff_t *, double *, ptrdiff_t *);
ptrdiff_t saxpy_(ptrdiff_t *, float *, float *, ptrdiff_t *, float *, ptrdiff_t *);

double dcopy_(ptrdiff_t *, double *, ptrdiff_t *, double *, ptrdiff_t *);
float scopy_(ptrdiff_t *, float *, ptrdiff_t *, float *, ptrdiff_t *);
// ===== self-implemented C++ Wrapper for BLAS interface =====
// Since removing the dependency on BLAS, we manually realize
// the dot/scal/axpy/copy BLAS-compatible API via our naive implementation,
// which is for backward-compatibility (e.g., in Newton solver)

template<typename val_type> val_type dot(ptrdiff_t *len, val_type *x, ptrdiff_t *xinc, val_type *y, ptrdiff_t *yinc) {
val_type res = 0.0;
for (ptrdiff_t idx = 0; idx < *len; idx++) {
res += (*x) * (*y);
x += *xinc;
y += *yinc;
}
return res;
}

template<typename val_type> val_type dot(ptrdiff_t *, val_type *, ptrdiff_t *, val_type *, ptrdiff_t *);
template<> inline double dot(ptrdiff_t *len, double *x, ptrdiff_t *xinc, double *y, ptrdiff_t *yinc) { return ddot_(len, x, xinc, y, yinc); }
template<> inline float dot(ptrdiff_t *len, float *x, ptrdiff_t *xinc, float *y, ptrdiff_t *yinc) { return sdot_(len, x, xinc, y, yinc); }

template<typename val_type> val_type scal(ptrdiff_t *, val_type *, val_type *, ptrdiff_t *);
template<> inline double scal(ptrdiff_t *len, double *a, double *x, ptrdiff_t *xinc) { return dscal_(len, a, x, xinc); }
template<> inline float scal(ptrdiff_t *len, float *a, float *x, ptrdiff_t *xinc) { return sscal_(len, a, x, xinc); }
template<typename val_type> val_type scal(ptrdiff_t *len, val_type *a, val_type *x, ptrdiff_t *xinc) {
for (ptrdiff_t idx = 0; idx < *len; idx++) {
*x = (*x) * (*a);
x += *xinc;
}
return (val_type) 0;
}

template<typename val_type> ptrdiff_t axpy(ptrdiff_t *, val_type *, val_type *, ptrdiff_t *, val_type *, ptrdiff_t *);
template<> inline ptrdiff_t axpy(ptrdiff_t *len, double *alpha, double *x, ptrdiff_t *xinc, double *y, ptrdiff_t *yinc) { return daxpy_(len, alpha, x, xinc, y, yinc); };
template<> inline ptrdiff_t axpy(ptrdiff_t *len, float *alpha, float *x, ptrdiff_t *xinc, float *y, ptrdiff_t *yinc) { return saxpy_(len, alpha, x, xinc, y, yinc); };
template<typename val_type> ptrdiff_t axpy(ptrdiff_t *len, val_type *alpha, val_type *x, ptrdiff_t *xinc, val_type *y, ptrdiff_t *yinc) {
for (ptrdiff_t idx = 0; idx < *len; idx++) {
*y = (*y) + (*x) * (*alpha);
x += *xinc;
y += *yinc;
}
return (ptrdiff_t) 0;
}

template<typename val_type> val_type copy(ptrdiff_t *, val_type *, ptrdiff_t *, val_type *, ptrdiff_t *);
template<> inline double copy(ptrdiff_t *len, double *x, ptrdiff_t *xinc, double *y, ptrdiff_t *yinc) { return dcopy_(len,x,xinc,y,yinc); }
template<> inline float copy(ptrdiff_t *len, float *x, ptrdiff_t *xinc, float *y, ptrdiff_t *yinc) { return scopy_(len,x,xinc,y,yinc); }
template<typename val_type> val_type copy(ptrdiff_t *len, val_type *x, ptrdiff_t *xinc, val_type *y, ptrdiff_t *yinc) {
for (ptrdiff_t idx = 0; idx < *len; idx++) {
*y = *x;
x += *xinc;
y += *yinc;
}
return (val_type) 0;
}

// ===== do_dot_product =====
template<class IX, class VX, class IY, class VY>
Expand Down
4 changes: 2 additions & 2 deletions pecos/xmr/reranker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
inp_feat_dim: int = 1,
inp_dropout_prob: float = 0.1,
hid_dropout_prob: float = 0.1,
hid_actv_type: str = "gelu",
hid_actv_type: str = "relu6",
hid_size_list: list = [64, 128, 256],
**kwargs,
):
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(
text_config=None,
numr_config=None,
text_pooling_type="cls",
head_actv_type="gelu",
head_actv_type="relu6",
head_dropout_prob=0.1,
head_size_list=[128, 64],
**kwargs,
Expand Down
51 changes: 37 additions & 14 deletions pecos/xmr/reranker/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,6 @@ def forward(self, preds, target, alpha=0.5):
return loss1


LOSS_FN_DICT = {
"pairwise": PairwisePointwiseHybridLoss(
nn.MarginRankingLoss(reduction="mean", margin=0.1),
nn.MSELoss(reduction="mean"),
),
"listwise": ListwisePointwiseHybridLoss(
nn.CrossEntropyLoss(reduction="mean"),
nn.BCEWithLogitsLoss(reduction="mean"),
),
}


class LoggerCallback(TrainerCallback):
def on_epoch_begin(
self,
Expand Down Expand Up @@ -115,6 +103,8 @@ def on_log(
logs["loss"] = round(logs["loss"], 6)
if "grad_norm" in logs:
logs["grad_norm"] = round(logs["grad_norm"], 6)
if "learning_rate" in logs:
logs["learning_rate"] = round(logs["learning_rate"], 8)
if "epoch" in logs:
logs["epoch"] = round(logs["epoch"], 2)
if state.is_world_process_zero:
Expand All @@ -126,6 +116,17 @@ class RankingTrainer(Trainer, pecos.BaseClass):
Trainer class for the pecos.xmr.reranker.RankingModel.
"""

LOSS_FN_DICT = {
"pairwise": PairwisePointwiseHybridLoss(
nn.MarginRankingLoss(reduction="mean", margin=0.1),
nn.MSELoss(reduction="mean"),
),
"listwise": ListwisePointwiseHybridLoss(
nn.CrossEntropyLoss(reduction="mean"),
nn.BCEWithLogitsLoss(reduction="mean"),
),
}

@dataclass
class TrainingArgs(TrainingArguments, pecos.BaseParams):
loss_fn: str = "listwise"
Expand All @@ -148,10 +149,12 @@ def to_dict(self, with_meta=True):
return self.append_meta(d) if with_meta else d

def __init__(self, *args, **kwargs):
param_to_save = kwargs.pop("param_to_save")
param_to_save = kwargs.pop("param_to_save", None)
if not param_to_save:
raise ValueError("param_to_save can not be None!")
super(RankingTrainer, self).__init__(*args, **kwargs)

self.loss_fn = LOSS_FN_DICT[self.args.loss_fn]
self.loss_fn = self.LOSS_FN_DICT[self.args.loss_fn]
self.loss_alpha = self.args.loss_alpha
self.param_to_save = param_to_save

Expand Down Expand Up @@ -223,3 +226,23 @@ def compute_loss(

loss = self.loss_fn(preds_2d, target, alpha=self.loss_alpha)
return (loss, preds_1d) if return_outputs else loss

def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training.
Subclass and override this method to inject custom behavior.
Args:
logs (`Dict[str, float]`):
The values to log.
"""
if self.state.epoch is not None:
logs["epoch"] = self.state.epoch
if self.args.include_num_input_tokens_seen:
logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
logs["global_step"] = self.state.global_step

output = {**logs, **{"step": self.state.global_step}}
self.state.log_history.append(output)
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) # type: ignore
4 changes: 0 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
[aliases]
test=pytest

# TODO: remove pin on setuptools version after removing numpy.distutils
[build-system]
requires = ["setuptools<=73.0.1"]

# Configuration for pytest; enable coverage for pecos, emit
# XML, HTML, and terminal reports.
[tool:pytest]
Expand Down
27 changes: 2 additions & 25 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,32 +81,11 @@ def get_version(cls):
raise RuntimeError("Unable to find version string.")


class BlasHelper(object):
"""Helper class to figure out user's BLAS library path by Numpy's system-info tool."""

@classmethod
def get_blas_lib_dir(cls):
"""Return user's BLAS library found by Numpy's system-info tool. If not found, will raise error."""
import numpy.distutils.system_info as nps

blas_info = nps.get_info('lapack_opt')
assert blas_info, "No BLAS/LAPACK library is found, need to install BLAS."

blas_lib = blas_info['libraries']
blas_dir = blas_info['library_dirs']

assert blas_lib, "No BLAS/LAPACK library is found, need to install BLAS."
assert blas_dir, "No BLAS/LAPACK library directory is found, need to install BLAS."

return blas_lib, blas_dir


with open("README.md", "r", encoding="utf-8") as f:
long_description = f.read()

# Requirements
numpy_requires = [
'setuptools<=73.0.1', # TODO: remove pin on setuptools version after removing numpy.distutils
'numpy>=1.19.5,<2.0.0; python_version>="3.8"'
]
setup_requires = numpy_requires + [
Expand All @@ -124,7 +103,6 @@ def get_blas_lib_dir(cls):

# Fetch Numpy before building Numpy-dependent extension, if Numpy required version was not installed
setuptools.distutils.core.Distribution().fetch_build_eggs(numpy_requires)
blas_lib, blas_dir = BlasHelper.get_blas_lib_dir()

# Get extra manual compile args if any
# Example usage:
Expand All @@ -140,10 +118,9 @@ def get_blas_lib_dir(cls):
"pecos.core.libpecos_float32",
sources=["pecos/core/libpecos.cpp"],
include_dirs=["pecos/core", "/usr/include/", "/usr/local/include"],
libraries=["gomp", "gcc"] + blas_lib,
library_dirs=blas_dir,
libraries=["gomp", "gcc", "stdc++"],
extra_compile_args=["-fopenmp", "-O3", "-std=c++17"] + manual_compile_args,
extra_link_args=['-Wl,--no-as-needed', f"-Wl,-rpath,{':'.join(blas_dir)}"]
extra_link_args=['-Wl,--no-as-needed', f"-Wl,-rpath"]
)

setuptools.setup(
Expand Down

0 comments on commit a4870fb

Please sign in to comment.