From 3bfc5a4e62c8b8231f6df129fab708466bdfccb8 Mon Sep 17 00:00:00 2001 From: taozhiwang Date: Sat, 22 Jun 2024 13:56:14 -0400 Subject: [PATCH 1/4] Update index_data.py for data conversion and alignment --- qlib/model/riskmodel/shrink.py | 4 +++- qlib/utils/index_data.py | 12 +++++++++++- tests/misc/test_index_data.py | 18 ++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/qlib/model/riskmodel/shrink.py b/qlib/model/riskmodel/shrink.py index c3c0e48ef8..b2594f707d 100644 --- a/qlib/model/riskmodel/shrink.py +++ b/qlib/model/riskmodel/shrink.py @@ -247,7 +247,9 @@ def _get_shrink_param_lw_single_factor(self, X: np.ndarray, S: np.ndarray, F: np v1 = y.T.dot(z) / t - cov_mkt[:, None] * S roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt v3 = z.T.dot(z) / t - var_mkt * S - roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2 + roff3 = ( + np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2 + ) roff = 2 * roff1 - roff3 rho = rdiag + roff diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index 113f9802d7..b5ae3df125 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -108,6 +108,12 @@ def __init__(self, idx_list: Union[List, pd.Index, "Index", int]): self.index_map = self.idx_list = np.arange(idx_list) self._is_sorted = True else: + # Check if all elements in idx_list are of the same type + if not all(isinstance(x, type(idx_list[0])) for x in idx_list): + raise TypeError("All elements in idx_list must be of the same type") + # Check if all elements in idx_list are of the same datetime64 precision + if isinstance(idx_list[0], np.datetime64) and not all(x.dtype == idx_list[0].dtype for x in idx_list): + raise TypeError("All elements in idx_list must be of the same datetime64 precision") self.idx_list = np.array(idx_list) # NOTE: only the first appearance is indexed 
self.index_map = dict(zip(self.idx_list, range(len(self)))) @@ -131,7 +137,11 @@ def _convert_type(self, item): if self.idx_list.dtype.type is np.datetime64: if isinstance(item, pd.Timestamp): # This happens often when creating index based on pandas.DatetimeIndex and query with pd.Timestamp - return item.to_numpy() + return item.to_numpy().astype(self.idx_list.dtype) + elif isinstance(item, np.datetime64): + # This happens often when creating index based on np.datetime64 and query with another precision + return item.astype(self.idx_list.dtype) + return item def index(self, item) -> int: diff --git a/tests/misc/test_index_data.py b/tests/misc/test_index_data.py index 2db644f8a6..b3045a5c7f 100644 --- a/tests/misc/test_index_data.py +++ b/tests/misc/test_index_data.py @@ -94,6 +94,24 @@ def test_corner_cases(self): print(sd) self.assertTrue(sd.iloc[0] == 2) + # test different precisions of time data + timeindex = [ + np.datetime64("2024-06-22T00:00:00.000000000"), + np.datetime64("2024-06-21T00:00:00.000000000"), + np.datetime64("2024-06-20T00:00:00.000000000"), + ] + sd = idd.SingleData([1, 2, 3], index=timeindex) + self.assertTrue( + sd.index.index(np.datetime64("2024-06-21T00:00:00.000000000")) + == sd.index.index(np.datetime64("2024-06-21T00:00:00")) + ) + self.assertTrue(sd.index.index(pd.Timestamp("2024-06-21 00:00")) == 1) + + # Bad case: the input is not aligned + timeindex[1] = (np.datetime64("2024-06-21T00:00:00.00"),) + with self.assertRaises(TypeError): + sd = idd.SingleData([1, 2, 3], index=timeindex) + def test_ops(self): sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) From cd7e3fd0010ccdc84795ed4bcc92c41bf3a58c65 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Sun, 23 Jun 2024 09:48:43 +0800 Subject: [PATCH 2/4] Update qlib/utils/index_data.py --- qlib/utils/index_data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/qlib/utils/index_data.py 
b/qlib/utils/index_data.py index b5ae3df125..ddc2694dfe 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -141,7 +141,8 @@ def _convert_type(self, item): elif isinstance(item, np.datetime64): # This happens often when creating index based on np.datetime64 and query with another precision return item.astype(self.idx_list.dtype) - + # NOTE: It is hard to consider every cases at first. + # We just try to cover part of cases to make it more user-friendly return item def index(self, item) -> int: From a6911508d19bf09445f7a8608c3bdb0258b6a8d6 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Sun, 23 Jun 2024 09:49:15 +0800 Subject: [PATCH 3/4] Update qlib/utils/index_data.py --- qlib/utils/index_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index ddc2694dfe..715040538a 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -141,7 +141,7 @@ def _convert_type(self, item): elif isinstance(item, np.datetime64): # This happens often when creating index based on np.datetime64 and query with another precision return item.astype(self.idx_list.dtype) - # NOTE: It is hard to consider every cases at first. + # NOTE: It is hard to consider every case at first. 
# We just try to cover part of cases to make it more user-friendly return item From 7d03d4430e20f8d1e4d1e581967df82a5a045c02 Mon Sep 17 00:00:00 2001 From: taozhiwang Date: Sat, 22 Jun 2024 23:28:12 -0400 Subject: [PATCH 4/4] fix linting --- qlib/model/riskmodel/shrink.py | 4 +--- qlib/utils/index_data.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/qlib/model/riskmodel/shrink.py b/qlib/model/riskmodel/shrink.py index b2594f707d..c3c0e48ef8 100644 --- a/qlib/model/riskmodel/shrink.py +++ b/qlib/model/riskmodel/shrink.py @@ -247,9 +247,7 @@ def _get_shrink_param_lw_single_factor(self, X: np.ndarray, S: np.ndarray, F: np v1 = y.T.dot(z) / t - cov_mkt[:, None] * S roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt v3 = z.T.dot(z) / t - var_mkt * S - roff3 = ( - np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2 - ) + roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2 roff = 2 * roff1 - roff3 rho = rdiag + roff diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index 715040538a..6c4525ce36 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -141,8 +141,8 @@ def _convert_type(self, item): elif isinstance(item, np.datetime64): # This happens often when creating index based on np.datetime64 and query with another precision return item.astype(self.idx_list.dtype) - # NOTE: It is hard to consider every case at first. - # We just try to cover part of cases to make it more user-friendly + # NOTE: It is hard to consider every case at first. + # We just try to cover part of cases to make it more user-friendly return item def index(self, item) -> int: