Skip to content

Commit

Permalink
Merge branch 'microsoft:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
karlotimmerman authored Jan 18, 2023
2 parents 56401de + 60d8870 commit 8f5723f
Show file tree
Hide file tree
Showing 37 changed files with 1,379 additions and 177 deletions.
11 changes: 9 additions & 2 deletions erroranalysis/erroranalysis/analyzer/error_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,17 @@ def _compute_mutual_info(self, input_data, diff):
error.
:rtype: list[float]
"""
# if only one row, replicate it to avoid exception
if input_data.shape[0] == 1:
input_data = np.concatenate((input_data, input_data))
diff = np.concatenate((diff, diff))
n_neighbors = min(3, input_data.shape[0] - 1)
if self._model_task == ModelTask.CLASSIFICATION:
return mutual_info_classif(input_data, diff).tolist()
return mutual_info_classif(
input_data, diff, n_neighbors=n_neighbors).tolist()
else:
return mutual_info_regression(input_data, diff).tolist()
return mutual_info_regression(
input_data, diff, n_neighbors=n_neighbors).tolist()

def compute_root_stats(self):
"""Compute the root all data statistics.
Expand Down
4 changes: 2 additions & 2 deletions erroranalysis/erroranalysis/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

name = 'erroranalysis'
_major = '0'
_minor = '3'
_patch = '13'
_minor = '4'
_patch = '0'
version = '{}.{}.{}'.format(_major, _minor, _patch)
22 changes: 21 additions & 1 deletion erroranalysis/tests/test_importances.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time

import numpy as np
import pytest
from common_utils import replicate_dataset

from erroranalysis._internal.constants import ModelTask
Expand All @@ -15,9 +16,12 @@
from rai_test_utils.models.model_utils import (create_models_classification,
create_models_regression)
from rai_test_utils.models.sklearn import (
create_sklearn_random_forest_classifier,
create_sklearn_random_forest_regressor, create_titanic_pipeline)

TOL = 1e-10
NUM_SAMPLE_ROWS = 100
DEFAULT_SAMPLE_COLS = 20


class TestImportances(object):
Expand Down Expand Up @@ -77,7 +81,7 @@ def test_large_data_importances(self):
# mutual information can be very costly for large number of rows
# hence, assert we downsample to compute importances for large data
X_train, y_train, X_test, y_test, _ = \
create_binary_classification_dataset(100)
create_binary_classification_dataset(NUM_SAMPLE_ROWS)
feature_names = list(X_train.columns)
model = create_sklearn_random_forest_regressor(X_train, y_train)
X_test, y_test = replicate_dataset(X_test, y_test)
Expand All @@ -95,6 +99,22 @@ def test_large_data_importances(self):
# note execution time is in seconds
assert execution_time < 20

@pytest.mark.parametrize('num_rows', [1, 2, 3, 4])
def test_small_data_importances(self, num_rows):
# validate we can run on very few rows
X_train, y_train, X_test, y_test, _ = \
create_binary_classification_dataset(NUM_SAMPLE_ROWS)
feature_names = list(X_train.columns)
model = create_sklearn_random_forest_classifier(X_train, y_train)
X_test = X_test[:num_rows]
y_test = y_test[:num_rows]
categorical_features = []
model_analyzer = ModelAnalyzer(model, X_test, y_test,
feature_names,
categorical_features)
scores = model_analyzer.compute_importances()
assert len(scores) == DEFAULT_SAMPLE_COLS

def test_importances_missings(self):
X_train, X_test, y_train, y_test, feature_names, _ = create_iris_data()

Expand Down
2 changes: 1 addition & 1 deletion libs/core-ui/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.

export * from "./lib/cohortKey";
export * from "./lib/Cohort/isAllDataCohort";
export * from "./lib/Cohort/allDataCohortUtils";
export * from "./lib/Cohort/Cohort";
export * from "./lib/Cohort/CohortList/CohortList";
export * from "./lib/Cohort/Constants";
Expand Down
2 changes: 1 addition & 1 deletion libs/core-ui/src/lib/Cohort/CohortInfo/CohortInfo.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import { localization } from "@responsible-ai/localization";
import React from "react";

import { getCohortFilterCount } from "../../util/getCohortFilterCount";
import { isAllDataErrorCohort } from "../allDataCohortUtils";
import { ErrorCohortStats } from "../CohortStats";
import { ErrorCohort } from "../ErrorCohort";
import { isAllDataErrorCohort } from "../isAllDataCohort";
import { PredictionPath } from "../PredictionPath/PredictionPath";

import { cohortInfoStyles } from "./CohortInfo.styles";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
IModelAssessmentContext,
ModelAssessmentContext
} from "../../Context/ModelAssessmentContext";
import { isAllDataErrorCohort } from "../isAllDataCohort";
import { isAllDataErrorCohort } from "../allDataCohortUtils";

export interface ICohortInfoSectionProps {
toggleShiftCohortVisibility: () => void;
Expand Down
2 changes: 1 addition & 1 deletion libs/core-ui/src/lib/components/AxisConfig.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export interface IAxisConfigProps {
canBin: boolean;
mustBin: boolean;
canDither: boolean;
allowTreatAsCategorical?: boolean;
allowTreatAsCategorical: boolean;
hideDroppedFeatures?: boolean;
onAccept: (newConfig: ISelectorConfig) => void;
}
Expand Down
62 changes: 39 additions & 23 deletions libs/core-ui/src/lib/components/AxisConfigBinOptions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import React from "react";

import { AxisTypes, ISelectorConfig } from "../util/IGenericChartProps";
import { JointDataset } from "../util/JointDataset";
import { IJointMeta } from "../util/JointDatasetUtils";

import { axisConfigBinOptionsStyles } from "./AxisConfigBinOptions.styles";
import { AxisConfigDialogSpinButton } from "./AxisConfigDialogSpinButton";
Expand All @@ -26,7 +27,8 @@ export interface IAxisConfigBinOptionsProps {
minHistCols: number;
mustBin: boolean;
selectedBinCount?: number;
allowTreatAsCategorical?: boolean;
allowTreatAsCategorical: boolean;
allowLogarithmicScaling?: boolean;
selectedColumn: ISelectorConfig;
onBinCountUpdated: (binCount?: number) => void;
onEnableLogarithmicScaling: (checked?: boolean | undefined) => void;
Expand Down Expand Up @@ -54,28 +56,24 @@ export class AxisConfigBinOptions extends React.PureComponent<IAxisConfigBinOpti
</Text>
</Stack.Item>
)}
{(selectedMeta.featureRange?.rangeType === RangeTypes.Integer ||
selectedMeta.featureRange?.rangeType === RangeTypes.Numeric) &&
allowUserInteract(this.props.selectedColumn.property) && (
<Toggle
key="logarithmic-toggle"
label={localization.Interpret.AxisConfigDialog.logarithmicScaling}
inlineLabel
checked={selectedMeta.AxisType === AxisTypes.Logarithmic}
onChange={this.enableLogarithmicScaling}
/>
)}
{selectedMeta.featureRange?.rangeType === RangeTypes.Integer &&
this.props.allowTreatAsCategorical &&
allowUserInteract(this.props.selectedColumn.property) && (
<Toggle
key="categorical-toggle"
label={localization.Interpret.AxisConfigDialog.TreatAsCategorical}
inlineLabel
checked={selectedMeta.treatAsCategorical}
onChange={this.setAsCategorical}
/>
)}
{this.displayLogarithmicToggle(selectedMeta) && (
<Toggle
key="logarithmic-toggle"
label={localization.Interpret.AxisConfigDialog.logarithmicScaling}
inlineLabel
checked={selectedMeta.AxisType === AxisTypes.Logarithmic}
onChange={this.enableLogarithmicScaling}
/>
)}
{this.displayCategoricalToggle(selectedMeta) && (
<Toggle
key="categorical-toggle"
label={localization.Interpret.AxisConfigDialog.TreatAsCategorical}
inlineLabel
checked={selectedMeta.treatAsCategorical}
onChange={this.setAsCategorical}
/>
)}
{selectedMeta?.treatAsCategorical ? (
<>
<Text variant="small">
Expand Down Expand Up @@ -126,6 +124,24 @@ export class AxisConfigBinOptions extends React.PureComponent<IAxisConfigBinOpti
);
};

private displayLogarithmicToggle = (selectedMeta: IJointMeta): boolean => {
const allowLogarithmicScaling = this.props.allowLogarithmicScaling ?? true;
return (
(selectedMeta.featureRange?.rangeType === RangeTypes.Integer ||
selectedMeta.featureRange?.rangeType === RangeTypes.Numeric) &&
allowLogarithmicScaling &&
allowUserInteract(this.props.selectedColumn.property)
);
};

private displayCategoricalToggle = (selectedMeta: IJointMeta): boolean => {
return (
selectedMeta.featureRange?.rangeType === RangeTypes.Integer &&
this.props.allowTreatAsCategorical &&
allowUserInteract(this.props.selectedColumn.property)
);
};

private readonly setAsCategorical = (
_ev?: React.FormEvent<HTMLElement>,
checked?: boolean
Expand Down
3 changes: 2 additions & 1 deletion libs/core-ui/src/lib/components/AxisConfigDialog.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ export interface IAxisConfigDialogProps {
canBin: boolean;
mustBin: boolean;
canDither: boolean;
allowTreatAsCategorical?: boolean;
allowTreatAsCategorical: boolean;
allowLogarithmicScaling?: boolean;
hideDroppedFeatures?: boolean;
onAccept: (newConfig: ISelectorConfig) => void;
onCancel: () => void;
Expand Down
7 changes: 7 additions & 0 deletions libs/counterfactuals/src/lib/CounterfactualChart.styles.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@ export interface ICounterfactualChartStyles {
lowerChartContainer: IStyle;
rotatedVerticalBox: IStyle;
verticalAxis: IStyle;
buttonStyle: IStyle;
}

export const counterfactualChartStyles: () => IProcessedStyleSet<ICounterfactualChartStyles> =
() => {
return mergeStyleSets<ICounterfactualChartStyles>({
buttonStyle: {
marginBottom: "10px",
marginTop: "10px",
paddingBottom: "10px",
paddingTop: "10px"
},
chartWithAxes: {
...fullLgDown,
paddingTop: "30px",
Expand Down
31 changes: 27 additions & 4 deletions libs/counterfactuals/src/lib/CounterfactualChartLegend.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import {
IComboBox,
ComboBox,
PrimaryButton,
Stack
Stack,
DefaultButton
} from "@fluentui/react";
import {
defaultModelAssessmentContext,
Expand Down Expand Up @@ -38,6 +39,7 @@ export interface ICounterfactualChartLegendProps {
selectedPointsIndexes: number[];
indexSeries: number[];
isCounterfactualsDataLoading?: boolean;
isBubbleChartRendered?: boolean;
removeCustomPoint: (index: number) => void;
setTemporaryPointToCopyOfDatasetPoint: (
index: number,
Expand All @@ -48,6 +50,7 @@ export interface ICounterfactualChartLegendProps {
toggleCustomActivation: (index: number) => void;
togglePanel: () => void;
toggleSelectionOfPoint: (index?: number) => void;
setIsRevertButtonClicked: (status: boolean) => void;
}

export class CounterfactualChartLegend extends React.PureComponent<ICounterfactualChartLegendProps> {
Expand All @@ -59,7 +62,7 @@ export class CounterfactualChartLegend extends React.PureComponent<ICounterfactu
const classNames = counterfactualChartStyles();
return (
<Stack className={classNames.legendAndText}>
{this.displayDatapointDropbox() && (
{this.displayComboBox() && (
<ComboBox
id={"CounterfactualSelectedDatapoint"}
className={classNames.legendLabel}
Expand All @@ -82,7 +85,7 @@ export class CounterfactualChartLegend extends React.PureComponent<ICounterfactu
)}
</div>
<PrimaryButton
className={classNames.legendLabel}
className={classNames.buttonStyle}
onClick={this.props.togglePanel}
disabled={this.disableCounterfactualPanel()}
text={
Expand All @@ -91,6 +94,15 @@ export class CounterfactualChartLegend extends React.PureComponent<ICounterfactu
: localization.Counterfactuals.createCounterfactual
}
/>
{this.displayRevertButton() && (
<DefaultButton
className={classNames.buttonStyle}
onClick={this.onRevertButtonClick}
text={localization.Counterfactuals.revertToBubbleChart}
title={localization.Counterfactuals.revertToBubbleChart}
disabled={this.props.isCounterfactualsDataLoading}
/>
)}
{this.props.customPoints.length > 0 && (
<InteractiveLegend
items={this.props.customPoints.map((row, rowIndex) => {
Expand Down Expand Up @@ -119,7 +131,7 @@ export class CounterfactualChartLegend extends React.PureComponent<ICounterfactu
return localization.Counterfactuals.currentClass;
}

private displayDatapointDropbox(): boolean {
private displayComboBox(): boolean {
const isLargeDataEnabled = ifEnableLargeData(this.context.dataset);
if (!isLargeDataEnabled) {
return true;
Expand All @@ -128,6 +140,13 @@ export class CounterfactualChartLegend extends React.PureComponent<ICounterfactu
return isLargeDataEnabled && this.props.indexSeries.length > 0;
}

private displayRevertButton(): boolean {
return (
ifEnableLargeData(this.context.dataset) &&
this.props.indexSeries.length > 0
);
}

private selectPointFromDropdown = (
_event: React.FormEvent<IComboBox>,
item?: IComboBoxOption
Expand Down Expand Up @@ -156,6 +175,10 @@ export class CounterfactualChartLegend extends React.PureComponent<ICounterfactu
);
};

private onRevertButtonClick = (): void => {
this.props.setIsRevertButtonClicked(true);
};

private getDataOptions(): IComboBoxOption[] {
let indexes = this.context.selectedErrorCohort.cohort.unwrap(
JointDataset.IndexLabel
Expand Down
15 changes: 14 additions & 1 deletion libs/counterfactuals/src/lib/CounterfactualChartWithLegend.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export interface ICounterfactualChartWithLegendProps {
indexSeries: number[];
temporaryPoint: { [key: string]: any } | undefined;
isCounterfactualsDataLoading?: boolean;
isRevertButtonClicked: boolean;
onChartPropsUpdated: (chartProps: IGenericChartProps) => void;
onCustomPointLengthUpdated: (customPointLength: number) => void;
onSelectedPointsIndexesUpdated: (selectedPointsIndexes: number[]) => void;
Expand All @@ -48,6 +49,8 @@ export interface ICounterfactualChartWithLegendProps {
setCounterfactualData: (absoluteIndex?: number) => Promise<void>;
telemetryHook?: (message: ITelemetryEvent) => void;
onIndexSeriesUpdated?: (indexSeries: number[]) => void;
setIsRevertButtonClicked: (status: boolean) => void;
resetIndexes: () => void;
}

export interface ICounterfactualChartWithLegendState {
Expand Down Expand Up @@ -128,6 +131,7 @@ export class CounterfactualChartWithLegend extends React.PureComponent<
isCounterfactualsDataLoading={
this.props.isCounterfactualsDataLoading
}
setIsRevertButtonClicked={this.props.setIsRevertButtonClicked}
/>
</Stack>
</Stack.Item>
Expand Down Expand Up @@ -179,12 +183,16 @@ export class CounterfactualChartWithLegend extends React.PureComponent<
};

private onChartPropsUpdated = (newProps: IGenericChartProps): void => {
this.resetCustomPoints();
this.props.onChartPropsUpdated(newProps);
};

private resetCustomPoints = (): void => {
this.setState({
customPointIsActive: [],
customPoints: [],
pointIsActive: []
});
this.props.onChartPropsUpdated(newProps);
};

private saveAsPoint = (): void => {
Expand Down Expand Up @@ -222,6 +230,7 @@ export class CounterfactualChartWithLegend extends React.PureComponent<
private getLargeCounterfactualChartComponent = (): React.ReactNode => {
return (
<LargeCounterfactualChart
cohort={this.context.selectedErrorCohort.cohort}
chartProps={this.props.chartProps}
customPoints={this.state.customPoints}
isPanelOpen={this.state.isPanelOpen}
Expand All @@ -241,6 +250,10 @@ export class CounterfactualChartWithLegend extends React.PureComponent<
setCounterfactualData={this.props.setCounterfactualData}
onIndexSeriesUpdated={this.props.onIndexSeriesUpdated}
isCounterfactualsDataLoading={this.props.isCounterfactualsDataLoading}
isRevertButtonClicked={this.props.isRevertButtonClicked}
setIsRevertButtonClicked={this.props.setIsRevertButtonClicked}
resetIndexes={this.props.resetIndexes}
resetCustomPoints={this.resetCustomPoints}
/>
);
};
Expand Down
Loading

0 comments on commit 8f5723f

Please sign in to comment.