Skip to content

Commit 7dfa115

Browse files
Merge branch 'main' into add_pytorch_DLRM_example
2 parents 489c705 + eddc2ab commit 7dfa115

File tree

12 files changed, with 96 additions and 83 deletions.

merlin/models/torch/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,11 @@
3131
from merlin.models.torch.router import RouterBlock
3232
from merlin.models.torch.transforms.agg import Concat, Stack
3333

34+
input_schema = schema.input_schema
35+
output_schema = schema.output_schema
36+
target_schema = schema.target_schema
37+
feature_schema = schema.feature_schema
38+
3439
__all__ = [
3540
"Batch",
3641
"BinaryOutput",
@@ -55,6 +60,10 @@
5560
"Concat",
5661
"Stack",
5762
"schema",
63+
"input_schema",
64+
"output_schema",
65+
"feature_schema",
66+
"target_schema",
5867
"DLRMBlock",
5968
"DLRMModel",
6069
]

merlin/models/torch/batch.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -375,10 +375,10 @@ def sample_features(
375375
return sample_batch(data, batch_size, shuffle).features
376376

377377

378-
@schema.output.register_tensor(Batch)
378+
@schema.output_schema.register_tensor(Batch)
379379
def _(input):
380380
output_schema = Schema()
381-
output_schema += schema.output.tensors(input.features)
382-
output_schema += schema.output.tensors(input.targets)
381+
output_schema += schema.output_schema.tensors(input.features)
382+
output_schema += schema.output_schema.tensors(input.targets)
383383

384384
return output_schema

merlin/models/torch/block.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -588,31 +588,31 @@ def set_pre(module: nn.Module, pre: BlockContainer):
588588
return set_pre(module[0], pre)
589589

590590

591-
@schema.input.register(BlockContainer)
591+
@schema.input_schema.register(BlockContainer)
592592
def _(module: BlockContainer, input: Schema):
593-
return schema.input(module[0], input) if module else input
593+
return schema.input_schema(module[0], input) if module else input
594594

595595

596-
@schema.input.register(ParallelBlock)
596+
@schema.input_schema.register(ParallelBlock)
597597
def _(module: ParallelBlock, input: Schema):
598598
if module.pre:
599-
return schema.input(module.pre)
599+
return schema.input_schema(module.pre)
600600

601601
out_schema = Schema()
602602
for branch in module.branches.values():
603-
out_schema += schema.input(branch, input)
603+
out_schema += schema.input_schema(branch, input)
604604

605605
return out_schema
606606

607607

608-
@schema.output.register(ParallelBlock)
608+
@schema.output_schema.register(ParallelBlock)
609609
def _(module: ParallelBlock, input: Schema):
610610
if module.post:
611-
return schema.output(module.post, input)
611+
return schema.output_schema(module.post, input)
612612

613613
output = Schema()
614614
for name, branch in module.branches.items():
615-
branch_schema = schema.output(branch, input)
615+
branch_schema = schema.output_schema(branch, input)
616616

617617
if len(branch_schema) == 1 and branch_schema.first.name == "output":
618618
branch_schema = Schema([branch_schema.first.with_name(name)])
@@ -622,9 +622,9 @@ def _(module: ParallelBlock, input: Schema):
622622
return output
623623

624624

625-
@schema.output.register(BlockContainer)
625+
@schema.output_schema.register(BlockContainer)
626626
def _(module: BlockContainer, input: Schema):
627-
return schema.output(module[-1], input) if module else input
627+
return schema.output_schema(module[-1], input) if module else input
628628

629629

630630
BlockT = TypeVar("BlockT", bound=BlockContainer)
@@ -720,13 +720,13 @@ def _extract_block(main, selection, route, name=None):
720720
if isinstance(main, ParallelBlock):
721721
return _extract_parallel(main, selection, route=route, name=name)
722722

723-
main_schema = schema.input(main)
724-
route_schema = schema.input(route)
723+
main_schema = schema.input_schema(main)
724+
route_schema = schema.input_schema(route)
725725

726726
if main_schema == route_schema:
727727
from merlin.models.torch.inputs.select import SelectFeatures
728728

729-
out_schema = schema.output(main, main_schema)
729+
out_schema = schema.output_schema(main, main_schema)
730730
if len(out_schema) == 1 and out_schema.first.name == "output":
731731
out_schema = Schema([out_schema.first.with_name(name)])
732732

merlin/models/torch/blocks/mlp.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
from torch import nn
55

66
from merlin.models.torch.block import Block
7-
from merlin.models.torch.schema import Schema, output
7+
from merlin.models.torch.schema import Schema, output_schema
88
from merlin.models.torch.transforms.agg import Concat, MaybeAgg
99

1010

@@ -84,8 +84,8 @@ def __init__(
8484
super().__init__(*modules)
8585

8686

87-
@output.register(nn.LazyLinear)
88-
@output.register(nn.Linear)
89-
@output.register(MLPBlock)
90-
def _output_schema_block(module: nn.LazyLinear, input: Schema):
91-
return output.tensors(torch.ones((1, module.out_features), dtype=float))
87+
@output_schema.register(nn.LazyLinear)
88+
@output_schema.register(nn.Linear)
89+
@output_schema.register(MLPBlock)
90+
def _output_schema_block(module: nn.LazyLinear, inputs: Schema):
91+
return output_schema.tensors(torch.ones((1, module.out_features), dtype=float))

merlin/models/torch/inputs/select.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -201,8 +201,8 @@ def forward(self, inputs, batch: Batch) -> Dict[str, torch.Tensor]:
201201

202202
@schema.extract.register(SelectKeys)
203203
def _(main, selection, route, name=None):
204-
main_schema = schema.input(main)
205-
route_schema = schema.input(route)
204+
main_schema = schema.input_schema(main)
205+
route_schema = schema.input_schema(route)
206206

207207
diff = main_schema.excluding_by_name(route_schema.column_names)
208208

merlin/models/torch/schema.py

Lines changed: 39 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema
9797
return super().__call__(module, inputs)
9898
except NotImplementedError:
9999
raise ValueError(
100-
f"Could not get output schema of {module} " "please call mm.trace_schema first."
100+
f"Could not get output schema of {module} " "please call `mm.schema.trace` first."
101101
)
102102

103103
def trace(
@@ -127,7 +127,7 @@ def _func(module: nn.Module, input: Schema) -> Schema:
127127

128128
def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema:
129129
try:
130-
_inputs = input(module)
130+
_inputs = input_schema(module)
131131
inputs = _inputs
132132
except ValueError:
133133
pass
@@ -156,7 +156,7 @@ def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema
156156
return super().__call__(module, inputs)
157157
except NotImplementedError:
158158
raise ValueError(
159-
f"Could not get output schema of {module} " "please call mm.trace_schema first."
159+
f"Could not get output schema of {module} " "please call `mm.schema.trace` first."
160160
)
161161

162162
def trace(
@@ -165,7 +165,7 @@ def trace(
165165
inputs: Union[torch.Tensor, Dict[str, torch.Tensor], Schema],
166166
outputs: Union[torch.Tensor, Dict[str, torch.Tensor], Schema],
167167
) -> Schema:
168-
_input_schema = input.get_schema(inputs)
168+
_input_schema = input_schema.get_schema(inputs)
169169
_output_schema = self.get_schema(outputs)
170170

171171
try:
@@ -207,8 +207,8 @@ def extract(self, module: nn.Module, selection: Selection, route: nn.Module, nam
207207
return fn(module, selection, route, name=name)
208208

209209

210-
input = _InputSchemaDispatch("input_schema")
211-
output = _OutputSchemaDispatch("output_schema")
210+
input_schema = _InputSchemaDispatch("input_schema")
211+
output_schema = _OutputSchemaDispatch("output_schema")
212212
select = _SelectDispatch("selection")
213213
extract = _ExtractDispatch("extract")
214214

@@ -240,13 +240,13 @@ def _hook(mod: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor):
240240
mod.__input_schemas = ()
241241
mod.__output_schemas = ()
242242

243-
_input_schema = input.trace(mod, inputs[0])
243+
_input_schema = input_schema.trace(mod, inputs[0])
244244
if _input_schema not in mod.__input_schemas:
245245
mod.__input_schemas += (_input_schema,)
246-
mod.__output_schemas += (output.trace(mod, _input_schema, outputs),)
246+
mod.__output_schemas += (output_schema.trace(mod, _input_schema, outputs),)
247247

248248
def add_hook(m):
249-
custom_modules = list(output.dispatcher.registry.keys())
249+
custom_modules = list(output_schema.dispatcher.registry.keys())
250250
if m and isinstance(m, tuple(custom_modules[1:])):
251251
return
252252

@@ -261,7 +261,7 @@ def add_hook(m):
261261
return module_out
262262

263263

264-
def features(module: nn.Module) -> Schema:
264+
def feature_schema(module: nn.Module) -> Schema:
265265
"""Extract the feature schema from a PyTorch Module.
266266
267267
This function operates by applying the `get_feature_schema` method
@@ -293,7 +293,7 @@ def get_feature_schema(module):
293293
return feature_schema
294294

295295

296-
def targets(module: nn.Module) -> Schema:
296+
def target_schema(module: nn.Module) -> Schema:
297297
"""
298298
Extract the target schema from a PyTorch Module.
299299
@@ -484,7 +484,7 @@ def select(self, selection: Selection) -> "Selectable":
484484
raise NotImplementedError()
485485

486486

487-
@output.register_tensor(torch.Tensor)
487+
@output_schema.register_tensor(torch.Tensor)
488488
def _tensor_to_schema(input, name="output"):
489489
kwargs = dict(dims=input.shape[1:], dtype=input.dtype)
490490

@@ -494,13 +494,13 @@ def _tensor_to_schema(input, name="output"):
494494
return Schema([ColumnSchema(name, **kwargs)])
495495

496496

497-
@input.register_tensor(torch.Tensor)
497+
@input_schema.register_tensor(torch.Tensor)
498498
def _(input):
499499
return _tensor_to_schema(input, "input")
500500

501501

502-
@input.register_tensor(Dict[str, torch.Tensor])
503-
@output.register_tensor(Dict[str, torch.Tensor])
502+
@input_schema.register_tensor(Dict[str, torch.Tensor])
503+
@output_schema.register_tensor(Dict[str, torch.Tensor])
504504
def _(input):
505505
output = Schema()
506506
for k, v in sorted(input.items()):
@@ -509,23 +509,27 @@ def _(input):
509509
return output
510510

511511

512-
@input.register_tensor(Tuple[torch.Tensor])
513-
@output.register_tensor(Tuple[torch.Tensor])
514-
@input.register_tensor(Tuple[torch.Tensor, torch.Tensor])
515-
@output.register_tensor(Tuple[torch.Tensor, torch.Tensor])
516-
@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
517-
@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
518-
@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
519-
@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
520-
@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
521-
@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
522-
@input.register_tensor(
512+
@input_schema.register_tensor(Tuple[torch.Tensor])
513+
@output_schema.register_tensor(Tuple[torch.Tensor])
514+
@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor])
515+
@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor])
516+
@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
517+
@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
518+
@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
519+
@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
520+
@input_schema.register_tensor(
521+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
522+
)
523+
@output_schema.register_tensor(
524+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
525+
)
526+
@input_schema.register_tensor(
523527
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
524528
)
525-
@output.register_tensor(
529+
@output_schema.register_tensor(
526530
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
527531
)
528-
@input.register_tensor(
532+
@input_schema.register_tensor(
529533
Tuple[
530534
torch.Tensor,
531535
torch.Tensor,
@@ -536,7 +540,7 @@ def _(input):
536540
torch.Tensor,
537541
]
538542
)
539-
@output.register_tensor(
543+
@output_schema.register_tensor(
540544
Tuple[
541545
torch.Tensor,
542546
torch.Tensor,
@@ -547,7 +551,7 @@ def _(input):
547551
torch.Tensor,
548552
]
549553
)
550-
@input.register_tensor(
554+
@input_schema.register_tensor(
551555
Tuple[
552556
torch.Tensor,
553557
torch.Tensor,
@@ -559,7 +563,7 @@ def _(input):
559563
torch.Tensor,
560564
]
561565
)
562-
@output.register_tensor(
566+
@output_schema.register_tensor(
563567
Tuple[
564568
torch.Tensor,
565569
torch.Tensor,
@@ -571,7 +575,7 @@ def _(input):
571575
torch.Tensor,
572576
]
573577
)
574-
@input.register_tensor(
578+
@input_schema.register_tensor(
575579
Tuple[
576580
torch.Tensor,
577581
torch.Tensor,
@@ -584,7 +588,7 @@ def _(input):
584588
torch.Tensor,
585589
]
586590
)
587-
@output.register_tensor(
591+
@output_schema.register_tensor(
588592
Tuple[
589593
torch.Tensor,
590594
torch.Tensor,
@@ -597,7 +601,7 @@ def _(input):
597601
torch.Tensor,
598602
]
599603
)
600-
@input.register_tensor(
604+
@input_schema.register_tensor(
601605
Tuple[
602606
torch.Tensor,
603607
torch.Tensor,
@@ -611,7 +615,7 @@ def _(input):
611615
torch.Tensor,
612616
]
613617
)
614-
@output.register_tensor(
618+
@output_schema.register_tensor(
615619
Tuple[
616620
torch.Tensor,
617621
torch.Tensor,

tests/unit/torch/inputs/test_select.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -72,8 +72,8 @@ def test_forward(self):
7272

7373
outputs = mm.schema.trace(block, self.batch.features["session_id"], batch=self.batch)
7474
assert len(outputs) == 5
75-
assert mm.schema.input(block).column_names == ["input"]
76-
assert mm.schema.features(block).column_names == [
75+
assert mm.input_schema(block).column_names == ["input"]
76+
assert mm.feature_schema(block).column_names == [
7777
"user_id",
7878
"country",
7979
"user_age",

0 commit comments

Comments
 (0)