Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New neighborhood code #2314

Merged
merged 9 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Libs/Optimize/Domain/MeshDomain.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class MeshDomain : public ParticleDomain {

void UpdateZeroCrossingPoint() override {}

std::shared_ptr<MeshWrapper> GetMeshWrapper() const { return mesh_wrapper_; }

private:
std::shared_ptr<MeshWrapper> mesh_wrapper_;
std::shared_ptr<MeshWrapper> geodesics_mesh_;
Expand Down
20 changes: 8 additions & 12 deletions Libs/Optimize/Function/CurvatureSamplingFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,8 @@ void CurvatureSamplingFunction::UpdateNeighborhood(const CurvatureSamplingFuncti
m_CurrentNeighborhood.clear();
for (int offset = 0; offset < domains_per_shape; offset++) {
const auto domain_t = domain_base * domains_per_shape + offset;
const auto neighborhood_ = system->GetNeighborhood(domain_t).GetPointer();
auto neighborhood = system->GetNeighborhood(domain_t);
using ImageType = itk::Image<float, Dimension>;
auto neighborhood__ = dynamic_cast<const ParticleSurfaceNeighborhood*>(neighborhood_);

// unfortunately required because we need to mutate the cosine weighting state
auto neighborhood = const_cast<ParticleSurfaceNeighborhood*>(neighborhood__);

if (!m_IsSharedBoundaryEnabled && domain_t != d) {
continue;
Expand All @@ -273,19 +269,19 @@ void CurvatureSamplingFunction::UpdateNeighborhood(const CurvatureSamplingFuncti
std::vector<ParticlePointIndexPair> res;
if (domain_t == d) {
// same domain
res = neighborhood->FindNeighborhoodPoints(pos, idx, weights, distances, radius);
res = neighborhood->find_neighborhood_points(pos, idx, weights, distances, radius);
} else {
// cross domain

bool weighting_state = neighborhood->IsWeightingEnabled();
bool weighting_state = neighborhood->is_weighting_enabled();
// Disable cosine-falloff weighting for cross-domain sampling term. Contours don't have normals.
neighborhood->SetWeightingEnabled(false);
neighborhood->SetForceEuclidean(true);
neighborhood->set_weighting_enabled(false);
neighborhood->set_force_euclidean(true);

res = neighborhood->FindNeighborhoodPoints(pos, -1, weights, distances, radius);
res = neighborhood->find_neighborhood_points(pos, -1, weights, distances, radius);

neighborhood->SetForceEuclidean(false);
neighborhood->SetWeightingEnabled(weighting_state);
neighborhood->set_force_euclidean(false);
neighborhood->set_weighting_enabled(weighting_state);
}

assert(weights.size() == distances.size() && res.size() == weights.size());
Expand Down
1 change: 0 additions & 1 deletion Libs/Optimize/Function/CurvatureSamplingFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "Libs/Optimize/Container/MeanCurvatureContainer.h"
#include "Libs/Optimize/Domain/ImageDomainWithCurvature.h"
#include "Libs/Optimize/Domain/ImageDomainWithGradients.h"
#include "Libs/Optimize/Neighborhood/ParticleSurfaceNeighborhood.h"
#include "SamplingFunction.h"
#include "itkCommand.h"

Expand Down
116 changes: 116 additions & 0 deletions Libs/Optimize/Neighborhood/ParticleNeighborhood.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#include "ParticleNeighborhood.h"

#include <Libs/Optimize/Domain/MeshDomain.h>

#include "ParticleSystem.h"

namespace shapeworks {

//--------------------------------------------------------------------------------------------------
std::pair<std::vector<ParticlePointIndexPair>, std::vector<double>> ParticleNeighborhood::get_points_in_sphere(
const itk::Point<double, 3>& position, int id, double radius) {
// iterate over all particles in the system for this domain, return those within radius
std::vector<ParticlePointIndexPair> neighbors;
std::vector<double> distances;

auto num_particles = ps_->GetNumberOfParticles(domain_id_);

for (unsigned int i = 0; i < num_particles; i++) {
if (i == id) {
continue;
}
auto p = ps_->GetPosition(i, domain_id_);
double distance = (p - position).GetNorm();
if (distance < radius) {
neighbors.push_back(ParticlePointIndexPair(p, i));
distances.push_back(distance);
}
}

return {neighbors, distances};
}

//--------------------------------------------------------------------------------------------------

std::vector<ParticlePointIndexPair> ParticleNeighborhood::find_neighborhood_points(
const itk::Point<double, 3>& position, int id, std::vector<double>& weights, std::vector<double>& distances,
double radius) {
auto [pointlist, neighbor_distance] = get_points_in_sphere(position, id, radius);

using GradientVectorType = vnl_vector_fixed<float, 3>;
using PointType = itk::Point<double, 3>;

GradientVectorType normal;
if (weighting_enabled_) { // uninitialized otherwise, but we're trying to avoid looking up the normal if we can
normal = domain_->SampleNormalAtPoint(position, id);
}

weights.clear();
distances.clear();

bool use_euclidean = true;

if (domain_->GetDomainType() == DomainType::Mesh) {
// cast to MeshDomain to ask if geodesics are enabled
auto mesh_domain = std::dynamic_pointer_cast<MeshDomain>(domain_);
use_euclidean = !mesh_domain->GetMeshWrapper()->IsGeodesicsEnabled();
} else if (domain_->GetDomainType() == DomainType::Contour) {
use_euclidean = false;
}

if (force_euclidean_) {
use_euclidean = true;
}

std::vector<ParticlePointIndexPair> ret;

for (int i = 0; i < pointlist.size(); i++) {
const auto& pt_b = pointlist[i].Point;
const auto& idx_b = pointlist[i].Index;

double distance = neighbor_distance[i];
bool is_within_distance = true;

if (!use_euclidean) {
is_within_distance = domain_->IsWithinDistance(position, id, pt_b, idx_b, radius, distance);
}

if (is_within_distance) {
ret.push_back(pointlist[i]);
distances.push_back(distance);

// todo change the APIs so don't have to pass a std::vector<double> of 1s whenever weighting is disabled
if (!weighting_enabled_) {
weights.push_back(1.0);
continue;
}

const GradientVectorType pn = domain_->SampleNormalAtPoint(pointlist[i].Point, pointlist[i].Index);
const double cosine = dot_product(normal, pn); // normals already normalized
if (cosine >= flat_cutoff_) {
weights.push_back(1.0);
} else {
// Drop to zero influence over 90 degrees.
weights.push_back(cos((flat_cutoff_ - cosine) / (1.0 + flat_cutoff_) * 1.5708));
}
}
}
return ret;
}

//--------------------------------------------------------------------------------------------------
std::vector<ParticlePointIndexPair> ParticleNeighborhood::find_neighborhood_points(
const itk::Point<double, 3>& position, int id, std::vector<double>& weights, double radius) {
std::vector<double> distances;
return find_neighborhood_points(position, id, weights, distances, radius);
}

//--------------------------------------------------------------------------------------------------
std::vector<ParticlePointIndexPair> ParticleNeighborhood::find_neighborhood_points(
const itk::Point<double, 3>& position, int id, double radius) {
std::vector<double> weights;
std::vector<double> distances;
return find_neighborhood_points(position, id, weights, distances, radius);
}

} // namespace shapeworks
131 changes: 37 additions & 94 deletions Libs/Optimize/Neighborhood/ParticleNeighborhood.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,113 +2,56 @@

#include <vector>

#include "Libs/Optimize/Container/GenericContainer.h"
#include "Libs/Optimize/Domain/ParticleDomain.h"
#include "ParticlePointIndexPair.h"
#include "itkDataObject.h"
#include "itkPoint.h"
#include "itkWeakPointer.h"

namespace shapeworks {
class ParticleSystem;

/** \class ParticleNeighborhood
*
* A ParticleNeighborhood is responsible for computing neighborhoods of
* particles. Given a point position in a domain, and a neighborhood radius,
* the ParticleNeighborhood returns a list of points that are neighbors of that
* point. The base class, ParticleNeighborhood, must be subclassed to provide
* functionality; the base class will throw an exception when
* FindNeighborhoodPoints is called.
* point.
*/

class ParticleNeighborhood : public itk::DataObject {
class ParticleNeighborhood {
public:
constexpr static unsigned int VDimension = 3;
/** Standard class typedefs */
typedef ParticleNeighborhood Self;
typedef DataObject Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
typedef itk::WeakPointer<const Self> ConstWeakPointer;

/** Method for creation through the object factory. */
itkNewMacro(Self);

/** Run-time type information (and related methods). */
itkTypeMacro(ParticleNeighborhood, DataObject);

/** Dimensionality of the domain of the particle system. */
itkStaticConstMacro(Dimension, unsigned int, VDimension);

/** Point type used to store particle locations. */
typedef itk::Point<double, VDimension> PointType;

/** Domain type. The Domain object provides bounds and distance
information. */
using DomainType = shapeworks::ParticleDomain;

/** Container type for points. This matches the itkParticleSystem container
type. */
typedef GenericContainer<PointType> PointContainerType;

/** Point list (vector) type. This is the type of list returned by FindNeighborhoodPoints. */
typedef std::vector<ParticlePointIndexPair> PointVectorType;

/** Set/Get the point container. These are the points parsed by the
Neighborhood class when FindNeighborhoodPoints is called. */
itkSetObjectMacro(PointContainer, PointContainerType);
itkGetConstObjectMacro(PointContainer, PointContainerType);

/** Compile a list of points that are within a specified radius of a given
point. The default implementation will throw an exception. */
virtual PointVectorType FindNeighborhoodPoints(const PointType&, int idx, double) const {
itkExceptionMacro("No algorithm for finding neighbors has been specified.");
}
/** This method finds neighborhood points as in the previous method, but also
computes a vector of weights associated with each of those points. */
virtual PointVectorType FindNeighborhoodPoints(const PointType&, int idx, std::vector<double>&, double) const {
itkExceptionMacro("No algorithm for finding neighbors has been specified.");
}
/** This method finds neighborhood points as in the previous method, but also
computes a vector of distances associated with each of those points. */
virtual PointVectorType FindNeighborhoodPoints(const PointType&, int idx, std::vector<double>&, std::vector<double>&,
double) const {
itkExceptionMacro("No algorithm for finding neighbors has been specified.");
}
virtual unsigned int FindNeighborhoodPoints(const PointType&, int idx, double, PointVectorType&) const {
itkExceptionMacro("No algorithm for finding neighbors has been specified.");
return 0;
}

/** Set the Domain that this neighborhood will use. The Domain object is
important because it defines bounds and distance measures. */
// itkSetObjectMacro(Domain, DomainType);
// itkGetConstObjectMacro(Domain, DomainType);
virtual void SetDomain(DomainType::Pointer domain) {
m_Domain = domain;
this->Modified();
};
DomainType::Pointer GetDomain() const { return m_Domain; };

/** For efficiency, itkNeighborhoods are not necessarily observers of
itkParticleSystem, but have specific methods invoked for various events.
AddPosition is called by itkParticleSystem when a particle location is
added. SetPosition is called when a particle location is set.
RemovePosition is called when a particle location is removed.*/
virtual void AddPosition(const PointType& p, unsigned int idx, int threadId = 0) {}
virtual void SetPosition(const PointType& p, unsigned int idx, int threadId = 0) {}
virtual void RemovePosition(unsigned int idx, int threadId = 0) {}

protected:
ParticleNeighborhood() {}
void PrintSelf(std::ostream& os, itk::Indent indent) const { Superclass::PrintSelf(os, indent); }
virtual ~ParticleNeighborhood(){};
explicit ParticleNeighborhood(ParticleSystem* ps, int domain_id = -1) : ps_(ps), domain_id_(domain_id) {}

private:
ParticleNeighborhood(const Self&); // purposely not implemented
void operator=(const Self&); // purposely not implemented
std::vector<ParticlePointIndexPair> find_neighborhood_points(const itk::Point<double, 3>& position, int id,
std::vector<double>& weights,
std::vector<double>& distances, double radius);

std::vector<ParticlePointIndexPair> find_neighborhood_points(const itk::Point<double, 3>& position, int id,
std::vector<double>& weights, double radius);

std::vector<ParticlePointIndexPair> find_neighborhood_points(const itk::Point<double, 3>& position, int id,
double radius);

void set_weighting_enabled(bool is_enabled) { weighting_enabled_ = is_enabled; }

typename PointContainerType::Pointer m_PointContainer;
typename DomainType::Pointer m_Domain;
bool is_weighting_enabled() const { return weighting_enabled_; }

void set_force_euclidean(bool is_enabled) { force_euclidean_ = is_enabled; }

bool is_force_euclidean() const { return force_euclidean_; }

void set_domain(ParticleDomain::Pointer domain) { domain_ = domain; };
ParticleDomain::Pointer get_domain() const { return domain_; };

void set_domain_id(int id) { domain_id_ = id; }

private:
std::pair<std::vector<ParticlePointIndexPair>, std::vector<double>> get_points_in_sphere(
const itk::Point<double, 3>& position, int id, double radius);

ParticleSystem* ps_;
ParticleDomain::Pointer domain_;
int domain_id_{-1};
double flat_cutoff_{0.3};
bool weighting_enabled_{true};
bool force_euclidean_{false};
};

} // end namespace shapeworks
Loading
Loading