Skip to content

Commit

Permalink
"QC Normalization"
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDAVAUX committed Jun 25, 2024
1 parent d3ba5e1 commit 68b3761
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 27 deletions.
1 change: 1 addition & 0 deletions install_pytorch3d.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
# Command to determine PyTorch version string
PYTORCH_VERSION=$(python -c "import sys; import torch; pyt_version_str=torch.__version__.split('+')[0].replace('.', ''); version_str=''.join([f'py3{sys.version_info.minor}_cu', torch.version.cuda.replace('.', ''), f'_pyt{pyt_version_str}']); print(version_str)")


# Command to install PyTorch3D
pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/${PYTORCH_VERSION}/download.html
1 change: 0 additions & 1 deletion shapeaxi/saxi_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ def set_wm_as_texture(self, sphere, wm_path):
def getitem_per_hemisphere(self, hemisphere, idx):
row = self.df.loc[idx]
sub_session = '_' + row['eventname']
print(row['Subject_ID'], row['eventname'])
path_to_fs_data = os.path.join(self.freesurfer_path, row['Subject_ID'], row['Subject_ID'] + sub_session, 'surf')

# Load Data
Expand Down
22 changes: 14 additions & 8 deletions shapeaxi/saxi_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from pytorch3d.renderer import (
FoVPerspectiveCameras, PerspectiveCameras, look_at_rotation,
RasterizationSettings, MeshRenderer, MeshRasterizer, MeshRendererWithFragments, BlendParams,
SoftSilhouetteShader, HardPhongShader, SoftPhongShader, AmbientLights, PointLights, TexturesUV, TexturesVertex, TexturesAtlas
SoftSilhouetteShader, HardPhongShader, SoftPhongShader, AmbientLights, PointLights,
TexturesUV, TexturesVertex, TexturesAtlas, Textures
)

import json
Expand Down Expand Up @@ -1278,10 +1279,11 @@ def test_step(self):
class AttentionRing(nn.Module):
def __init__(self, in_units, out_units, neigh_orders):
super().__init__()
self.num_heads = 8
# neigh_order: (Nviews previous level, Neighbors next level)
self.neigh_orders = neigh_orders
#MHA

# self.MHA = MultiHeadAttentionModule(in_units, self.num_heads, batch_first=True)
self.Att = SelfAttention(in_units, out_units, dim=2)

def forward(self, query, values):
Expand All @@ -1298,6 +1300,7 @@ def forward(self, query, values):
query = query[:, self.neigh_orders] # (batch, Nv_{n-1}, Idx_{n}, features)
values = values[:, self.neigh_orders] # (batch, Nv_{n-1}, Idx_{n}, features)

# x, _ = self.MHA(query, values, values)
context_vector, score = self.Att(query, values)

return context_vector, score
Expand Down Expand Up @@ -1456,15 +1459,18 @@ def get_features(self,V,F,VF,FF,side):
x, _ = getattr(self, f'Attention{side}')(x,values)

return x



def render(self,V,F,VF,FF):
textures = TexturesVertex(verts_features=VF[:, :, :3])
meshes = Meshes(
verts=V,
faces=F,
textures=textures
)
# textures = TexturesVertex(verts_features=VF[:, :, :3])

dummy_textures = [torch.ones((v.shape[0], 3), device=v.device) for v in V] # (V, C) for each mesh
dummy_textures = torch.stack(dummy_textures) # (N, V, C)

textures = Textures(verts_rgb=dummy_textures)
meshes = Meshes(verts=V, faces=F, textures=textures)

PF = []
for i in range(self.nbr_cam):
pix_to_face = self.GetView(meshes,i)
Expand Down
36 changes: 18 additions & 18 deletions shapeaxi/saxi_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,26 +264,26 @@ def SaxiRing_train(args, checkpoint_callback, mount_point, train, val, test, ear
# textures=textures_r
# )

data.setup()
# data.setup()

for batch in data.train_dataloader():
VL, FL, VFL, FFL, VR, FR, VFR, FFR, Y = batch
# for batch in data.train_dataloader():
# VL, FL, VFL, FFL, VR, FR, VFR, FFR, Y = batch

textures_l = TexturesVertex(verts_features=VL[:, :, :3])
meshes = Meshes(
verts=VL,
faces=FL,
textures=textures_l
)

textures_r = TexturesVertex(verts_features=VR[:, :, :3])
meshes = Meshes(
verts=VR,
faces=FR,
textures=textures_r
)

quit()
# textures_l = TexturesVertex(verts_features=VL[:, :, :3])
# meshes = Meshes(
# verts=VL,
# faces=FL,
# textures=textures_l
# )

# textures_r = TexturesVertex(verts_features=VR[:, :, :3])
# meshes = Meshes(
# verts=VR,
# faces=FR,
# textures=textures_r
# )

# quit()

#Creation of our model
SAXINETS = getattr(saxi_nets, args.nn)
Expand Down

0 comments on commit 68b3761

Please sign in to comment.