Skip to content

Commit

Permalink
Speeding up the random amplitudes code
Browse files Browse the repository at this point in the history
  • Loading branch information
JeffreyEarly committed Sep 24, 2024
1 parent 4ab959b commit 81ce492
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -447,14 +447,16 @@
methods (Static)
wvt = waveVortexTransformFromFile(path,options)

resultsTable = speedTest

matrix = CosineTransformForwardMatrix(N)
matrix = CosineTransformBackMatrix(n)
matrix = SineTransformForwardMatrix(N)
matrix = SineTransformBackMatrix(N)
end


methods (Access=protected)
methods %(Access=protected)
function ProfileTransforms(self)
Ubar = self.UAp.*self.Ap + self.UAm.*self.Am + self.UA0.*self.A0;
Nbar = self.NAp.*self.Ap + self.NAm.*self.Am + self.NA0.*self.A0;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
function resultsTable = speedTest
% Initialize a WVTransformConstantStratification instance from an existing file
%
% - Topic: Initialization
% - Declaration: wvt = waveVortexTransformFromFile(path,options)
% - Parameter path: path to a NetCDF file
% - Parameter iTime: (optional) time index to initialize from (default 1)

Nxy = [32 64 64 128 128 128 256 256 256].';
Nz = [32 32 64 32 64 128 32 64 128].'+1;
nReps = [50 50 50 50 50 50 50 25 10].';
measuredTime = zeros(size(Nxy));

for iProfile=1:length(Nxy)
wvt = WVTransformConstantStratification([15e3, 15e3, 1300], [Nxy(iProfile) Nxy(iProfile) Nz(iProfile)]);
wvt.initWithRandomFlow(uvMax=0.01);
[Fp,Fm,F0] = wvt.nonlinearFlux();
tic
for iRep=1:nReps(iProfile)
wvt.t = iRep; % prevent caching
[Fp,Fm,F0] = wvt.nonlinearFlux();
end
measuredTime(iProfile) = toc/nReps(iProfile);
end
resultsTable = table(Nxy,Nz,measuredTime);
disp(resultsTable)
end
12 changes: 8 additions & 4 deletions Matlab/WaveVortexModel/UnitTests/ProfileableSpeedTest.m
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@

% profile on
% wvt = WVTransformHydrostatic([15e3, 15e3, 5000], 2*[64 64 33], N2=@(z) (5.2e-3)*(5.2e-3)*ones(size(z)));
wvt = WVTransformBoussinesq([15e3, 15e3, 5000], [64 64 33], N2=@(z) (5.2e-3)*(5.2e-3)*ones(size(z)));
% wvt = WVTransformConstantStratification([15e3, 15e3, 5000], [64 64 33]);
% wvt = WVTransformBoussinesq([15e3, 15e3, 5000], [64 64 33], N2=@(z) (5.2e-3)*(5.2e-3)*ones(size(z)));
wvt = WVTransformConstantStratification([15e3, 15e3, 5000], [128 128 128]);
% wvt = WVTransformSingleMode([2000e3 1000e3], 2*[256 128], h=0.8, latitude=25);
% profile viewer

%%
profile on
wvt.initWithRandomFlow();
profile viewer

% wvt.removeEnergyFromAliasedModes();
% spatialFlux = WVNonlinearFluxSpatial(wvt);
Expand Down Expand Up @@ -59,7 +63,7 @@
[Fp,Fm,F0] = wvt.nonlinearFlux();

tic
for i=1:50
for i=1:10
wvt.t = i;
[Fp,Fm,F0] = wvt.nonlinearFlux();
end
Expand All @@ -74,7 +78,7 @@

