Skip to content

Commit

Permalink
t
Browse files Browse the repository at this point in the history
  • Loading branch information
SumGuo-88 committed Jan 3, 2025
1 parent bb9fbe1 commit 6a65561
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions source/tests/pt/test_make_stat_input.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import numpy as np
import torch
from torch.utils.data import DataLoader
from deepmd.pt.utils.stat import make_stat_input,compute_output_stats
from deepmd.pt.utils.dataset import DeepmdDataSetForLoader
from deepmd.utils.data import DataRequirementItem

def collate_fn(batch):
if isinstance(batch, dict):
batch = [batch]
collated_batch = {}
for key in batch[0].keys():
for key in batch[0].keys():
data_list = [d[key] for d in batch]
if isinstance(data_list[0], np.ndarray):
data_np = np.stack(data_list)
Expand All @@ -20,7 +20,6 @@ def collate_fn(batch):
collated_batch[key] = torch.tensor(data_list)
return collated_batch


class TestMakeStatInput(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand All @@ -43,7 +42,7 @@ def setUpClass(cls):
dataloader = DataLoader(
dataset,
sampler=sampler,
batch_size=1,
batch_size=1,
num_workers=0,
drop_last=False,
collate_fn=collate_fn,
Expand All @@ -55,7 +54,7 @@ def test_make_stat_input(self):
lst = make_stat_input(
datasets=self.datasets,
dataloaders=self.dataloaders,
nbatches=nbatches,
nbatches=1,
min_frames_per_element_forstat=1,
enable_element_completion=True,
)
Expand Down

0 comments on commit 6a65561

Please sign in to comment.