Skip to content

Commit

Permalink
Standardize matrix functions
Browse files Browse the repository at this point in the history
Only one *correct* matrix multiply function now

All matrices except Oklab are fed through a compile-time transpose so the
visually written columns match up with m[vec][elem]
  • Loading branch information
Beinsezii committed Jun 27, 2024
1 parent cd8395b commit aaa2742
Showing 1 changed file with 56 additions and 52 deletions.
108 changes: 56 additions & 52 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,33 @@ const JZAZBZ_P: f32 = 1.7 * PQEOTF_M2;
// ### CONSTS ### }}}

// ### MATRICES ### {{{

/// Transposes a 3x3 matrix at compile time.
///
/// It's easier to write matrices out visually, row by row, and then
/// transpose them so that the first index selects a written column:
///
/// ```text
/// [X1, X2]  ->  [X1, Y1]
/// [Y1, Y2]      [X2, Y2]
/// ```
///
/// The result satisfies `t(m)[i][j] == m[j][i]`.
const fn t(m: [[f32; 3]; 3]) -> [[f32; 3]; 3] {
    // Name every cell by its written (row, column) position, then
    // reassemble with rows and columns swapped.
    let [[r0c0, r0c1, r0c2], [r1c0, r1c1, r1c2], [r2c0, r2c1, r2c2]] = m;
    [
        [r0c0, r1c0, r2c0],
        [r0c1, r1c1, r2c1],
        [r0c2, r1c2, r2c2],
    ]
}

/// Multiplies the 3-component vector `p` by the 3x3 matrix `m`.
///
/// `m` is indexed as `m[input component][output component]`: output
/// element `i` combines `p[0]`, `p[1]`, `p[2]` with `m[0][i]`, `m[1][i]`,
/// `m[2][i]` respectively. Assuming `fma(a, b)` evaluates `self * a + b`,
/// each element is `p[0] * m[0][i] + p[1] * m[1][i] + p[2] * m[2][i]`.
fn mm<T: DType>(m: [[f32; 3]; 3], p: [T; 3]) -> [T; 3] {
    let [x, y, z] = p;
    // One output element: the `z` product is the innermost term, then the
    // `y` and `x` terms are accumulated onto it through nested `fma` calls.
    let dot = |i: usize| x.fma(m[0][i].to_dt(), y.fma(m[1][i].to_dt(), z * m[2][i].to_dt()));
    [dot(0), dot(1), dot(2)]
}

// CIE XYZ
const XYZ65_MAT: [[f32; 3]; 3] = [
const XYZ65_MAT: [[f32; 3]; 3] = t([
[0.4124, 0.3576, 0.1805],
[0.2126, 0.7152, 0.0722],
[0.0193, 0.1192, 0.9505],
];
]);

