From b75a59aa0b5916a47e0ce21670fdb43ab2a91d2f Mon Sep 17 00:00:00 2001 From: Gaurav Gupta <47334368+gaugup@users.noreply.github.com> Date: Thu, 12 Jan 2023 17:34:24 -0800 Subject: [PATCH 01/11] Move import to top level from within the function (#1898) Signed-off-by: Gaurav Gupta Signed-off-by: Gaurav Gupta --- responsibleai/tests/rai_insights/test_rai_insights.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/responsibleai/tests/rai_insights/test_rai_insights.py b/responsibleai/tests/rai_insights/test_rai_insights.py index fd5ad0e307..aa986dc128 100644 --- a/responsibleai/tests/rai_insights/test_rai_insights.py +++ b/responsibleai/tests/rai_insights/test_rai_insights.py @@ -28,6 +28,7 @@ from responsibleai._internal.constants import ManagerNames from responsibleai._tools.shared.state_directory_management import \ DirectoryManager +from responsibleai.feature_metadata import FeatureMetadata LABELS = 'labels' @@ -62,7 +63,6 @@ def test_rai_insights_iris(self, manager_type): ManagerParams.FEATURE_IMPORTANCE: True } - from responsibleai.feature_metadata import FeatureMetadata feature_metadata = FeatureMetadata( identity_feature_name='sepal length', dropped_features=['petal length']) From 051133a9ea52639efed3690018f03667dd35cca2 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Thu, 12 Jan 2023 22:16:51 -0500 Subject: [PATCH 02/11] Add forecasting dashboard - UI - transformation creation, table, and comparison (#1874) * Add postga build trigger (#1755) (#1756) Signed-off-by: Gaurav Gupta Signed-off-by: Gaurav Gupta Signed-off-by: Gaurav Gupta * Add model wrapper for wrapping predictions and test data (#1762) * Add model wrapper for wrapping predictions and test data Signed-off-by: Gaurav Gupta * Fix failing tests Signed-off-by: Gaurav Gupta Signed-off-by: Gaurav Gupta * Change description of cohort selection panel in Aggregate Feature Importance (#1770) Signed-off-by: Gaurav Gupta Signed-off-by: Gaurav Gupta * Support cohort filtering of string target in rai_insights (#1771) * Port tests Signed-off-by: Gaurav Gupta * Fix rai_insights Signed-off-by: Gaurav Gupta Signed-off-by: Gaurav Gupta * Simplify tests in test_cohort_filter.py (#1772) * Simply tests in test_cohort_filter.py Signed-off-by: Gaurav Gupta * Change test name Signed-off-by: Gaurav Gupta Signed-off-by: Gaurav Gupta * Forecasting Dashboard This commit adds the Forecasting Dashboard capabilities in both UI and SDK. * remove unrelated file * basic set of changes to get it working again * forecasting_grains -> time_series_id_column_names * fix breaking changes in latest main, add quantile prediction, fix date conversion, remove obsolete managers * fix UI components (including table and what-if creation, chart still left tbd) * get dashboard to work with multiple time series and switching back and forth * remove data explorer and model overview, create and edit cohort functionalities from forecasting * lintfix * localizzation * remove console.log * remove question * remove unrelated changes, fix package.json * string fixes and package.json adjustment * localization * localization * more localization and removing of unrelated files * remove unrelated files * remove unrelated file changes * remove is_forecasting_true_y * lintfix * add dropdown label * cache transformation predictions * lintfix * remove isUndefinedOrEmpty * remove isUndefinedOrEmpty * rename isAllDataCohort file as requested * lintfix * remove unused componentDidMount Signed-off-by: Gaurav Gupta Co-authored-by: Gaurav Gupta <47334368+gaugup@users.noreply.github.com> Co-authored-by: Gaurav Gupta Co-authored-by: John Wang --- libs/core-ui/src/index.ts | 2 +- .../src/lib/Cohort/CohortInfo/CohortInfo.tsx | 2 +- .../CohortInfoSection/CohortInfoSection.tsx | 2 +- ...AllDataCohort.ts => allDataCohortUtils.ts} | 0 .../Controls/ForecastComparison.tsx | 115 ++++++++- .../Controls/TransformationCreation.tsx | 216 ++++++++++++++++ .../Controls/TransformationCreationDialog.tsx | 242 ++++++++++++++++++ .../Controls/TransformationsTable.tsx | 147 +++++++++++ .../Controls/getForecastPrediction.ts | 13 +- .../ForecastingDashboard.tsx | 128 +++++++-- .../Interfaces/Transformation.ts | 81 ++++++ .../Cohort/ShiftCohort.tsx | 14 +- 12 files changed, 919 insertions(+), 43 deletions(-) rename libs/core-ui/src/lib/Cohort/{isAllDataCohort.ts => allDataCohortUtils.ts} (100%) create mode 100644 libs/forecasting/src/lib/ForecastingDashboard/Controls/TransformationCreation.tsx create mode 100644 libs/forecasting/src/lib/ForecastingDashboard/Controls/TransformationCreationDialog.tsx create mode 100644 libs/forecasting/src/lib/ForecastingDashboard/Controls/TransformationsTable.tsx create mode 100644 libs/forecasting/src/lib/ForecastingDashboard/Interfaces/Transformation.ts diff --git a/libs/core-ui/src/index.ts b/libs/core-ui/src/index.ts index bb94f51f46..4506e4a3d0 100644 --- a/libs/core-ui/src/index.ts +++ b/libs/core-ui/src/index.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. export * from "./lib/cohortKey"; -export * from "./lib/Cohort/isAllDataCohort"; +export * from "./lib/Cohort/allDataCohortUtils"; export * from "./lib/Cohort/Cohort"; export * from "./lib/Cohort/CohortList/CohortList"; export * from "./lib/Cohort/Constants"; diff --git a/libs/core-ui/src/lib/Cohort/CohortInfo/CohortInfo.tsx b/libs/core-ui/src/lib/Cohort/CohortInfo/CohortInfo.tsx index cec5822fa9..018ff58bdd 100644 --- a/libs/core-ui/src/lib/Cohort/CohortInfo/CohortInfo.tsx +++ b/libs/core-ui/src/lib/Cohort/CohortInfo/CohortInfo.tsx @@ -6,9 +6,9 @@ import { localization } from "@responsible-ai/localization"; import React from "react"; import { getCohortFilterCount } from "../../util/getCohortFilterCount"; +import { isAllDataErrorCohort } from "../allDataCohortUtils"; import { ErrorCohortStats } from "../CohortStats"; import { ErrorCohort } from "../ErrorCohort"; -import { isAllDataErrorCohort } from "../isAllDataCohort"; import { PredictionPath } from "../PredictionPath/PredictionPath"; import { cohortInfoStyles } from "./CohortInfo.styles"; diff --git a/libs/core-ui/src/lib/Cohort/CohortInfoSection/CohortInfoSection.tsx b/libs/core-ui/src/lib/Cohort/CohortInfoSection/CohortInfoSection.tsx index 182bff70ae..06c72cb0cc 100644 --- a/libs/core-ui/src/lib/Cohort/CohortInfoSection/CohortInfoSection.tsx +++ b/libs/core-ui/src/lib/Cohort/CohortInfoSection/CohortInfoSection.tsx @@ -10,7 +10,7 @@ import { IModelAssessmentContext, ModelAssessmentContext } from "../../Context/ModelAssessmentContext"; -import { isAllDataErrorCohort } from "../isAllDataCohort"; +import { isAllDataErrorCohort } from "../allDataCohortUtils"; export interface ICohortInfoSectionProps { toggleShiftCohortVisibility: () => void; diff --git a/libs/core-ui/src/lib/Cohort/isAllDataCohort.ts b/libs/core-ui/src/lib/Cohort/allDataCohortUtils.ts similarity index 100% rename from libs/core-ui/src/lib/Cohort/isAllDataCohort.ts rename to libs/core-ui/src/lib/Cohort/allDataCohortUtils.ts diff --git a/libs/forecasting/src/lib/ForecastingDashboard/Controls/ForecastComparison.tsx b/libs/forecasting/src/lib/ForecastingDashboard/Controls/ForecastComparison.tsx index d3d8faf1f0..41efd9aab2 100644 --- a/libs/forecasting/src/lib/ForecastingDashboard/Controls/ForecastComparison.tsx +++ b/libs/forecasting/src/lib/ForecastingDashboard/Controls/ForecastComparison.tsx @@ -14,15 +14,20 @@ import { SeriesOptionsType } from "highcharts"; import React from "react"; import { forecastingDashboardStyles } from "../ForecastingDashboard.styles"; +import { Transformation } from "../Interfaces/Transformation"; import { getForecastPrediction } from "./getForecastPrediction"; -export class IForecastComparisonProps {} +export interface IForecastComparisonProps { + transformations: Map; +} export interface IForecastComparisonState { timeSeriesId?: number; baselinePrediction?: Array<[number, number]>; trueY?: Array<[number, number]>; + transformationPredictions: Map>; + selectedTransformations: Set; } const stackTokens = { @@ -39,7 +44,11 @@ export class ForecastComparison extends React.Component< public constructor(props: IForecastComparisonProps) { super(props); - this.state = {}; + this.state = { + baselinePrediction: undefined, + selectedTransformations: new Set(), + transformationPredictions: new Map>() + }; } public async componentDidMount(): Promise { @@ -58,11 +67,36 @@ export class ForecastComparison extends React.Component< if (currentlySelectedTimeSeriesId !== this.state.timeSeriesId) { const trueY = this.getTrueY(); const baselinePrediction = await this.getBaselineForecastPrediction(); + const selectedTransformationsAndPredictions = + await this.getSelectedForecastPredictions( + [...this.props.transformations.keys()], + true + ); this.setState({ baselinePrediction, + selectedTransformations: + selectedTransformationsAndPredictions.selectedTransformations, timeSeriesId: currentlySelectedTimeSeriesId, + transformationPredictions: + selectedTransformationsAndPredictions.transformationPredictions, trueY }); + return; + } + + // Check if any new transformations were added. + // If so, add their corresponding predictions to this.state.transformationPredictions. + // If we add deletion for transformations we will need to check for transformations + // that have been removed and delete their corresponding predictions, too. + const currentTransformations = [...this.props.transformations.keys()]; + const prevTransformations = new Set( + this.state.transformationPredictions.keys() + ); + const newlyAddedTransformations = currentTransformations.filter( + (t) => !prevTransformations.has(t) + ); + if (newlyAddedTransformations.length > 0) { + this.addSelectedForecastPredictions(newlyAddedTransformations); } } @@ -89,6 +123,17 @@ export class ForecastComparison extends React.Component< type: "spline" } as SeriesOptionsType); } + this.state.selectedTransformations.forEach((transformationName) => { + const transformationPredictions = + this.state.transformationPredictions.get(transformationName); + if (transformationPredictions) { + seriesData.push({ + data: transformationPredictions, + name: transformationName, + type: "spline" + } as SeriesOptionsType); + } + }); return ( @@ -155,6 +200,72 @@ export class ForecastComparison extends React.Component< return undefined; }; + private getSelectedForecastPredictions = async ( + newTransformationNames: string[], + ignoreExisting?: boolean + ): Promise<{ + transformationPredictions: Map>; + selectedTransformations: Set; + }> => { + const newTransformationPredictions = await Promise.all( + newTransformationNames.map(async (newTransformationName) => { + if (this.context.requestForecast === undefined) { + return; + } + const newTransformation = this.props.transformations.get( + newTransformationName + ); + if (newTransformation === undefined) { + return; + } + + const pred = await getForecastPrediction( + newTransformation.cohort.cohort, + this.context.jointDataset, + this.context.requestForecast, + newTransformation + ); + if (pred && this.context.dataset.index) { + return orderByTime(pred, this.getIndices(this.context.dataset.index)); + } + return undefined; + }) + ); + + const newMap = ignoreExisting + ? new Map>() + : new Map(this.state.transformationPredictions); + const newSet = ignoreExisting + ? new Set() + : new Set(this.state.selectedTransformations); + newTransformationNames.forEach((newTransformationName, index) => { + const newPredictions = newTransformationPredictions[index]; + if (newPredictions !== undefined) { + newMap.set(newTransformationName, newPredictions); + newSet.add(newTransformationName); + } + }); + + return { + selectedTransformations: newSet, + transformationPredictions: newMap + }; + }; + + private addSelectedForecastPredictions = async ( + newTransformationNames: string[] + ): Promise => { + const selectedTransformationsAndPredictions = + await this.getSelectedForecastPredictions(newTransformationNames); + + this.setState({ + selectedTransformations: + selectedTransformationsAndPredictions.selectedTransformations, + transformationPredictions: + selectedTransformationsAndPredictions.transformationPredictions + }); + }; + private readonly getTrueY = (): Array<[number, number]> | undefined => { if (this.context.dataset.index) { return orderByTime( diff --git a/libs/forecasting/src/lib/ForecastingDashboard/Controls/TransformationCreation.tsx b/libs/forecasting/src/lib/ForecastingDashboard/Controls/TransformationCreation.tsx new file mode 100644 index 0000000000..ebd56d393b --- /dev/null +++ b/libs/forecasting/src/lib/ForecastingDashboard/Controls/TransformationCreation.tsx @@ -0,0 +1,216 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { + Stack, + Text, + Label, + ComboBox, + IComboBox, + IComboBoxOption, + SpinButton +} from "@fluentui/react"; +import { + defaultModelAssessmentContext, + ModelAssessmentContext +} from "@responsible-ai/core-ui"; +import { localization } from "@responsible-ai/localization"; +import React from "react"; + +import { forecastingDashboardStyles } from "../ForecastingDashboard.styles"; +import { + Operation, + transformationOperations, + Feature, + isMultiplicationOrDivision +} from "../Interfaces/Transformation"; + +export interface ITransformationCreationProps { + transformationName?: string; + transformationOperation?: Operation; + transformationFeature?: Feature; + transformationValue: number; + transformationValueErrorMessage?: string; + onChangeTransformationFeature: (item: IComboBoxOption) => void; + onChangeTransformationValue: (newValue: number) => void; + onChangeTransformationOperation: (operation: Operation) => void; +} + +interface ITransformationCreationState { + featureOptions: IComboBoxOption[]; +} + +export class TransformationCreation extends React.Component< + ITransformationCreationProps, + ITransformationCreationState +> { + public static contextType = ModelAssessmentContext; + public context: React.ContextType = + defaultModelAssessmentContext; + + private transformationValueStep = 0.01; + + public constructor(props: ITransformationCreationProps) { + super(props); + this.state = { featureOptions: [] }; + } + + public componentDidMount(): void { + this.setState({ + featureOptions: this.context.dataset.feature_names + .map((featureName, idx) => { + return { featureName, idx }; + }) + .filter(({ featureName, idx }) => { + const columnMetaName = `Data${idx.toString()}`; + const columnMetadata = + this.context.jointDataset.metaDict[columnMetaName]; + const isDatetimeFeature = + this.context.dataset.feature_metadata?.datetime_features?.includes( + featureName + ) ?? false; + const isTimeSeriesIdColumn = + this.context.dataset.feature_metadata?.time_series_id_column_names?.includes( + featureName + ) ?? false; + return ( + !columnMetadata.isCategorical && + !columnMetadata.treatAsCategorical && + !isDatetimeFeature && + !isTimeSeriesIdColumn + ); + }) + .map(({ featureName, idx }) => { + return { + key: `Data${idx.toString()}`, + text: featureName + } as IComboBoxOption; + }) + }); + } + + public render(): React.ReactNode { + const classNames = forecastingDashboardStyles(); + + return ( + + + + + + + + { + return { key: t.key, text: t.displayName }; + })} + selectedKey={this.props.transformationOperation?.key} + className={classNames.smallDropdown} + onChange={this.onChangeTransformationOperation} + /> + + {this.props.transformationOperation && ( + <> + {isMultiplicationOrDivision(this.props.transformationOperation) && ( + + + { + localization.Forecasting.TransformationCreation + .divisionAndMultiplicationBy + } + + + )} + + + + {this.props.transformationValueErrorMessage && ( +
+ + {this.props.transformationValueErrorMessage} + +
+ )} +
+ + )} +
+ ); + } + + private onChangeTransformationValue = ( + _event: React.SyntheticEvent, + newValue?: string + ): void => { + if (newValue) { + this.props.onChangeTransformationValue(Number(newValue)); + } + }; + + private onChangeTransformationOperation = ( + _event: React.FormEvent, + item?: IComboBoxOption + ): void => { + if (item) { + const transformationOperation = transformationOperations.find( + (op) => op.key === item.key + ); + if (transformationOperation) { + this.props.onChangeTransformationOperation(transformationOperation); + } + } + }; + + private onChangeTransformationFeature = ( + _event: React.FormEvent, + item?: IComboBoxOption + ): void => { + if (item) { + this.props.onChangeTransformationFeature(item); + } + }; +} diff --git a/libs/forecasting/src/lib/ForecastingDashboard/Controls/TransformationCreationDialog.tsx b/libs/forecasting/src/lib/ForecastingDashboard/Controls/TransformationCreationDialog.tsx new file mode 100644 index 0000000000..3237714849 --- /dev/null +++ b/libs/forecasting/src/lib/ForecastingDashboard/Controls/TransformationCreationDialog.tsx @@ -0,0 +1,242 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { + Stack, + PrimaryButton, + TextField, + Text, + IComboBoxOption, + Dialog, + DialogFooter, + DialogType +} from "@fluentui/react"; +import { + defaultModelAssessmentContext, + ModelAssessmentContext +} from "@responsible-ai/core-ui"; +import { localization } from "@responsible-ai/localization"; +import React from "react"; + +import { forecastingDashboardStyles } from "../ForecastingDashboard.styles"; +import { + Transformation, + Operation, + Feature +} from "../Interfaces/Transformation"; + +import { TransformationCreation } from "./TransformationCreation"; + +export interface ITransformationCreationDialogProps { + addTransformation: (name: string, transformation: Transformation) => void; + transformations: Map; + isVisible: boolean; +} + +export interface ITransformationCreationDialogState { + transformationName?: string; + transformationOperation?: Operation; + transformationFeature?: Feature; + transformationValue: number; +} + +export class TransformationCreationDialog extends React.Component< + ITransformationCreationDialogProps, + ITransformationCreationDialogState +> { + public static contextType = ModelAssessmentContext; + public context: React.ContextType = + defaultModelAssessmentContext; + + public constructor(props: ITransformationCreationDialogProps) { + super(props); + this.state = { transformationValue: 2 }; + } + + public componentDidUpdate( + prevProps: Readonly + ): void { + if (this.props.isVisible !== prevProps.isVisible) { + this.setState({ + transformationFeature: undefined, + transformationName: undefined, + transformationOperation: undefined, + transformationValue: 2 + }); + } + } + + public render(): React.ReactNode { + const classNames = forecastingDashboardStyles(); + + let transformationNameErrorMessage = undefined; + if (!this.state.transformationName) { + transformationNameErrorMessage = + localization.Forecasting.TransformationCreation + .scenarioNamingInstructions; + } else if ( + this.state.transformationName && + this.props.transformations.has(this.state.transformationName) + ) { + transformationNameErrorMessage = + localization.Forecasting.TransformationCreation + .scenarioNamingCollisionMessage; + } + + let transformationValueErrorMessage = undefined; + if ( + this.state.transformationOperation && + (this.state.transformationValue < + this.state.transformationOperation.minValue || + this.state.transformationValue > + this.state.transformationOperation.maxValue || + this.state.transformationOperation.excludedValues.includes( + this.state.transformationValue + )) + ) { + transformationValueErrorMessage = localization.formatString( + localization.Forecasting.TransformationCreation.valueErrorMessage, + this.state.transformationOperation.displayName, + this.state.transformationOperation.minValue, + this.state.transformationOperation.maxValue, + this.state.transformationOperation.excludedValues.toString() + ); + } + + let transformationCombinationErrorMessage = undefined; + if ( + this.state.transformationOperation && + this.state.transformationFeature + ) { + // ensure the current selection isn't a duplicate + const transformation = this.createTransformation(); + if (transformation) { + this.props.transformations.forEach((existingTransformation) => { + const equalsResult = transformation.equals(existingTransformation); + if (equalsResult) { + transformationCombinationErrorMessage = + localization.Forecasting.TransformationCreation + .invalidCombinationErrorMessage; + } + }); + } + } + return ( + + ); + } + + private onChangeTransformationName = ( + _event: React.FormEvent, + newValue?: string + ): void => { + this.setState({ transformationName: newValue || "" }); + }; + + private onChangeTransformationValue = (newValue: number): void => { + this.setState({ transformationValue: newValue }); + }; + + private onChangeTransformationOperation = (operation: Operation): void => { + this.setState({ + transformationOperation: operation + }); + }; + + private onChangeTransformationFeature = (item: IComboBoxOption): void => { + this.setState({ + transformationFeature: { key: item.key as string, text: item.text } + }); + }; + + private addTransformation = (): void => { + const transformation = this.createTransformation(); + if (this.state.transformationName && transformation) { + this.props.addTransformation( + this.state.transformationName, + transformation + ); + } + }; + + private createTransformation(): Transformation | undefined { + if ( + this.state.transformationFeature && + this.state.transformationOperation + ) { + return new Transformation( + this.context.baseErrorCohort, + this.state.transformationOperation, + this.state.transformationFeature, + this.state.transformationValue + ); + } + return undefined; + } +} diff --git a/libs/forecasting/src/lib/ForecastingDashboard/Controls/TransformationsTable.tsx b/libs/forecasting/src/lib/ForecastingDashboard/Controls/TransformationsTable.tsx new file mode 100644 index 0000000000..745cd533c8 --- /dev/null +++ b/libs/forecasting/src/lib/ForecastingDashboard/Controls/TransformationsTable.tsx @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { DetailsList, SelectionMode, Stack, Text } from "@fluentui/react"; +import { + defaultModelAssessmentContext, + ModelAssessmentContext +} from "@responsible-ai/core-ui"; +import { localization } from "@responsible-ai/localization"; +import React from "react"; + +import { forecastingDashboardStyles } from "../ForecastingDashboard.styles"; +import { + isMultiplicationOrDivision, + Transformation +} from "../Interfaces/Transformation"; + +interface ITransformationsTableProps { + transformations: Map; +} + +interface ITransformationsTableState { + rows: ITransformationRow[]; +} + +const stackTokens = { + childrenGap: "l1" +}; + +interface ITransformationRow { + key: string; + transformationName: string; + method: string; +} + +export class TransformationsTable extends React.Component< + ITransformationsTableProps, + ITransformationsTableState +> { + public static contextType = ModelAssessmentContext; + public context: React.ContextType = + defaultModelAssessmentContext; + + public constructor(props: ITransformationsTableProps) { + super(props); + this.state = { rows: this.calculateUpdatedTable() }; + } + + public componentDidUpdate(): void { + // Currently, transformations are not editable or deletable. + // If that changes in the future we will have to update these checks. + const nTransformations = this.props.transformations.size; + const prevTransformationNames = new Set( + this.state.rows.map((t) => t.transformationName) + ); + const didUpdate = + prevTransformationNames.size !== nTransformations || + nTransformations !== + [...this.props.transformations.keys()].filter((t) => + prevTransformationNames.has(t) + ).length; + + if (didUpdate) { + this.setState({ rows: this.calculateUpdatedTable() }); + } + } + + public render(): React.ReactNode { + if ( + this.props.transformations.size === 0 || + this.state === undefined || + this.state.rows === undefined || + this.state.rows.length === 0 + ) { + return; + } + const classNames = forecastingDashboardStyles(); + + const forecastNames: string[] = []; + const forecastTransformations: Transformation[] = []; + + for (const [ + forecastName, + forecastTransformation + ] of this.props.transformations.entries()) { + forecastNames.push(forecastName); + forecastTransformations.push(forecastTransformation); + } + + return ( + + + + {localization.formatString( + localization.Forecasting.TransformationTable.header, + this.props.transformations.size + )} + + + + + + + ); + } + + public calculateUpdatedTable(): ITransformationRow[] { + const rows: ITransformationRow[] = [ + ...this.props.transformations.entries() + ].map(([transformationName, transformation], index) => { + const method = `${transformation.feature.text} ${ + transformation.operation.displayName + } ${ + isMultiplicationOrDivision(transformation.operation) + ? localization.Forecasting.TransformationTable + .divisionAndMultiplicationBy + : "" + }${transformation.value.toString()}`; + return { + key: index.toString(), + method, + transformationName + }; + }); + return rows; + } +} diff --git a/libs/forecasting/src/lib/ForecastingDashboard/Controls/getForecastPrediction.ts b/libs/forecasting/src/lib/ForecastingDashboard/Controls/getForecastPrediction.ts index ce8bdbe39e..3c8903418d 100644 --- a/libs/forecasting/src/lib/ForecastingDashboard/Controls/getForecastPrediction.ts +++ b/libs/forecasting/src/lib/ForecastingDashboard/Controls/getForecastPrediction.ts @@ -3,12 +3,15 @@ import { Cohort, JointDataset } from "@responsible-ai/core-ui"; +import { Transformation } from "../Interfaces/Transformation"; + export async function getForecastPrediction( cohort: Cohort, jointDataset: JointDataset, requestForecast: | ((request: any[], abortSignal: AbortSignal) => Promise) - | undefined + | undefined, + transformation?: Transformation ): Promise { if (requestForecast === undefined) { return; @@ -17,7 +20,13 @@ export async function getForecastPrediction( [ Cohort.getLabeledFilters(cohort.filters, jointDataset), Cohort.getLabeledCompositeFilters(cohort.compositeFilters, jointDataset), - [] + transformation + ? [ + transformation.operation.key, + transformation.feature.key, + transformation.value + ] + : [] ], new AbortController().signal ); diff --git a/libs/forecasting/src/lib/ForecastingDashboard/ForecastingDashboard.tsx b/libs/forecasting/src/lib/ForecastingDashboard/ForecastingDashboard.tsx index dda3a2bf01..31a147ddf8 100644 --- a/libs/forecasting/src/lib/ForecastingDashboard/ForecastingDashboard.tsx +++ b/libs/forecasting/src/lib/ForecastingDashboard/ForecastingDashboard.tsx @@ -1,7 +1,13 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -import { Dropdown, IDropdownOption, Stack, Text } from "@fluentui/react"; +import { + Dropdown, + IDropdownOption, + PrimaryButton, + Stack, + Text +} from "@fluentui/react"; import { defaultModelAssessmentContext, isAllDataErrorCohort, @@ -11,11 +17,17 @@ import { localization } from "@responsible-ai/localization"; import React from "react"; import { ForecastComparison } from "./Controls/ForecastComparison"; +import { TransformationCreationDialog } from "./Controls/TransformationCreationDialog"; +import { TransformationsTable } from "./Controls/TransformationsTable"; import { forecastingDashboardStyles } from "./ForecastingDashboard.styles"; +import { Transformation } from "./Interfaces/Transformation"; export class IForecastingDashboardProps {} -export class IForecastingDashboardState {} +export interface IForecastingDashboardState { + transformations?: Map>; + isTransformationCreatorVisible: boolean; +} export class ForecastingDashboard extends React.Component< IForecastingDashboardProps, @@ -25,6 +37,14 @@ export class ForecastingDashboard extends React.Component< public context: React.ContextType = defaultModelAssessmentContext; + public constructor(props: IForecastingDashboardProps) { + super(props); + + this.state = { + isTransformationCreatorVisible: false + }; + } + public render(): React.ReactNode { const classNames = forecastingDashboardStyles(); @@ -48,35 +68,68 @@ export class ForecastingDashboard extends React.Component< }; }); + const cohortTransformations = + this.state.transformations?.get( + this.context.baseErrorCohort.cohort.getCohortID() + ) ?? new Map(); + return ( - - - {localization.Forecasting.whatIfDescription} - - - - - {!noCohortSelected && ( + <> + + + {localization.Forecasting.whatIfDescription} + - + - )} - + {!noCohortSelected && ( + <> + + { + this.setState({ isTransformationCreatorVisible: true }); + }} + text={localization.Forecasting.TransformationCreation.title} + /> + + + + {cohortTransformations.size > 0 && ( + + + + )} + + + + + + )} + + + ); } @@ -94,4 +147,23 @@ export class ForecastingDashboard extends React.Component< } } }; + + private addTransformation = ( + name: string, + transformation: Transformation + ): void => { + const currentCohortID = this.context.baseErrorCohort.cohort.getCohortID(); + const newMap = + this.state.transformations ?? + new Map>(); + const cohortMap = + this.state.transformations?.get(currentCohortID) ?? + new Map(); + cohortMap.set(name, transformation); + newMap.set(currentCohortID, cohortMap); + this.setState({ + isTransformationCreatorVisible: false, + transformations: newMap + }); + }; } diff --git a/libs/forecasting/src/lib/ForecastingDashboard/Interfaces/Transformation.ts b/libs/forecasting/src/lib/ForecastingDashboard/Interfaces/Transformation.ts new file mode 100644 index 0000000000..cd880a1d55 --- /dev/null +++ b/libs/forecasting/src/lib/ForecastingDashboard/Interfaces/Transformation.ts @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { ErrorCohort } from "@responsible-ai/core-ui"; +import { localization } from "@responsible-ai/localization"; + +export type Operation = { + key: string; + displayName: string; + minValue: number; + maxValue: number; + excludedValues: number[]; +}; + +export function isMultiplicationOrDivision(operation: Operation): boolean { + return ["multiply", "divide"].includes(operation.key); +} + +export const transformationOperations: Operation[] = [ + { + displayName: localization.Forecasting.Transformations.multiply, + excludedValues: [0, 1], + key: "multiply", + maxValue: 1000, + minValue: -1000 + }, + { + displayName: localization.Forecasting.Transformations.divide, + excludedValues: [0, 1], + key: "divide", + maxValue: 1000, + minValue: -1000 + }, + { + displayName: localization.Forecasting.Transformations.add, + excludedValues: [0], + key: "add", + maxValue: 1000, + minValue: -1000 + }, + { + displayName: localization.Forecasting.Transformations.subtract, + excludedValues: [0], + key: "subtract", + maxValue: 1000, + minValue: -1000 + } +]; + +export type Feature = { + key: string; + text: string; +}; + +export class Transformation { + public cohort: ErrorCohort; + public operation: Operation; + public feature: Feature; + public value: number; + + public constructor( + cohort: ErrorCohort, + operation: Operation, + feature: Feature, + value: number + ) { + this.cohort = cohort; + this.operation = operation; + this.feature = feature; + this.value = value; + } + + public equals(obj: Transformation): boolean { + return ( + this.cohort.cohort.getCohortID() === obj.cohort.cohort.getCohortID() && + this.operation.key === obj.operation.key && + this.feature.key === obj.feature.key && + this.value === obj.value + ); + } +} diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ShiftCohort.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ShiftCohort.tsx index 3ae107336d..fea8804170 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ShiftCohort.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Cohort/ShiftCohort.tsx @@ -43,14 +43,12 @@ export class ShiftCohort extends React.Component< defaultModelAssessmentContext; public componentDidMount(): void { - const savedCohorts = this.context.errorCohorts - .filter((errorCohort) => !errorCohort.isTemporary) - .filter( - (errorCohort) => - !errorCohort.isTemporary && - (this.props.showAllDataCohort || - !isAllDataErrorCohort(errorCohort, true)) - ); + const savedCohorts = this.context.errorCohorts.filter( + (errorCohort) => + !errorCohort.isTemporary && + (this.props.showAllDataCohort || + !isAllDataErrorCohort(errorCohort, true)) + ); const options: IDropdownOption[] = savedCohorts.map( (savedCohort: ErrorCohort, index: number) => { return { key: index, text: savedCohort.cohort.name }; From a635d6ddbf0d1263b594adf2a64d45677d4df91a Mon Sep 17 00:00:00 2001 From: Kashyap Patel <64443771+ms-kashyap@users.noreply.github.com> Date: Fri, 13 Jan 2023 13:29:05 -0500 Subject: [PATCH 03/11] Make `DataBalanceManager.compute()` throw warning instead of exception (#1902) * Make DataBalanceManager.compute() throw warning instead of exception since it is enabled by default * Lint fix * Isort fix --- .../managers/data_balance_manager.py | 46 +++++++++++-------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/responsibleai/responsibleai/managers/data_balance_manager.py b/responsibleai/responsibleai/managers/data_balance_manager.py index f6023de707..cb51046e61 100644 --- a/responsibleai/responsibleai/managers/data_balance_manager.py +++ b/responsibleai/responsibleai/managers/data_balance_manager.py @@ -4,6 +4,7 @@ """Defines the Data Balance Manager class.""" import json +import warnings from pathlib import Path from typing import Dict, List @@ -157,30 +158,35 @@ def compute(self): if not self._is_added: return - self._validate() + try: + self._validate() - self._df = prepare_df(df=self._df) + self._df = prepare_df(df=self._df) - feature_balance_measures = {} - for pos_label in self._classes: - feature_balance_measures[pos_label] = FeatureBalanceMeasures( - cols_of_interest=self._cols_of_interest, - label_col=self._target_column, - pos_label=pos_label, - ).measures(dataset=self._df) + feature_balance_measures = {} + for pos_label in self._classes: + feature_balance_measures[pos_label] = FeatureBalanceMeasures( + cols_of_interest=self._cols_of_interest, + label_col=self._target_column, + pos_label=pos_label, + ).measures(dataset=self._df) - distribution_balance_measures = DistributionBalanceMeasures( - cols_of_interest=self._cols_of_interest - ).measures(dataset=self._df) - aggregate_balance_measures = AggregateBalanceMeasures( - cols_of_interest=self._cols_of_interest - ).measures(dataset=self._df) + distribution_balance_measures = DistributionBalanceMeasures( + cols_of_interest=self._cols_of_interest + ).measures(dataset=self._df) + aggregate_balance_measures = AggregateBalanceMeasures( + cols_of_interest=self._cols_of_interest + ).measures(dataset=self._df) - self._set_data_balance_measures( - feature_balance_measures=feature_balance_measures, - distribution_balance_measures=distribution_balance_measures, - aggregate_balance_measures=aggregate_balance_measures, - ) + self._set_data_balance_measures( + feature_balance_measures=feature_balance_measures, + distribution_balance_measures=distribution_balance_measures, + aggregate_balance_measures=aggregate_balance_measures, + ) + except Exception as e: + warnings.warn( + f"Failed to compute data balance measures due to {e!r}." + ) def _set_data_balance_measures( self, From cac55903f53b932a1f594c88bb9a89979ab25371 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Fri, 13 Jan 2023 15:32:25 -0500 Subject: [PATCH 04/11] raiwidgets: Fix url in setup.py (#1906) --- raiwidgets/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/raiwidgets/setup.py b/raiwidgets/setup.py index 122e4d1b91..b5cad4243c 100644 --- a/raiwidgets/setup.py +++ b/raiwidgets/setup.py @@ -32,7 +32,7 @@ "Machine Learning models.", long_description=long_description, long_description_content_type="text/markdown", - url="https://github.com/microsoft/responsible-ai-widgets", + url="https://github.com/microsoft/responsible-ai-toolbox", packages=setuptools.find_packages(), python_requires='>=3.6', install_requires=install_requires, From 6c25ee28d206dfd8c28d90ff74c66d223f03bf02 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta <47334368+gaugup@users.noreply.github.com> Date: Fri, 13 Jan 2023 15:39:56 -0800 Subject: [PATCH 05/11] Raise `UserConfigValidationException` when `treatment_feature` is empty list (#1904) Signed-off-by: Gaurav Gupta Signed-off-by: Gaurav Gupta --- .../responsibleai/managers/causal_manager.py | 8 +++++++ .../test_rai_insights_validations.py | 24 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/responsibleai/responsibleai/managers/causal_manager.py b/responsibleai/responsibleai/managers/causal_manager.py index 4833f3984e..7b2f4e49eb 100644 --- a/responsibleai/responsibleai/managers/causal_manager.py +++ b/responsibleai/responsibleai/managers/causal_manager.py @@ -157,6 +157,14 @@ def add( :param random_state: Controls the randomness of the estimator. :type random_state: int or RandomState or None """ + if not isinstance(treatment_features, list): + raise UserConfigValidationException( + "Expecting a list for treatment_features but got {0}".format( + type(treatment_features))) + if len(treatment_features) == 0: + raise UserConfigValidationException( + "Please specify at least one feature in " + "treatment_features list") for feature in treatment_features: if self._feature_metadata and \ self._feature_metadata.dropped_features and \ diff --git a/responsibleai/tests/rai_insights/test_rai_insights_validations.py b/responsibleai/tests/rai_insights/test_rai_insights_validations.py index 3b61612aad..8799365fed 100644 --- a/responsibleai/tests/rai_insights/test_rai_insights_validations.py +++ b/responsibleai/tests/rai_insights/test_rai_insights_validations.py @@ -555,6 +555,30 @@ def test_treatment_features_list_not_having_train_features(self): with pytest.raises(UserConfigValidationException, match=message): rai_insights.causal.add(treatment_features=['not_a_feature']) + def test_treatment_features_list_not_having_any_features(self): + X_train, y_train, X_test, y_test, _ = \ + create_binary_classification_dataset() + + model = create_lightgbm_classifier(X_train, y_train) + X_train[TARGET] = y_train + X_test[TARGET] = y_test + + rai_insights = RAIInsights( + model=model, + train=X_train, + test=X_test, + target_column=TARGET, + task_type='classification') + + message = ("Please specify at least one feature in " + "treatment_features list") + with pytest.raises(UserConfigValidationException, match=message): + rai_insights.causal.add(treatment_features=[]) + + message = ("Expecting a list for treatment_features but got") + with pytest.raises(UserConfigValidationException, match=message): + rai_insights.causal.add(treatment_features={}) + def test_treatment_features_having_dropped_features(self): X_train, y_train, X_test, y_test, _ = \ create_binary_classification_dataset() From 527bc6b18e749fc5ac71f96d93c9de62caf9e5f5 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Sun, 15 Jan 2023 11:46:33 -0500 Subject: [PATCH 06/11] fix small data erroring out on mutual info score for error analysis guidance (#1907) --- .../erroranalysis/analyzer/error_analyzer.py | 11 ++++++++-- erroranalysis/tests/test_importances.py | 22 ++++++++++++++++++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/erroranalysis/erroranalysis/analyzer/error_analyzer.py b/erroranalysis/erroranalysis/analyzer/error_analyzer.py index 0dd6228d5e..7589274f83 100644 --- a/erroranalysis/erroranalysis/analyzer/error_analyzer.py +++ b/erroranalysis/erroranalysis/analyzer/error_analyzer.py @@ -472,10 +472,17 @@ def _compute_mutual_info(self, input_data, diff): error. :rtype: list[float] """ + # if only one row, replicate it to avoid exception + if input_data.shape[0] == 1: + input_data = np.concatenate((input_data, input_data)) + diff = np.concatenate((diff, diff)) + n_neighbors = min(3, input_data.shape[0] - 1) if self._model_task == ModelTask.CLASSIFICATION: - return mutual_info_classif(input_data, diff).tolist() + return mutual_info_classif( + input_data, diff, n_neighbors=n_neighbors).tolist() else: - return mutual_info_regression(input_data, diff).tolist() + return mutual_info_regression( + input_data, diff, n_neighbors=n_neighbors).tolist() def compute_root_stats(self): """Compute the root all data statistics. diff --git a/erroranalysis/tests/test_importances.py b/erroranalysis/tests/test_importances.py index ead24d33a5..d7fa15d219 100644 --- a/erroranalysis/tests/test_importances.py +++ b/erroranalysis/tests/test_importances.py @@ -4,6 +4,7 @@ import time import numpy as np +import pytest from common_utils import replicate_dataset from erroranalysis._internal.constants import ModelTask @@ -15,9 +16,12 @@ from rai_test_utils.models.model_utils import (create_models_classification, create_models_regression) from rai_test_utils.models.sklearn import ( + create_sklearn_random_forest_classifier, create_sklearn_random_forest_regressor, create_titanic_pipeline) TOL = 1e-10 +NUM_SAMPLE_ROWS = 100 +DEFAULT_SAMPLE_COLS = 20 class TestImportances(object): @@ -77,7 +81,7 @@ def test_large_data_importances(self): # mutual information can be very costly for large number of rows # hence, assert we downsample to compute importances for large data X_train, y_train, X_test, y_test, _ = \ - create_binary_classification_dataset(100) + create_binary_classification_dataset(NUM_SAMPLE_ROWS) feature_names = list(X_train.columns) model = create_sklearn_random_forest_regressor(X_train, y_train) X_test, y_test = replicate_dataset(X_test, y_test) @@ -95,6 +99,22 @@ def test_large_data_importances(self): # note execution time is in seconds assert execution_time < 20 + @pytest.mark.parametrize('num_rows', [1, 2, 3, 4]) + def test_small_data_importances(self, num_rows): + # validate we can run on very few rows + X_train, y_train, X_test, y_test, _ = \ + create_binary_classification_dataset(NUM_SAMPLE_ROWS) + feature_names = list(X_train.columns) + model = create_sklearn_random_forest_classifier(X_train, y_train) + X_test = X_test[:num_rows] + y_test = y_test[:num_rows] + categorical_features = [] + model_analyzer = ModelAnalyzer(model, X_test, y_test, + feature_names, + categorical_features) + scores = model_analyzer.compute_importances() + assert len(scores) == DEFAULT_SAMPLE_COLS + def test_importances_missings(self): X_train, X_test, y_train, y_test, feature_names, _ = create_iris_data() From cb3b74fe85964ee0d44cd8d3dd0cc804f73f9ced Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 17 Jan 2023 11:53:16 -0500 Subject: [PATCH 07/11] release erroranalysis v0.4.0 (#1909) --- erroranalysis/erroranalysis/version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/erroranalysis/erroranalysis/version.py b/erroranalysis/erroranalysis/version.py index 6d47b0a2c0..e7725d6ac3 100644 --- a/erroranalysis/erroranalysis/version.py +++ b/erroranalysis/erroranalysis/version.py @@ -3,6 +3,6 @@ name = 'erroranalysis' _major = '0' -_minor = '3' -_patch = '13' +_minor = '4' +_patch = '0' version = '{}.{}.{}'.format(_major, _minor, _patch) From 04419236c5961b764300cc55bee0fd2a73581697 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 17 Jan 2023 13:57:40 -0500 Subject: [PATCH 08/11] update raiwidgets and responsibleai to erroranalysis 0.4.0 (#1910) --- raiwidgets/requirements.txt | 2 +- responsibleai/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/raiwidgets/requirements.txt b/raiwidgets/requirements.txt index 11410b0050..5247f090f2 100644 --- a/raiwidgets/requirements.txt +++ b/raiwidgets/requirements.txt @@ -5,6 +5,6 @@ rai-core-flask==0.5.0 itsdangerous==2.0.1 scikit-learn>=0.22.1 lightgbm>=2.0.11 -erroranalysis>=0.3.13 +erroranalysis>=0.4.0 fairlearn>=0.7.0 raiutils>=0.3.0 diff --git a/responsibleai/requirements.txt b/responsibleai/requirements.txt index 5413aec321..8975c19735 100644 --- a/responsibleai/requirements.txt +++ b/responsibleai/requirements.txt @@ -1,7 +1,7 @@ dice-ml>=0.9,<0.10 econml>=0.14.0 jsonschema -erroranalysis>=0.3.13 +erroranalysis>=0.4.0 interpret-community>=0.28.0 lightgbm>=2.0.11 numpy>=1.17.2,<1.24.0 From 5f6dd92ef2a6569a6b7eec130003243a5a795b2d Mon Sep 17 00:00:00 2001 From: Vinutha Karanth Date: Tue, 17 Jan 2023 13:14:24 -0800 Subject: [PATCH 09/11] Enable logarithmic scaling, reflect cohort change and add add Revert to bubble chart button (#1905) * bub ch Signed-off-by: vinutha karanth * update Signed-off-by: vinutha karanth * root dep Signed-off-by: vinutha karanth * update Signed-off-by: vinutha karanth * localcounter code Signed-off-by: vinutha karanth * update interface Signed-off-by: vinutha karanth * local count Signed-off-by: vinutha karanth * update local imp chart marker Signed-off-by: vinutha karanth * allowtreatascat Signed-off-by: vinutha karanth * add few changes Signed-off-by: vinutha karanth * enus change Signed-off-by: vinutha karanth * custom points updates Signed-off-by: vinutha karanth * custom point abs index changes Signed-off-by: vinutha karanth * make local imp call async Signed-off-by: vinutha karanth * update Signed-off-by: vinutha karanth * enus change Signed-off-by: vinutha karanth * create new comp for large counter data Signed-off-by: vinutha karanth * add loading for bubble Signed-off-by: vinutha karanth * loading update for local imp Signed-off-by: vinutha karanth * rem unwanted code in largecc Signed-off-by: vinutha karanth * remove console log Signed-off-by: vinutha karanth * lintfix Signed-off-by: vinutha karanth * lintfix Signed-off-by: vinutha karanth * lintfix and err fix Signed-off-by: vinutha karanth * lintfix and err fix Signed-off-by: vinutha karanth * update snap Signed-off-by: vinutha karanth * add err message for local fetch Signed-off-by: vinutha karanth * bub data err msg Signed-off-by: vinutha karanth * rem dups Signed-off-by: vinutha karanth * en log Signed-off-by: vinutha karanth * move to diff folder Signed-off-by: vinutha karanth * ena log and cohort Signed-off-by: vinutha karanth * lintfix Signed-off-by: vinutha karanth * revert button changes Signed-off-by: vinutha karanth * lintfix Signed-off-by: vinutha karanth * style and intfix Signed-off-by: vinutha karanth * lintfix Signed-off-by: vinutha karanth * update Signed-off-by: vinutha karanth * update Signed-off-by: vinutha karanth * reset index Signed-off-by: vinutha karanth * add instance Signed-off-by: vinutha karanth * remove any Signed-off-by: vinutha karanth * address comment Signed-off-by: vinutha karanth * update Signed-off-by: vinutha karanth Signed-off-by: vinutha karanth --- .../core-ui/src/lib/components/AxisConfig.tsx | 2 +- .../lib/components/AxisConfigBinOptions.tsx | 62 ++++++---- .../src/lib/components/AxisConfigDialog.tsx | 3 +- .../src/lib/CounterfactualChart.styles.ts | 7 ++ .../src/lib/CounterfactualChartLegend.tsx | 31 ++++- .../src/lib/CounterfactualChartWithLegend.tsx | 15 ++- .../src/lib/CounterfactualComponent.tsx | 52 +++++++-- .../src/lib/CounterfactualComponentUtils.ts | 25 ++++ .../LargeCounterfactualChart.tsx | 108 ++++++++++++++---- .../calculateBubbleData.ts | 53 +++++---- libs/localization/src/lib/en.json | 1 + 11 files changed, 279 insertions(+), 80 deletions(-) create mode 100644 libs/counterfactuals/src/lib/CounterfactualComponentUtils.ts diff --git a/libs/core-ui/src/lib/components/AxisConfig.tsx b/libs/core-ui/src/lib/components/AxisConfig.tsx index 0ad2ede729..57a4684625 100644 --- a/libs/core-ui/src/lib/components/AxisConfig.tsx +++ b/libs/core-ui/src/lib/components/AxisConfig.tsx @@ -17,7 +17,7 @@ export interface IAxisConfigProps { canBin: boolean; mustBin: boolean; canDither: boolean; - allowTreatAsCategorical?: boolean; + allowTreatAsCategorical: boolean; hideDroppedFeatures?: boolean; onAccept: (newConfig: ISelectorConfig) => void; } diff --git a/libs/core-ui/src/lib/components/AxisConfigBinOptions.tsx b/libs/core-ui/src/lib/components/AxisConfigBinOptions.tsx index b7536a550d..28a947fa6e 100644 --- a/libs/core-ui/src/lib/components/AxisConfigBinOptions.tsx +++ b/libs/core-ui/src/lib/components/AxisConfigBinOptions.tsx @@ -8,6 +8,7 @@ import React from "react"; import { AxisTypes, ISelectorConfig } from "../util/IGenericChartProps"; import { JointDataset } from "../util/JointDataset"; +import { IJointMeta } from "../util/JointDatasetUtils"; import { axisConfigBinOptionsStyles } from "./AxisConfigBinOptions.styles"; import { AxisConfigDialogSpinButton } from "./AxisConfigDialogSpinButton"; @@ -26,7 +27,8 @@ export interface IAxisConfigBinOptionsProps { minHistCols: number; mustBin: boolean; selectedBinCount?: number; - allowTreatAsCategorical?: boolean; + allowTreatAsCategorical: boolean; + allowLogarithmicScaling?: boolean; selectedColumn: ISelectorConfig; onBinCountUpdated: (binCount?: number) => void; onEnableLogarithmicScaling: (checked?: boolean | undefined) => void; @@ -54,28 +56,24 @@ export class AxisConfigBinOptions extends React.PureComponent )} - {(selectedMeta.featureRange?.rangeType === RangeTypes.Integer || - selectedMeta.featureRange?.rangeType === RangeTypes.Numeric) && - allowUserInteract(this.props.selectedColumn.property) && ( - - )} - {selectedMeta.featureRange?.rangeType === RangeTypes.Integer && - this.props.allowTreatAsCategorical && - allowUserInteract(this.props.selectedColumn.property) && ( - - )} + {this.displayLogarithmicToggle(selectedMeta) && ( + + )} + {this.displayCategoricalToggle(selectedMeta) && ( + + )} {selectedMeta?.treatAsCategorical ? ( <> @@ -126,6 +124,24 @@ export class AxisConfigBinOptions extends React.PureComponent { + const allowLogarithmicScaling = this.props.allowLogarithmicScaling ?? true; + return ( + (selectedMeta.featureRange?.rangeType === RangeTypes.Integer || + selectedMeta.featureRange?.rangeType === RangeTypes.Numeric) && + allowLogarithmicScaling && + allowUserInteract(this.props.selectedColumn.property) + ); + }; + + private displayCategoricalToggle = (selectedMeta: IJointMeta): boolean => { + return ( + selectedMeta.featureRange?.rangeType === RangeTypes.Integer && + this.props.allowTreatAsCategorical && + allowUserInteract(this.props.selectedColumn.property) + ); + }; + private readonly setAsCategorical = ( _ev?: React.FormEvent, checked?: boolean diff --git a/libs/core-ui/src/lib/components/AxisConfigDialog.tsx b/libs/core-ui/src/lib/components/AxisConfigDialog.tsx index d608c439f1..1ee57367bd 100644 --- a/libs/core-ui/src/lib/components/AxisConfigDialog.tsx +++ b/libs/core-ui/src/lib/components/AxisConfigDialog.tsx @@ -39,7 +39,8 @@ export interface IAxisConfigDialogProps { canBin: boolean; mustBin: boolean; canDither: boolean; - allowTreatAsCategorical?: boolean; + allowTreatAsCategorical: boolean; + allowLogarithmicScaling?: boolean; hideDroppedFeatures?: boolean; onAccept: (newConfig: ISelectorConfig) => void; onCancel: () => void; diff --git a/libs/counterfactuals/src/lib/CounterfactualChart.styles.ts b/libs/counterfactuals/src/lib/CounterfactualChart.styles.ts index 88bfbcd955..368eda7aed 100644 --- a/libs/counterfactuals/src/lib/CounterfactualChart.styles.ts +++ b/libs/counterfactuals/src/lib/CounterfactualChart.styles.ts @@ -16,11 +16,18 @@ export interface ICounterfactualChartStyles { lowerChartContainer: IStyle; rotatedVerticalBox: IStyle; verticalAxis: IStyle; + buttonStyle: IStyle; } export const counterfactualChartStyles: () => IProcessedStyleSet = () => { return mergeStyleSets({ + buttonStyle: { + marginBottom: "10px", + marginTop: "10px", + paddingBottom: "10px", + paddingTop: "10px" + }, chartWithAxes: { ...fullLgDown, paddingTop: "30px", diff --git a/libs/counterfactuals/src/lib/CounterfactualChartLegend.tsx b/libs/counterfactuals/src/lib/CounterfactualChartLegend.tsx index 646d7aa7dc..69878f1f53 100644 --- a/libs/counterfactuals/src/lib/CounterfactualChartLegend.tsx +++ b/libs/counterfactuals/src/lib/CounterfactualChartLegend.tsx @@ -6,7 +6,8 @@ import { IComboBox, ComboBox, PrimaryButton, - Stack + Stack, + DefaultButton } from "@fluentui/react"; import { defaultModelAssessmentContext, @@ -38,6 +39,7 @@ export interface ICounterfactualChartLegendProps { selectedPointsIndexes: number[]; indexSeries: number[]; isCounterfactualsDataLoading?: boolean; + isBubbleChartRendered?: boolean; removeCustomPoint: (index: number) => void; setTemporaryPointToCopyOfDatasetPoint: ( index: number, @@ -48,6 +50,7 @@ export interface ICounterfactualChartLegendProps { toggleCustomActivation: (index: number) => void; togglePanel: () => void; toggleSelectionOfPoint: (index?: number) => void; + setIsRevertButtonClicked: (status: boolean) => void; } export class CounterfactualChartLegend extends React.PureComponent { @@ -59,7 +62,7 @@ export class CounterfactualChartLegend extends React.PureComponent - {this.displayDatapointDropbox() && ( + {this.displayComboBox() && ( + {this.displayRevertButton() && ( + + )} {this.props.customPoints.length > 0 && ( { @@ -119,7 +131,7 @@ export class CounterfactualChartLegend extends React.PureComponent 0; } + private displayRevertButton(): boolean { + return ( + ifEnableLargeData(this.context.dataset) && + this.props.indexSeries.length > 0 + ); + } + private selectPointFromDropdown = ( _event: React.FormEvent, item?: IComboBoxOption @@ -156,6 +175,10 @@ export class CounterfactualChartLegend extends React.PureComponent { + this.props.setIsRevertButtonClicked(true); + }; + private getDataOptions(): IComboBoxOption[] { let indexes = this.context.selectedErrorCohort.cohort.unwrap( JointDataset.IndexLabel diff --git a/libs/counterfactuals/src/lib/CounterfactualChartWithLegend.tsx b/libs/counterfactuals/src/lib/CounterfactualChartWithLegend.tsx index d8f2f4bb24..dfab12ff82 100644 --- a/libs/counterfactuals/src/lib/CounterfactualChartWithLegend.tsx +++ b/libs/counterfactuals/src/lib/CounterfactualChartWithLegend.tsx @@ -28,6 +28,7 @@ export interface ICounterfactualChartWithLegendProps { indexSeries: number[]; temporaryPoint: { [key: string]: any } | undefined; isCounterfactualsDataLoading?: boolean; + isRevertButtonClicked: boolean; onChartPropsUpdated: (chartProps: IGenericChartProps) => void; onCustomPointLengthUpdated: (customPointLength: number) => void; onSelectedPointsIndexesUpdated: (selectedPointsIndexes: number[]) => void; @@ -48,6 +49,8 @@ export interface ICounterfactualChartWithLegendProps { setCounterfactualData: (absoluteIndex?: number) => Promise; telemetryHook?: (message: ITelemetryEvent) => void; onIndexSeriesUpdated?: (indexSeries: number[]) => void; + setIsRevertButtonClicked: (status: boolean) => void; + resetIndexes: () => void; } export interface ICounterfactualChartWithLegendState { @@ -128,6 +131,7 @@ export class CounterfactualChartWithLegend extends React.PureComponent< isCounterfactualsDataLoading={ this.props.isCounterfactualsDataLoading } + setIsRevertButtonClicked={this.props.setIsRevertButtonClicked} />
@@ -179,12 +183,16 @@ export class CounterfactualChartWithLegend extends React.PureComponent< }; private onChartPropsUpdated = (newProps: IGenericChartProps): void => { + this.resetCustomPoints(); + this.props.onChartPropsUpdated(newProps); + }; + + private resetCustomPoints = (): void => { this.setState({ customPointIsActive: [], customPoints: [], pointIsActive: [] }); - this.props.onChartPropsUpdated(newProps); }; private saveAsPoint = (): void => { @@ -222,6 +230,7 @@ export class CounterfactualChartWithLegend extends React.PureComponent< private getLargeCounterfactualChartComponent = (): React.ReactNode => { return ( ); }; diff --git a/libs/counterfactuals/src/lib/CounterfactualComponent.tsx b/libs/counterfactuals/src/lib/CounterfactualComponent.tsx index eabf124e9e..dead922f92 100644 --- a/libs/counterfactuals/src/lib/CounterfactualComponent.tsx +++ b/libs/counterfactuals/src/lib/CounterfactualComponent.tsx @@ -29,6 +29,7 @@ import { getSortArrayAndIndex } from "../util/getSortArrayAndIndex"; import { getLocalCounterfactualsFromSDK } from "../util/largeCounterfactualsView/getOnScatterPlotPointClick"; import { CounterfactualChartWithLegend } from "./CounterfactualChartWithLegend"; +import { hasAxisTypeChanged } from "./CounterfactualComponentUtils"; import { CounterfactualErrorDialog } from "./CounterfactualErrorDialog"; import { CounterfactualLocalImportanceChart } from "./CounterfactualLocalImportanceChart"; export interface ICounterfactualComponentProps { @@ -53,6 +54,7 @@ export interface ICounterfactualComponentState { indexSeries: number[]; isCounterfactualsDataLoading?: boolean; localCounterfactualErrorMessage?: string; + isRevertButtonClicked: boolean; } export class CounterfactualComponent extends React.PureComponent< @@ -65,6 +67,7 @@ export class CounterfactualComponent extends React.PureComponent< private selectedFeatureImportance: IGlobalSeries[] = []; private validationErrors: { [key: string]: string | undefined } = {}; private temporaryPoint: { [key: string]: any } | undefined; + private changedKeys: string[] = []; public constructor(props: ICounterfactualComponentProps) { super(props); @@ -73,6 +76,7 @@ export class CounterfactualComponent extends React.PureComponent< customPointLength: 0, indexSeries: [], isCounterfactualsDataLoading: false, + isRevertButtonClicked: false, localCounterfactualErrorMessage: undefined, request: undefined, selectedPointsIndexes: [], @@ -156,6 +160,9 @@ export class CounterfactualComponent extends React.PureComponent< setCounterfactualData={this.setCounterfactualData} onIndexSeriesUpdated={this.onIndexSeriesUpdated} isCounterfactualsDataLoading={this.state.isCounterfactualsDataLoading} + isRevertButtonClicked={this.state.isRevertButtonClicked} + setIsRevertButtonClicked={this.setIsRevertButtonClicked} + resetIndexes={this.resetIndexes} /> { + this.changedKeys = []; + this.compareChartProps(newProps, this.state.chartProps); const shouldResetIndexes = ifEnableLargeData(this.context.dataset) && - !_.isEqual(this.state.chartProps, newProps); + !_.isEqual(this.state.chartProps, newProps) && + !hasAxisTypeChanged(this.changedKeys); this.setState({ chartProps: newProps }); if (shouldResetIndexes) { - this.setState({ - counterfactualsData: this.props.data, - customPointLength: 0, - indexSeries: [], - selectedPointsIndexes: [] - }); + this.resetIndexes(); + } + }; + + private resetIndexes = (): void => { + this.setState({ + counterfactualsData: this.props.data, + customPointLength: 0, + indexSeries: [], + selectedPointsIndexes: [] + }); + }; + + private compareChartProps = ( + newProps: IGenericChartProps, + oldProps?: IGenericChartProps + ): void => { + if (oldProps) { + for (const key in newProps) { + if (typeof newProps[key] === "object") { + this.compareChartProps(newProps[key], oldProps[key]); + } + if (newProps[key] !== oldProps[key]) { + this.changedKeys.push(key); + } + } } }; @@ -269,6 +299,14 @@ export class CounterfactualComponent extends React.PureComponent< this.setState({ indexSeries }); + this.setIsRevertButtonClicked(false); + }; + + private setIsRevertButtonClicked = (status: boolean): void => { + this.setState({ isRevertButtonClicked: status }); + if (status) { + this.resetIndexes(); + } }; private onSelectedPointsIndexesUpdated = (newSelection: number[]): void => { diff --git a/libs/counterfactuals/src/lib/CounterfactualComponentUtils.ts b/libs/counterfactuals/src/lib/CounterfactualComponentUtils.ts new file mode 100644 index 0000000000..745d9d5457 --- /dev/null +++ b/libs/counterfactuals/src/lib/CounterfactualComponentUtils.ts @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +enum FieldChangeUpdate { + Dither = "dither", + Property = "property", + Type = "type" +} + +export function hasAxisTypeChanged(changedKeys: string[]): boolean { + // return true only if type of the axis has changed in panel + const changedKeysTemp = removeParentKeys(changedKeys); + if ( + changedKeysTemp.length === 1 && + changedKeysTemp.includes(FieldChangeUpdate.Type) + ) { + return true; + } + return false; +} + +function removeParentKeys(changedKeys: string[]): string[] { + const valuesToRemove = new Set(["options", "xAxis", "yAxis"]); // Since chartProps is a nested object, these are parent keys which are usually changed if inner keys are changed. + return changedKeys.filter((item) => !valuesToRemove.has(item)); +} diff --git a/libs/counterfactuals/src/lib/largeCounterfactualsView/LargeCounterfactualChart.tsx b/libs/counterfactuals/src/lib/largeCounterfactualsView/LargeCounterfactualChart.tsx index 3c9b144e3d..1a98a35ec6 100644 --- a/libs/counterfactuals/src/lib/largeCounterfactualsView/LargeCounterfactualChart.tsx +++ b/libs/counterfactuals/src/lib/largeCounterfactualsView/LargeCounterfactualChart.tsx @@ -14,28 +14,38 @@ import { JointDataset, TelemetryLevels, ifEnableLargeData, + Cohort, IHighchartsConfig, IHighchartBubbleSDKClusterData } from "@responsible-ai/core-ui"; import _ from "lodash"; import React from "react"; -import { calculateBubblePlotDataFromErrorCohort } from "../../util/largeCounterfactualsView/calculateBubbleData"; +import { + calculateBubblePlotDataFromErrorCohort, + instanceOfHighChart +} from "../../util/largeCounterfactualsView/calculateBubbleData"; import { getCounterfactualsScatterOption, IScatterPoint } from "../../util/largeCounterfactualsView/getCounterfactualsScatterOption"; import { ICounterfactualChartProps } from "../CounterfactualChart"; import { counterfactualChartStyles } from "../CounterfactualChart.styles"; +import { hasAxisTypeChanged } from "../CounterfactualComponentUtils"; import { CounterfactualPanel } from "../CounterfactualPanel"; import { LargeCounterfactualChartArea } from "./LargeCounterfactualChartArea"; export interface ILargeCounterfactualChartProps extends ICounterfactualChartProps { + cohort: Cohort; isCounterfactualsDataLoading?: boolean; + isRevertButtonClicked: boolean; setCounterfactualData: (absoluteIndex?: number) => Promise; onIndexSeriesUpdated?: (indexSeries: number[]) => void; + setIsRevertButtonClicked: (status: boolean) => void; + resetIndexes: () => void; + resetCustomPoints: () => void; } export interface ICounterfactualChartState { @@ -47,6 +57,7 @@ export interface ICounterfactualChartState { indexSeries: number[]; isBubbleChartDataLoading: boolean; bubbleChartErrorMessage?: string; + isBubbleChartRendered: boolean; } export class LargeCounterfactualChart extends React.PureComponent< @@ -56,6 +67,7 @@ export class LargeCounterfactualChart extends React.PureComponent< public static contextType = ModelAssessmentContext; public context: React.ContextType = defaultModelAssessmentContext; + private changedKeys: string[] = []; public constructor(props: ILargeCounterfactualChartProps) { super(props); @@ -64,6 +76,7 @@ export class LargeCounterfactualChart extends React.PureComponent< bubbleChartErrorMessage: undefined, indexSeries: [], isBubbleChartDataLoading: false, + isBubbleChartRendered: true, plotData: undefined, xDialogOpen: false, xSeries: [], @@ -77,19 +90,19 @@ export class LargeCounterfactualChart extends React.PureComponent< } public componentDidUpdate(prevProps: ILargeCounterfactualChartProps): void { + if (this.shouldUpdateBubbleChartPlot(prevProps)) { + this.updateBubblePlot(); + return; + } + if (this.hasAxisTypeChanged(prevProps.chartProps)) { + this.updateScatterPlot(); + return; + } if (!_.isEqual(prevProps.chartProps, this.props.chartProps)) { this.updateBubblePlot(); - } else if ( - !_.isEqual( - prevProps.selectedPointsIndexes, - this.props.selectedPointsIndexes - ) || - !_.isEqual(prevProps.customPoints, this.props.customPoints) || - !_.isEqual( - prevProps.isCounterfactualsDataLoading, - this.props.isCounterfactualsDataLoading - ) - ) { + return; + } + if (this.shouldUpdateScatterPlot(prevProps)) { this.updateScatterPlot(); } } @@ -131,6 +144,7 @@ export class LargeCounterfactualChart extends React.PureComponent< canBin={bin} mustBin={bin} allowTreatAsCategorical={!ifEnableLargeData(this.context.dataset)} + allowLogarithmicScaling={!this.state.isBubbleChartRendered} canDither={this.props.chartProps.chartType === ChartTypes.Scatter} hideDroppedFeatures onAccept={this.onYSet} @@ -145,6 +159,7 @@ export class LargeCounterfactualChart extends React.PureComponent< mustBin={bin} canDither={this.props.chartProps.chartType === ChartTypes.Scatter} allowTreatAsCategorical={!ifEnableLargeData(this.context.dataset)} + allowLogarithmicScaling={!this.state.isBubbleChartRendered} hideDroppedFeatures onAccept={this.onXSet} onCancel={this.setXDialogOpen} @@ -190,10 +205,27 @@ export class LargeCounterfactualChart extends React.PureComponent< this.props.onChartPropsUpdated(newProps); }; + private compareChartProps = ( + newProps: IGenericChartProps, + oldProps: IGenericChartProps + ): void => { + for (const key in newProps) { + if (typeof newProps[key] === "object") { + this.compareChartProps(newProps[key], oldProps[key]); + } + if (newProps[key] !== oldProps[key]) { + this.changedKeys.push(key); + } + } + }; + private readonly setSeries = (newProps: IGenericChartProps): void => { + this.changedKeys = []; + this.compareChartProps(newProps, this.props.chartProps); const shouldResetIndexes = ifEnableLargeData(this.context.dataset) && - !_.isEqual(this.props.chartProps, newProps); + !_.isEqual(this.props.chartProps, newProps) && + !hasAxisTypeChanged(this.changedKeys); if (shouldResetIndexes) { this.setState({ indexSeries: [], @@ -203,6 +235,40 @@ export class LargeCounterfactualChart extends React.PureComponent< } }; + private readonly shouldUpdateScatterPlot = ( + prevProps: ILargeCounterfactualChartProps + ): boolean => { + return ( + !_.isEqual( + prevProps.selectedPointsIndexes, + this.props.selectedPointsIndexes + ) || + !_.isEqual(prevProps.customPoints, this.props.customPoints) || + !_.isEqual( + prevProps.isCounterfactualsDataLoading, + this.props.isCounterfactualsDataLoading + ) + ); + }; + + private readonly shouldUpdateBubbleChartPlot = ( + prevProps: ILargeCounterfactualChartProps + ): boolean => { + return ( + this.props.cohort.name !== prevProps.cohort.name || + (this.props.isRevertButtonClicked && + prevProps.isRevertButtonClicked !== this.props.isRevertButtonClicked) + ); + }; + + private readonly hasAxisTypeChanged = ( + prevChartProps: IGenericChartProps + ): boolean => { + this.changedKeys = []; + this.compareChartProps(this.props.chartProps, prevChartProps); + return hasAxisTypeChanged(this.changedKeys); + }; + private readonly setXDialogOpen = (): void => { this.setState({ xDialogOpen: !this.state.xDialogOpen }); }; @@ -211,16 +277,17 @@ export class LargeCounterfactualChart extends React.PureComponent< this.setState({ yDialogOpen: !this.state.yDialogOpen }); }; - private async updateBubblePlot(): Promise { + private async updateBubblePlot(): Promise { this.setState({ isBubbleChartDataLoading: true }); + this.props.onIndexSeriesUpdated && this.props.onIndexSeriesUpdated([]); + this.props.resetIndexes(); + this.props.resetCustomPoints(); const plotData = await this.getBubblePlotData(); - if (plotData && plotData["error"]) { - this.setState({ - bubbleChartErrorMessage: plotData["error"].split(":").pop() - }); + if (plotData && !instanceOfHighChart(plotData)) { this.setState({ + bubbleChartErrorMessage: plotData.toString().split(":").pop(), isBubbleChartDataLoading: false, plotData: undefined }); @@ -229,6 +296,7 @@ export class LargeCounterfactualChart extends React.PureComponent< this.setState({ bubbleChartErrorMessage: undefined, isBubbleChartDataLoading: false, + isBubbleChartRendered: true, plotData }); } @@ -254,9 +322,8 @@ export class LargeCounterfactualChart extends React.PureComponent< IHighchartsConfig | IHighchartBubbleSDKClusterData | undefined > { return await calculateBubblePlotDataFromErrorCohort( - this.context.selectedErrorCohort.cohort, + this.props.cohort, this.props.chartProps, - this.props.selectedPointsIndexes, this.props.customPoints, this.context.jointDataset, this.context.dataset, @@ -276,6 +343,7 @@ export class LargeCounterfactualChart extends React.PureComponent< ): void => { this.setState({ indexSeries, + isBubbleChartRendered: false, plotData: scatterPlotData, xSeries, ySeries diff --git a/libs/counterfactuals/src/util/largeCounterfactualsView/calculateBubbleData.ts b/libs/counterfactuals/src/util/largeCounterfactualsView/calculateBubbleData.ts index c9b67bfa65..7e42da1c00 100644 --- a/libs/counterfactuals/src/util/largeCounterfactualsView/calculateBubbleData.ts +++ b/libs/counterfactuals/src/util/largeCounterfactualsView/calculateBubbleData.ts @@ -17,7 +17,6 @@ import { IScatterPoint } from "./getCounterfactualsScatterOption"; export async function calculateBubblePlotDataFromErrorCohort( errorCohort: Cohort, chartProps: IGenericChartProps, - selectedPointsIndexes: number[], customPoints: Array<{ [key: string]: any; }>, @@ -41,29 +40,33 @@ export async function calculateBubblePlotDataFromErrorCohort( onIndexSeriesUpdated?: (indexSeries: number[]) => void ): Promise { if (ifEnableLargeData(dataset) && requestBubblePlotData) { - const bubbleChartData = await calculateBubblePlotDataFromSDK( - errorCohort, - jointDataset, - requestBubblePlotData, - jointDataset.metaDict[chartProps?.xAxis.property].label, - jointDataset.metaDict[chartProps?.yAxis.property].label - ); - if (bubbleChartData.error) { - return bubbleChartData; + try { + const selectedPointsIndexes: number[] = []; + const bubbleChartData = await calculateBubblePlotDataFromSDK( + errorCohort, + jointDataset, + requestBubblePlotData, + jointDataset.metaDict[chartProps?.xAxis.property].label, + jointDataset.metaDict[chartProps?.yAxis.property].label + ); + return getBubbleChartOptions( + bubbleChartData.clusters, + jointDataset.metaDict[chartProps?.xAxis.property].label, + jointDataset.metaDict[chartProps?.yAxis.property].label, + chartProps, + jointDataset, + selectedPointsIndexes, + customPoints, + isCounterfactualsDataLoading, + onBubbleClick, + selectPointFromChartLargeData, + onIndexSeriesUpdated + ); + } catch (error) { + if (error) { + return error; + } } - return getBubbleChartOptions( - bubbleChartData.clusters, - jointDataset.metaDict[chartProps?.xAxis.property].label, - jointDataset.metaDict[chartProps?.yAxis.property].label, - chartProps, - jointDataset, - selectedPointsIndexes, - customPoints, - isCounterfactualsDataLoading, - onBubbleClick, - selectPointFromChartLargeData, - onIndexSeriesUpdated - ); } return undefined; } @@ -99,3 +102,7 @@ export async function calculateBubblePlotDataFromSDK( return result; } + +export function instanceOfHighChart(object: any): object is IHighchartsConfig { + return "chart" in object; +} diff --git a/libs/localization/src/lib/en.json b/libs/localization/src/lib/en.json index 3d0db3b55b..359cb7bf53 100644 --- a/libs/localization/src/lib/en.json +++ b/libs/localization/src/lib/en.json @@ -151,6 +151,7 @@ "counterfactualName": "What-if counterfactual name", "createWhatIfCounterfactual": "Create what-if counterfactual", "createCounterfactual": "Counterfactual", + "revertToBubbleChart": "Revert to bubble chart", "createOwn": "Create your own counterfactual:", "currentClass": "Current class", "currentRange": "Current range", From 9d24ea287155faca9cfc2442f9c60d952d7ec5a4 Mon Sep 17 00:00:00 2001 From: tongy-msft <91754176+tongyu-microsoft@users.noreply.github.com> Date: Tue, 17 Jan 2023 15:37:39 -0800 Subject: [PATCH 10/11] Fix model wrapper, remove task type from error analysis (#1912) --- .../managers/error_analysis_manager.py | 71 ++++++++++++++----- .../rai_insights/rai_insights.py | 3 +- 2 files changed, 55 insertions(+), 19 deletions(-) diff --git a/responsibleai/responsibleai/managers/error_analysis_manager.py b/responsibleai/responsibleai/managers/error_analysis_manager.py index 2d82f5f72f..f6a7ecf59a 100644 --- a/responsibleai/responsibleai/managers/error_analysis_manager.py +++ b/responsibleai/responsibleai/managers/error_analysis_manager.py @@ -16,7 +16,7 @@ from erroranalysis._internal.error_report import \ json_converter as report_json_converter from responsibleai._config.base_config import BaseConfig -from responsibleai._interfaces import ErrorAnalysisData, TaskType +from responsibleai._interfaces import ErrorAnalysisData from responsibleai._internal.constants import ErrorAnalysisManagerKeys as Keys from responsibleai._internal.constants import ListProperties, ManagerNames from responsibleai._tools.shared.state_directory_management import \ @@ -82,13 +82,26 @@ def as_error_config(json_dict): return json_dict -class MetadataRemovalModelWrapper(): - """Defines MetadataRemovalModelWrapper, wrapping the model - to ignore dropped feature metadata if any.""" +def get_wrapped_model(model, dropped_features): + predict_proba_flag = hasattr(model, 'predict_proba') + if predict_proba_flag: + wrapper_model = MetadataRemovalClassificationModelWrapper( + model, + dropped_features) + else: + wrapper_model = MetadataRemovalRegressionModelWrapper( + model, + dropped_features) + return wrapper_model + + +class MetadataRemovalClassificationModelWrapper(): + """Defines MetadataRemovalClassificationModelWrapper, wrapping the + classification model to ignore dropped feature metadata if any.""" def __init__(self, model: any, dropped_features: Optional[List[str]] = None): - """If needed, wraps the model to ignore the dropped features. + """If needed, wraps the classification model to ignore the dropped features. :param model: The model or function to evaluate on the examples. :type model: function or model with a predict or predict_proba function @@ -111,6 +124,30 @@ def _apply_func(self, func, dataset): return func(dataset.drop(columns=self.dropped_features, axis=1)) +class MetadataRemovalRegressionModelWrapper(): + """Defines MetadataRemovalRegressionModelWrapper, wrapping the + regression model to ignore dropped feature metadata if any.""" + + def __init__(self, model: any, + dropped_features: Optional[List[str]] = None): + """If needed, wraps the model to ignore the dropped features. + + :param model: The model or function to evaluate on the examples. + :type model: function or model with a predict function + :param dropped_features: List of features that were dropped by the + the user during training of their model. + :type dropped_features: Optional[List[str]] + """ + self.model = model + self.dropped_features = dropped_features + + def predict(self, dataset: pd.DataFrame): + if self.dropped_features is None or len(self.dropped_features) == 0: + return self.model.predict(dataset) + return self.model.predict(dataset.drop( + columns=self.dropped_features, axis=1)) + + class ErrorAnalysisConfig(BaseConfig): """Defines the ErrorAnalysisConfig, specifying the parameters to run.""" @@ -188,8 +225,7 @@ class ErrorAnalysisManager(BaseManager): def __init__(self, model: Any, dataset: pd.DataFrame, target_column: str, classes: Optional[List] = None, categorical_features: Optional[List[str]] = None, - dropped_features: Optional[List[str]] = None, - task_type: Optional[TaskType] = None): + dropped_features: Optional[List[str]] = None): """Creates an ErrorAnalysisManager object. :param model: The model to analyze errors on. @@ -210,8 +246,6 @@ def __init__(self, model: Any, dataset: pd.DataFrame, target_column: str, training. This includes metadata that is useful for evaluating the model. :type dropped_features: Optional[List[str]] - :param task_type: The task type of the model. - :type task_type: TaskType """ self._true_y = dataset[target_column] self._dataset = dataset.drop(columns=[target_column]) @@ -220,13 +254,15 @@ def __init__(self, model: Any, dataset: pd.DataFrame, target_column: str, self._categorical_features = categorical_features self._ea_config_list = [] self._ea_report_list = [] - self._analyzer = ModelAnalyzer(MetadataRemovalModelWrapper( - model, dropped_features), + wrapper_model = get_wrapped_model( + model, + dropped_features) + self._analyzer = ModelAnalyzer( + wrapper_model, self._dataset, self._true_y, self._feature_names, self._categorical_features, - model_task=task_type, classes=self._classes) def add(self, max_depth: int = 3, num_leaves: int = 31, @@ -440,7 +476,6 @@ def _load(path, rai_insights): inst.__dict__['_ea_report_list'] = ea_report_list inst.__dict__['_ea_config_list'] = ea_config_list - task_type = rai_insights.task_type categorical_features = rai_insights.categorical_features inst.__dict__['_categorical_features'] = categorical_features target_column = rai_insights.target_column @@ -454,11 +489,13 @@ def _load(path, rai_insights): if rai_insights._feature_metadata is not None: dropped_features = rai_insights._feature_metadata.dropped_features inst.__dict__['_dropped_features'] = dropped_features - inst.__dict__['_analyzer'] = ModelAnalyzer(MetadataRemovalModelWrapper( - rai_insights.model, dropped_features), + wrapper_model = get_wrapped_model( + rai_insights.model, + dropped_features) + inst.__dict__['_analyzer'] = ModelAnalyzer( + wrapper_model, dataset, true_y, feature_names, - categorical_features, - model_task=task_type) + categorical_features) return inst diff --git a/responsibleai/responsibleai/rai_insights/rai_insights.py b/responsibleai/responsibleai/rai_insights/rai_insights.py index bb8a7679fc..8a804a0a5b 100644 --- a/responsibleai/responsibleai/rai_insights/rai_insights.py +++ b/responsibleai/responsibleai/rai_insights/rai_insights.py @@ -249,8 +249,7 @@ def _initialize_managers(self): self.model, self.test, self.target_column, self._classes, self.categorical_features, - dropped_features, - task_type=self.task_type) + dropped_features) self._explainer_manager = ExplainerManager( self.model, self.get_train_data(), self.get_test_data(), From 60d8870d1db82cc5c65bb6f2f4b0ad5559e58805 Mon Sep 17 00:00:00 2001 From: tongy-msft <91754176+tongyu-microsoft@users.noreply.github.com> Date: Tue, 17 Jan 2023 16:35:35 -0800 Subject: [PATCH 11/11] Update some sample notebooks to include dropped_features (#1911) * update some notebooks to include dropped_features * update some notebooks to include dropped_features * update some notebooks to include dropped_features * remove unused changes * remove unused changes * address comments * address comment --- .../getting-started.ipynb | 11 ++++++-- ...aidashboard-diabetes-decision-making.ipynb | 27 ++++++++++++++++--- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/notebooks/responsibleaidashboard/getting-started.ipynb b/notebooks/responsibleaidashboard/getting-started.ipynb index c48cedb4fe..6da4edebac 100644 --- a/notebooks/responsibleaidashboard/getting-started.ipynb +++ b/notebooks/responsibleaidashboard/getting-started.ipynb @@ -110,7 +110,9 @@ "source": [ "It is necessary to initialize a RAIInsights object upon which the different components can be loaded. `task_type` holds the string `'regression'` or `'classification'` depending on the developer's purpose.\n", "\n", - "Users can also specify categorical features via the `categorical_features` parameter." + "Users can also specify categorical features via the `categorical_features` parameter.\n", + "\n", + "Using the `FeatureMetadata` container, you can declare an `identity_feature`, and specify features to withhold from the model via the `dropped_features` parameter. The `FeatureMetadata` serves as an input argument for `RAIInsights`." ] }, { @@ -119,8 +121,13 @@ "metadata": {}, "source": [ "```Python\n", + "from responsibleai.feature_metadata import FeatureMetadata\n", + "# Add 's1' as an identity feature, set 'age' as a dropped feature\n", + "feature_metadata = FeatureMetadata(identity_feature_name='s1', dropped_features=['age'])\n", + "\n", "task_type = 'regression'\n", - "rai_insights = RAIInsights(model, train_data, test_data, target_feature, task_type)\n", + "\n", + "rai_insights = RAIInsights(model, train_data, test_data, target_feature, task_type, categorical_features=[], feature_metadata=feature_metadata)\n", "```" ] }, diff --git a/notebooks/responsibleaidashboard/responsibleaidashboard-diabetes-decision-making.ipynb b/notebooks/responsibleaidashboard/responsibleaidashboard-diabetes-decision-making.ipynb index 52ca030b82..61d6f9e482 100644 --- a/notebooks/responsibleaidashboard/responsibleaidashboard-diabetes-decision-making.ipynb +++ b/notebooks/responsibleaidashboard/responsibleaidashboard-diabetes-decision-making.ipynb @@ -114,6 +114,25 @@ "data.feature_names" ] }, + { + "cell_type": "markdown", + "id": "46119282", + "metadata": {}, + "source": [ + "You may define `features_to_drop` and drop any features from `X_train`. The model will be trained without `features_to_drop`. If `features_to_drop` is not set, `X_train_after_drop` will be the same as `X_train`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3db1ef3b", + "metadata": {}, + "outputs": [], + "source": [ + "features_to_drop = []\n", + "X_train_after_drop = X_train.drop(features_to_drop, axis=1)" + ] + }, { "cell_type": "markdown", "id": "59853607", @@ -130,7 +149,7 @@ "outputs": [], "source": [ "model = RandomForestRegressor(random_state=0)\n", - "model.fit(X_train, y_train)" + "model.fit(X_train_after_drop, y_train)" ] }, { @@ -161,7 +180,7 @@ "\n", "RAIInsights accepts the model, the train dataset, the test dataset, the target feature string, the task type string, and a list of strings of categorical feature names as its arguments.\n", "\n", - "You may also create the `FeatureMetadata` container and identify any feature of your choice as the `identity_feature`. The `FeatureMetadata` may also be passed into the `RAIInsights`." + "You may also create the `FeatureMetadata` container, identify any feature of your choice as the `identity_feature`, and specify dropped features via the `dropped_features` parameter. The `FeatureMetadata` may also be passed into the `RAIInsights`." ] }, { @@ -172,8 +191,8 @@ "outputs": [], "source": [ "from responsibleai.feature_metadata import FeatureMetadata\n", - "# Add 's1' as an identity feature\n", - "feature_metadata = FeatureMetadata(identity_feature_name='s1')\n", + "# Add 's1' as an identity feature, set features_to_drop as dropped features\n", + "feature_metadata = FeatureMetadata(identity_feature_name='s1', dropped_features=features_to_drop)\n", "rai_insights = RAIInsights(model, train_data, test_data, target_feature, 'regression',\n", " categorical_features=[], feature_metadata=feature_metadata)" ]