@@ -134,11 +134,11 @@ def create(self, ename, data, icnfg=None, rname=None, shape=(1,), lat=None):
134
134
self .ename .append (ename )
135
135
136
136
pyobs .check_type (data , "data" , list , np .ndarray )
137
- if isinstance (data [0 ], ( int , np . int32 , np . int64 ) ):
137
+ if pyobs . is_type (data [0 ], pyobs . types . INT ):
138
138
data = [np .array (data ).astype (pyobs .int )]
139
- elif isinstance (data [0 ], ( float , np . float32 , np . float64 ) ):
139
+ elif pyobs . is_type (data [0 ], pyobs . types . FLOAT ):
140
140
data = [np .array (data ).astype (pyobs .double )]
141
- elif isinstance (data [0 ], ( complex , np . complex64 , np . complex128 ) ):
141
+ elif pyobs . is_type (data [0 ], pyobs . types . COMPLEX ):
142
142
data = [np .array (data ).astype (pyobs .complex )]
143
143
else :
144
144
pyobs .assertion (True , "Data type not supported" )
@@ -147,11 +147,11 @@ def create(self, ename, data, icnfg=None, rname=None, shape=(1,), lat=None):
147
147
nc = [len (data [ir ]) // self .size for ir in range (R )]
148
148
if rname is None :
149
149
rname = list (range (R ))
150
- elif isinstance (rname , ( str , int , np . int32 , np . int64 ) ):
150
+ elif pyobs . is_type (rname , str , pyobs . types . INT ):
151
151
rname = [rname ]
152
152
if icnfg is None :
153
153
icnfg = [range (_nc ) for _nc in nc ]
154
- elif isinstance (icnfg [0 ], ( int , np . int32 , np . int64 ) ):
154
+ elif pyobs . is_type (icnfg [0 ], pyobs . types . INT ):
155
155
icnfg = [icnfg ]
156
156
157
157
pyobs .check_type (rname , "rname" , list )
@@ -409,7 +409,7 @@ def f(x):
409
409
return transform (self , f )
410
410
411
411
def __getitem__ (self , args ):
412
- if isinstance (args , ( int , np . int32 , np . int64 , slice , np .ndarray ) ):
412
+ if pyobs . is_type (args , pyobs . types . INT , slice , np .ndarray ):
413
413
args = [args ]
414
414
na = len (args )
415
415
pyobs .assertion (na == len (self .shape ), "Unexpected argument" )
@@ -420,12 +420,12 @@ def f(x):
420
420
return transform (self , f )
421
421
422
422
def __setitem__ (self , args , yobs ):
423
- if isinstance (args , ( int , np . int32 , np . int64 , slice , np .ndarray ) ):
423
+ if pyobs . is_type (args , pyobs . types . INT , slice , np .ndarray ):
424
424
args = [args ]
425
425
else :
426
426
args = [
427
427
[a ]
428
- if isinstance (a , ( int , np . int32 , np . int64 , slice , np .ndarray ) )
428
+ if pyobs . is_type (a , pyobs . types . INT , slice , np .ndarray )
429
429
else a
430
430
for a in args
431
431
]
@@ -454,6 +454,41 @@ def __setitem__(self, args, yobs):
454
454
)
455
455
self .cdata [key ].assign (submask , yobs .cdata [key ])
456
456
457
+ def rt (self , axis = None ):
458
+ """
459
+ Removes trivial tensor indices reducing the dimensionality of the observable.
460
+
461
+ Parameters:
462
+ axis (int, list or array): axis to be considered for tensor removal.
463
+
464
+ Returns:
465
+ observable
466
+
467
+ Examples:
468
+ >>> obs.shape
469
+ (10, 3, 1)
470
+ >>> obs.rt().shape
471
+ (10, 3)
472
+ """
473
+ ## duplicate of pyobs.remove_tensor, to be eventually deprecated
474
+ Nd = len (self .shape )
475
+ if axis is None :
476
+ selection = [True ] * Nd
477
+ else :
478
+ selection = [False ] * Nd
479
+ for a in pyobs .to_list (axis ):
480
+ selection [a ] = True
481
+
482
+ new_shape = []
483
+ for mu in range (Nd ):
484
+ if (self .shape [mu ] == 1 ) and (selection [mu ] is True ):
485
+ continue
486
+ new_shape .append (self .shape [mu ])
487
+ if not new_shape :
488
+ new_shape .append (1 )
489
+ return pyobs .reshape (self , tuple (new_shape ))
490
+
491
+
457
492
##################################
458
493
# overloaded basic math operations
459
494
@@ -654,14 +689,18 @@ class provides additional details for the automatic or manual
654
689
>>> einfo = {'A': errinfo(Stau=3.0), 'B': errinfo(W=30)}
655
690
>>> [v,e] = obsC.error(errinfo=einfo,plot=True)
656
691
"""
657
- [sigma , sigma_tot , _ ] = self .error_core (errinfo , plot , pfile )
658
-
659
- if plot : # pragma: no cover
660
- h = [len (self .ename ), len (self .cdata )]
661
- if sum (h ) > 1 :
662
- plot_piechart (self .description , sigma , sigma_tot .real )
663
-
664
- return [self .mean , np .sqrt (sigma_tot )]
692
+ def error_real (obs ):
693
+ [sigma , sigma_tot , _ ] = obs .error_core (errinfo , plot , pfile )
694
+
695
+ if plot : # pragma: no cover
696
+ h = [len (obs .ename ), len (obs .cdata )]
697
+ if sum (h ) > 1 :
698
+ plot_piechart (obs .description , sigma , sigma_tot .real )
699
+ return sigma_tot
700
+
701
+ if np .iscomplexobj (self .mean ):
702
+ return [self .mean , np .sqrt (error_real (self .real ())) + 1j * np .sqrt (error_real (self .imag ()))]
703
+ return [self .mean , np .sqrt (error_real (self ))]
665
704
666
705
def error_breakdown (self , errinfo = {}):
667
706
"""
0 commit comments