5
5
6
6
using namespace std ;
7
7
8
+ // Input params
8
9
#define LATTICE prhs[0 ]
9
10
#define RIN prhs[1 ]
10
11
#define NEWLATTICE prhs[2 ]
@@ -15,6 +16,14 @@ using namespace std;
15
16
#define GPUPOOL prhs[7 ]
16
17
#define INTEGRATOR prhs[8 ]
17
18
19
+ // Free locally allocated memory
20
+ #define CLEANUP () \
21
+ if (mxLostCoord) mxDestroyArray(mxLostCoord); \
22
+ delete[] ref_pts; \
23
+ delete[] xnturnPtr; \
24
+ delete[] xnelemPtr; \
25
+ delete[] xlostPtr;
26
+
18
27
void mexFunction (int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
19
28
20
29
if (nlhs > 2 )
@@ -25,13 +34,20 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
25
34
AbstractInterface::setHandler (new MatlabInterface ());
26
35
}
27
36
37
+ // Temporary buffers
38
+ uint32_t *ref_pts = nullptr ;
39
+ uint32_t *xnturnPtr = nullptr ;
40
+ uint32_t *xnelemPtr = nullptr ;
41
+ bool *xlostPtr = nullptr ;
42
+ mxArray *mxLostCoord = nullptr ;
43
+
28
44
// Default symplectic integrator (4th order)
29
45
static SymplecticIntegrator integrator (4 );
30
46
// Lattice object
31
47
static Lattice *gpuLattice = nullptr ;
32
48
33
49
int num_turns=(int )mxGetScalar (NTURNS);
34
- int keep_lattice=(mxGetScalar (NEWLATTICE) == 0 ) ? 0 : 1 ;
50
+ int keep_lattice=(mxGetScalar (NEWLATTICE) == 0 ) ? 1 : 0 ;
35
51
int keep_counter=(int )mxGetScalar (KEEPCOUNTER);
36
52
int counter=(int )mxGetScalar (TURN);
37
53
int losses=(nlhs == 2 );
@@ -49,18 +65,17 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
49
65
AT_FLOAT *drin = (AT_FLOAT *)mxGetDoubles (RIN);
50
66
51
67
// Reference points
52
- uint32_t *ref_pts;
53
68
uint32_t num_refs = (uint32_t )mxGetNumberOfElements (REFPTS);
54
69
55
70
if ( num_refs==0 ) {
56
71
// One ref at the end of the turn
57
72
num_refs = 1 ;
58
- ref_pts = ( uint32_t *) mxCalloc (num_refs, sizeof ( uint32_t )) ;
73
+ ref_pts = new uint32_t [num_refs] ;
59
74
ref_pts[0 ] = mxGetNumberOfElements (LATTICE);
60
75
} else {
61
76
// Convert indices to uint32_t
62
77
mxDouble *dblrefpts = mxGetDoubles (REFPTS);
63
- ref_pts = ( uint32_t *) mxCalloc (num_refs, sizeof ( uint32_t )) ;
78
+ ref_pts = new uint32_t [num_refs] ;
64
79
for (int i = 0 ; i < num_refs; i++)
65
80
ref_pts[i] = ((int ) dblrefpts[i]) - 1 ;
66
81
}
@@ -105,7 +120,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
105
120
} catch (string& errStr) {
106
121
delete gpuLattice;
107
122
gpuLattice = nullptr ;
108
- mxFree (ref_pts );
123
+ CLEANUP ( );
109
124
string err = " at_gpupass() build lattice failed: " + errStr;
110
125
mexErrMsgIdAndTxt (" Atpass:RuntimeError" ,err.c_str ());
111
126
}
@@ -116,7 +131,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
116
131
try {
117
132
gpuLattice->fillGPUMemory ();
118
133
} catch (string& errStr) {
119
- mxFree (ref_pts );
134
+ CLEANUP ( );
120
135
string err = " at_gpupass() fill GPU memory failed: " + errStr;
121
136
mexErrMsgIdAndTxt (" Atpass:RuntimeError" ,err.c_str ());
122
137
}
@@ -125,11 +140,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
125
140
if ( !keep_counter )
126
141
gpuLattice->setTurnCounter (counter);
127
142
128
- // Buffer for lost info
129
- uint32_t *xnturnPtr = nullptr ;
130
- uint32_t *xnelemPtr = nullptr ;
131
- bool *xlostPtr = nullptr ;
132
- mxArray *mxLostCoord = nullptr ;
133
143
134
144
try {
135
145
@@ -138,7 +148,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
138
148
uint32_t outsize=num_particles*num_refs*num_turns;
139
149
plhs[0 ] = mxCreateDoubleMatrix (6 ,outsize,mxREAL);
140
150
if ( plhs[0 ]==nullptr ) {
141
- mxFree (ref_pts );
151
+ CLEANUP ( );
142
152
mexErrMsgIdAndTxt (" Atpass:RuntimeError" ," Not enough memory while trying to allocate particle output coordinates" );
143
153
}
144
154
AT_FLOAT *drout = (AT_FLOAT *)mxGetDoubles (plhs[0 ]);
@@ -155,6 +165,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
155
165
xlostPtr = new bool [num_particles];
156
166
AT_FLOAT *xlostcoordPtr = (AT_FLOAT *)mxGetDoubles (mxLostCoord);
157
167
168
+ // Tracking
158
169
gpuLattice->run (num_turns,num_particles,drin,drout,num_refs,ref_pts,num_starts,track_starts,xnturnPtr,xnelemPtr,xlostcoordPtr,false );
159
170
160
171
// Format result for AT
@@ -183,29 +194,23 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
183
194
mxSetField (mxLoss, 0 , lossinfo[1 ], mxNturn);
184
195
mxSetField (mxLoss, 0 , lossinfo[2 ], mxNelem);
185
196
mxSetField (mxLoss, 0 , lossinfo[3 ], mxLostCoord);
197
+ mxLostCoord = nullptr ; // Mark as used to avoid unwanted free
186
198
plhs[1 ]=mxLoss;
187
199
188
200
} else {
189
201
202
+ // Tracking
190
203
gpuLattice->run (num_turns,num_particles,drin,drout,num_refs,ref_pts,num_starts,track_starts,nullptr ,nullptr ,nullptr ,false );
191
204
192
205
}
193
206
194
207
} catch (string& errStr) {
195
- mxFree (ref_pts );
208
+ CLEANUP ( );
196
209
mxDestroyArray (plhs[0 ]);
197
- if (mxLostCoord) mxDestroyArray (mxLostCoord);
198
- delete[] xnturnPtr;
199
- delete[] xnelemPtr;
200
- delete[] xlostPtr;
201
210
string err = " at_gpupass() run failed: " + errStr;
202
211
mexErrMsgIdAndTxt (" Atpass:RuntimeError" ,err.c_str ());
203
- return ; // Avoid warning (delete non allocated memory)
204
212
}
205
213
206
- delete[] xnturnPtr;
207
- delete[] xnelemPtr;
208
- delete[] xlostPtr;
209
- mxFree (ref_pts);
214
+ CLEANUP ();
210
215
211
216
}
0 commit comments