Skip to content

Commit 4642c99

Browse files
author
Jean-Luc Pons
committed
Fix keeplattice issue (matlab) and improved error status
1 parent 044b082 commit 4642c99

File tree

2 files changed

+28
-23
lines changed

2 files changed

+28
-23
lines changed

atgpu/Lattice.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ void Lattice::addElement() {
4343

4444
} catch (string& err) {
4545
// Try to retrieve name of element (if any)
46-
string idxStr = "#" + to_string(elements.size());
46+
string idxStr = "#" + to_string(elements.size()) + " (from #0)";
4747
string name = "";
4848
try {
4949
name = " (" + AbstractInterface::getInstance()->getString("Name") + ")";

atmat/attrack/gpupass.cpp

+27-22
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
using namespace std;
77

8+
// Input params
89
#define LATTICE prhs[0]
910
#define RIN prhs[1]
1011
#define NEWLATTICE prhs[2]
@@ -15,6 +16,14 @@ using namespace std;
1516
#define GPUPOOL prhs[7]
1617
#define INTEGRATOR prhs[8]
1718

19+
// Free locally allocated memory
20+
#define CLEANUP() \
21+
if(mxLostCoord) mxDestroyArray(mxLostCoord); \
22+
delete[] ref_pts; \
23+
delete[] xnturnPtr; \
24+
delete[] xnelemPtr; \
25+
delete[] xlostPtr;
26+
1827
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
1928

2029
if (nlhs > 2)
@@ -25,13 +34,20 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
2534
AbstractInterface::setHandler(new MatlabInterface());
2635
}
2736

37+
// Temporary buffers
38+
uint32_t *ref_pts = nullptr;
39+
uint32_t *xnturnPtr = nullptr;
40+
uint32_t *xnelemPtr = nullptr;
41+
bool *xlostPtr = nullptr;
42+
mxArray *mxLostCoord = nullptr;
43+
2844
// Default symplectic integrator (4th order)
2945
static SymplecticIntegrator integrator(4);
3046
// Lattice object
3147
static Lattice *gpuLattice = nullptr;
3248

3349
int num_turns=(int)mxGetScalar(NTURNS);
34-
int keep_lattice=(mxGetScalar(NEWLATTICE) == 0) ? 0 : 1;
50+
int keep_lattice=(mxGetScalar(NEWLATTICE) == 0) ? 1 : 0;
3551
int keep_counter=(int)mxGetScalar(KEEPCOUNTER);
3652
int counter=(int)mxGetScalar(TURN);
3753
int losses=(nlhs == 2);
@@ -49,18 +65,17 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
4965
AT_FLOAT *drin = (AT_FLOAT *)mxGetDoubles(RIN);
5066

5167
// Reference points
52-
uint32_t *ref_pts;
5368
uint32_t num_refs = (uint32_t)mxGetNumberOfElements(REFPTS);
5469

5570
if( num_refs==0 ) {
5671
// One ref at the end of the turn
5772
num_refs = 1;
58-
ref_pts = (uint32_t *) mxCalloc(num_refs, sizeof(uint32_t));
73+
ref_pts = new uint32_t[num_refs];
5974
ref_pts[0] = mxGetNumberOfElements(LATTICE);
6075
} else {
6176
// Convert indices to uint32_t
6277
mxDouble *dblrefpts = mxGetDoubles(REFPTS);
63-
ref_pts = (uint32_t *) mxCalloc(num_refs, sizeof(uint32_t));
78+
ref_pts = new uint32_t[num_refs];
6479
for (int i = 0; i < num_refs; i++)
6580
ref_pts[i] = ((int) dblrefpts[i]) - 1;
6681
}
@@ -105,7 +120,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
105120
} catch (string& errStr) {
106121
delete gpuLattice;
107122
gpuLattice = nullptr;
108-
mxFree(ref_pts);
123+
CLEANUP();
109124
string err = "at_gpupass() build lattice failed: " + errStr;
110125
mexErrMsgIdAndTxt("Atpass:RuntimeError",err.c_str());
111126
}
@@ -116,7 +131,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
116131
try {
117132
gpuLattice->fillGPUMemory();
118133
} catch (string& errStr) {
119-
mxFree(ref_pts);
134+
CLEANUP();
120135
string err = "at_gpupass() fill GPU memory failed: " + errStr;
121136
mexErrMsgIdAndTxt("Atpass:RuntimeError",err.c_str());
122137
}
@@ -125,11 +140,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
125140
if( !keep_counter )
126141
gpuLattice->setTurnCounter(counter);
127142

128-
// Buffer for lost info
129-
uint32_t *xnturnPtr = nullptr;
130-
uint32_t *xnelemPtr = nullptr;
131-
bool *xlostPtr = nullptr;
132-
mxArray *mxLostCoord = nullptr;
133143

134144
try {
135145

@@ -138,7 +148,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
138148
uint32_t outsize=num_particles*num_refs*num_turns;
139149
plhs[0] = mxCreateDoubleMatrix(6,outsize,mxREAL);
140150
if( plhs[0]==nullptr ) {
141-
mxFree(ref_pts);
151+
CLEANUP();
142152
mexErrMsgIdAndTxt("Atpass:RuntimeError","Not enough memory while trying to allocate particle output coordinates");
143153
}
144154
AT_FLOAT *drout = (AT_FLOAT *)mxGetDoubles(plhs[0]);
@@ -155,6 +165,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
155165
xlostPtr = new bool[num_particles];
156166
AT_FLOAT *xlostcoordPtr = (AT_FLOAT *)mxGetDoubles(mxLostCoord);
157167

168+
// Tracking
158169
gpuLattice->run(num_turns,num_particles,drin,drout,num_refs,ref_pts,num_starts,track_starts,xnturnPtr,xnelemPtr,xlostcoordPtr,false);
159170

160171
// Format result for AT
@@ -183,29 +194,23 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
183194
mxSetField(mxLoss, 0, lossinfo[1], mxNturn);
184195
mxSetField(mxLoss, 0, lossinfo[2], mxNelem);
185196
mxSetField(mxLoss, 0, lossinfo[3], mxLostCoord);
197+
mxLostCoord = nullptr; // Mark as used to avoid unwanted free
186198
plhs[1]=mxLoss;
187199

188200
} else {
189201

202+
// Tracking
190203
gpuLattice->run(num_turns,num_particles,drin,drout,num_refs,ref_pts,num_starts,track_starts,nullptr,nullptr,nullptr,false);
191204

192205
}
193206

194207
} catch (string& errStr) {
195-
mxFree(ref_pts);
208+
CLEANUP();
196209
mxDestroyArray(plhs[0]);
197-
if(mxLostCoord) mxDestroyArray(mxLostCoord);
198-
delete[] xnturnPtr;
199-
delete[] xnelemPtr;
200-
delete[] xlostPtr;
201210
string err = "at_gpupass() run failed: " + errStr;
202211
mexErrMsgIdAndTxt("Atpass:RuntimeError",err.c_str());
203-
return; // Avoid warning (delete non allocated memory)
204212
}
205213

206-
delete[] xnturnPtr;
207-
delete[] xnelemPtr;
208-
delete[] xlostPtr;
209-
mxFree(ref_pts);
214+
CLEANUP();
210215

211216
}

0 commit comments

Comments
 (0)