Skip to content

Commit b00efc2

Browse files
committed
rt method added to observable class; remove_tensor will be deprecated; internal types cleanup
1 parent 72ebf2f commit b00efc2

File tree

7 files changed

+90
-29
lines changed

7 files changed

+90
-29
lines changed

pyobs/core/data.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@
2626

2727

2828
def is_int(x):
29-
return isinstance(x, (int, np.int32, np.int64))
30-
29+
return pyobs.is_type(x, pyobs.types.INT)
3130

3231
def expand_data(data, idx, shape):
3332
v = np.prod(shape)

pyobs/core/ndobs.py

+55-16
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,11 @@ def create(self, ename, data, icnfg=None, rname=None, shape=(1,), lat=None):
134134
self.ename.append(ename)
135135

136136
pyobs.check_type(data, "data", list, np.ndarray)
137-
if isinstance(data[0], (int, np.int32, np.int64)):
137+
if pyobs.is_type(data[0], pyobs.types.INT):
138138
data = [np.array(data).astype(pyobs.int)]
139-
elif isinstance(data[0], (float, np.float32, np.float64)):
139+
elif pyobs.is_type(data[0], pyobs.types.FLOAT):
140140
data = [np.array(data).astype(pyobs.double)]
141-
elif isinstance(data[0], (complex, np.complex64, np.complex128)):
141+
elif pyobs.is_type(data[0], pyobs.types.COMPLEX):
142142
data = [np.array(data).astype(pyobs.complex)]
143143
else:
144144
pyobs.assertion(True, "Data type not supported")
@@ -147,11 +147,11 @@ def create(self, ename, data, icnfg=None, rname=None, shape=(1,), lat=None):
147147
nc = [len(data[ir]) // self.size for ir in range(R)]
148148
if rname is None:
149149
rname = list(range(R))
150-
elif isinstance(rname, (str, int, np.int32, np.int64)):
150+
elif pyobs.is_type(rname, str, pyobs.types.INT):
151151
rname = [rname]
152152
if icnfg is None:
153153
icnfg = [range(_nc) for _nc in nc]
154-
elif isinstance(icnfg[0], (int, np.int32, np.int64)):
154+
elif pyobs.is_type(icnfg[0], pyobs.types.INT):
155155
icnfg = [icnfg]
156156

157157
pyobs.check_type(rname, "rname", list)
@@ -409,7 +409,7 @@ def f(x):
409409
return transform(self, f)
410410

411411
def __getitem__(self, args):
412-
if isinstance(args, (int, np.int32, np.int64, slice, np.ndarray)):
412+
if pyobs.is_type(args, pyobs.types.INT, slice, np.ndarray):
413413
args = [args]
414414
na = len(args)
415415
pyobs.assertion(na == len(self.shape), "Unexpected argument")
@@ -420,12 +420,12 @@ def f(x):
420420
return transform(self, f)
421421

422422
def __setitem__(self, args, yobs):
423-
if isinstance(args, (int, np.int32, np.int64, slice, np.ndarray)):
423+
if pyobs.is_type(args, pyobs.types.INT, slice, np.ndarray):
424424
args = [args]
425425
else:
426426
args = [
427427
[a]
428-
if isinstance(a, (int, np.int32, np.int64, slice, np.ndarray))
428+
if pyobs.is_type(a, pyobs.types.INT, slice, np.ndarray)
429429
else a
430430
for a in args
431431
]
@@ -454,6 +454,41 @@ def __setitem__(self, args, yobs):
454454
)
455455
self.cdata[key].assign(submask, yobs.cdata[key])
456456

457+
def rt(self, axis=None):
458+
"""
459+
Removes trivial tensor indices reducing the dimensionality of the observable.
460+
461+
Parameters:
462+
axis (int, list or array): axis to be considered for tensor removal.
463+
464+
Returns:
465+
observable
466+
467+
Examples:
468+
>>> obs.shape
469+
(10, 3, 1)
470+
>>> obs.rt().shape
471+
(10, 3)
472+
"""
473+
## duplicate of pyobs.remove_tensor, to be eventually deprecated
474+
Nd = len(self.shape)
475+
if axis is None:
476+
selection = [True] * Nd
477+
else:
478+
selection = [False] * Nd
479+
for a in pyobs.to_list(axis):
480+
selection[a] = True
481+
482+
new_shape = []
483+
for mu in range(Nd):
484+
if (self.shape[mu] == 1) and (selection[mu] is True):
485+
continue
486+
new_shape.append(self.shape[mu])
487+
if not new_shape:
488+
new_shape.append(1)
489+
return pyobs.reshape(self, tuple(new_shape))
490+
491+
457492
##################################
458493
# overloaded basic math operations
459494

