Skip to content

Commit

Permalink
Add support for saturation arithmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
mfasi committed May 31, 2024
1 parent f38d309 commit d4b9161
Show file tree
Hide file tree
Showing 8 changed files with 512 additions and 308 deletions.
68 changes: 52 additions & 16 deletions mex/cpfloat.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ void mexFunction(int nlhs,
fpopts->precision = 11;
fpopts->emin = -14;
fpopts->emax = 15;
fpopts->subnormal = CPFLOAT_SUBN_USE;
fpopts->explim = CPFLOAT_EXPRANGE_TARG;
fpopts->round = CPFLOAT_RND_NE;
fpopts->saturation = CPFLOAT_SAT_NO;
fpopts->subnormal = CPFLOAT_SUBN_USE;

fpopts->flip = CPFLOAT_SOFTERR_NO;
fpopts->p = 0.5;

fpopts->bitseed = NULL;
fpopts->randseedf = NULL;
fpopts->randseed = NULL;
Expand All @@ -54,6 +57,7 @@ void mexFunction(int nlhs,
/* Parse second argument and populate fpopts structure. */
if (nrhs > 1) {
bool is_subn_rnd_default = false;
bool is_saturation_default = false;
if(!mxIsEmpty(prhs[1]) && !mxIsStruct(prhs[1])) {
mexErrMsgIdAndTxt("cpfloat:invalidstruct",
"Second argument must be a struct.");
Expand All @@ -62,7 +66,7 @@ void mexFunction(int nlhs,

if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
/* Use default format, for compatibility with chop. */
/* Set default format, for compatibility with chop. */
strcpy(fpopts->format, "h");
else if (mxGetClassID(tmp) == mxCHAR_CLASS)
strcpy(fpopts->format, mxArrayToString(tmp));
Expand All @@ -80,6 +84,7 @@ void mexFunction(int nlhs,
fpopts->precision = 4;
fpopts->emin = -6;
fpopts->emax = 8;
is_saturation_default = true;
} else if (!strcmp(fpopts->format, "q52") ||
!strcmp(fpopts->format, "fp8-e5m2") ||
!strcmp(fpopts->format, "E5M2")) {
Expand Down Expand Up @@ -161,6 +166,31 @@ void mexFunction(int nlhs,
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->round = *((double *)mxGetData(tmp));
}
tmp = mxGetField(prhs[1], 0, "saturation");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->saturation = CPFLOAT_SAT_NO;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->saturation = *((double *)mxGetData(tmp));
} else {
if (is_saturation_default)
fpopts->saturation = CPFLOAT_SAT_USE; /* Default for E4M3. */
else
fpopts->saturation = CPFLOAT_SAT_NO;
}
tmp = mxGetField(prhs[1], 0, "subnormal");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
fpopts->subnormal = CPFLOAT_SUBN_USE;
else if (mxGetClassID(tmp) == mxDOUBLE_CLASS)
fpopts->subnormal = *((double *)mxGetData(tmp));
} else {
if (is_subn_rnd_default)
fpopts->subnormal = CPFLOAT_SUBN_RND; /* Default for bfloat16. */
else
fpopts->subnormal = CPFLOAT_SUBN_USE;
}

tmp = mxGetField(prhs[1], 0, "flip");
if (tmp != NULL) {
if (mxGetM(tmp) == 0 && mxGetN(tmp) == 0)
Expand Down Expand Up @@ -288,10 +318,11 @@ void mexFunction(int nlhs,

/* Allocate and return second output. */
if (nlhs > 1) {
const char* field_names[] = {"format", "params", "subnormal", "round",
"flip", "p", "explim"};
const char* field_names[] = {"format", "params", "explim",
"round", "saturation", "subnormal",
"flip", "p"};
mwSize dims[2] = {1, 1};
plhs[1] = mxCreateStructArray(2, dims, 7, field_names);
plhs[1] = mxCreateStructArray(2, dims, 8, field_names);
mxSetFieldByNumber(plhs[1], 0, 0, mxCreateString(fpopts->format));

mxArray *outparams = mxCreateDoubleMatrix(1,3,mxREAL);
Expand All @@ -301,30 +332,35 @@ void mexFunction(int nlhs,
outparamsptr[2] = fpopts->emax;
mxSetFieldByNumber(plhs[1], 0, 1, outparams);

mxArray *outsubnormal = mxCreateDoubleMatrix(1,1,mxREAL);
double *outsubnormalptr = mxGetData(outsubnormal);
outsubnormalptr[0] = fpopts->subnormal;
mxSetFieldByNumber(plhs[1], 0, 2, outsubnormal);
mxArray *outexplim = mxCreateDoubleMatrix(1, 1, mxREAL);
double *outexplimptr = mxGetData(outexplim);
outexplimptr[0] = fpopts->explim;
mxSetFieldByNumber(plhs[1], 0, 2, outexplim);

mxArray *outround = mxCreateDoubleMatrix(1,1,mxREAL);
double *outroundptr = mxGetData(outround);
outroundptr[0] = fpopts->round;
mxSetFieldByNumber(plhs[1], 0, 3, outround);

mxArray *outsaturation = mxCreateDoubleMatrix(1,1,mxREAL);
double *outsaturationptr = mxGetData(outsaturation);
outsaturationptr[0] = fpopts->saturation;
mxSetFieldByNumber(plhs[1], 0, 4, outsaturation);

mxArray *outsubnormal = mxCreateDoubleMatrix(1,1,mxREAL);
double *outsubnormalptr = mxGetData(outsubnormal);
outsubnormalptr[0] = fpopts->subnormal;
mxSetFieldByNumber(plhs[1], 0, 5, outsubnormal);

mxArray *outflip = mxCreateDoubleMatrix(1,1,mxREAL);
double *outflipptr = mxGetData(outflip);
outflipptr[0] = fpopts->flip;
mxSetFieldByNumber(plhs[1], 0, 4, outflip);
mxSetFieldByNumber(plhs[1], 0, 6, outflip);

mxArray *outp = mxCreateDoubleMatrix(1,1,mxREAL);
double *outpptr = mxGetData(outp);
outpptr[0] = fpopts->p;
mxSetFieldByNumber(plhs[1], 0, 5, outp);

mxArray *outexplim = mxCreateDoubleMatrix(1,1,mxREAL);
double *outexplimptr = mxGetData(outexplim);
outexplimptr[0] = fpopts->explim;
mxSetFieldByNumber(plhs[1], 0, 6, outexplim);
mxSetFieldByNumber(plhs[1], 0, 7, outp);

}
if (nlhs > 2)
Expand Down
16 changes: 11 additions & 5 deletions mex/cpfloat.m
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@
% the target format, respectively. The default value of this field is
% the vector [11,-14,15].
%
% * The scalar FPOPTS.subnormal specifies the support for subnormal numbers.
% The target floating-point format will not support subnormal numbers if
% this field is set to 0, and will support them otherwise. The default value
% for this field is 0 if the target format is 'bfloat16' and 1 otherwise.
%
% * The scalar FPOPTS.explim specifies the support for an extended exponent
% range. The target floating-point format will have the exponent range of
% the storage format ('single' or 'double', depending on the class of X) if
Expand All @@ -63,6 +58,17 @@
% Any other value results in no rounding. The default value for this field
% is 1.
%
% * The scalar FPOPTS.saturation specifies whether saturation arithmetic is in
% use. On overflow, the target floating-point format will use the largest
% representable floating-point if this field is set to 0, and infinity
% otherwise. The default value for this field is 1 if the target format is
% 'E4M3' and 1 otherwise.

% * The scalar FPOPTS.subnormal specifies the support for subnormal numbers.
% The target floating-point format will not support subnormal numbers if
% this field is set to 0, and will support them otherwise. The default value
% for this field is 0 if the target format is 'bfloat16' and 1 otherwise.
%
% * The scalar FPOPTS.flip specifies whether the function should simulate the
% occurrence of a single bit flip striking the floating-point representation
% of elements of Y. Possible values are:
Expand Down
2 changes: 1 addition & 1 deletion src/cpfloat_binary32.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ static inline int cpf_fmaf(float *X, const float *A, const float *B,
#define INTSUFFIX U

#define DEFPREC 24
#define DEFEMAX 127
#define DEFEMIN -126
#define DEFEMAX 127
#define NLEADBITS 9
#define NBITS 32
#define FULLMASK 0xFFFFFFFFU
Expand Down
2 changes: 1 addition & 1 deletion src/cpfloat_binary64.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ static inline int cpf_fma(double *X, const double *A, const double *B,
#define INTTYPE uint64_t
#define INTSUFFIX ULL
#define DEFPREC 53
#define DEFEMAX 1023
#define DEFEMIN -1022
#define DEFEMAX 1023
#define NLEADBITS 12
#define NBITS 64
#define FULLMASK 0xFFFFFFFFFFFFFFFFULL
Expand Down
39 changes: 31 additions & 8 deletions src/cpfloat_definitions.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
*
* + @ref cpfloat_explim_t,
* + @ref cpfloat_rounding_t,
* + @ref cpfloat_saturation_t,
* + @ref cpfloat_softerr_t,
* + @ref cpfloat_subnormal_t,
*
Expand Down Expand Up @@ -88,6 +89,16 @@ typedef enum {
CPFLOAT_NO_RND = 8,
} cpfloat_rounding_t;

/**
* @brief Saturation modes available in CPFloat.
*/
typedef enum {
/** Use standard arithmetic. */
CPFLOAT_SAT_NO = 0,
/** Use saturation arithmetic. */
CPFLOAT_SAT_USE = 1,
} cpfloat_saturation_t;

/**
* @brief Soft fault simulation modes available in CPFloat.
*/
Expand Down Expand Up @@ -214,14 +225,6 @@ typedef struct {
* exponent is larger than the maximum allowed by the storage format.
*/
cpfloat_exponent_t emax;
/**
* @brief Support for subnormal numbers in target format.
*
* @details Subnormal numbers are supported if this field is set to
* `CPFLOAT_SUBN_USE` and rounded to a normal number using the current
* rounding mode if it is set to `CPFLOAT_SUBN_RND`.
*/
cpfloat_subnormal_t subnormal;
/**
* @brief Support for extended exponents in target format.
*
Expand Down Expand Up @@ -256,6 +259,24 @@ typedef struct {
* those in the list above is specified.
*/
cpfloat_rounding_t round;
/**
* @brief Support for subnormal numbers in target format.
*
* @details Subnormal numbers are supported if this field is set to
* `CPFLOAT_SUBN_USE` and rounded to a normal number using the current
* rounding mode if it is set to `CPFLOAT_SUBN_RND`.
*/
cpfloat_saturation_t saturation;
/**
* @brief Support for subnormal numbers in target format.
*
* @details Subnormal numbers are supported if this field is set to
* `CPFLOAT_SUBN_USE` and rounded to a normal number using the current
* rounding mode if it is set to `CPFLOAT_SUBN_RND`.
*/
cpfloat_subnormal_t subnormal;

/* Bit flips. */
/**
* @brief Support for soft errors.
*
Expand All @@ -281,6 +302,8 @@ typedef struct {
* contain a number in the interval [0,1].
*/
double p;

/* Internal: state of pseudo-random number generator. */
/**
* @brief Internal state of pseudo-random number generator for single bits.
*
Expand Down
Loading

0 comments on commit d4b9161

Please sign in to comment.