@@ -97,7 +97,7 @@ def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema
97
97
return super ().__call__ (module , inputs )
98
98
except NotImplementedError :
99
99
raise ValueError (
100
- f"Could not get output schema of { module } " "please call mm.trace_schema first."
100
+ f"Could not get output schema of { module } " "please call ` mm.schema.trace` first."
101
101
)
102
102
103
103
def trace (
@@ -127,7 +127,7 @@ def _func(module: nn.Module, input: Schema) -> Schema:
127
127
128
128
def __call__ (self , module : nn .Module , inputs : Optional [Schema ] = None ) -> Schema :
129
129
try :
130
- _inputs = input (module )
130
+ _inputs = input_schema (module )
131
131
inputs = _inputs
132
132
except ValueError :
133
133
pass
@@ -156,7 +156,7 @@ def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema
156
156
return super ().__call__ (module , inputs )
157
157
except NotImplementedError :
158
158
raise ValueError (
159
- f"Could not get output schema of { module } " "please call mm.trace_schema first."
159
+ f"Could not get output schema of { module } " "please call ` mm.schema.trace` first."
160
160
)
161
161
162
162
def trace (
@@ -165,7 +165,7 @@ def trace(
165
165
inputs : Union [torch .Tensor , Dict [str , torch .Tensor ], Schema ],
166
166
outputs : Union [torch .Tensor , Dict [str , torch .Tensor ], Schema ],
167
167
) -> Schema :
168
- _input_schema = input .get_schema (inputs )
168
+ _input_schema = input_schema .get_schema (inputs )
169
169
_output_schema = self .get_schema (outputs )
170
170
171
171
try :
@@ -207,8 +207,8 @@ def extract(self, module: nn.Module, selection: Selection, route: nn.Module, nam
207
207
return fn (module , selection , route , name = name )
208
208
209
209
210
- input = _InputSchemaDispatch ("input_schema" )
211
- output = _OutputSchemaDispatch ("output_schema" )
210
+ input_schema = _InputSchemaDispatch ("input_schema" )
211
+ output_schema = _OutputSchemaDispatch ("output_schema" )
212
212
select = _SelectDispatch ("selection" )
213
213
extract = _ExtractDispatch ("extract" )
214
214
@@ -240,13 +240,13 @@ def _hook(mod: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor):
240
240
mod .__input_schemas = ()
241
241
mod .__output_schemas = ()
242
242
243
- _input_schema = input .trace (mod , inputs [0 ])
243
+ _input_schema = input_schema .trace (mod , inputs [0 ])
244
244
if _input_schema not in mod .__input_schemas :
245
245
mod .__input_schemas += (_input_schema ,)
246
- mod .__output_schemas += (output .trace (mod , _input_schema , outputs ),)
246
+ mod .__output_schemas += (output_schema .trace (mod , _input_schema , outputs ),)
247
247
248
248
def add_hook (m ):
249
- custom_modules = list (output .dispatcher .registry .keys ())
249
+ custom_modules = list (output_schema .dispatcher .registry .keys ())
250
250
if m and isinstance (m , tuple (custom_modules [1 :])):
251
251
return
252
252
@@ -261,7 +261,7 @@ def add_hook(m):
261
261
return module_out
262
262
263
263
264
- def features (module : nn .Module ) -> Schema :
264
+ def feature_schema (module : nn .Module ) -> Schema :
265
265
"""Extract the feature schema from a PyTorch Module.
266
266
267
267
This function operates by applying the `get_feature_schema` method
@@ -293,7 +293,7 @@ def get_feature_schema(module):
293
293
return feature_schema
294
294
295
295
296
- def targets (module : nn .Module ) -> Schema :
296
+ def target_schema (module : nn .Module ) -> Schema :
297
297
"""
298
298
Extract the target schema from a PyTorch Module.
299
299
@@ -484,7 +484,7 @@ def select(self, selection: Selection) -> "Selectable":
484
484
raise NotImplementedError ()
485
485
486
486
487
- @output .register_tensor (torch .Tensor )
487
+ @output_schema .register_tensor (torch .Tensor )
488
488
def _tensor_to_schema (input , name = "output" ):
489
489
kwargs = dict (dims = input .shape [1 :], dtype = input .dtype )
490
490
@@ -494,13 +494,13 @@ def _tensor_to_schema(input, name="output"):
494
494
return Schema ([ColumnSchema (name , ** kwargs )])
495
495
496
496
497
- @input .register_tensor (torch .Tensor )
497
+ @input_schema .register_tensor (torch .Tensor )
498
498
def _ (input ):
499
499
return _tensor_to_schema (input , "input" )
500
500
501
501
502
- @input .register_tensor (Dict [str , torch .Tensor ])
503
- @output .register_tensor (Dict [str , torch .Tensor ])
502
+ @input_schema .register_tensor (Dict [str , torch .Tensor ])
503
+ @output_schema .register_tensor (Dict [str , torch .Tensor ])
504
504
def _ (input ):
505
505
output = Schema ()
506
506
for k , v in sorted (input .items ()):
@@ -509,23 +509,27 @@ def _(input):
509
509
return output
510
510
511
511
512
- @input .register_tensor (Tuple [torch .Tensor ])
513
- @output .register_tensor (Tuple [torch .Tensor ])
514
- @input .register_tensor (Tuple [torch .Tensor , torch .Tensor ])
515
- @output .register_tensor (Tuple [torch .Tensor , torch .Tensor ])
516
- @input .register_tensor (Tuple [torch .Tensor , torch .Tensor , torch .Tensor ])
517
- @output .register_tensor (Tuple [torch .Tensor , torch .Tensor , torch .Tensor ])
518
- @input .register_tensor (Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ])
519
- @output .register_tensor (Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ])
520
- @input .register_tensor (Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ])
521
- @output .register_tensor (Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ])
522
- @input .register_tensor (
512
+ @input_schema .register_tensor (Tuple [torch .Tensor ])
513
+ @output_schema .register_tensor (Tuple [torch .Tensor ])
514
+ @input_schema .register_tensor (Tuple [torch .Tensor , torch .Tensor ])
515
+ @output_schema .register_tensor (Tuple [torch .Tensor , torch .Tensor ])
516
+ @input_schema .register_tensor (Tuple [torch .Tensor , torch .Tensor , torch .Tensor ])
517
+ @output_schema .register_tensor (Tuple [torch .Tensor , torch .Tensor , torch .Tensor ])
518
+ @input_schema .register_tensor (Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ])
519
+ @output_schema .register_tensor (Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ])
520
+ @input_schema .register_tensor (
521
+ Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]
522
+ )
523
+ @output_schema .register_tensor (
524
+ Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]
525
+ )
526
+ @input_schema .register_tensor (
523
527
Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]
524
528
)
525
- @output .register_tensor (
529
+ @output_schema .register_tensor (
526
530
Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]
527
531
)
528
- @input .register_tensor (
532
+ @input_schema .register_tensor (
529
533
Tuple [
530
534
torch .Tensor ,
531
535
torch .Tensor ,
@@ -536,7 +540,7 @@ def _(input):
536
540
torch .Tensor ,
537
541
]
538
542
)
539
- @output .register_tensor (
543
+ @output_schema .register_tensor (
540
544
Tuple [
541
545
torch .Tensor ,
542
546
torch .Tensor ,
@@ -547,7 +551,7 @@ def _(input):
547
551
torch .Tensor ,
548
552
]
549
553
)
550
- @input .register_tensor (
554
+ @input_schema .register_tensor (
551
555
Tuple [
552
556
torch .Tensor ,
553
557
torch .Tensor ,
@@ -559,7 +563,7 @@ def _(input):
559
563
torch .Tensor ,
560
564
]
561
565
)
562
- @output .register_tensor (
566
+ @output_schema .register_tensor (
563
567
Tuple [
564
568
torch .Tensor ,
565
569
torch .Tensor ,
@@ -571,7 +575,7 @@ def _(input):
571
575
torch .Tensor ,
572
576
]
573
577
)
574
- @input .register_tensor (
578
+ @input_schema .register_tensor (
575
579
Tuple [
576
580
torch .Tensor ,
577
581
torch .Tensor ,
@@ -584,7 +588,7 @@ def _(input):
584
588
torch .Tensor ,
585
589
]
586
590
)
587
- @output .register_tensor (
591
+ @output_schema .register_tensor (
588
592
Tuple [
589
593
torch .Tensor ,
590
594
torch .Tensor ,
@@ -597,7 +601,7 @@ def _(input):
597
601
torch .Tensor ,
598
602
]
599
603
)
600
- @input .register_tensor (
604
+ @input_schema .register_tensor (
601
605
Tuple [
602
606
torch .Tensor ,
603
607
torch .Tensor ,
@@ -611,7 +615,7 @@ def _(input):
611
615
torch .Tensor ,
612
616
]
613
617
)
614
- @output .register_tensor (
618
+ @output_schema .register_tensor (
615
619
Tuple [
616
620
torch .Tensor ,
617
621
torch .Tensor ,
0 commit comments