@@ -654,14 +689,18 @@ class provides additional details for the automatic or manual
654689
>>> einfo = {'A': errinfo(Stau=3.0), 'B': errinfo(W=30)}
655690
>>> [v,e] = obsC.error(errinfo=einfo,plot=True)
656691
"""
657-
[sigma, sigma_tot, _] = self.error_core(errinfo, plot, pfile)
658-
659-
if plot: # pragma: no cover
660-
h = [len(self.ename), len(self.cdata)]
661-
if sum(h) > 1:
662-
plot_piechart(self.description, sigma, sigma_tot.real)
663-
664-
return [self.mean, np.sqrt(sigma_tot)]
692+
def error_real(obs):
693+
[sigma, sigma_tot, _] = obs.error_core(errinfo, plot, pfile)
694+
695+
if plot: # pragma: no cover
696+
h = [len(obs.ename), len(obs.cdata)]
697+
if sum(h) > 1:
698+
plot_piechart(obs.description, sigma, sigma_tot.real)
699+
return sigma_tot
700+
701+
if np.iscomplexobj(self.mean):
702+
return [self.mean, np.sqrt(error_real(self.real())) + 1j*np.sqrt(error_real(self.imag()))]
703+
return [self.mean, np.sqrt(error_real(self))]
665704

666705
def error_breakdown(self, errinfo={}):
667706
"""

pyobs/default.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,22 @@
1919
#
2020
#################################################################################
2121

22-
import numpy
22+
import numpy as np
2323
import functools
2424
import time
2525

26-
__all__ = ["is_verbose", "set_verbose", "log_timer", "complex", "double", "int"]
26+
__all__ = [
27+
"is_verbose",
28+
"set_verbose",
29+
"log_timer",
30+
"complex",
31+
"double",
32+
"int",
33+
]
2734

28-
complex = numpy.complex128
29-
double = numpy.float64
30-
int = numpy.int32
35+
complex = np.complex128
36+
double = np.float64
37+
int = np.int32
3138

3239
verbose = ["save", "load", "mfit"]
3340

@@ -46,7 +53,6 @@ def set_verbose(func, yesno=True):
4653
if func in verbose:
4754
verbose.remove(func)
4855

49-
5056
def log_timer(tag):
5157
def decorator(func):
5258
@functools.wraps(func)

pyobs/tensor/manipulate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def remove_tensor(x, axis=None):
7474
(10, 3)
7575
"""
7676
Nd = len(x.shape)
77-
if isinstance(axis, int):
77+
if pyobs.is_type(axis, pyobs.types.INT):
7878
axis = [axis]
7979
if axis is None:
8080
selection = [True] * Nd

pyobs/utils.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#################################################################################
22
#
33
# utils.py: generic utility routines
4-
# Copyright (C) 2020 Mattia Bruno
4+
# Copyright (C) 2020-2025 Mattia Bruno
55
#
66
# This program is free software; you can redistribute it and/or
77
# modify it under the terms of the GNU General Public License
@@ -30,13 +30,16 @@
3030
"PyobsError",
3131
"assertion",
3232
"check_type",
33+
"types",
34+
"is_type",
3335
"valerr",
3436
"tex_table",
3537
"slice_ndarray",
3638
"import_string",
3739
"array",
3840
"double_array",
3941
"int_array",
42+
"to_list",
4043
]
4144

4245

@@ -161,6 +164,13 @@ def check_type(obj, s, *t):
161164
if c == len(t):
162165
raise TypeError(f"Unexpected type for {s} [{t}]")
163166

167+
class types:
168+
INT = (int, numpy.int32, numpy.int64)
169+
FLOAT = (float, numpy.float32, numpy.float64)
170+
COMPLEX = (complex, numpy.complex64, numpy.complex128)
171+
172+
def is_type(x, *args):
173+
return isinstance(x, args)
164174

165175
def slice_to_range(sl, n):
166176
return list(range(n)[sl])
@@ -203,7 +213,7 @@ def slice_ndarray(t, *args):
203213
aa.append(range(s[ia]))
204214
else:
205215
aa.append(a)
206-
elif isinstance(a, (int, numpy.int32, numpy.int64)):
216+
elif pyobs.is_type(a, pyobs.types.INT):
207217
aa.append([a])
208218
else: # pragma: no cover
209219
raise PyobsError("slicing not understood")
@@ -233,3 +243,8 @@ def core(string):
233243
out = [core(s) for s in data]
234244
return numpy.array(out)
235245
return core(data)
246+
247+
def to_list(x):
248+
if numpy.isdim(x)==0:
249+
return [x]
250+
return list(x)

tests/core/complex.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
obsA_i = pyobs.observable()
2323
obsA_i.create('a', data_i, shape=(2,))
2424

25-
print(obsA_r, obsA_i)
25+
print(obsA_r.error())
26+
print(obsA_i.error())
27+
print(obsB.error())
2628

2729
def check(A,B):
2830
b, db = B.error()

tests/core/slice.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
obsB.create('B',data)
6363
[vB, eB] = obsB.error()
6464

65-
obsC = pyobs.remove_tensor(pyobs.stack([obsA,obsB]))
65+
obsC = pyobs.stack([obsA,obsB]).rt()
6666
obsC.peek()
6767
[v,e] = obsC[0].error()
6868
assert abs(vA - v) < 1e-12

0 commit comments

Comments
 (0)