%%
profile on
for i=1:50
for i=1:10
wvt.t = i;
[Fp,Fm,F0] = wvt.nonlinearFlux();
end
Expand Down
55 changes: 34 additions & 21 deletions Matlab/WaveVortexModel/WVFlowComponent.m
Original file line number Diff line number Diff line change
Expand Up @@ -171,36 +171,49 @@
Am double
A0 double
end
if isequal(options.A0Spectrum,@isempty)
A0Spectrum = @(k,j) ones(size(k));
else
A0Spectrum = options.A0Spectrum;
end
if isequal(options.ApmSpectrum,@isempty)
ApmSpectrum = @(k,j) ones(size(k));
else
ApmSpectrum = options.ApmSpectrum;
end
% if isequal(options.A0Spectrum,@isempty)
% A0Spectrum = @(k,j) ones(size(k));
% else
% A0Spectrum = options.A0Spectrum;
% end
% if isequal(options.ApmSpectrum,@isempty)
% ApmSpectrum = @(k,j) ones(size(k));
% else
% ApmSpectrum = options.ApmSpectrum;
% end

[Ap,Am,A0] = self.randomAmplitudes(shouldOnlyRandomizeOrientations=options.shouldOnlyRandomizeOrientations);
hasRandomA0 = any(A0(:));
hasRandomApm = any(Ap(:)) || any(Am(:));

kRadial = self.wvt.kRadial;
Kh = self.wvt.Kh;
J = self.wvt.J;
dk = kRadial(2)-kRadial(1);
for iJ=1:length(self.wvt.j)
for iK=1:length(kRadial)
indicesForK = kRadial(iK)-dk/2 < Kh & Kh <= kRadial(iK)+dk/2 & J == self.wvt.j(iJ);

if any(A0(:))
energyPerA0Component = integral(@(k) A0Spectrum(k,J(iJ)),max(kRadial(iK)-dk/2,0),kRadial(iK)+dk/2)/sum(indicesForK(:));
A0(indicesForK) = A0(indicesForK).*sqrt(energyPerA0Component./(self.wvt.A0_TE_factor(indicesForK) ));
for iK=1:length(kRadial)
indicesForK = kRadial(iK)-dk/2 < Kh & Kh <= kRadial(iK)+dk/2;
for iJ=1:length(self.wvt.j)
% this is faster than logical indexing
indicesForKJ = find(indicesForK & J == self.wvt.j(iJ));
nIndicesForKJ = length(indicesForKJ);

if hasRandomA0
if isequal(options.A0Spectrum,@isempty)
energyPerA0Component = (kRadial(iK)+dk/2 - max(kRadial(iK)-dk/2,0))/nIndicesForKJ;
else
energyPerA0Component = integral(@(k) A0Spectrum(k,J(iJ)),max(kRadial(iK)-dk/2,0),kRadial(iK)+dk/2)/nIndicesForKJ;
end
A0(indicesForKJ) = A0(indicesForKJ).*sqrt(energyPerA0Component./(self.wvt.A0_TE_factor(indicesForKJ) ));
end

if any(Ap(:)) || any(Am(:))
energyPerApmComponent = integral(@(k) ApmSpectrum(k,J(iJ)),max(kRadial(iK)-dk/2,0),kRadial(iK)+dk/2)/sum(indicesForK(:))/2;
Ap(indicesForK) = Ap(indicesForK).*sqrt(energyPerApmComponent./(self.wvt.Apm_TE_factor(indicesForK) ));
Am(indicesForK) = Am(indicesForK).*sqrt(energyPerApmComponent./(self.wvt.Apm_TE_factor(indicesForK) ));
if hasRandomApm
if isequal(options.ApmSpectrum,@isempty)
energyPerApmComponent = (kRadial(iK)+dk/2 - max(kRadial(iK)-dk/2,0))/nIndicesForKJ/2;
else
energyPerApmComponent = integral(@(k) ApmSpectrum(k,J(iJ)),max(kRadial(iK)-dk/2,0),kRadial(iK)+dk/2)/nIndicesForKJ/2;
end
Ap(indicesForKJ) = Ap(indicesForKJ).*sqrt(energyPerApmComponent./(self.wvt.Apm_TE_factor(indicesForKJ) ));
Am(indicesForKJ) = Am(indicesForKJ).*sqrt(energyPerApmComponent./(self.wvt.Apm_TE_factor(indicesForKJ) ));
end
end
end
Expand Down

0 comments on commit 81ce492

Please sign in to comment.