Skip to content

Commit

Permalink
Merge pull request #14 from badranX/qdkt
Browse files Browse the repository at this point in the history
optemize QDKT laplacian matrix generation
  • Loading branch information
kervias authored Dec 3, 2023
2 parents ea1f688 + 8395d57 commit 04f2622
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions edustudio/atom_op/mid2cache/single/M2C_QDKT_OP.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import networkx as nx
from ..common import BaseMid2Cache
import torch
from torch.nn import functional as F
import numpy as np


Expand All @@ -12,11 +14,19 @@ def process(self, **kwargs):
self.num_q = dt_info['exer_count']
self.num_c = dt_info['cpt_count']
self.Q_mat = kwargs['Q_mat']
graph = self.generate_graph()
laplacian_matrix = self.laplacian_matrix(graph)
laplacian_matrix = self.laplacian_matrix_by_vectorization()
kwargs['laplacian_matrix'] = laplacian_matrix
return kwargs

def laplacian_matrix_by_vectorization(self):
normQ = F.normalize(self.Q_mat.float(), p=2, dim=-1)
A = torch.mm(normQ, normQ.T) > (1 - 1/len(normQ))
A = A.int() #Adjacency matrix
D = A.sum(-1, dtype=torch.int32)
diag_idx = [range(len(A)), range(len(A))]
A[diag_idx] = D - A[diag_idx]
return A

def generate_graph(self):

graph = nx.Graph()
Expand Down

0 comments on commit 04f2622

Please sign in to comment.