Skip to content

Commit

Permalink
Merge pull request #337 from laserkelvin/batchschema-pyg-collate-fix
Browse files Browse the repository at this point in the history
`BatchSchema` collate fix for PyG graphs
  • Loading branch information
laserkelvin authored Dec 20, 2024
2 parents a28a5f5 + cd6b530 commit 32d690c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions matsciml/datasets/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ def collate_samples_into_batch_schema(samples: list[DataSampleSchema]) -> object
Instance of a ``BatchSchema`` object. This is not explicitly annotated
since the model/class is defined dynamically based off incoming data.
"""
ref_schema = samples[0].schema()
ref_schema = samples[0].model_json_schema()
# initial keys are going to hold the main structure of the schema
schema_to_generate = {
"num_atoms": (NDArray[Shape["*"], int] | torch.LongTensor, ...),
Expand All @@ -1103,7 +1103,7 @@ def collate_samples_into_batch_schema(samples: list[DataSampleSchema]) -> object
schema_to_generate[key] = (type(data), ...)
collected_data[key] = data
collected_data["num_edges"] = _concatenate_data_list(
[sample.graph.batch_num_edges() for sample in samples]
[sample.graph.edge_index.size(-1) for sample in samples]
).long()
else:
from dgl import DGLGraph, batch
Expand Down

0 comments on commit 32d690c

Please sign in to comment.