Skip to content

Commit

Permalink
Take review comments into account
Browse files Browse the repository at this point in the history
  • Loading branch information
cphyc committed Sep 6, 2023
1 parent 9384474 commit dc614e1
Showing 1 changed file with 31 additions and 13 deletions.
44 changes: 31 additions & 13 deletions yt/frontends/rockstar/data_structures.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import glob
import os
from functools import cached_property
from typing import List, Optional
from typing import Any, List, Optional

import numpy as np

Expand All @@ -11,6 +11,7 @@
from yt.geometry.particle_geometry_handler import ParticleIndex
from yt.utilities import fortran_utils as fpu
from yt.utilities.cosmology import Cosmology
from yt.utilities.exceptions import YTFieldNotFound

from .definitions import header_dt
from .fields import RockstarFieldInfo
Expand All @@ -20,7 +21,7 @@ class RockstarBinaryFile(HaloCatalogFile):
header: dict
_position_offset: int
_member_offset: int
_Npart: np.array
_Npart: "np.ndarray[Any, np.dtype[np.int64]]"
_ids_halos: List[int]
_file_size: int

Expand All @@ -46,7 +47,9 @@ def __init__(self, ds, io, filename, file_id, range):

super().__init__(ds, io, filename, file_id, range)

def _read_member(self, ihalo: int) -> Optional[np.array]:
def _read_member(
self, ihalo: int
) -> Optional["np.ndarray[Any, np.dtype[np.int64]]"]:
if ihalo not in self._ids_halos:
return None

Expand All @@ -59,7 +62,7 @@ def _read_member(self, ihalo: int) -> Optional[np.array]:
ids = np.fromfile(f, dtype=np.int64, count=self._Npart[ind_halo])
return ids

def _read_particle_positions(self, ptype, f=None):
def _read_particle_positions(self, ptype: str, f=None):
"""
Read all particle positions in this file.
"""
Expand Down Expand Up @@ -166,32 +169,47 @@ def _is_valid(cls, filename, *args, **kwargs):
return True
return False

def halo(self, halo_id, ptype="DM"):
def halo(self, ptype, particle_identifier):
return RockstarHaloContainer(
halo_id,
ptype,
particle_identifier,
parent_ds=None,
halo_ds=self,
)


class RockstarHaloContainer:
def __init__(self, ptype, particle_identifier, parent_ds, halo_ds):
# if ptype not in parent_ds.particle_types_raw:
# raise RuntimeError(
# f'Possible halo types are {parent_ds.particle_types_raw}, supplied "{ptype}".'
# )
def __init__(self, ptype, particle_identifier, *, parent_ds, halo_ds):
if ptype not in halo_ds.particle_types_raw:
raise RuntimeError(
f'Possible halo types are {halo_ds.particle_types_raw}, supplied "{ptype}".'
)

self.ds = parent_ds
self.halo_ds = halo_ds
self.ptype = ptype
self.particle_identifier = particle_identifier

def __repr__(self):
return "%s_%s_%09d" % (self.ds, self.ptype, self.particle_identifier)
return "%s_%s_%09d" % (self.halo_ds, self.ptype, self.particle_identifier)

def __getitem__(self, key):
return self.region[key]
if isinstance(key, tuple):
ptype, field = key
else:
ptype = self.ptype
field = key

data = {
"mass": self.mass,
"position": self.position,
"velocity": self.velocity,
"member_ids": self.member_ids,
}
if ptype == "halos" and field in data:
return data[field]

raise YTFieldNotFound((ptype, field), dataset=self.ds)

@cached_property
def ihalo(self):
Expand Down

0 comments on commit dc614e1

Please sign in to comment.