@@ -9,6 +9,7 @@ function atmexall(varargin)
9
9
% -fail Throw an exception if compiling any passmethod fails
10
10
% (By defaults compilation goes on)
11
11
% -openmp Build the integrators for OpenMP parallelisation
12
+ % -cuda CUDA_PATH Build the GPU tracking support using Cuda
12
13
% -c_only Do no compile C++ passmethods
13
14
% -DOMP_PARTICLE_THRESHOLD=n
14
15
% Set the parallelisation threshold to n particles
@@ -25,6 +26,7 @@ function atmexall(varargin)
25
26
26
27
pdir= fullfile(fileparts(atroot ),' atintegrators' );
27
28
[openmp ,varargs ]=getflag(varargin ,' -openmp' );
29
+ [cuda ,varargs ]=getoption(varargs ,' -cuda' ,' None' );
28
30
[miss_only ,varargs ]=getflag(varargs ,' -missing' );
29
31
[c_only ,varargs ]=getflag(varargs ,' -c_only' );
30
32
[fail ,varargs ]=getflag(varargs ,' -fail' );
@@ -92,6 +94,49 @@ function atmexall(varargin)
92
94
compile([alloptions , {passinclude }, LIBDL , ompoptions ], fullfile(cdir ,' atpass.c' ));
93
95
compile([atoptions , ompoptions ],fullfile(cdir ,' coptions.c' ))
94
96
97
+ % gpuextensions
98
+ if ~strcmp(cuda ,' None' )
99
+ gpudir= fullfile(fileparts(atroot ),' atgpu' ,' ' );
100
+ if ispc()
101
+ % TODO
102
+ error(' AT:atmexall' , ' GPU windows not supported' );
103
+ elseif ismac()
104
+ % TODO
105
+ error(' AT:atmexall' , ' GPU ismac not supported' );
106
+ else
107
+ gpuflags = {sprintf(' -I"%s "' ,gpudir ),...
108
+ sprintf(' -I"%s /include"' ,cuda ),...
109
+ sprintf(' -L"%s /lib64"' ,cuda ),...
110
+ sprintf(' LDFLAGS=$LDFLAGS -Wl,-rpath,"%s /lib64"' ,cuda ),...
111
+ ' -DCUDA' };
112
+ end
113
+ compile([alloptions , {passinclude }, gpuflags ], ...
114
+ fullfile(cdir ,' gpuinfo.cpp' ),...
115
+ fullfile(gpudir ,' MatlabInterface.cpp' ), ...
116
+ fullfile(gpudir ,' AbstractInterface.cpp' ), ...
117
+ fullfile(gpudir ,' CudaGPU.cpp' ), ...
118
+ fullfile(gpudir ,' AbstractGPU.cpp' ), ...
119
+ ' -lcuda' ,' -lnvrtc' );
120
+ compile([alloptions , {passinclude }, gpuflags ], ...
121
+ fullfile(cdir ,' gpupass.cpp' ),...
122
+ fullfile(gpudir ,' AbstractGPU.cpp' ), ...
123
+ fullfile(gpudir ,' CudaGPU.cpp' ), ...
124
+ fullfile(gpudir ,' AbstractInterface.cpp' ), ...
125
+ fullfile(gpudir ,' MatlabInterface.cpp' ), ...
126
+ fullfile(gpudir ,' Lattice.cpp' ), ...
127
+ fullfile(gpudir ,' PassMethodFactory.cpp' ), ...
128
+ fullfile(gpudir ,' SymplecticIntegrator.cpp' ), ...
129
+ fullfile(gpudir ,' IdentityPass.cpp' ), ...
130
+ fullfile(gpudir ,' DriftPass.cpp' ), ...
131
+ fullfile(gpudir ,' StrMPoleSymplectic4Pass.cpp' ), ...
132
+ fullfile(gpudir ,' BndMPoleSymplectic4Pass.cpp' ), ...
133
+ fullfile(gpudir ,' StrMPoleSymplectic4RadPass.cpp' ), ...
134
+ fullfile(gpudir ,' BndMPoleSymplectic4RadPass.cpp' ), ...
135
+ fullfile(gpudir ,' CavityPass.cpp' ), ...
136
+ fullfile(gpudir ,' RFCavityPass.cpp' ), ...
137
+ ' -lcuda' ,' -lnvrtc' );
138
+ end
139
+
95
140
[warnmess ,warnid ]=lastwarn ; % #ok<ASGLU>
96
141
if strcmp(warnid ,' MATLAB:mex:GccVersion_link' )
97
142
warning(' Disabling the compiler warning' );
0 commit comments