Kommentar
Åtkomst till den här sidan kräver auktorisering. Du kan prova att logga in eller ändra kataloger.
Åtkomst till den här sidan kräver auktorisering. Du kan prova att ändra kataloger.
Använd Kernel SHAP (SHapley Additive exPlanations) för att förklara en tabellklassificeringsmodell. Kernel-SHAP är en modellagnostisk metod som uppskattar varje funktions bidrag till en modells förutsägelse. Du tränar en logistisk regressionsmodell på datamängden Adult Census Income och använder sedan SynapseML-transformatorn TabularSHAP för att beräkna förklaringar på funktionsnivå.
Förutsättningar
Skaffa en Microsoft Fabric-prenumeration. Eller registrera dig för en kostnadsfri utvärderingsversion av Microsoft Fabric.
Logga in på Microsoft Fabric.
Växla till Fabric med hjälp av upplevelseväxlaren längst ned till vänster på startsidan.
- Skapa en ny notebook-fil på din arbetsyta och koppla den till ett sjöhus. Mer information finns i Skapa en notebook-fil.
SynapseML, PySpark, pandas och plotly är förinstallerade i Fabric notebook-miljöer. Ingen extra paketinstallation krävs.
Importera paket och definiera hjälp-UDF:er
I din Fabric notebook-fil klistrar du in följande kod i en cell och kör den. Det här steget importerar de bibliotek som krävs och definierar två användardefinierade funktioner (UDF:er) för att extrahera vektorelement senare.
import pyspark
from synapse.ml.explainers import TabularSHAP
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.sql.types import FloatType, ArrayType
from pyspark.sql.functions import col, lit, rand, broadcast, udf
import pandas as pd
vec_access = udf(lambda v, i: float(v[i]), FloatType())
vec2array = udf(lambda vec: vec.toArray().tolist(), ArrayType(FloatType()))
Kontrollera: Kör följande kod i en ny cell. Du bör se utdata TabularSHAP imported successfully.
print("TabularSHAP imported successfully")
print(f"PySpark version: {pyspark.__version__}")
Läsa in data och träna en klassificeringsmodell
Läs in datamängden Adult Census Income från Azure Blob Storage, indexera måletiketten och träna en pipeline för logistisk regression.
df = spark.read.parquet(
"wasbs://publicwasb@mmlspark.blob.core.windows.net/AdultCensusIncome.parquet"
)
labelIndexer = StringIndexer(
inputCol="income", outputCol="label", stringOrderType="alphabetAsc"
).fit(df)
print("Label index assignment: " + str(set(zip(labelIndexer.labels, [0, 1]))))
training = labelIndexer.transform(df).cache()
categorical_features = [
"workclass",
"education",
"marital-status",
"occupation",
"relationship",
"race",
"sex",
"native-country",
]
categorical_features_idx = [feat + "_idx" for feat in categorical_features]
categorical_features_enc = [feat + "_enc" for feat in categorical_features]
numeric_features = [
"age",
"education-num",
"capital-gain",
"capital-loss",
"hours-per-week",
]
strIndexer = StringIndexer(
inputCols=categorical_features, outputCols=categorical_features_idx
)
onehotEnc = OneHotEncoder(
inputCols=categorical_features_idx, outputCols=categorical_features_enc
)
vectAssem = VectorAssembler(
inputCols=categorical_features_enc + numeric_features, outputCol="features"
)
lr = LogisticRegression(featuresCol="features", labelCol="label", weightCol="fnlwgt")
pipeline = Pipeline(stages=[strIndexer, onehotEnc, vectAssem, lr])
model = pipeline.fit(training)
Kontrollera: Kör följande cell. Du bör se antal rader för träningsdata och bekräftelse på stegen i pipelinen.
print(f"Training rows: {training.count()}")
print(f"Pipeline stages: {[type(s).__name__ for s in model.stages]}")
assert training.count() > 30000, "Dataset should contain over 30,000 rows"
print("Model trained successfully")
# Expected output:
#Training rows: 32561
#Pipeline stages: ['StringIndexerModel', 'OneHotEncoderModel', #'VectorAssembler', 'LogisticRegressionModel']
#Model trained successfully
Välj observationer för att förklara
Välj slumpmässigt fem observationer från de poängsatta träningsdata. Dessa observationer är de instanser som du genererar SHAP-förklaringar för.
explain_instances = (
model.transform(training).orderBy(rand()).limit(5).repartition(200).cache()
)
display(explain_instances)
Verifiera: Bekräfta exempelstorleken.
count = explain_instances.count()
print(f"Explain instances: {count}")
assert count == 5, f"Expected 5 rows, got {count}"
print("Sample selected successfully")
Konfigurera och köra TabularSHAP
Skapa en TabularSHAP förklaring och tillämpa den på de valda observationerna. Nyckelparametrarna är:
| Parameter | Description |
|---|---|
inputCols |
Funktionskolumner som modellen använder för förutsägelse. |
outputCol |
Namnet på kolumnen som innehåller SHAP-utdatavärden. |
numSamples |
Antal perturbationsexempel för KERNEL SHAP-uppskattning. Högre värden är mer exakta men långsammare. |
model |
Den tränade pipelinemodellen för att förklara. |
targetCol |
Kolumnen för modellutdata som ska förklaras. I det här exemplet är kolumnen probability. |
targetClasses |
Klassindex som ska förklaras.
[1] förklarar endast sannolikhet för klass 1. Använd [0, 1] för att förklara båda klasserna. |
backgroundData |
Ett exempel på träningsdata som används som referensdistribution för integrering av funktioner. |
shap = TabularSHAP(
inputCols=categorical_features + numeric_features,
outputCol="shapValues",
numSamples=5000,
model=model,
targetCol="probability",
targetClasses=[1],
backgroundData=broadcast(training.orderBy(rand()).limit(100).cache()),
)
shap_df = shap.transform(explain_instances)
Note
Det här steget kan ta flera minuter beroende på numSamples och klusterstorlek. Med numSamples=5000 och fem observationer kan du förvänta dig 3–10 minuter på ett standard Fabric Spark-kluster.
Kontrollera: Kontrollera att SHAP-utdatakolumnen finns.
assert "shapValues" in shap_df.columns, "shapValues column missing"
print(f"SHAP output columns: {shap_df.columns}")
print("TabularSHAP transform completed")
Extrahera SHAP-värden
Extrahera klass 1-sannolikheten och SHAP-värdena från resultatet DataFrame. För varje observation börjar SHAP-värdevektorn med basvärdet (medelvärdet av bakgrundsdatauppsättningen), följt av ett värde per funktion.
shaps = (
shap_df.withColumn("probability", vec_access(col("probability"), lit(1)))
.withColumn("shapValues", vec2array(col("shapValues").getItem(0)))
.select(
["shapValues", "probability", "label"] + categorical_features + numeric_features
)
)
shaps_local = shaps.toPandas()
shaps_local.sort_values("probability", ascending=False, inplace=True, ignore_index=True)
pd.set_option("display.max_colwidth", None)
display(shaps_local)
Kontrollera: Bekräfta strukturen för pandas DataFrame.
expected_cols = len(categorical_features) + len(numeric_features) + 3
print(f"DataFrame shape: {shaps_local.shape}")
print(f"Expected columns: {expected_cols}, Actual: {shaps_local.shape[1]}")
assert shaps_local.shape == (5, expected_cols), f"Unexpected shape: {shaps_local.shape}"
print("SHAP values extracted successfully")
Visualisera SHAP-värden
Skapa ett stapeldiagram för varje observation som visar hur varje funktion bidrar till den förväntade sannolikheten.
from plotly.subplots import make_subplots
import plotly.graph_objects as go
features = categorical_features + numeric_features
features_with_base = ["Base"] + features
rows = shaps_local.shape[0]
fig = make_subplots(
rows=rows,
cols=1,
subplot_titles="Probability: "
+ shaps_local["probability"].apply("{:.2%}".format)
+ "; Label: "
+ shaps_local["label"].astype(str),
)
for index, row in shaps_local.iterrows():
feature_values = [0] + [row[feature] for feature in features]
shap_values = row["shapValues"]
list_of_tuples = list(zip(features_with_base, feature_values, shap_values))
shap_pdf = pd.DataFrame(list_of_tuples, columns=["name", "value", "shap"])
fig.add_trace(
go.Bar(
x=shap_pdf["name"],
y=shap_pdf["shap"],
hovertext="value: " + shap_pdf["value"].astype(str),
),
row=index + 1,
col=1,
)
fig.update_yaxes(range=[-1, 1], fixedrange=True, zerolinecolor="black")
fig.update_xaxes(type="category", tickangle=45, fixedrange=True)
fig.update_layout(height=400 * rows, title_text="SHAP explanations")
fig.show()
Kontrollera: Bekräfta att ritobjektet har skapats.
print(f"Figure traces: {len(fig.data)}")
print(f"Figure height: {fig.layout.height}px")
assert len(fig.data) == 5, f"Expected 5 traces, got {len(fig.data)}"
print("Visualization created successfully")
Tolka resultatet
Varje underplot representerar en observation. Staplarna visar:
- Bas: Den genomsnittliga modellens utdata över bakgrundsdatamängden (baslinjesannolikheten).
- Positiva SHAP-värden: Funktioner som driver förutsägelsen mot klass 1 (intäkter större än 50 000).
- Negativa SHAP-värden: Egenskaper som driver prediktionen mot klass 0 (inkomst mindre än eller lika med 50 000).
Summan av basvärdet och alla shap-värden för funktionen är lika med modellens förväntade sannolikhet för observationen.
Felsökning
| Problematik | Orsak | Lösning |
|---|---|---|
OutOfMemoryError under TabularSHAP |
numSamples är för stort för tillgängligt minne. |
Minska numSamples, till exempel till 1 000 eller öka Spark-körminnet. |
| SHAP-transformering är långsam | Hög numSamples med många funktioner ökar beräkningstiden. |
Minska numSamples till 1 000–2 000 för snabbare undersökande resultat. Öka för slutlig analys. |
FileNotFoundException för parquet |
Nätverksåtkomsten till mmlspark.blob.core.windows.net är blockerad. |
Kontrollera att din Fabric arbetsyta har utgående Internetåtkomst. Alternativt kan du ladda upp datauppsättningen till ditt lakehouse. |
shapValues kolumnen innehåller null-värden |
Vissa observationer kan misslyckas om funktionsvärden ligger utanför träningsdistributionen. | Sök efter null- eller oväntade värden i indatafunktioner. Filtrera null-värden från resultat. |
display() visar inga utdata |
Koden körs utanför en Fabric notebook-miljö. | Använd shaps_local.head() eller print(shaps_local) i standardmiljöer för Python. |
Rensa
Om du har laddat upp datauppsättningen till ett lakehouse för den här självstudien, tar du bort den för att frigöra lagringsutrymme:
# Remove cached DataFrames from memory
training.unpersist()
explain_instances.unpersist()
print("Cached DataFrames released")