Integrating a language model with ULTRA #9
Hi!
Those aren't really tail representations anymore, because message passing updates all node states and starts from the initial node states, the `boundary` (lines 137 to 140 in 33c6e6b).

In the GNN layer code, we use those initial states as well (lines 192 to 194 in 33c6e6b).
That is, in each GNN layer we actually do have an interaction function of the (initial) head state and the (current) node states. Adding other entity/relation features seems quite straightforward; I see two possible ways:

This is the less expressive way, but it won't require re-training the model from scratch. In any case, if your graphs have >10k nodes, I'd recommend projecting the LLM features (usually 768d or more, depending on the LLM) down to a smaller dimension (32/64d) in order to fit the full-batch GNN layer onto a GPU.
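For instance, such a down-projection could look like this minimal sketch (sizes and names are my assumptions, not code from the repo):

```python
import torch
import torch.nn as nn

# illustrative sizes: 10k nodes, 768-d LLM features projected to 64-d
num_nodes, llm_dim, small_dim = 10_000, 768, 64
lm_features = torch.randn(num_nodes, llm_dim)  # one LLM vector per node

# learned projection to shrink memory footprint of the full-batch GNN
down_proj = nn.Linear(llm_dim, small_dim)
small_features = down_proj(lm_features)
print(tuple(small_features.shape))  # (10000, 64)
```

The projection can be trained jointly with the rest of the model, so it also learns which parts of the LLM features matter for link prediction.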
Just an update: I tested all three suggested methods. Generally, the pre-trained embeddings were added to the `EntityNBFNet`. I slightly modified the code and added the following:

```python
if lm_vectors is not None:
    # can decide whether to freeze or not...
    self.lm_vectors = nn.Embedding.from_pretrained(lm_vectors, freeze=True)
    self.merge_linear = nn.Linear(feature_dim, 64)
```

Per your 1st suggestion, training from scratch with the following:

```python
# .....original code.....
boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

# interaction of boundary with lm vectors
if self.lm_vectors is not None:
    lm_vectors = self.lm_vectors(h_index)  # mistake - see @migalkin's comment below
    lm_vectors = lm_vectors.unsqueeze(1).expand(-1, data.num_nodes, -1)
    boundary = torch.cat([boundary, lm_vectors], dim=-1)
    boundary = self.merge_linear(boundary)
```
Those lines

```python
lm_vectors = self.lm_vectors(h_index)
lm_vectors = lm_vectors.unsqueeze(1).expand(-1, data.num_nodes, -1)
```

would take only the LM features of the head nodes in the batch and copy them to all nodes in the graph - is this what you want? If you want to initialize each node with its own LM feature, then you don't need to call the embedding layer; just take all of its weights (and repeat along the batch dimension).
@daniel4x Have you made this into a separate branch? I would be really interested to see your code and hear which of the integration methods performed best for you. In my use case I have a mixture of LLM features (edges) and a set of pre-trained embedding features for each of the node types. Your experience that node features offer significant performance benefits tallies a lot with mine, so it would be great to integrate this into my code.
Hi @migalkin,
First of all, kudos for your work!!!! (both ULTRA and NodePiece 😄).
I'm curious to hear your thoughts about integrating a language model (LM) with ULTRA.
Previously, with other KG models such as nodepiece, it was straightforward to integrate a language model to enrich the graph embeddings with textual embeddings.
I used to concat both the entity textual and graph representations and maybe apply additional layers to match the desired dimensions.
example:
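A minimal sketch of that concatenation (dimensions, names, and the extra projection layer are my illustrative assumptions):

```python
import torch
import torch.nn as nn

graph_dim, text_dim, out_dim = 100, 768, 100
proj = nn.Linear(graph_dim + text_dim, out_dim)  # extra layer to match dims

num_entities = 10
graph_emb = torch.randn(num_entities, graph_dim)  # e.g. NodePiece entity embeddings
text_emb = torch.randn(num_entities, text_dim)    # LM embeddings of entity descriptions

# concatenate both views of each entity, then project to the desired size
entity_repr = proj(torch.cat([graph_emb, text_emb], dim=-1))
print(tuple(entity_repr.shape))  # (10, 100)
```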
So far, it worked well and boosted the model's performance by ~50% with TransE and by up to ~30% with NodePiece on my datasets.
With ULTRA I guess that I have some additional work to do :)...
I started with understanding how the entity representation is "generated" on the fly:
https://github.com/DeepGraphLearning/ULTRA/blob/33c6e6b8e522aed3d33f6ce5d3a1883ca9284718/ultra/models.py#L166-L174C4
I understand that from that point only the tail representations are used to feed the MLP.
I replaced the MLP with my own MLP to match the dimension of the concatenation of both representations. Then, I tried to concatenate the output from ULTRA with the textual entity representation. As far as I understand, due to this "late" concatenation, only the tail entity's textual representation will be used.
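Concretely, that late-fusion setup might look like this sketch (all dimensions and names are my assumptions):

```python
import torch
import torch.nn as nn

ultra_dim, text_dim = 64, 768
# the scoring MLP widened to accept the concatenated representation
mlp = nn.Sequential(nn.Linear(ultra_dim + text_dim, 64), nn.ReLU(), nn.Linear(64, 1))

num_tails = 32
tail_repr = torch.randn(num_tails, ultra_dim)  # ULTRA's tail representations
tail_text = torch.randn(num_tails, text_dim)   # textual embeddings of those tails

# only the tail entities' textual features ever reach the score here,
# which is why this fusion is "late" and less expressive
scores = mlp(torch.cat([tail_repr, tail_text], dim=-1))
print(tuple(scores.shape))  # (32, 1)
```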
When tested, I got (almost) the same results with/without the textual representation.
Not sure what I expect to hear :), but I hope you may have an idea for combining both representations.