// Original commonly used inverted array
// const XYZ65_MAT_INV: [[f32; 3]; 3] = [
Expand All @@ -260,13 +281,14 @@ const XYZ65_MAT: [[f32; 3]; 3] = [
// ];

// Higher precision invert using numpy. Helps with back conversions
const XYZ65_MAT_INV: [[f32; 3]; 3] = [
const XYZ65_MAT_INV: [[f32; 3]; 3] = t([
[3.2406254773, -1.5372079722, -0.4986285987],
[-0.9689307147, 1.8757560609, 0.0415175238],
[0.0557101204, -0.2040210506, 1.0569959423],
];
]);

// OKLAB
// They appear to be provided already transposed for code in the blog post
const OKLAB_M1: [[f32; 3]; 3] = [
[0.8189330101, 0.0329845436, 0.0482003018],
[0.3618667424, 0.9293118715, 0.2643662691],
Expand All @@ -289,68 +311,50 @@ const OKLAB_M2_INV: [[f32; 3]; 3] = [
];

// JzAzBz
const JZAZBZ_M1: [[f32; 3]; 3] = [
const JZAZBZ_M1: [[f32; 3]; 3] = t([
[0.41478972, 0.579999, 0.0146480],
[-0.2015100, 1.120649, 0.0531008],
[-0.0166008, 0.264800, 0.6684799],
];
const JZAZBZ_M2: [[f32; 3]; 3] = [
]);
const JZAZBZ_M2: [[f32; 3]; 3] = t([
[0.500000, 0.500000, 0.000000],
[3.524000, -4.066708, 0.542708],
[0.199076, 1.096799, -1.295875],
];
]);

const JZAZBZ_M1_INV: [[f32; 3]; 3] = [
const JZAZBZ_M1_INV: [[f32; 3]; 3] = t([
[1.9242264358, -1.0047923126, 0.037651404],
[0.3503167621, 0.7264811939, -0.0653844229],
[-0.090982811, -0.3127282905, 1.5227665613],
];
const JZAZBZ_M2_INV: [[f32; 3]; 3] = [
]);
const JZAZBZ_M2_INV: [[f32; 3]; 3] = t([
[1., 0.1386050433, 0.0580473162],
[1., -0.1386050433, -0.0580473162],
[1., -0.096019242, -0.8118918961],
];
]);

// ICtCp
const ICTCP_M1: [[f32; 3]; 3] = [
const ICTCP_M1: [[f32; 3]; 3] = t([
[1688. / 4096., 2146. / 4096., 262. / 4096.],
[683. / 4096., 2951. / 4096., 462. / 4096.],
[99. / 4096., 309. / 4096., 3688. / 4096.],
];
const ICTCP_M2: [[f32; 3]; 3] = [
]);
const ICTCP_M2: [[f32; 3]; 3] = t([
[2048. / 4096., 2048. / 4096., 0. / 4096.],
[6610. / 4096., -13613. / 4096., 7003. / 4096.],
[17933. / 4096., -17390. / 4096., -543. / 4096.],
];
]);

const ICTCP_M1_INV: [[f32; 3]; 3] = [
const ICTCP_M1_INV: [[f32; 3]; 3] = t([
[3.4366066943, -2.5064521187, 0.0698454243],
[-0.7913295556, 1.9836004518, -0.1922708962],
[-0.0259498997, -0.0989137147, 1.1248636144],
];
const ICTCP_M2_INV: [[f32; 3]; 3] = [
]);
const ICTCP_M2_INV: [[f32; 3]; 3] = t([
[1., 0.008609037, 0.111029625],
[1., -0.008609037, -0.111029625],
[1., 0.5600313357, -0.320627175],
];

/// 3 * 3x3 Matrix multiply with vector transposed, ie pixel @ matrix
fn matmul3t<T: DType>(p: [T; 3], m: [[f32; 3]; 3]) -> [T; 3] {
[
p[0].fma(m[0][0].to_dt(), p[1].fma(m[1][0].to_dt(), p[2] * m[2][0].to_dt())),
p[0].fma(m[0][1].to_dt(), p[1].fma(m[1][1].to_dt(), p[2] * m[2][1].to_dt())),
p[0].fma(m[0][2].to_dt(), p[1].fma(m[1][2].to_dt(), p[2] * m[2][2].to_dt())),
]
}

/// Transposed 3 * 3x3 matrix multiply, ie matrix @ pixel
fn matmul3<T: DType>(m: [[f32; 3]; 3], p: [T; 3]) -> [T; 3] {
[
p[0].fma(m[0][0].to_dt(), p[1].fma(m[0][1].to_dt(), p[2] * m[0][2].to_dt())),
p[0].fma(m[1][0].to_dt(), p[1].fma(m[1][1].to_dt(), p[2] * m[1][2].to_dt())),
p[0].fma(m[2][0].to_dt(), p[1].fma(m[2][1].to_dt(), p[2] * m[2][2].to_dt())),
]
}
]);
// ### MATRICES ### }}}

// ### TRANSFER FUNCTIONS ### {{{
Expand Down Expand Up @@ -1112,7 +1116,7 @@ pub fn lrgb_to_xyz<T: DType, const N: usize>(pixel: &mut [T; N])
where
Channels<N>: ValidChannels,
{
[pixel[0], pixel[1], pixel[2]] = matmul3(XYZ65_MAT, [pixel[0], pixel[1], pixel[2]])
[pixel[0], pixel[1], pixel[2]] = mm(XYZ65_MAT, [pixel[0], pixel[1], pixel[2]])
}

/// Convert from CIE XYZ to CIE LAB.
Expand Down Expand Up @@ -1147,9 +1151,9 @@ pub fn xyz_to_oklab<T: DType, const N: usize>(pixel: &mut [T; N])
where
Channels<N>: ValidChannels,
{
let mut lms = matmul3t([pixel[0], pixel[1], pixel[2]], OKLAB_M1);
let mut lms = mm(OKLAB_M1, [pixel[0], pixel[1], pixel[2]]);
lms.iter_mut().for_each(|c| *c = c.scbrt());
[pixel[0], pixel[1], pixel[2]] = matmul3t(lms, OKLAB_M2);
[pixel[0], pixel[1], pixel[2]] = mm(OKLAB_M2, lms);
}

/// Convert CIE XYZ to JzAzBz
Expand All @@ -1159,7 +1163,7 @@ pub fn xyz_to_jzazbz<T: DType, const N: usize>(pixel: &mut [T; N])
where
Channels<N>: ValidChannels,
{
let mut lms = matmul3(
let mut lms = mm(
JZAZBZ_M1,
[
pixel[0].fma(JZAZBZ_B.to_dt(), T::ff32(-JZAZBZ_B + 1.0) * pixel[2]),
Expand All @@ -1170,7 +1174,7 @@ where

lms.iter_mut().for_each(|e| *e = pqz_oetf(*e));

let lab = matmul3(JZAZBZ_M2, lms);
let lab = mm(JZAZBZ_M2, lms);

pixel[0] = (T::ff32(1.0 + JZAZBZ_D) * lab[0]) / lab[0].fma(JZAZBZ_D.to_dt(), 1.0.to_dt()) - JZAZBZ_D0.to_dt();
pixel[1] = lab[1];
Expand Down Expand Up @@ -1200,10 +1204,10 @@ where
// };
// pixel.iter_mut().for_each(|c| bt2020(c));

let mut lms = matmul3(ICTCP_M1, [pixel[0], pixel[1], pixel[2]]);
let mut lms = mm(ICTCP_M1, [pixel[0], pixel[1], pixel[2]]);
// lms prime
lms.iter_mut().for_each(|c| *c = pq_oetf(*c));
[pixel[0], pixel[1], pixel[2]] = matmul3(ICTCP_M2, lms);
[pixel[0], pixel[1], pixel[2]] = mm(ICTCP_M2, lms);
}

/// Converts an LAB based space to a cylindrical representation.
Expand Down Expand Up @@ -1334,7 +1338,7 @@ pub fn xyz_to_lrgb<T: DType, const N: usize>(pixel: &mut [T; N])
where
Channels<N>: ValidChannels,
{
[pixel[0], pixel[1], pixel[2]] = matmul3(XYZ65_MAT_INV, [pixel[0], pixel[1], pixel[2]])
[pixel[0], pixel[1], pixel[2]] = mm(XYZ65_MAT_INV, [pixel[0], pixel[1], pixel[2]])
}

/// Convert from CIE LAB to CIE XYZ.
Expand Down Expand Up @@ -1369,9 +1373,9 @@ pub fn oklab_to_xyz<T: DType, const N: usize>(pixel: &mut [T; N])
where
Channels<N>: ValidChannels,
{
let mut lms = matmul3t([pixel[0], pixel[1], pixel[2]], OKLAB_M2_INV);
let mut lms = mm(OKLAB_M2_INV, [pixel[0], pixel[1], pixel[2]]);
lms.iter_mut().for_each(|c| *c = c.powi(3));
[pixel[0], pixel[1], pixel[2]] = matmul3t(lms, OKLAB_M1_INV);
[pixel[0], pixel[1], pixel[2]] = mm(OKLAB_M1_INV, lms);
}

/// Convert JzAzBz to CIE XYZ
Expand All @@ -1381,7 +1385,7 @@ pub fn jzazbz_to_xyz<T: DType, const N: usize>(pixel: &mut [T; N])
where
Channels<N>: ValidChannels,
{
let mut lms = matmul3(
let mut lms = mm(
JZAZBZ_M2_INV,
[
(pixel[0] + JZAZBZ_D0.to_dt())
Expand All @@ -1393,7 +1397,7 @@ where

lms.iter_mut().for_each(|c| *c = pqz_eotf(*c));

[pixel[0], pixel[1], pixel[2]] = matmul3(JZAZBZ_M1_INV, lms);
[pixel[0], pixel[1], pixel[2]] = mm(JZAZBZ_M1_INV, lms);

pixel[0] = pixel[2].fma((JZAZBZ_B - 1.0).to_dt(), pixel[0]) / JZAZBZ_B.to_dt();
pixel[1] = pixel[0].fma((JZAZBZ_G - 1.0).to_dt(), pixel[1]) / JZAZBZ_G.to_dt();
Expand All @@ -1415,10 +1419,10 @@ where
Channels<N>: ValidChannels,
{
// lms prime
let mut lms = matmul3(ICTCP_M2_INV, [pixel[0], pixel[1], pixel[2]]);
let mut lms = mm(ICTCP_M2_INV, [pixel[0], pixel[1], pixel[2]]);
// non-prime lms
lms.iter_mut().for_each(|c| *c = pq_eotf(*c));
[pixel[0], pixel[1], pixel[2]] = matmul3(ICTCP_M1_INV, lms);
[pixel[0], pixel[1], pixel[2]] = mm(ICTCP_M1_INV, lms);
}

/// Retrieves an LAB based space from its cylindrical representation.
Expand Down

0 comments on commit aaa2742

Please sign in to comment.