Beyond Black Box: Enhancing Model Explainability with LLMs and SHAP


TL;DR

This tutorial showcases how to use GenAI tools to convert complex SHAP values from a machine learning model into simple, narrative explanations, making model predictions easily understandable and communicable to non-technical stakeholders.

Introduction

In the realm of machine learning, 'black box' models often offer impressive predictive performance but lack transparency in their decision-making processes. SHAP (SHapley Additive exPlanations) has emerged as a powerful tool to demystify these models by quantifying the impact of each feature on the prediction.

While SHAP values offer a powerful way to understand machine learning predictions, their technical nature often leaves stakeholders in the dark. This tutorial introduces a swift, GenAI-driven solution that transforms these intricate explanations into relatable narratives, making complex model insights easily shareable and understandable for all involved parties.

This blog post explores how Large Language Models (LLMs) can turn SHAP values into intuitive, narrative explanations, using the Titanic dataset as a case study. We'll use the Julia programming language, specifically the MLJ.jl framework for machine learning and ShapML.jl for generating SHAP values.

It is inspired by the paper "Tell Me a Story! Narrative-Driven XAI with Large Language Models", where the authors share interesting findings from a user study:

"...over 90% of the surveyed general audience finds the narrative generated by SHAPstories convincing."

Excerpt from "Tell Me a Story! Narrative-Driven XAI with Large Language Models"

Setting Up the Experiment

The Titanic dataset, a classic in machine learning, offers a vivid context for survival prediction. Since it's not the focus of this article, we'll borrow the approach from the Titanic MLJ tutorial.

# To set up your environment, you'll need to install the following packages:
# ]add DataFramesMeta MLJ DecisionTree MLJDecisionTreeInterface ShapML PromptingTools 

using MLJ, DataFramesMeta, ShapML
using MLJ.CategoricalArrays: unwrap
using PromptingTools
const PT = PromptingTools

# Get the data and transform it into MLJ format
table = OpenML.load(42638)
df = DataFrame(table)
dropmissing!(df, :embarked)
coerce!(df, :sibsp => Count, :survived => OrderedFactor, :pclass => OrderedFactor, :sex => OrderedFactor, :embarked => OrderedFactor)
y, X = unpack(df, ==(:survived), !=(:cabin));
train, test = partition(eachindex(y), 0.7, shuffle=true, rng=1234);

# Train a model on 70% of data
Tree = @load DecisionTreeClassifier pkg = DecisionTree
tree = Tree(max_depth=5)
mach = machine(tree, X, y)
fit!(mach, rows=train)
y_pred = predict(mach, X);

Next, we use ShapML.jl to compute SHAP values for our model's predictions:

# Generate SHAP values for the classifier predictions
predict_proba(model, data) = DataFrame(; y_pred=predict(model, data) |> x -> pdf.(x, "1"))
data_shap = ShapML.shap(explain=X,
    model=mach,
    predict_function=predict_proba,
    sample_size=60,
    seed=1234
);

Let's show the SHAP results of a single data point:

shap_ = @chain begin
    @rsubset data_shap :index == 3
    @orderby -:shap_effect
end
6×6 DataFrame
 Row │ index  feature_name  feature_value  shap_effect  shap_effect_sd  intercept 
     │ Int64  String        Any            Float64      Float64         Float64   
─────┼────────────────────────────────────────────────────────────────────────────
   1 │     3  sex           female          0.349978         0.288405    0.350693
   2 │     3  sibsp         0               0.0166667        0.129099    0.350693
   3 │     3  fare          7.925           0.0165051        0.177344    0.350693
   4 │     3  embarked      S              -0.00239717       0.0130181   0.350693
   5 │     3  age           26.0           -0.0193997        0.237259    0.350693
   6 │     3  pclass        3              -0.125878         0.247836    0.350693

Confusing, right? Let's try to make sense of these values.

(See the ShapML.jl documentation for more details on the SHAP effect.)
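To make the table a little less cryptic, here is a minimal sketch of how the columns fit together, assuming the usual additive reading of SHAP values: the intercept (the average prediction) plus the sum of the per-feature effects approximates the predicted probability for that passenger. The numbers are copied from the table above; the variable names are ours. Because ShapML.jl estimates the effects by sampling, the reconstruction is only approximate.

```julia
# Values copied from the SHAP table for index 3 (feature => shap_effect)
intercept = 0.350693
effects = [
    "sex" => 0.349978,
    "sibsp" => 0.0166667,
    "fare" => 0.0165051,
    "embarked" => -0.00239717,
    "age" => -0.0193997,
    "pclass" => -0.125878,
]

# Additive reading: baseline + sum of feature effects ≈ predicted probability
reconstructed = intercept + sum(last, effects)
println("Reconstructed survival probability: ", round(reconstructed, digits=3))  # 0.586

# Rank features by the size of their effect, regardless of sign
ranked = sort(effects; by=p -> abs(last(p)), rev=true)
for (name, effect) in ranked
    direction = effect >= 0 ? "raises" : "lowers"
    println(rpad(name, 10), direction, " the survival probability by ", round(abs(effect), digits=3))
end
```

Read this way, being female pushes the prediction up the most, while travelling in third class pulls it down the most.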

Bridging the Gap with LLMs

Here we introduce the use of LLMs to create stories that explain predictions in a more human-friendly manner.

We can create a story template (already done in PromptingTools.jl under the name StorytellerExplainSHAP) and use aigenerate to interpolate these values into a coherent narrative.

Let's prepare the general information about the task and dataset:

# Describe the data (perhaps column names alone would suffice?)
feature_description = let
    io = IOBuffer()
    show(io, describe(X, :mean, :min, :max, :nunique); summary=false, eltypes=false)
    "\n" * String(take!(io))
end

# Passed to aigenerate as keyword arguments
task_context = (; task_definition="which of the Titanic passengers have died or survived based on their data",
    feature_description, label_definition="that the passenger survived",
    # keep instructions None for now, see `Practical Considerations` section below
    instructions="None.");

Let's prepare a utility function for individual instances (idx is the position of the instance in the dataset):

"Prepares the context for the selected instance to be provided to the LLM"
function prepare_instance_context(data_shap, y_pred, y, idx::Int)
    proba_ = pdf(y_pred[idx], "1")
    shap_ = data_shap.intercept[1] + (@rsubset(data_shap, :index == idx).shap_effect |> sum)

    probability_pct = proba_ >= 0.5 ? round(Int, proba_ * 100) : round(Int, (1 - proba_) * 100)
    prediction = proba_ >= 0.5 ? "the passenger survived" : "the passenger died"
    outcome = unwrap(y[idx]) == "1" ? "the passenger survived" : "the passenger died"
    classified_correctly = prediction == outcome ? "correctly classified" : "misclassified"

    @info "Selected item: $idx, Proba: $proba_ vs SHAP values for instance $shap_ -> Outcome: $outcome"

    # Generate the SHAP table
    io = IOBuffer()
    shap_ = @chain begin
        @rsubset data_shap :index == idx
        @rsubset !(:shap_effect ≈ 0)
        @orderby -:shap_effect
        @rtransform :shap_effect = round(:shap_effect, digits=2)
        @select :feature_name :feature_value :shap_effect
    end
    show(io, shap_; summary=false, eltypes=false)
    shap_table = String(take!(io))

    return (; probability_pct, prediction, outcome, classified_correctly, shap_table)
end;
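The part of this function that turns a raw probability into a statement for the LLM can be tried in isolation. Below is a minimal sketch that mirrors that logic: the reported percentage is always the confidence in the predicted class, so a 7.6% survival probability becomes "died" with 92% confidence. The helper name summarize_prediction and its keyword arguments are hypothetical and not part of any package.

```julia
# Hypothetical helper mirroring the thresholding logic in prepare_instance_context
function summarize_prediction(proba::Real; threshold::Real=0.5,
        positive::AbstractString="the passenger survived",
        negative::AbstractString="the passenger died")
    is_positive = proba >= threshold
    # Report the probability of the predicted class, not of the positive class
    confidence = is_positive ? proba : 1 - proba
    return (; prediction=is_positive ? positive : negative,
        probability_pct=round(Int, confidence * 100))
end

low = summarize_prediction(0.07623318385650224)
println(low)   # predicted "the passenger died" with probability_pct = 92

high = summarize_prediction(0.9)
println(high)  # predicted "the passenger survived" with probability_pct = 90
```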

We're ready to start generating stories!

Example 1: Misclassified

instance_context = prepare_instance_context(data_shap, y_pred, y, 821)
#[ Info: Selected item: 821, Proba: 0.07623318385650224 vs SHAP values for instance 0.12002796217871464 -> Outcome: the passenger survived

msg = aigenerate(:StorytellerExplainSHAP; task_context..., instance_context..., model="gpt4t")
[ Info: Tokens: 774 @ Cost: $0.0122 in 18.0 seconds
AIMessage("In the grand tapestry of the Titanic's voyage, the waves of fate seemed to conspire against a young man, a third-class passenger with little more than the humble fare that booked his passage. He was en route to a new life, with no siblings or spouse along for the journey, a solitary figure amongst the throng. His ticket, priced at a modest sum, suggested a man of simple means, likely overlooked amidst the wealth and splendor of the Titanic's more affluent passengers. The ship's records indicate that he embarked at Southampton, a bustling port where many souls boarded, unaware of the tragedy that lay ahead. Despite these factors that seemingly stacked the odds against his survival, the young man found a way to defy the cold arithmetic of survival on that fateful night.

The prediction cast a dark shadow, suggesting the young man had perished in the disaster, overwhelmed by the unfortunate confluence of his situation. However, his resolve or perhaps serendipity led him to survive, a living testament to the unpredictability of life and the limitations of even the most finely tuned models of artificial conjecture.")

Example 2: Correctly classified as survived

instance_context = prepare_instance_context(data_shap, y_pred, y, 864)
msg = aigenerate(:StorytellerExplainSHAP; task_context..., instance_context..., model="gpt4t")
[ Info: Tokens: 793 @ Cost: $0.0128 in 23.1 seconds
AIMessage("On a fateful night in 1912, aboard the ill-fated Titanic, a middle-aged woman made her journey across the Atlantic. Not hailing from the opulence of first class but also not confined to the cramped conditions of the third class, she traveled in second class comfort, which, on this particular voyage, turned out to be a relatively safer berth. Alone without siblings or spouse, she could decisively move and respond to the ensuing chaos that frigid night. Her gender played a pivotal role as women were prioritized during the lifeboat evacuations. The relatively modest fare she paid for her passage, and her age, while painting a picture of an average middle-class woman, surprisingly did not significantly alter her survival chances.

In the end, the combination of her traveling in second class, her companionship status and the fact that she was a woman in her early 40s, tilted the scales in favor of her survival. Despite the tragedy, she was one of the fortunate to secure a place in a lifeboat and live to tell the tale. The model concluded that this passenger survived, mirroring her real-life outcome, driven mainly by her travel class and gender amidst the Titanic's tragic demise.")

Example 3: Correctly classified as died

instance_context = prepare_instance_context(data_shap, y_pred, y, 126)
# [ Info: Selected item: 126, Proba: 0.07623318385650224 vs SHAP values for instance 0.12002796217871464 -> Outcome: the passenger died

msg = aigenerate(:StorytellerExplainSHAP; task_context..., instance_context..., model="gpt4t")
[ Info: Tokens: 800 @ Cost: $0.013 in 16.6 seconds
AIMessage("On a fateful night aboard the Titanic, amidst the cold embrace of the Atlantic, the destiny of a young man was sealed. This traveler, a male aged 30, had boarded the vessel with a third-class ticket, a choice that offered limited access to lifeboats and safety measures as the tragedy unfolded around him. His journey had commenced at Queenstown, now known as Cobh, a port indicated by a single-letter ticket stamp 'Q', marking his point of embarkation, a detail as undistinguished as the modest fare he had paid for his passage, a mere 7.75 dollars.

In these desperate times, when families clung together and women and children were ushered to safety first, his solitary status, with not a sibling or spouse to claim companionship, rendered him nearly invisible in the chaos. The narrative ends with the echo of his footsteps fading into silence, the algorithm's analysis as cold and unyielding as the night, concluding his fate as one among the lost souls of the Titanic. The model, with solemn certainty, predicts his demise, a story woven from the threads of his social standing, his gender, and the austerity of his journey—a young man, alone and in third class, claimed by the sea.")

While the stories are not perfect, they are surprisingly easy to understand. Another GenAI success with just a few lines of extra code!

Practical Considerations

Choosing the Right Model: I have found only GPT-4 to be useful for this approach. Weaker models would require some fine-tuning or few-shot prompting to get the desired results.

Even though each story takes roughly 20-30 seconds, the overall time is quite reasonable when you send multiple requests in parallel (e.g., with asyncmap).
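As a rough sketch of the parallel pattern, the snippet below uses Base's asyncmap with a stand-in function that simply sleeps instead of calling an LLM. The name mock_story is hypothetical; in practice you would replace its body with the prepare_instance_context and aigenerate calls shown earlier. Because the tasks wait concurrently, three requests take about as long as one.

```julia
# Hypothetical stand-in for an LLM request: waits briefly, then returns a string
function mock_story(idx::Int)
    sleep(0.2)  # simulates network latency
    return "Story for passenger $idx"
end

indices = [821, 864, 126]

# Run up to 3 requests concurrently; results come back in input order
elapsed = @elapsed stories = asyncmap(mock_story, indices; ntasks=3)

foreach(println, stories)
println("Finished ", length(stories), " stories in ", round(elapsed, digits=2), " seconds")
```

Keep ntasks modest with a real API, since providers usually enforce rate limits.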

Costs? While 1 cent per story might seem like a lot, imagine how long it takes to write functions to produce coherent narratives! You're saving a lot of time and effort by using this approach.

Moreover, you rarely need to explain ALL predictions. It's usually enough to explain a few examples, either to get a sense of the model's behavior or to respond to an audit or a stakeholder request.

Boardroom Readiness? The stories above are not really suitable for a boardroom presentation. However, you can easily provide additional information about the audience, context, style, tone, etc. via the instructions placeholder.

msg = aigenerate(:StorytellerExplainSHAP; task_context..., instance_context..., instructions="Be brief. Adjust the story for boardroom presentation.", model="gpt4t")
[ Info: Tokens: 766 @ Cost: $0.0111 in 19.2 seconds
AIMessage("In our analysis, we encountered a male passenger aged 27 who had embarked from Southampton without any siblings or spouse aboard. Traveling third class and having paid a fare far below average, he was enveloped by the perilous reputation of the most economically restrained accommodations of the Titanic. This narrative, coupled with being a young male—the archetype often expected to give precedence to women and children during life-saving procedures—cast a long shadow over his likelihood of survival. Our prediction model rendered a grim forecast, plunging the odds against his survival amidst the tragedy well-known to history.")
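You may wonder why the call above works even though task_context already contains instructions="None.". In Julia, an explicit keyword that follows a splatted named tuple takes precedence over the splatted value. Here is a minimal sketch of that behavior; collect_kwargs is a hypothetical helper used only to inspect what a function receives.

```julia
# Hypothetical helper: returns the keyword arguments it receives as a NamedTuple
collect_kwargs(; kwargs...) = (; kwargs...)

defaults = (; label_definition="that the passenger survived", instructions="None.")

# Splat only: the defaults pass through unchanged
unchanged = collect_kwargs(; defaults...)
println(unchanged.instructions)   # None.

# Splat followed by an explicit keyword: the explicit value wins
overridden = collect_kwargs(; defaults..., instructions="Be brief.")
println(overridden.instructions)  # Be brief.
```

This keeps task_context reusable: define sensible defaults once and override individual fields per call.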

What is being sent to the LLM? If you want to see what we're sending to the LLM (for debugging), simply render the template without sending it to the LLM:

conv_rendered = PT.render(PT.NoSchema(), :StorytellerExplainSHAP; task_context..., instance_context...)
println(conv_rendered[2].content)
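Conceptually, rendering a template means substituting each {{placeholder}} with the matching keyword argument. The sketch below illustrates the idea with plain string replacement. It is not PromptingTools.jl's actual implementation, and both the template text and the function name fill_placeholders are made up for illustration.

```julia
# Illustrative only: a toy template with handlebar-style placeholders
template = "Task: predict {{task_definition}}. Prediction: {{prediction}} ({{probability_pct}}% confidence)."

# Replace every {{key}} with the string form of the matching keyword argument
function fill_placeholders(template::AbstractString; kwargs...)
    out = String(template)
    for (key, value) in pairs(kwargs)
        out = replace(out, "{{$(key)}}" => string(value))
    end
    return out
end

rendered = fill_placeholders(template;
    task_definition="who survived the Titanic",
    prediction="the passenger died",
    probability_pct=92)
println(rendered)
```

Printing the rendered text before spending tokens is a cheap way to catch an empty or misnamed placeholder.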

Conclusion

Combining SHAP explanations with LLM-generated stories offers a novel way to enhance the interpretability of machine learning models. This approach is particularly valuable for data scientists who need to share the explanations with non-technical stakeholders (Responsible AI!). We encourage our readers to experiment with this methodology in their own models to gain deeper insights and share their experiences and templates with the community.

Note: This approach is not specific to SHAP. You can use it with any other explainability tool, e.g., CounterfactualExplanations.jl. All you need to do is tweak the prompt template slightly. PromptingTools PRs welcome :)


Further Reading and Resources

CC BY-SA 4.0 Jan Siml. Last modified: April 28, 2024.