diff --git a/pecos/ann/hnsw/model.py b/pecos/ann/hnsw/model.py index 97f8e2f0..8db989ee 100644 --- a/pecos/ann/hnsw/model.py +++ b/pecos/ann/hnsw/model.py @@ -93,7 +93,7 @@ def __init__(self, model_ptr, num_item, feat_dim, fn_dict, pred_params=None): self.pred_params = self.PredParams.from_dict(pred_params) def __del__(self): - if self.model_ptr and type(self.model_ptr) == c_void_p: + if self.model_ptr and isinstance(self.model_ptr, c_void_p): self.fn_dict["destruct"](self.model_ptr) @property diff --git a/pecos/ann/pairwise/__init__.py b/pecos/ann/pairwise/__init__.py new file mode 100644 index 00000000..8b023b61 --- /dev/null +++ b/pecos/ann/pairwise/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions +# and limitations under the License. +from .model import PairwiseANN # noqa diff --git a/pecos/ann/pairwise/model.py b/pecos/ann/pairwise/model.py new file mode 100644 index 00000000..20f1bedb --- /dev/null +++ b/pecos/ann/pairwise/model.py @@ -0,0 +1,294 @@ +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions +# and limitations under the License. +from ctypes import ( + POINTER, + c_bool, + c_float, + c_uint32, + c_char_p, + c_void_p, +) +import os +import copy +import json +import dataclasses as dc +import numpy as np +import scipy.sparse as smat + +import pecos +from pecos.core import ScipyCsrF32, ScipyDrmF32 +from pecos.core import clib as pecos_clib + + +class PairwiseANN(pecos.BaseClass): + @dc.dataclass + class TrainParams(pecos.BaseParams): + """Training Parameters of PairwiseANN class + + Attributes: + metric_type (str): distance metric type, can only be "ip" for inner product for now + """ + + metric_type: str = "ip" + + @dc.dataclass + class PredParams(pecos.BaseParams): + """Prediction Parameters of PairwiseANN class + + Attributes: + topk (int): maximum number of candidates (sorted by distances, nearest first) return by the searcher per query + """ + + topk: int = 10 + + class Searchers(object): + def __init__(self, model, max_batch_size=256, max_only_topk=10, num_searcher=1): + self.searchers_ptr = model.fn_dict["searchers_create"]( + model.model_ptr, + num_searcher, + ) + self.destruct_fn = model.fn_dict["searchers_destruct"] + + # searchers also hold the memory of returned np.ndarray + self.max_batch_size = max_batch_size + self.max_only_topk = max_only_topk + max_nnz = max_batch_size * max_only_topk + self.Imat = np.zeros(max_nnz, dtype=np.uint32) + self.Mmat = np.zeros(max_nnz, dtype=np.uint32) + self.Dmat = np.zeros(max_nnz, dtype=np.float32) + self.Vmat = np.zeros(max_nnz, dtype=np.float32) + + def __del__(self): + if self.searchers_ptr is not None: + self.destruct_fn(self.searchers_ptr) + + def ctypes(self): + return self.searchers_ptr + + def reset(self, reset_nnz): + self.Imat[:reset_nnz].fill(0) + self.Mmat[:reset_nnz].fill(0) + self.Dmat[:reset_nnz].fill(0.0) + self.Vmat[:reset_nnz].fill(0.0) + + def __init__( + self, model_ptr, num_input_keys, num_label_keys, feat_dim, fn_dict, 
pred_params=None + ): + """constructor of PairwiseANN class + Args: + model_ptr (c_void_p): pointer to C instance pecos::ann:PairwiseANN. + num_item (int): number of item being indexed + feat_dim (int): feature dimension of each item + fn_dict (dict): dictionary that stores the C/C++ functions to call + pred_params (PairwiseANN.PredParams, optional): prediction parameters + """ + self.model_ptr = model_ptr + self.num_input_keys = num_input_keys + self.num_label_keys = num_label_keys + self.feat_dim = feat_dim + self.fn_dict = fn_dict + self.pred_params = self.PredParams.from_dict(pred_params) + + def __del__(self): + if self.model_ptr and isinstance(self.model_ptr, c_void_p): + self.fn_dict["destruct"](self.model_ptr) + + @property + def data_type(self): + return self.fn_dict["data_type"] + + @property + def metric_type(self): + return self.fn_dict["metric_type"] + + @staticmethod + def create_pymat(X): + """create PyMat wrapper given the input X matrix + Args: + X (nd.array, scipy.sparse.csr_matrix): database matrix to be indexed. (num_item x feat_dim). + Returns: + pX (ScipyDrmF32/ScipyCsrF32): python wrapper class for np.array/csr_matrix + data_type (str): data type of X, either drm or csr + """ + pX = None + data_type = None + if isinstance(X, (np.ndarray, ScipyDrmF32)): + pX = ScipyDrmF32.init_from(X) + data_type = "drm" + elif isinstance(X, (smat.csr_matrix, ScipyCsrF32)): + pX = ScipyCsrF32.init_from(X) + data_type = "csr" + else: + raise ValueError("type(X)={} is NOT supported!".format(type(X))) + return pX, data_type + + @classmethod + def train(cls, X, Y, train_params=None, pred_params=None): + """train and return the PairwiseANN indexer + Args: + X (numpy.array or smat.csr_matrix): database matrix to be indexed. (num_item x feat_dim). 
+ train_params (Pairwise.TrainParams, optional): instance of PairwiseANN.TrainParams + pred_params (Pairwise.PredParams, optional): instance of PairwiseANN.PredParams + Returns: + PairwiseANN (pecos.ann.pairwise.PairwiseANN): the trained PairwiseANN model + """ + train_params = cls.TrainParams.from_dict(train_params) + pred_params = cls.PredParams.from_dict(pred_params) + + if isinstance(Y, smat.csr_matrix): + pY = ScipyCsrF32.init_from(Y) + else: + raise ValueError("type(Y) != smat.csr_matix..") + pX, data_type = cls.create_pymat(X) + fn_dict = pecos_clib.pairwise_ann_init(data_type, train_params.metric_type) + model_ptr = fn_dict["train"](pX, pY) + return cls(model_ptr, pY.rows, pY.cols, pX.cols, fn_dict, pred_params) + + @classmethod + def load(cls, model_folder, lazy_load=False): + """Load PairwiseANN model from file + Args: + model_folder (str): model directory from which the model is loaded. + lazy_load (bool): whether to lazy_load memory-mapped files (default False). + Returns: + PairwiseANN (pecos.ann.pairwise.PairwiseANN): the loaded PairwiseANN model + """ + with open("{}/param.json".format(model_folder), "r") as fin: + param = json.loads(fin.read()) + if param["model"] != cls.__name__: + raise ValueError(f"param[model] != cls.__name__") + + if not ("data_type" in param and "metric_type" in param): + raise ValueError(f"param.json did not have data_type or metric_type!") + fn_dict = pecos_clib.pairwise_ann_init(param["data_type"], param["metric_type"]) + + c_model_dir = f"{model_folder}/c_model" + if not os.path.isdir(c_model_dir): + raise ValueError(f"c_model_dir did not exist: {c_model_dir}") + model_ptr = fn_dict["load"](c_char_p(c_model_dir.encode("utf-8")), c_bool(lazy_load)) + pred_params = cls.PredParams.from_dict(param["pred_kwargs"]) + return cls( + model_ptr, + param["num_input_keys"], + param["num_label_keys"], + param["feat_dim"], + fn_dict, + pred_params, + ) + + def save(self, model_folder): + """Save PairwiseANN Model to file + Args: + 
model_folder (str): model directory to which the model is saved + """ + if not os.path.exists(model_folder): + os.makedirs(model_folder) + param = { + "model": self.__class__.__name__, + "data_type": self.data_type, + "metric_type": self.metric_type, + "num_input_keys": self.num_input_keys, + "num_label_keys": self.num_label_keys, + "feat_dim": self.feat_dim, + "pred_kwargs": self.pred_params.to_dict(), + } + param = self.append_meta(param) + with open("{}/param.json".format(model_folder), "w") as fout: + fout.write(json.dumps(param, indent=True)) + c_model_dir = f"{model_folder}/c_model" + self.fn_dict["save"](self.model_ptr, c_char_p(c_model_dir.encode("utf-8"))) + + def searchers_create(self, max_batch_size=256, max_only_topk=10, num_searcher=1): + """create searchers that pre-allocate intermediate variables (e.g., topk_queue) + Args: + max_batch_size (int): the maximum batch size for the input/label pairs to be inference + max_only_topk (int): the maximum only topk for the kNN to return + num_searcher: number of searcher for multi-thread inference + Returns: + PairwiseANN.Searchers: the pre-allocated PairwiseANN.Searchers (class object) + """ + if not self.model_ptr: + raise ValueError("self.model_ptr must exist before using searchers_create()") + if num_searcher <= 0: + raise ValueError("num_searcher={} <= 0 is NOT valid".format(num_searcher)) + return PairwiseANN.Searchers(self, max_batch_size, max_only_topk, num_searcher) + + def get_pred_params(self): + """Return a deep copy of prediction parameters + + Returns: + copied_pred_params (dict): Prediction parameters. + """ + return copy.deepcopy(self.pred_params) + + def predict(self, input_feat, label_keys, searchers, pred_params=None, is_same_input=False): + """predict with multi-thread. The searchers are required to be provided. + Args: + input_feat (numpy.array or smat.csr_matrix): input feature matrix (first key) to find kNN. 
+ if is_same_input == False, the shape should be (batch_size, feat_dim) + if is_same_input == True, the shape should be (1, feat_dim) + label_keys (numpy.array): the label keys (second key) to find kNN. The shape should be (batch_size, ). + searchers (c_void_p): pointer to C/C++ vector. Created by PairwiseANN.searchers_create(). + pred_params (Pairwise.PredParams, optional): instance of pecos.ann.pairwise.Pairwise.PredParams. + is_same_input (bool): whether to use the same first row of X to do prediction. + For real-time inference with same input query, set is_same_input = True. + For batch prediction with varying input querues, set is_same_input = False. + Returns: + Imat (np.array): returned kNN input key indices. Shape of (batch_size, topk) + Mmat (np.array): returned kNN masking array. 1/0 mean value is or is not presented. Shape of (batch_size, topk) + Dmat (np.array): returned kNN distance array. Shape of (batch_size, topk) + Vmat (np.array): returned kNN value array. Shape of (batch_size, topk) + """ + input_feat_py, data_type = self.create_pymat(input_feat) + if data_type != self.data_type: + raise ValueError( + "data_type={} is NOT consistent with self.data_type={}".format( + data_type, self.data_type + ) + ) + if input_feat_py.cols != self.feat_dim: + raise ValueError( + "input_feat_py.cols={} is NOT consistent with self.feat_dim={}".format( + input_feat_py.cols, self.feat_dim + ) + ) + if not isinstance(label_keys, np.ndarray): + raise TypeError(f"type(label_keys) != np.array") + if not is_same_input and input_feat_py.rows != label_keys.shape[0]: + raise ValueError(f"input_feat_py.rows != label_keys.shape[0]") + + batch_size = label_keys.shape[0] + pred_params = self.get_pred_params() if pred_params is None else pred_params + only_topk = pred_params.topk + cur_nnz = batch_size * only_topk + if batch_size > searchers.max_batch_size: + raise ValueError(f"cur_batch_size > searchers.max_batch_size") + if only_topk > searchers.max_only_topk: + raise 
ValueError(f"cur_only_topk > searchers.max_only_topk") + + searchers.reset(cur_nnz) + self.fn_dict["predict"]( + searchers.ctypes(), + batch_size, + only_topk, + input_feat_py, + label_keys.ctypes.data_as(POINTER(c_uint32)), + searchers.Imat.ctypes.data_as(POINTER(c_uint32)), + searchers.Mmat.ctypes.data_as(POINTER(c_uint32)), + searchers.Dmat.ctypes.data_as(POINTER(c_float)), + searchers.Vmat.ctypes.data_as(POINTER(c_float)), + c_bool(is_same_input), + ) + Imat = searchers.Imat[:cur_nnz].reshape(batch_size, only_topk) + Mmat = searchers.Mmat[:cur_nnz].reshape(batch_size, only_topk) + Dmat = searchers.Dmat[:cur_nnz].reshape(batch_size, only_topk) + Vmat = searchers.Vmat[:cur_nnz].reshape(batch_size, only_topk) + return Imat, Mmat, Dmat, Vmat diff --git a/pecos/core/ann/distance.hpp b/pecos/core/ann/distance.hpp index eb823c01..e3a45655 100644 --- a/pecos/core/ann/distance.hpp +++ b/pecos/core/ann/distance.hpp @@ -11,6 +11,9 @@ * and limitations under the License. */ +#ifndef __DISTANCE_H__ +#define __DISTANCE_H__ + namespace pecos { namespace ann { @@ -23,6 +26,8 @@ namespace ann { #include "distance_impl/default.hpp" #endif -} +} // end of namespace ann +} // end of namespace pecos + +#endif // end of __DISTANCE_H__ -} diff --git a/pecos/core/ann/feat_vectors.hpp b/pecos/core/ann/feat_vectors.hpp index 65d74399..6daccbc3 100644 --- a/pecos/core/ann/feat_vectors.hpp +++ b/pecos/core/ann/feat_vectors.hpp @@ -11,6 +11,9 @@ * and limitations under the License. */ +#ifndef __FEAT_VECTORS_H__ +#define __FEAT_VECTORS_H__ + #pragma once #if defined(__GNUC__) #define PORTABLE_ALIGN32 __attribute__((aligned(32))) @@ -334,3 +337,6 @@ namespace ann { } // end of namespace ann } // end of namespace pecos + +#endif // end of __FEAT_VECTORS_H__ + diff --git a/pecos/core/ann/hnsw.hpp b/pecos/core/ann/hnsw.hpp index af1f9bdd..f8962a5e 100644 --- a/pecos/core/ann/hnsw.hpp +++ b/pecos/core/ann/hnsw.hpp @@ -11,6 +11,9 @@ * and limitations under the License. 
*/ +#ifndef __HNSW_H__ +#define __HNSW_H__ + #include #include #include @@ -1333,3 +1336,6 @@ namespace ann { } // end of namespace ann } // end of namespace pecos + +#endif // end of __HNSW_H__ + diff --git a/pecos/core/ann/pairwise.hpp b/pecos/core/ann/pairwise.hpp new file mode 100644 index 00000000..7b7fd559 --- /dev/null +++ b/pecos/core/ann/pairwise.hpp @@ -0,0 +1,287 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +#ifndef __PAIRWISE_ANN_H__ +#define __PAIRWISE_ANN_H__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ann/hnsw.hpp" +#include "third_party/nlohmann_json/json.hpp" +#include "utils/file_util.hpp" +#include "utils/mmap_util.hpp" +#include "utils/matrix.hpp" +#include "utils/random.hpp" +#include "utils/type_util.hpp" + +namespace pecos { + +namespace ann { + + typedef uint32_t index_type; + typedef uint64_t mem_index_type; + typedef float32_t value_type; + + template + struct is_sparse_mat { + static const bool value = false; // compile-time constant + }; + template<> + struct is_sparse_mat { + static const bool value = true; // compile-time constant + }; + template<> + struct is_sparse_mat { + static const bool value = true; // compile-time constant + }; + + template + void save_mat( + const MAT_T &X, mmap_util::MmapStore& mmap_s, + typename std::enable_if::value, MAT_T>::type* = 0 + ) { + auto nnz = X.get_nnz(); + 
mmap_s.fput_one(X.rows); + mmap_s.fput_one(X.cols); + mmap_s.fput_one(nnz); + mmap_s.fput_multiple(X.indptr, (X.IS_COLUMN_MAJORED ? X.cols : X.rows) + 1); + mmap_s.fput_multiple(X.indices, nnz); + mmap_s.fput_multiple(X.data, nnz); + } + + template + void save_mat( + const MAT_T &X, mmap_util::MmapStore& mmap_s, + typename std::enable_if::value, MAT_T>::type* = 0 + ) { + auto nnz = X.get_nnz(); + mmap_s.fput_one(X.rows); + mmap_s.fput_one(X.cols); + mmap_s.fput_one(nnz); + mmap_s.fput_multiple(X.val, nnz); + } + + template + void load_mat(MAT_T &X, mmap_util::MmapStore& mmap_s) { + X.rows = mmap_s.fget_one(); + X.cols = mmap_s.fget_one(); + auto nnz = mmap_s.fget_one(); + X.indptr = mmap_s.fget_multiple((X.IS_COLUMN_MAJORED ? X.cols : X.rows) + 1); + X.indices = mmap_s.fget_multiple(nnz); + X.data = mmap_s.fget_multiple(nnz); + } + + template<> + void load_mat(pecos::drm_t &X, mmap_util::MmapStore& mmap_s) { + X.rows = mmap_s.fget_one(); + X.cols = mmap_s.fget_one(); + auto nnz = mmap_s.fget_one(); + X.val = mmap_s.fget_multiple(nnz); + } + + template + struct KeyValPair { + T1 input_key_idx; + T2 input_key_dist; + T2 input_label_val; + KeyValPair(const T1& input_key_idx=T1(), const T2& input_key_dist=T2(), const T2& input_label_val=T2()): + input_key_idx(input_key_idx), input_key_dist(input_key_dist), input_label_val(input_label_val) {} + bool operator<(const KeyValPair& other) const { return input_key_dist < other.input_key_dist; } + bool operator>(const KeyValPair& other) const { return input_key_dist > other.input_key_dist; } + }; + + // PairwiseANN Interface + template + struct PairwiseANN { + typedef FeatVec_T feat_vec_t; + typedef MAT_T mat_t; + typedef pecos::ann::KeyValPair pair_t; + typedef pecos::ann::heap_t> max_heap_t; + + struct Searcher { + typedef PairwiseANN pairwise_ann_t; + typedef pecos::ann::heap_t> max_heap_t; + + const pairwise_ann_t* pairwise_ann; + max_heap_t topk_queue; + + Searcher(const pairwise_ann_t* _pairwise_ann=nullptr): 
pairwise_ann(_pairwise_ann) {} + + void reset() { topk_queue.clear(); } + + max_heap_t& predict_single(const feat_vec_t& query_vec, const index_type label_key, index_type topk) { + return pairwise_ann->predict_single(query_vec, label_key, topk, *this); + } + }; + + Searcher create_searcher() const { + return Searcher(this); + } + + // scalar variables + index_type num_input_keys; // N + index_type num_label_keys; // L + index_type feat_dim; // d + + // matrices + pecos::csc_t Y_csc; // shape of [N, L] + mat_t X_trn; // shape of [N, d] + + // for loading memory-mapped file + pecos::mmap_util::MmapStore mmap_store; + + // destructor + ~PairwiseANN() {} + + static nlohmann::json load_config(const std::string& filepath) { + std::ifstream loadfile(filepath); + std::string json_str; + if (loadfile.is_open()) { + json_str.assign( + std::istreambuf_iterator(loadfile), + std::istreambuf_iterator() + ); + } else { + throw std::runtime_error("Unable to open config file at " + filepath); + } + auto j_params = nlohmann::json::parse(json_str); + std::string cur_pairwise_ann_t = pecos::type_util::full_name(); + std::string inp_pairwise_ann_t = j_params["pairwise_ann_t"]; + if (cur_pairwise_ann_t != inp_pairwise_ann_t) { + throw std::invalid_argument("Inconsistent PairwiseANN_T: cur = " + cur_pairwise_ann_t + " inp = " + inp_pairwise_ann_t); + } + return j_params; + } + + void save_config(const std::string& filepath) const { + nlohmann::json j_params = { + {"pairwise_ann_t", pecos::type_util::full_name()}, + {"version", "v1.0"}, + {"train_params", { + {"num_input_keys", num_input_keys}, + {"num_label_keys", num_label_keys}, + {"feat_dim", feat_dim}, + {"nnz_of_Y", Y_csc.get_nnz()}, + {"nnz_of_X", X_trn.get_nnz()} + } + } + }; + std::ofstream savefile(filepath, std::ofstream::trunc); + if (savefile.is_open()) { + savefile << j_params.dump(4); + savefile.close(); + } else { + throw std::runtime_error("Unable to save config file to " + filepath); + } + } + + void save(const 
std::string& model_dir) const { + if (mkdir(model_dir.c_str(), 0777) == -1) { + if (errno != EEXIST) { + throw std::runtime_error("Unable to create save folder at " + model_dir); + } + } + save_config(model_dir + "/config.json"); + std::string index_path = model_dir + "/index.mmap_store"; + mmap_util::MmapStore mmap_s = mmap_util::MmapStore(); + mmap_s.open(index_path.c_str(), "w"); + // save scalar variables + mmap_s.fput_one(num_input_keys); + mmap_s.fput_one(num_label_keys); + mmap_s.fput_one(feat_dim); + // save matrices + save_mat(Y_csc, mmap_s); + save_mat(X_trn, mmap_s); + mmap_s.close(); + } + + void load(const std::string& model_dir, bool lazy_load = false) { + auto config = load_config(model_dir + "/config.json"); + std::string version = config.find("version") != config.end() ? config["version"] : "not found"; + if (version == "v1.0") { + std::string index_path = model_dir + "/index.mmap_store"; + mmap_store.open(index_path.c_str(), lazy_load ? "r_lazy" : "r"); + // load scalar variables + num_input_keys = mmap_store.fget_one(); + num_label_keys = mmap_store.fget_one(); + feat_dim = mmap_store.fget_one(); + // load matrices + load_mat(Y_csc, mmap_store); + load_mat(X_trn, mmap_store); + // DO NOT call mmap_store.close() as the actual memory is held by this->mmap_store object. 
+ } else { + throw std::runtime_error("Unable to load memory-mapped file with version = " + version); + } + } + + void train(const mat_t &X_trn, const pecos::csr_t &Y_trn) { + // sanity check + std::string mat_t_str = pecos::type_util::full_name(); + if (mat_t_str != "pecos::csr_t" && mat_t_str != "pecos::drm_t") { + throw std::runtime_error("X_trn should be either csr_t or drm_t!"); + } + if (X_trn.rows != Y_trn.rows) { + throw std::runtime_error("X_trn.rows != Y_trn.rows"); + } + // scalar variables + this->num_input_keys = Y_trn.rows; + this->num_label_keys = Y_trn.cols; + this->feat_dim = X_trn.cols; + // matrices + this->Y_csc = Y_trn.to_csc(); + this->X_trn = X_trn; + } + + max_heap_t& predict_single( + const feat_vec_t& query_vec, + const index_type label_key, + index_type topk, + Searcher& searcher + ) const { + searcher.reset(); + max_heap_t& topk_queue = searcher.topk_queue; + + const auto& rid_vec = this->Y_csc.get_col(label_key); + for (index_type idx = 0; idx < rid_vec.nnz; idx++) { + const auto input_key_idx = rid_vec.idx[idx]; + const auto input_label_val = rid_vec.val[idx]; + value_type input_key_dist = feat_vec_t::distance(query_vec, X_trn.get_row(input_key_idx)); + topk_queue.emplace(input_key_idx, input_key_dist, input_label_val); + } + if (topk < rid_vec.nnz) { + while (topk_queue.size() > topk) { + topk_queue.pop(); + } + } + std::sort_heap(topk_queue.begin(), topk_queue.end()); + return topk_queue; + } + }; + +} // end of namespace ann +} // end of namespace pecos + +#endif // end of __PAIRWISE_ANN_H__ + diff --git a/pecos/core/ann/quantizer.hpp b/pecos/core/ann/quantizer.hpp index 5e4f12a1..ed2b2891 100644 --- a/pecos/core/ann/quantizer.hpp +++ b/pecos/core/ann/quantizer.hpp @@ -11,6 +11,8 @@ * and limitations under the License. 
*/ +#ifndef __QUANTIZER_H__ +#define __QUANTIZER_H__ #if defined(__x86_64__) || defined(__amd64__) #include "quantizer_impl/x86.hpp" @@ -20,3 +22,5 @@ #include "quantizer_impl/default.hpp" #endif +#endif // end of __QUANTIZER_H__ + diff --git a/pecos/core/base.py b/pecos/core/base.py index a72ebf7e..a4e969d0 100644 --- a/pecos/core/base.py +++ b/pecos/core/base.py @@ -532,6 +532,7 @@ def __init__(self, dirname, soname, forced_rebuild=False): self.link_clustering() self.link_tfidf_vectorizer() self.link_ann_hnsw_methods() + self.link_pairwise_ann_methods() self.link_mmap_hashmap_methods() self.link_mmap_valstore_methods() self.link_calibrator_methods() @@ -1440,7 +1441,7 @@ def tfidf_destruct(self, model): Args: model (ptr): Pointer to C Tfdif model. """ - if type(model) == c_void_p: + if isinstance(model, c_void_p): self.clib_float32.c_tfidf_destruct(model) def tfidf_save(self, model, save_dir): @@ -1703,6 +1704,108 @@ def ann_hnsw_init(self, data_type, metric_type): ) return self.ann_hnsw_fn_dict[data_type, metric_type] + def link_pairwise_ann_methods(self): + """ + Specify C-lib's PairwiseANN method argument and return type. 
+ """ + data_type_map = {"drm": POINTER(ScipyDrmF32), "csr": POINTER(ScipyCsrF32)} + metric_type_list = ["ip"] + self.pairwise_ann_fn_dict = {} + for data_type in data_type_map: + for metric_type in metric_type_list: + local_fn_dict = {"data_type": data_type, "metric_type": metric_type} + + fn_name = "train" + c_fn_name = f"c_pairwise_ann_{fn_name}_{data_type}_{metric_type}_f32" + local_fn_dict[fn_name] = getattr(self.clib_float32, c_fn_name) + res_list = c_void_p # pointer to C/C++ pecos::ann::PairwiseANN + arg_list = [ + data_type_map[data_type], # X_trn (either drm/csr) + data_type_map["csr"], # Y_csr (csr only) + ] + corelib.fillprototype(local_fn_dict[fn_name], res_list, arg_list) + + fn_name = "load" + c_fn_name = f"c_pairwise_ann_{fn_name}_{data_type}_{metric_type}_f32" + local_fn_dict[fn_name] = getattr(self.clib_float32, c_fn_name) + res_list = c_void_p # pointer to C/C++ pecos::ann::PairwiseANN + arg_list = [ + c_char_p, # pointer to C/C++ pecos:ann::PairwiseANN + c_bool, # bool for lazy_load of mmap files + ] + corelib.fillprototype(local_fn_dict[fn_name], res_list, arg_list) + + fn_name = "save" + c_fn_name = f"c_pairwise_ann_{fn_name}_{data_type}_{metric_type}_f32" + local_fn_dict[fn_name] = getattr(self.clib_float32, c_fn_name) + res_list = None + arg_list = [ + c_void_p, # pointer to C/C++ pecos::ann::PairwiseANN + c_char_p, # pointer to char* model_dir + ] + corelib.fillprototype(local_fn_dict[fn_name], res_list, arg_list) + + fn_name = "destruct" + c_fn_name = f"c_pairwise_ann_{fn_name}_{data_type}_{metric_type}_f32" + local_fn_dict[fn_name] = getattr(self.clib_float32, c_fn_name) + res_list = None + arg_list = [c_void_p] # pointer to C/C++ pecos::ann::PairwiseANN + corelib.fillprototype(local_fn_dict[fn_name], res_list, arg_list) + + fn_name = "searchers_create" + c_fn_name = f"c_pairwise_ann_{fn_name}_{data_type}_{metric_type}_f32" + local_fn_dict[fn_name] = getattr(self.clib_float32, c_fn_name) + res_list = c_void_p # pointer to vector + 
arg_list = [ + c_void_p, # pointer C/C++ pecos::ann::PairwiseANN + c_uint32, # number of searcher + ] + corelib.fillprototype(local_fn_dict[fn_name], res_list, arg_list) + + fn_name = "searchers_destruct" + c_fn_name = f"c_pairwise_ann_{fn_name}_{data_type}_{metric_type}_f32" + local_fn_dict[fn_name] = getattr(self.clib_float32, c_fn_name) + res_list = None + arg_list = [ + c_void_p + ] # pointer to C/C++ std::vector + corelib.fillprototype(local_fn_dict[fn_name], res_list, arg_list) + + fn_name = "predict" + c_fn_name = f"c_pairwise_ann_{fn_name}_{data_type}_{metric_type}_f32" + local_fn_dict[fn_name] = getattr(self.clib_float32, c_fn_name) + res_list = None + arg_list = [ + c_void_p, # pointer to C/C++ std::vector + c_uint32, # batch_size + c_uint32, # topk + data_type_map[data_type], # ScipyDrmF32*/ScipyCsrF32* Q_tst + POINTER(c_uint32), # uint32_t* label_keys + POINTER(c_uint32), # uint32_t* ret_Imat + POINTER(c_uint32), # uint32_t* ret_Mmat + POINTER(c_float), # float* ret_Dmat + POINTER(c_float), # float* ret_Vmat + c_bool, # bool for is_same_input flag for either batch vs. 
real-time inference + ] + corelib.fillprototype(local_fn_dict[fn_name], res_list, arg_list) + + self.pairwise_ann_fn_dict[data_type, metric_type] = local_fn_dict + + def pairwise_ann_init(self, data_type, metric_type): + """Python to C/C++ interface for PairwiseANN initialization + Args: + data_type (str): data type for query matrices, can be either drm or csr + metric_type (str): metric type for computing distance functions, can only be ip + Returns: + pairwise_ann_fn_dict (dict): a dictionary that holds clib's C/C++ functions for Python to call + """ + + if (data_type, metric_type) not in self.pairwise_ann_fn_dict: + raise NotImplementedError( + "data_type={} and metric_type={} is not implemented".format(data_type, metric_type) + ) + return self.pairwise_ann_fn_dict[data_type, metric_type] + def link_mmap_hashmap_methods(self): """ Specify C-lib's Memory-mappable Hashmap methods arguments and return types. diff --git a/pecos/core/libpecos.cpp b/pecos/core/libpecos.cpp index d1dcb1bf..87695445 100644 --- a/pecos/core/libpecos.cpp +++ b/pecos/core/libpecos.cpp @@ -21,6 +21,7 @@ #include "xmc/linear_solver.hpp" #include "ann/feat_vectors.hpp" #include "ann/hnsw.hpp" +#include "ann/pairwise.hpp" // ===== C Interface of Functions ====== // C Interface of Types/Structures can be found in utils/matrix.hpp @@ -476,6 +477,102 @@ extern "C" { C_ANN_HNSW_PREDICT(_csr_ip_f32, ScipyCsrF32, pecos::csr_t, hnsw_csr_ip_t) C_ANN_HNSW_PREDICT(_csr_l2_f32, ScipyCsrF32, pecos::csr_t, hnsw_csr_l2_t) + + // ==== C Interface of PairwiseANN ==== + + typedef pecos::ann::PairwiseANN, pecos::csr_t> pairwise_ann_csr_ip_t; + typedef pecos::ann::PairwiseANN, pecos::drm_t> pairwise_ann_drm_ip_t; + + #define C_PAIRWISE_ANN_TRAIN(SUFFIX, PY_MAT, C_MAT, PAIRWISE_ANN_T) \ + void* c_pairwise_ann_train ## SUFFIX(const PY_MAT* pX, const ScipyCsrF32* pY) { \ + C_MAT X_trn(pX); \ + pecos::csr_t Y_csr(pY); \ + PAIRWISE_ANN_T *model_ptr = new PAIRWISE_ANN_T(); \ + model_ptr->train(X_trn, Y_csr); \ + 
return static_cast(model_ptr); \ + } + C_PAIRWISE_ANN_TRAIN(_csr_ip_f32, ScipyCsrF32, pecos::csr_t, pairwise_ann_csr_ip_t) + C_PAIRWISE_ANN_TRAIN(_drm_ip_f32, ScipyDrmF32, pecos::drm_t, pairwise_ann_drm_ip_t) + + #define C_PAIRWISE_ANN_LOAD(SUFFIX, PAIRWISE_ANN_T) \ + void* c_pairwise_ann_load ## SUFFIX(const char* model_dir, const bool lazy_load) { \ + PAIRWISE_ANN_T *model_ptr = new PAIRWISE_ANN_T(); \ + model_ptr->load(model_dir, lazy_load); \ + return static_cast(model_ptr); \ + } + C_PAIRWISE_ANN_LOAD(_drm_ip_f32, pairwise_ann_drm_ip_t) + C_PAIRWISE_ANN_LOAD(_csr_ip_f32, pairwise_ann_csr_ip_t) + + #define C_PAIRWISE_ANN_SAVE(SUFFIX, PAIRWISE_ANN_T) \ + void c_pairwise_ann_save ## SUFFIX(void* model_ptr, const char* model_dir) { \ + const auto &model = *static_cast(model_ptr); \ + model.save(model_dir); \ + } + C_PAIRWISE_ANN_SAVE(_drm_ip_f32, pairwise_ann_drm_ip_t) + C_PAIRWISE_ANN_SAVE(_csr_ip_f32, pairwise_ann_csr_ip_t) + + #define C_PAIRWISE_ANN_DESTRUCT(SUFFIX, PAIRWISE_ANN_T) \ + void c_pairwise_ann_destruct ## SUFFIX(void* model_ptr) { \ + delete static_cast(model_ptr); \ + } + C_PAIRWISE_ANN_DESTRUCT(_drm_ip_f32, pairwise_ann_drm_ip_t) + C_PAIRWISE_ANN_DESTRUCT(_csr_ip_f32, pairwise_ann_csr_ip_t) + + #define C_PAIRWISE_ANN_SEARCHERS_CREATE(SUFFIX, PAIRWISE_ANN_T) \ + void* c_pairwise_ann_searchers_create ## SUFFIX(void* model_ptr, uint32_t num_searcher) { \ + typedef typename PAIRWISE_ANN_T::Searcher searcher_t; \ + const auto &model = *static_cast(model_ptr); \ + auto searchers_ptr = new std::vector(); \ + for (uint32_t t = 0; t < num_searcher; t++) { \ + searchers_ptr->emplace_back(model.create_searcher()); \ + } \ + return static_cast(searchers_ptr); \ + } + C_PAIRWISE_ANN_SEARCHERS_CREATE(_drm_ip_f32, pairwise_ann_drm_ip_t) + C_PAIRWISE_ANN_SEARCHERS_CREATE(_csr_ip_f32, pairwise_ann_csr_ip_t) + + #define C_PAIRWISE_ANN_SEARCHERS_DESTRUCT(SUFFIX, PAIRWISE_ANN_T) \ + void c_pairwise_ann_searchers_destruct ## SUFFIX(void* searchers_ptr) { \ + typedef 
typename PAIRWISE_ANN_T::Searcher searcher_t; \ + delete static_cast*>(searchers_ptr); \ + } + C_PAIRWISE_ANN_SEARCHERS_DESTRUCT(_drm_ip_f32, pairwise_ann_drm_ip_t) + C_PAIRWISE_ANN_SEARCHERS_DESTRUCT(_csr_ip_f32, pairwise_ann_csr_ip_t) + + #define C_PAIRWISE_ANN_PREDICT(SUFFIX, PY_MAT, C_MAT, PAIRWISE_ANN_T) \ + void c_pairwise_ann_predict ## SUFFIX( \ + void* searchers_ptr, \ + uint32_t batch_size, \ + uint32_t topk, \ + const PY_MAT* pQ, \ + uint32_t* label_keys, \ + uint32_t* ret_Imat, \ + uint32_t* ret_Mmat, \ + float* ret_Dmat, \ + float* ret_Vmat, \ + const bool is_sample_input) { \ + C_MAT Q_tst(pQ); \ + auto& searchers = *static_cast*>(searchers_ptr); \ + omp_set_num_threads(searchers.size()); \ + OMP_PARA_FOR \ + for (uint32_t bidx=0; bidx < batch_size; bidx++) { \ + int tid = omp_get_thread_num(); \ + auto input_key = (is_sample_input ? 0 : bidx); \ + auto label_key = label_keys[bidx]; \ + auto& ret_pairs = searchers[tid].predict_single(Q_tst.get_row(input_key), label_key, topk); \ + for (uint32_t k=0; k < ret_pairs.size(); k++) { \ + uint64_t offset = static_cast(bidx) * topk; \ + ret_Imat[offset + k] = ret_pairs[k].input_key_idx; \ + ret_Dmat[offset + k] = ret_pairs[k].input_key_dist; \ + ret_Vmat[offset + k] = ret_pairs[k].input_label_val; \ + ret_Mmat[offset + k] = 1; \ + } \ + } \ + } + C_PAIRWISE_ANN_PREDICT(_drm_ip_f32, ScipyDrmF32, pecos::drm_t, pairwise_ann_drm_ip_t) + C_PAIRWISE_ANN_PREDICT(_csr_ip_f32, ScipyCsrF32, pecos::csr_t, pairwise_ann_csr_ip_t) + + // ==== C Interface of Memory-mappable Hashmap ==== typedef pecos::mmap_hashmap::Str2IntMap mmap_hashmap_str2int; diff --git a/pecos/core/utils/type_util.hpp b/pecos/core/utils/type_util.hpp index ff4a699c..e8462132 100644 --- a/pecos/core/utils/type_util.hpp +++ b/pecos/core/utils/type_util.hpp @@ -1,3 +1,19 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). 
You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +#ifndef __TYPE_UTIL_H__ +#define __TYPE_UTIL_H__ + #include #include #include @@ -164,3 +180,6 @@ namespace type_util { } // end namespace type_util } // end namespace pecos + +#endif // end of __TYPE_UTIL_H__ + diff --git a/pecos/utils/mmap_valstore_util.py b/pecos/utils/mmap_valstore_util.py index e152154e..a337e7bc 100644 --- a/pecos/utils/mmap_valstore_util.py +++ b/pecos/utils/mmap_valstore_util.py @@ -296,7 +296,7 @@ def __init__(self, max_row_size: int, max_col_size: int, trunc_val_len: int): # For str decoding, from memory view is faster than from Numpy view bytes_start_loc = [idx * self.trunc_val_len for idx in range(max_row_size * max_col_size)] self.byte_mem_views = [ - memoryview(self.vals[start_idx : start_idx + self.trunc_val_len]) + memoryview(self.vals[start_idx : start_idx + self.trunc_val_len]) # type: ignore[arg-type] for start_idx in bytes_start_loc ] diff --git a/test/pecos/ann/test_pairwise_ann.py b/test/pecos/ann/test_pairwise_ann.py new file mode 100644 index 00000000..80ee597d --- /dev/null +++ b/test/pecos/ann/test_pairwise_ann.py @@ -0,0 +1,169 @@ +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions +# and limitations under the License. +import pytest # noqa: F401; pylint: disable=unused-variable +import random # noqa +import numpy as np # noqa +import scipy.sparse as smat # noqa +from pytest import approx # noqa + +
+def test_importable():
+    """Smoke test: ``PairwiseANN`` can be imported from ``pecos.ann.pairwise``."""
+    from pecos.ann.pairwise import PairwiseANN  # noqa: F401
+
+
+def test_save_and_load(tmpdir):
+    """Predictions of a saved-then-reloaded model equal those of the original model."""
+    from pecos.ann.pairwise import PairwiseANN
+    # train data
+    X_trn = np.array([[1, 0], [0, 0], [-1, 0]]).astype(np.float32)
+    Y_csr = smat.csr_matrix(
+        np.array(
+            [
+                [1.1, 0.0, 0.0],
+                [0.0, 2.2, 2.4],
+                [3.1, 3.2, 0.0],
+            ]
+        ).astype(np.float32)
+    )
+    # test data: query with the training inputs, one label key per query row
+    X_tst = X_trn
+    label_keys = np.array([0, 1, 2]).astype(np.uint32)
+    # train & predict
+    train_params = PairwiseANN.TrainParams(metric_type="ip")
+    model = PairwiseANN.train(X_trn, Y_csr, train_params=train_params)
+    pred_params = PairwiseANN.PredParams(topk=2)
+    searchers = model.searchers_create(max_batch_size=250, max_only_topk=10, num_searcher=1)
+    It, Mt, Dt, Vt = model.predict(
+        X_tst,
+        label_keys,
+        searchers,
+        pred_params=pred_params,
+        is_same_input=False,
+    )
+    # save model, then drop the in-memory model and its searchers
+    model_folder = tmpdir.join("hnsw_model_dir")
+    model.save(model_folder)
+    del model, searchers
+    # load back and predict again
+    model = PairwiseANN.load(model_folder)
+    searchers = model.searchers_create(max_batch_size=250, max_only_topk=10, num_searcher=1)
+    Ip, Mp, Dp, Vp = model.predict(
+        X_tst,
+        label_keys,
+        searchers,
+        pred_params=pred_params,
+        is_same_input=False,
+    )
+    assert Ip == approx(It, abs=0.0), "pred failed: Ip != It"
+    assert Mp == approx(Mt, abs=0.0), "pred failed: Mp != Mt"
+    assert Dp == approx(Dt, abs=0.0), "pred failed: Dp != Dt"
+    assert Vp == approx(Vt, abs=0.0), "pred failed: Vp != Vt"
+    del model, searchers
+
+
+def test_predict_with_same_input():
+    """With ``is_same_input=True`` one query row is paired with every label key."""
+    from pecos.ann.pairwise import PairwiseANN
+    # train data
+    X_trn = np.array([[1, 0], [0, 0], [-1, 0]]).astype(np.float32)
+    Y_csr = smat.csr_matrix(
+        np.array(
+            [
+                [1.1, 0.0, 0.0],
+                [0.0, 2.2, 2.4],
+                [3.1, 3.2, 0.0],
+            ]
+        ).astype(np.float32)
+    )
+    # test data: a single query row shared by three label keys
+    X_tst = np.array([[2, 0]]).astype(np.float32)
+    label_keys = np.array([0, 1, 2]).astype(np.uint32)
+    # train
+    train_params = PairwiseANN.TrainParams(metric_type="ip")
+    model = PairwiseANN.train(X_trn, Y_csr, train_params=train_params)
+    # predict
+    pred_params = PairwiseANN.PredParams(topk=2)
+    searchers = model.searchers_create(max_batch_size=250, max_only_topk=10, num_searcher=1)
+    Ip, Mp, Dp, Vp = model.predict(
+        X_tst,
+        label_keys,
+        searchers,
+        pred_params=pred_params,
+        is_same_input=True,
+    )
+    # compare to expected ground truth (a 0 in Mt marks an unfilled slot; label key 2 has one match only)
+    It = np.array([[0, 2], [1, 2], [1, 0]]).astype(np.uint32)
+    Mt = np.array([[1, 1], [1, 1], [1, 0]]).astype(np.uint32)
+    Dt = np.array([[-1, 3], [1, 3], [1.0, 0.0]]).astype(np.float32)
+    Vt = np.array([[1.1, 3.1], [2.2, 3.2], [2.4, 0.0]]).astype(np.float32)
+    assert Ip == approx(It, abs=0.0), "pred failed: Ip != It"
+    assert Mp == approx(Mt, abs=0.0), "pred failed: Mp != Mt"
+    assert Dp == approx(Dt, abs=0.0), "pred failed: Dp != Dt"
+    assert Vp == approx(Vt, abs=0.0), "pred failed: Vp != Vt"
+    del model, searchers
+
+
+def test_predict_with_multiple_calls():
+    """One batched call and repeated single-pair calls on the same searchers agree."""
+    from pecos.ann.pairwise import PairwiseANN
+    # train data
+    X_trn = np.array([[1, 0], [0, 2], [-1, 0]]).astype(np.float32)
+    Y_csr = smat.csr_matrix(
+        np.array(
+            [
+                [1.1, 0.0, 0.0],
+                [0.0, 2.2, 2.4],
+                [3.1, 3.2, 0.0],
+            ]
+        ).astype(np.float32)
+    )
+    # test data: query with the training inputs, label key i for query row i
+    batch_size = 3
+    X_tst = X_trn
+    label_keys = np.arange(batch_size, dtype=np.uint32)
+    # train
+    train_params = PairwiseANN.TrainParams(metric_type="ip")
+    model = PairwiseANN.train(X_trn, Y_csr, train_params=train_params)
+    # batch predict
+    pred_params = PairwiseANN.PredParams(topk=2)
+    searchers = model.searchers_create(max_batch_size=250, max_only_topk=10, num_searcher=1)
+    Ip, Mp, Dp, Vp = model.predict(
+        X_tst,
+        label_keys,
+        searchers,
+        pred_params=pred_params,
+        is_same_input=False,
+    )
+    It = np.array([[0, 2], [1, 2], [1, 0]]).astype(np.uint32)
+    Mt = np.array([[1, 1], [1, 1], [1, 0]]).astype(np.uint32)
+    Dt = np.array([[0, 2], [-3, 1], [1, 0]]).astype(np.float32)
+    Vt = np.array([[1.1, 3.1], [2.2, 3.2], [2.4, 0.0]]).astype(np.float32)
+    assert Ip == approx(It, abs=0.0), "pred failed: Ip != It"
+    assert Mp == approx(Mt, abs=0.0), "pred failed: Mp != Mt"
+    assert Dp == approx(Dt, abs=0.0), "pred failed: Dp != Dt"
+    assert Vp == approx(Vt, abs=0.0), "pred failed: Vp != Vt"
+
+    # predict a single (query, label key) pair per call, reusing the same searchers,
+    # to test that the result buffers are properly reset between calls
+    for bidx in range(batch_size):
+        Ip_b, Mp_b, Dp_b, Vp_b = model.predict(
+            X_tst[bidx, :].reshape(1, -1),
+            np.array([label_keys[bidx]]),
+            searchers,
+            pred_params=pred_params,
+            is_same_input=True,
+        )
+        assert Ip_b == approx(It[bidx, :].reshape(1, -1), abs=0.0), f"bidx={bidx} failed: Ip != It"
+        assert Mp_b == approx(Mt[bidx, :].reshape(1, -1), abs=0.0), f"bidx={bidx} failed: Mp != Mt"
+        assert Dp_b == approx(Dt[bidx, :].reshape(1, -1), abs=0.0), f"bidx={bidx} failed: Dp != Dt"
+        assert Vp_b == approx(Vt[bidx, :].reshape(1, -1), abs=0.0), f"bidx={bidx} failed: Vp != Vt"
+    del model, searchers
diff --git a/test/pecos/utils/featurization/text/test_preprocessor.py b/test/pecos/utils/featurization/text/test_preprocessor.py index f0f7bc88..0861ffef 100644 --- a/test/pecos/utils/featurization/text/test_preprocessor.py +++ b/test/pecos/utils/featurization/text/test_preprocessor.py @@ -29,7 +29,7 @@ def assert_matrix_equal(Xtgt, X, tolerance=1e-6): - assert type(X) == type(Xtgt) + assert type(X) == type(Xtgt) # noqa: E721 if isinstance(Xtgt, np.ndarray): assert Xtgt == approx(X, abs=tolerance) elif isinstance(Xtgt, smat.spmatrix): diff --git a/test/pecos/utils/test_smat_utils.py b/test/pecos/utils/test_smat_utils.py index 69f92edb..08e20bf9 100644 ---
a/test/pecos/utils/test_smat_utils.py +++ b/test/pecos/utils/test_smat_utils.py @@ -87,18 +87,18 @@ def test_stack_csr(): X_hstack = smat_util.hstack_csr([X1, X2]) assert X_hstack.todense() == approx(np.hstack([X0, X0])) assert X_hstack.dtype == X1.dtype - assert type(X_hstack) == smat.csr_matrix + assert isinstance(X_hstack, smat.csr_matrix) X_vstack = smat_util.vstack_csr([X1, X2]) assert X_vstack.todense() == approx(np.vstack([X0, X0])) assert X_vstack.dtype == X1.dtype - assert type(X_vstack) == smat.csr_matrix + assert isinstance(X_vstack, smat.csr_matrix) X_block_diag = smat_util.block_diag_csr([X1, X2]) X_np_block_diag = np.hstack( [np.vstack([X0, np.zeros_like(X0)]), np.vstack([np.zeros_like(X0), X0])] ) assert X_block_diag.todense() == approx(X_np_block_diag) assert X_block_diag.dtype == X1.dtype - assert type(X_block_diag) == smat.csr_matrix + assert isinstance(X_block_diag, smat.csr_matrix) def test_stack_csc(): @@ -112,18 +112,18 @@ def test_stack_csc(): X_hstack = smat_util.hstack_csc([X1, X2]) assert X_hstack.todense() == approx(np.hstack([X0, X0])) assert X_hstack.dtype == X1.dtype - assert type(X_hstack) == smat.csc_matrix + assert isinstance(X_hstack, smat.csc_matrix) X_vstack = smat_util.vstack_csc([X1, X2]) assert X_vstack.todense() == approx(np.vstack([X0, X0])) assert X_vstack.dtype == X1.dtype - assert type(X_vstack) == smat.csc_matrix + assert isinstance(X_vstack, smat.csc_matrix) X_block_diag = smat_util.block_diag_csc([X1, X2]) X_np_block_diag = np.hstack( [np.vstack([X0, np.zeros_like(X0)]), np.vstack([np.zeros_like(X0), X0])] ) assert X_block_diag.todense() == approx(X_np_block_diag) assert X_block_diag.dtype == X1.dtype - assert type(X_block_diag) == smat.csc_matrix + assert isinstance(X_block_diag, smat.csc_matrix) def test_get_col_row_nonzero(): @@ -205,9 +205,9 @@ def test_get_row_submatrices(): X0_sub, X1_sub = smat_util.get_row_submatrices([X0, X1], row_indices) - assert type(X0_sub) == type(X0) + assert type(X0_sub) == type(X0) # 
noqa: E721 assert X0_sub == approx(Xres) - assert type(X1_sub) == type(X1) + assert type(X1_sub) == type(X1) # noqa: E721 assert X1_sub.todense() == approx(Xres)