Skip to content

Commit

Permalink
Merge pull request #8 from SNEWS2/feature-get-fields
Browse files Browse the repository at this point in the history
Feature get fields
  • Loading branch information
justinvasel authored Aug 19, 2024
2 parents 40d9f04 + d1ace16 commit a7d577b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
6 changes: 1 addition & 5 deletions docs/spec/index.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
---
title: SNEWS Data Specification
summary: A brief description of my document.
authors:
- Waylan Limberg
- Tom Christie
date: 2018-07-10
summary: TODO.
---
34 changes: 19 additions & 15 deletions snews/models/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,18 @@
"TimingTierMessage",
"compatible_message_types",
"create_messages",
"get_fields",
]


# .................................................................................................
def get_fields(model, required=False) -> list:
"""
Return a list of all or required fields for the message.
"""
return [k for k, v in model.model_fields.items() if v.is_required() or not required]


# .................................................................................................
def convert_timestamp_to_ns_precision(timestamp: Union[str, datetime, np.datetime64]) -> str:
"""
Expand Down Expand Up @@ -64,6 +73,9 @@ class MessageBase(BaseModel):

model_config = ConfigDict(validate_assignment=True)

# NOTE: This field is optional from the user's perspective, but during model validation,
# it will be automatically generated if not already specified, so in practice this field
# will never be empty.
id: Optional[str] = Field(
default=None,
title="Human-readable message ID",
Expand Down Expand Up @@ -155,18 +167,6 @@ def _format_id(self):

return self

def fields(self):
"""
Return a list of fields for the message.
"""
return list(self.model_fields.keys())

def required_fields(self):
"""
Return a list of required fields for the message.
"""
return [k for k, v in self.model_fields.items() if v.is_required()]


# .................................................................................................
class DetectorMessageBase(MessageBase):
Expand Down Expand Up @@ -423,6 +423,7 @@ def compatible_message_types(include_heartbeats=False, **kwargs) -> list:
message_type(**kwargs)
compatible_message_types.append(message_type)

# Coincidence tier messages can also double as heartbeats
if include_heartbeats and message_type == CoincidenceTierMessage:
compatible_message_types.append(HeartbeatMessage)

Expand All @@ -440,10 +441,13 @@ def create_messages(**kwargs) -> list:

messages = []
for message_type in compatible_message_types(**kwargs):
if message_type == HeartbeatMessage:
messages.append(message_type(detector_status="ON", **kwargs))
if message_type == HeartbeatMessage and "detector_status" not in kwargs.keys():
message = message_type(detector_status="ON", **kwargs)

else:
messages.append(message_type(**kwargs))
message = message_type(**kwargs)

messages.append(message)

if len(messages) == 0:
raise ValueError("No compatible message types found")
Expand Down

0 comments on commit a7d577b

Please sign in to comment.