Skip to content

Commit

Permalink
Merge pull request #467 from rest-for-physics/lobis-add-converters
Browse files Browse the repository at this point in the history
Add support for std::pair
  • Loading branch information
lobis committed Aug 28, 2023
2 parents 7220465 + ff34c2f commit cad41cd
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 14 deletions.
2 changes: 2 additions & 0 deletions macros/REST_OpenInputFile.C
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ void REST_OpenInputFile(const std::string& fileName) {
if (TRestTools::isRunFile(fileName)) {
printf("\n%s\n\n", "REST processed file identified. It contains a valid TRestRun.");
run = new TRestRun(fileName);
// print number of entries in the run
printf("\nThe run has %lld entries.\n", run->GetEntries());
printf("\nAttaching TRestRun %s as run...\n", fileName.c_str());
ana_tree = run->GetAnalysisTree();
printf("Attaching TRestAnalysisTree as ana_tree...\n");
Expand Down
2 changes: 1 addition & 1 deletion source/framework/core/inc/TRestHits.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class TRestHits {

public:
void Translate(Int_t n, Double_t x, Double_t y, Double_t z);
void RotateIn3D(Int_t n, Double_t alpha, Double_t beta, Double_t gamma, const TVector3& vMean);
void RotateIn3D(Int_t n, Double_t alpha, Double_t beta, Double_t gamma, const TVector3& center);
void Rotate(Int_t n, Double_t alpha, const TVector3& vAxis, const TVector3& vMean);

void AddHit(Double_t x, Double_t y, Double_t z, Double_t en, Double_t t = 0, REST_HitType type = XYZ);
Expand Down
6 changes: 4 additions & 2 deletions source/framework/core/inc/TRestRun.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ class TRestRun : public TRestMetadata {
void GetEntry(Long64_t entry);

void GetNextEntry() {
if (fCurrentEvent + 1 >= GetEntries()) fCurrentEvent = -1;
if (fCurrentEvent + 1 >= GetEntries()) {
fCurrentEvent = -1;
}
GetEntry(fCurrentEvent + 1);
}

Expand Down Expand Up @@ -246,7 +248,7 @@ class TRestRun : public TRestMetadata {

// Constructor & Destructor
TRestRun();
TRestRun(const std::string& filename);
explicit TRestRun(const std::string& filename);
~TRestRun();

ClassDefOverride(TRestRun, 6);
Expand Down
23 changes: 14 additions & 9 deletions source/framework/core/src/TRestHits.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ Bool_t TRestHits::isNaN(Int_t n) const {
///
Double_t TRestHits::GetEnergyIntegral() const {
Double_t sum = 0;
for (unsigned int i = 0; i < GetNumberOfHits(); i++) sum += GetEnergy(i);
for (unsigned int i = 0; i < GetNumberOfHits(); i++) {
sum += GetEnergy(i);
}
return sum;
}

Expand Down Expand Up @@ -207,8 +209,11 @@ Double_t TRestHits::GetEnergyInPrism(const TVector3& x0, const TVector3& x1, Dou
Double_t theta) const {
Double_t energy = 0.;

for (unsigned int n = 0; n < GetNumberOfHits(); n++)
if (isHitNInsidePrism(n, x0, x1, sizeX, sizeY, theta)) energy += this->GetEnergy(n);
for (unsigned int n = 0; n < GetNumberOfHits(); n++) {
if (isHitNInsidePrism(n, x0, x1, sizeX, sizeY, theta)) {
energy += this->GetEnergy(n);
}
}

return energy;
}
Expand Down Expand Up @@ -416,17 +421,17 @@ void TRestHits::Translate(Int_t n, double x, double y, double z) {
/// \brief It rotates hit `n` following rotations in Z, Y and X by angles gamma, beta and alpha. The
/// rotation is performed with center at `vMean`.
///
void TRestHits::RotateIn3D(Int_t n, Double_t alpha, Double_t beta, Double_t gamma, const TVector3& vMean) {
TVector3 position = GetPosition(n);
TVector3 vHit = position - vMean;
void TRestHits::RotateIn3D(Int_t n, Double_t alpha, Double_t beta, Double_t gamma, const TVector3& center) {
const TVector3 position = GetPosition(n);
TVector3 vHit = position - center;

vHit.RotateZ(gamma);
vHit.RotateY(beta);
vHit.RotateX(alpha);

fX[n] = vHit[0] + vMean[0];
fY[n] = vHit[1] + vMean[1];
fZ[n] = vHit[2] + vMean[2];
fX[n] = vHit.X() + center.X();
fY[n] = vHit.Y() + center.Y();
fZ[n] = vHit.Z() + center.Z();
}

///////////////////////////////////////////////
Expand Down
7 changes: 5 additions & 2 deletions source/framework/core/src/TRestRun.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1617,8 +1617,11 @@ TRestMetadata* TRestRun::GetMetadata(const TString& name, TFile* file) {
}
}
} else {
for (unsigned int i = 0; i < fMetadata.size(); i++)
if (fMetadata[i]->GetName() == name) return fMetadata[i];
for (unsigned int i = 0; i < fMetadata.size(); i++) {
if (fMetadata[i]->GetName() == name) {
return fMetadata[i];
}
}
}

return nullptr;
Expand Down
252 changes: 252 additions & 0 deletions source/framework/core/src/startup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ vector<T> StringToVector(string vec) {
return result;
}
AddConverter(VectorToString, StringToVector, vector<int>);
AddConverter(VectorToString, StringToVector, vector<UShort_t>);
AddConverter(VectorToString, StringToVector, vector<float>);
AddConverter(VectorToString, StringToVector, vector<double>);
AddConverter(VectorToString, StringToVector, vector<string>);
Expand Down Expand Up @@ -433,3 +434,254 @@ AddConverter(MapToString, StringToMap, map<TString comma TString>);
AddConverter(MapToString, StringToMap, map<TString comma string>);

AddConverter(MapToString, StringToMap, map<TString comma TVector2>);

template <class T, class U>
string PairToString(pair<T, U> p) {
string result = "{";
result += Converter<T>::thisptr->ToStringFunc(p.first);
result += ",";
result += Converter<U>::thisptr->ToStringFunc(p.second);
result += "}";
return result;
}
template <class T, class U>
pair<T, U> StringToPair(string vec) {
pair<T, U> result;
if (vec[0] == '{' && vec[vec.size() - 1] == '}') {
vec.erase(vec.begin());
vec.erase(vec.end() - 1);
vector<string> parts = Split(vec, ",");

if (parts.size() == 2) {
while (parts[0][0] == ' ') {
parts[0].erase(parts[0].begin());
}
while (parts[0][parts[0].size() - 1] == ' ') {
parts[0].erase(parts[0].end() - 1);
}
while (parts[1][0] == ' ') {
parts[1].erase(parts[1].begin());
}
while (parts[1][parts[1].size() - 1] == ' ') {
parts[1].erase(parts[1].end() - 1);
}
result.first = Converter<T>::thisptr->ParseStringFunc(parts[0]);
result.second = Converter<U>::thisptr->ParseStringFunc(parts[1]);
} else {
cout << "illegal format!" << endl;
return pair<T, U>{};
}

} else {
cout << "illegal format!" << endl;
return pair<T, U>{};
}
return result;
}
AddConverter(PairToString, StringToPair, pair<int comma int>);
AddConverter(PairToString, StringToPair, pair<int comma float>);
AddConverter(PairToString, StringToPair, pair<int comma double>);
AddConverter(PairToString, StringToPair, pair<UShort_t comma float>);
AddConverter(PairToString, StringToPair, pair<UShort_t comma double>);

// a vector of pairs
template <class T, class U>
string PairVectorToString(vector<pair<T, U>> vec) {
stringstream ss;
ss << "{";
int cont = 0;
for (auto const& x : vec) {
if (cont > 0) ss << ",";
cont++;

ss << "[";
ss << Converter<T>::thisptr->ToStringFunc(x.first);
ss << ":";
ss << Converter<U>::thisptr->ToStringFunc(x.second);
ss << "]";
}
ss << "}";
return ss.str();
}
template <class T, class U>
vector<pair<T, U>> StringToPairVector(string vec) {
vector<pair<T, U>> result;
// input string format: {[dd:7],[aa:8],[ss:9]}
if (vec[0] == '{' && vec[vec.size() - 1] == '}') {
vec.erase(vec.begin());
vec.erase(vec.end() - 1);
vector<string> parts = Split(vec, ",");

for (string part : parts) {
while (part[0] == ' ') {
part.erase(part.begin());
}
while (part[part.size() - 1] == ' ') {
part.erase(part.end() - 1);
}

if (part[0] == '[' && part[part.size() - 1] == ']') {
part.erase(part.begin());
part.erase(part.end() - 1);
vector<string> key_value = Split(part, ":");
if (key_value.size() == 2) {
T key = Converter<T>::thisptr->ParseStringFunc(key_value[0]);
U value = Converter<U>::thisptr->ParseStringFunc(key_value[1]);
result.push_back(pair<T, U>(key, value));
} else {
cout << "illegal format!" << endl;
return vector<pair<T, U>>{};
}
} else {
cout << "illegal format!" << endl;
return vector<pair<T, U>>{};
}
}

} else {
cout << "illegal format!" << endl;
return vector<pair<T, U>>{};
}

return result;
}
AddConverter(PairVectorToString, StringToPairVector, vector<pair<int comma int>>);
AddConverter(PairVectorToString, StringToPairVector, vector<pair<int comma float>>);
AddConverter(PairVectorToString, StringToPairVector, vector<pair<int comma double>>);
AddConverter(PairVectorToString, StringToPairVector, vector<pair<UShort_t comma float>>);
AddConverter(PairVectorToString, StringToPairVector, vector<pair<UShort_t comma double>>);

// Implement for triple (tuple)
template <class T, class U, class V>
string TripleToString(tuple<T, U, V> t) {
string result = "{";
result += Converter<T>::thisptr->ToStringFunc(get<0>(t));
result += ",";
result += Converter<U>::thisptr->ToStringFunc(get<1>(t));
result += ",";
result += Converter<V>::thisptr->ToStringFunc(get<2>(t));
result += "}";
return result;
}

template <class T, class U, class V>
tuple<T, U, V> StringToTriple(string vec) {
tuple<T, U, V> result;
if (vec[0] == '{' && vec[vec.size() - 1] == '}') {
vec.erase(vec.begin());
vec.erase(vec.end() - 1);
vector<string> parts = Split(vec, ",");

if (parts.size() == 3) {
while (parts[0][0] == ' ') {
parts[0].erase(parts[0].begin());
}
while (parts[0][parts[0].size() - 1] == ' ') {
parts[0].erase(parts[0].end() - 1);
}
while (parts[1][0] == ' ') {
parts[1].erase(parts[1].begin());
}
while (parts[1][parts[1].size() - 1] == ' ') {
parts[1].erase(parts[1].end() - 1);
}
while (parts[2][0] == ' ') {
parts[2].erase(parts[2].begin());
}
while (parts[2][parts[2].size() - 1] == ' ') {
parts[2].erase(parts[2].end() - 1);
}
get<0>(result) = Converter<T>::thisptr->ParseStringFunc(parts[0]);
get<1>(result) = Converter<U>::thisptr->ParseStringFunc(parts[1]);
get<2>(result) = Converter<V>::thisptr->ParseStringFunc(parts[2]);
} else {
cout << "illegal format!" << endl;
return tuple<T, U, V>{};
}

} else {
cout << "illegal format!" << endl;
return tuple<T, U, V>{};
}
return result;
}

AddConverter(TripleToString, StringToTriple, tuple<int comma int comma int>);
AddConverter(TripleToString, StringToTriple, tuple<int comma int comma float>);
AddConverter(TripleToString, StringToTriple, tuple<int comma int comma double>);
AddConverter(TripleToString, StringToTriple, tuple<UShort_t comma UShort_t comma int>);
AddConverter(TripleToString, StringToTriple, tuple<UShort_t comma UShort_t comma float>);
AddConverter(TripleToString, StringToTriple, tuple<UShort_t comma UShort_t comma double>);

// vector of triple
template <class T, class U, class V>
string TripleVectorToString(vector<tuple<T, U, V>> vec) {
stringstream ss;
ss << "{";
int cont = 0;
for (auto const& x : vec) {
if (cont > 0) ss << ",";
cont++;

ss << "[";
ss << Converter<T>::thisptr->ToStringFunc(get<0>(x));
ss << ":";
ss << Converter<U>::thisptr->ToStringFunc(get<1>(x));
ss << ":";
ss << Converter<V>::thisptr->ToStringFunc(get<2>(x));
ss << "]";
}
ss << "}";
return ss.str();
}

template <class T, class U, class V>
vector<tuple<T, U, V>> StringToTripleVector(string vec) {
vector<tuple<T, U, V>> result;
// input string format: {[dd:7],[aa:8],[ss:9]}
if (vec[0] == '{' && vec[vec.size() - 1] == '}') {
vec.erase(vec.begin());
vec.erase(vec.end() - 1);
vector<string> parts = Split(vec, ",");

for (string part : parts) {
while (part[0] == ' ') {
part.erase(part.begin());
}
while (part[part.size() - 1] == ' ') {
part.erase(part.end() - 1);
}

if (part[0] == '[' && part[part.size() - 1] == ']') {
part.erase(part.begin());
part.erase(part.end() - 1);
vector<string> key_value = Split(part, ":");
if (key_value.size() == 3) {
T key = Converter<T>::thisptr->ParseStringFunc(key_value[0]);
U value = Converter<U>::thisptr->ParseStringFunc(key_value[1]);
V value2 = Converter<V>::thisptr->ParseStringFunc(key_value[2]);
result.push_back(tuple<T, U, V>(key, value, value2));
} else {
cout << "illegal format!" << endl;
return vector<tuple<T, U, V>>{};
}
} else {
cout << "illegal format!" << endl;
return vector<tuple<T, U, V>>{};
}
}

} else {
cout << "illegal format!" << endl;
return vector<tuple<T, U, V>>{};
}

return result;
}

AddConverter(TripleVectorToString, StringToTripleVector, vector<tuple<int comma int comma int>>);
AddConverter(TripleVectorToString, StringToTripleVector, vector<tuple<int comma int comma float>>);
AddConverter(TripleVectorToString, StringToTripleVector, vector<tuple<int comma int comma double>>);
AddConverter(TripleVectorToString, StringToTripleVector, vector<tuple<UShort_t comma UShort_t comma int>>);
AddConverter(TripleVectorToString, StringToTripleVector, vector<tuple<UShort_t comma UShort_t comma float>>);
AddConverter(TripleVectorToString, StringToTripleVector, vector<tuple<UShort_t comma UShort_t comma double>>);

0 comments on commit cad41cd

Please sign in to comment.