diff --git a/ci/36.yaml b/ci/36.yaml index 91424574..80a7553f 100644 --- a/ci/36.yaml +++ b/ci/36.yaml @@ -35,3 +35,4 @@ dependencies: - pywget - proplot - contextily + - scikit-plot diff --git a/ci/37.yaml b/ci/37.yaml index 5381ba10..affa540b 100644 --- a/ci/37.yaml +++ b/ci/37.yaml @@ -35,3 +35,4 @@ dependencies: - pywget - proplot - contextily + - scikit-plot diff --git a/ci/38.yaml b/ci/38.yaml index db98c31f..84396022 100644 --- a/ci/38.yaml +++ b/ci/38.yaml @@ -35,3 +35,4 @@ dependencies: - pywget - proplot - contextily + - scikit-plot diff --git a/environment.yml b/environment.yml index 12c0b5be..6d261a81 100644 --- a/environment.yml +++ b/environment.yml @@ -25,4 +25,4 @@ dependencies: - tobler >=0.2.1 - proplot - contextily - + - scikit-plot diff --git a/geosnap/_community.py b/geosnap/_community.py index 105efece..0857536e 100644 --- a/geosnap/_community.py +++ b/geosnap/_community.py @@ -3,6 +3,7 @@ import geopandas as gpd import pandas as pd +import scikitplot as skplt from ._data import _Map, datasets from .analyze import cluster as _cluster @@ -321,6 +322,30 @@ def cluster_spatial( comm.models[model_name] = model return comm + def silplot(self, model_name=None, year=None, **kwargs): + """ Returns a silhouette plot of the model that is passed to it. + + Parameters + ---------- + model_name : str , required + model to be silhouette plotted + year : int, optional + year to be plotted if model created with pooling=='unique' + kwargs : **kwargs, optional + pass through to plot_silhouette() + Returns + ------- + silhouette plot of given model. 
+ + """ + if not year: + plot = skplt.metrics.plot_silhouette(self.models[model_name].X, self.models[model_name].labels, + **kwargs) + else: + plot = skplt.metrics.plot_silhouette(self.models[model_name][year].X, self.models[model_name][year].labels, + **kwargs) + return plot + def transition( self, cluster_col, time_var="year", id_var="geoid", w_type=None, permutations=0 ): diff --git a/geosnap/analyze/analytics.py b/geosnap/analyze/analytics.py index c7163eab..48810ca8 100644 --- a/geosnap/analyze/analytics.py +++ b/geosnap/analyze/analytics.py @@ -1,7 +1,5 @@ """Tools for the spatial analysis of neighborhood change.""" -from collections import namedtuple - import numpy as np import pandas as pd from sklearn.preprocessing import StandardScaler @@ -25,9 +23,47 @@ ward_spatial, ) -ModelResults = namedtuple( - "model", ["X", "columns", "labels", "instance", "W"], rename=False -) + +class ModelResults: + """Stores data about cluster and cluster_spatial models. + + Attributes + ---------- + X: array-like + data used to compute model + columns: list-like + columns used in model + W: libpysal.weights.W + libpysal spatial weights matrix used in model + labels: array-like + cluster label assigned to each observation (row of X) + instance: instance of model class used to generate neighborhood labels. + fitted model instance, e.g. sklearn.cluster.AgglomerativeClustering object + or other model class used to estimate class labels + + """ + def __init__(self, X, columns, labels,instance,W,): + """Initialize a new ModelResults instance. + + Parameters + ---------- + X: array-like + data used to compute model + columns: list-like + columns used to compute model + W: libpysal.weights.W + libpysal spatial weights matrix used in model + labels: array-like + cluster label assigned to each observation (row of X) + instance: AgglomerativeClustering object, or other model-specific object type + fitted model instance used to estimate the class labels + + """ + self.columns = columns + self.X = X + self.W = W + self.instance = instance + self.labels = labels def cluster(