RI Adversarial File Scanning Walkthrough

You are the AI Risk Officer at a Consumer Social Company. The NLP team has been tasked with implementing a text classification model to predict the top-level “sentiment” of posts on the app. These predictions will later be consumed by multiple models throughout the company, such as recommendation, lead prediction, and the core advertisement models. You want to verify your models are sufficiently robust to adversaries seeking to exploit model vulnerabilities and boost content that your user base does not actually like.

In this Notebook Walkthrough, we will review our core product, AI Stress Testing, applied to NLP models in an adversarial setting. RIME AI Stress Testing allows you to test any text classification model on any dataset, so you can quantify your model’s vulnerability to attacks and noisy data.

Your team’s NLP models are fine-tuned from state-of-the-art transformer models found on Hugging Face’s Model Hub 🤗. In particular, you have chosen to fine-tune a DistilBERT on data similar to the Stanford Sentiment Treebank dataset for a lightweight yet performant model.

For more information on how to connect with a Hugging Face Model or Hugging Face Dataset, check out the linked documentation.

To begin, please specify your RIME cluster’s URL and personal access token.

[ ]:
!pip install rime-sdk &> /dev/null
!pip install seaborn

from rime_sdk import Client

Establish the RIME Client

To get started, provide the API credentials and the base domain/address of the RIME service. You can generate and copy an API token from the API Access Tokens Page under Workspace settings. For the domain/address of the RIME service, contact your admin.

[Image: generating an API token from the Workspace settings page]
[ ]:
API_TOKEN = '' # PASTE API TOKEN
CLUSTER_URL = '' # PASTE DEDICATED DOMAIN OF RIME SERVICE (e.g., https://rime.example.rbst.io)
AGENT_ID = '' # PASTE AGENT_ID IF USING AN AGENT THAT IS NOT THE DEFAULT
[ ]:
client = Client(CLUSTER_URL, API_TOKEN)

Create a Project

Below, we create a project to store the results of this and future adversarial robustness stress test runs.

[ ]:
description = (
    "Evaluate the robustness of text classification models"
    " against adversarial attacks. Demonstration uses the"
    " SST-2 dataset (https://huggingface.co/datasets/sst2)"
    " and a fine-tuned version of the DistilBERT model"
    " (https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)."
)
project = client.create_project(
    name="NLP Adversarial Robustness Demo",
    description=description,
    model_task="MODEL_TASK_MULTICLASS_CLASSIFICATION"
)

Register Model and Datasets

Next, we register the dataset (Stanford Sentiment Treebank) and model (DistilBERT) from Hugging Face, as described above, with RIME.

[ ]:
from datetime import datetime

dt = str(datetime.now())  # timestamp suffix keeps registered model/dataset names unique
model_id = project.register_model(
    f'model_{dt}',
    model_config={
        "hugging_face": {
            "model_uri": "distilbert-base-uncased-finetuned-sst-2-english"
        }
    },
    agent_id=AGENT_ID
)
def _register_dataset(split_name):
    data_info = {
        "connection_info": {
            "hugging_face": {
                "dataset_uri": "sst2",
                "split_name": split_name,
            },
        },
        "data_params": {
            "label_col": "label",
            "text_features": ["sentence"],
            "sample": True,
            "nrows": 100
        },
    }
    return project.register_dataset(f'{split_name}_dataset_{dt}', data_info, agent_id=AGENT_ID)


ref_dataset_id = _register_dataset("train")
eval_dataset_id = _register_dataset("validation")

Start File Scan on the Registered Model

File scanning checks the registered model’s files for security vulnerabilities before the model is used in stress testing.

[ ]:
file_scan_job = client.start_file_scan(
    model_id=model_id, project_id=project.project_id, agent_id=AGENT_ID
)
file_scan_job.get_status(verbose=True, wait_until_finish=True)

Get File Scan Result for Registered Model

[ ]:
file_scan_result = client.get_file_scan_result(file_scan_id=file_scan_job.job_id)
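
The returned scan result summarizes any vulnerabilities found in the model’s files. Its exact structure can vary between rime-sdk versions, so as a minimal sketch we simply pretty-print whatever the SDK returns rather than assuming a specific schema:

[ ]:
import pprint

# Inspect the raw file scan result; field names and nesting depend on the
# rime-sdk version, so no particular schema is assumed here.
pprint.pprint(file_scan_result)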

Start Stress Test

The configuration below points the stress test at the registered model and datasets and selects the categories of tests to run.

[ ]:
stress_test_config = {
    "run_name": "DistilBERT Adversarial Robustness",
    "data_info": {
        "ref_dataset_id": ref_dataset_id,
        "eval_dataset_id": eval_dataset_id,
    },
    "model_id": model_id,
    "categories": [
        "TEST_CATEGORY_TYPE_ADVERSARIAL",
        "TEST_CATEGORY_TYPE_SUBSET_PERFORMANCE",
        "TEST_CATEGORY_TYPE_TRANSFORMATIONS",
        "TEST_CATEGORY_TYPE_BIAS_AND_FAIRNESS",
        "TEST_CATEGORY_TYPE_DATA_CLEANLINESS"
    ],
    "test_suite_config": {
        "global_exclude_columns": ["idx"]
    },
}
stress_job = client.start_stress_test(
    stress_test_config, project.project_id, agent_id=AGENT_ID
)
stress_job.get_status(verbose=True, wait_until_finish=True)

Review Adversarial Stress Test Run

Now that the test run is complete, we can check out the results in the RIME web interface.

[ ]:
test_run = stress_job.get_test_run()
test_run

Query Results

Alternatively, we can query the test case results to identify model vulnerabilities.

[ ]:
result_df = test_run.get_result_df()
result_df.head()

Test Severity: Let’s plot some of the results. First, let’s check the severity distribution of attack tests.

[ ]:
import seaborn as sns
import matplotlib.pyplot as plt
sns.set_theme(style="white")

severity_cols = [col for col in result_df.columns if 'severity_counts' in col.lower()]
severity_counts = result_df[severity_cols].iloc[0]
plt.pie(severity_counts, labels=severity_cols)
plt.show()
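
The raw column names can make for noisy pie labels. As a small optional tweak (assuming the severity columns follow a dotted ``...severity_counts.<LEVEL>`` naming pattern), you can strip everything before the last dot:

[ ]:
# Shorten labels like "severity_counts.SEVERITY_ALERT" to "SEVERITY_ALERT".
# Column names without a dot are left unchanged.
clean_labels = [col.split('.')[-1] for col in severity_cols]
plt.pie(severity_counts, labels=clean_labels)
plt.show()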

Reviewing Test Case Results

Next, let’s look at the results by attack type.

[ ]:
test_cases_df = test_run.get_test_cases_df(show_test_case_metrics=True)
test_cases_df.head()
[ ]:
fig = plt.figure(figsize=(10,10))
test_type_pass_rates = {
    name: (batch_df['severity'] == 'SEVERITY_PASS').sum() / len(batch_df)
    for name, batch_df in test_cases_df.groupby("test_batch_type")
}
sns.barplot(y=list(test_type_pass_rates.keys()), x=list(test_type_pass_rates.values()), orient='h')
plt.xlabel('Pass Rate')
plt.ylabel('Test Type')
plt.show()

You can also query certain test batch-level metrics, including the model’s accuracy on the original and perturbed inputs, and the average number of queries for the attack algorithm.

[ ]:
import pandas as pd

col_name_map = {
    "test_name": "Test Name",
    "PERFORMANCE_METRIC_VALUE:original_accuracy": "Original Accuracy",
    "PERFORMANCE_METRIC_VALUE:perturbed_accuracy": "Perturbed Accuracy",
    "ATTACK_DETAILS:avg_queries": "Average Number of Queries",
}

def _all_col_names_in_summary_df(col_names, summary_df):
    for col_name in col_names:
        if col_name not in summary_df.index:
            return False
    return True

test_batches = test_run.get_test_batches()
all_summaries = []
for batch in test_batches:
    summary_df = batch.summary(show_batch_metrics=True)
    if _all_col_names_in_summary_df(col_name_map.keys(), summary_df):
        all_summaries.append(summary_df[list(col_name_map.keys())])
metrics_df = pd.concat(all_summaries, axis=1).T
metrics_df = metrics_df.rename(columns=col_name_map).set_index("Test Name")
metrics_df.sort_values(by="Perturbed Accuracy")

It’s evident that while this model is fairly robust to simple transformation-style augmentations, it fails to withstand some character-level evolutionary attacks. This indicates that additional data augmentation and/or a data sanitization pipeline should be applied before this model goes into production! One way to add augmented data to your training set is by querying the failed test cases:

[ ]:
import pandas as pd

def filter_rows(text_series: pd.Series, label_series: pd.Series) -> pd.DataFrame:
    filter_indices = ~text_series.isna()
    return pd.DataFrame({'Augmented': text_series[filter_indices], "Labels": label_series[filter_indices]})

failed_df = test_cases_df[test_cases_df['severity'] == 'SEVERITY_ALERT']

# attack examples
perturbed_text_col = [col for col in test_cases_df.columns if col.endswith('perturbed_sentence')][0]
class_col = [col for col in test_cases_df.columns if col.endswith('original_class')][0]
perturbed_df = filter_rows(failed_df[perturbed_text_col], failed_df[class_col])

perturbed_df.head()
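
As a hedged sketch of the retraining idea above: assuming you load the original SST-2 training split with the Hugging Face datasets library, you can rename the queried columns to match the dataset schema and concatenate the attack examples onto the training data. Note that the class values RIME returns may need to be mapped back to the dataset’s integer labels before retraining.

[ ]:
# Sketch only: combine successful attack examples with the original SST-2
# training split for a data-augmentation retraining run.
from datasets import load_dataset

train_df = load_dataset("sst2", split="train").to_pandas()
augmented_df = perturbed_df.rename(columns={"Augmented": "sentence", "Labels": "label"})
combined_df = pd.concat([train_df[["sentence", "label"]], augmented_df], ignore_index=True)
combined_df.tail()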