attack_pipeline

This module ties the stages of the membership inference attack (MIA) pipeline together. The attack models are loaded and evaluated on two sets of records: the target model's training data (members) and data the target model has not seen before (non-members).

  1"""
  2.. include:: ../docs/attack_pipeline.md
  3"""
  4
  5from os import environ
  6
  7# Tensorflow C++ backend logging verbosity
  8environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # NOQA
  9
 10from typing import Dict, Tuple
 11
 12import numpy as np
 13from tensorflow.data import Dataset  # pyright: ignore
 14from tensorflow.python.framework import random_seed
 15
 16import utils
 17import datasets as ds
 18import target_models as tm
 19import attack_model as am
 20
 21global_seed: int = 1234
 22
 23
 24def set_seed(new_seed: int):
 25    """
 26    Set the global seed that will be used for all functions that include
 27    randomness.
 28    """
 29    global global_seed
 30    global_seed = new_seed
 31    random_seed.set_seed(global_seed)
 32
 33def load_target_data(config:Dict) -> Tuple[Dataset, Dataset]:
 34    """
 35    Returns tuple trainData, restData.
 36
 37    RestData is data unused for training and testing previously.
 38    """
 39    targetModelName = tm.get_model_name(config)
 40    targetTrainDataName = targetModelName + "_train_data"
 41    targetRestDataName = targetModelName + "_rest_data"
 42    targetTrainData = ds.load_target(targetTrainDataName)
 43    targetRestData = ds.load_target(targetRestDataName)
 44    return targetTrainData, targetRestData
 45
 46def run_pipeline(targetModel, targetTrainData, targetRestData):
 47    # TODO: batchSize is hardcoded
 48    numClasses = config["targetModel"]["classes"]
 49    batchSizeTarget = 100
 50    batchSizeAttack = config["attackModel"]["hyperparameters"]["batchSize"]
 51    targetTrainDataSize = config["targetDataset"]["trainSize"]
 52
 53    hash = utils.hash(str(config))
 54
 55    try:
 56        memberAttackPredictions = ds.load_numpy_array(f"{hash}_memberAttackPredictions.npy")
 57        nonmemberAttackPredictions = ds.load_numpy_array(f"{hash}_nonmemberAttackPredictions.npy")
 58
 59    except:
 60        attackModels = am.get_attack_models(config, [])
 61
 62        membersDataset = targetTrainData
 63        nonmembersDataset = targetRestData.take(targetTrainDataSize)
 64
 65        memberTargetPredictions = targetModel.predict(membersDataset.batch(batchSizeTarget))
 66        nonmemberTargetPredictions = targetModel.predict(nonmembersDataset.batch(batchSizeTarget))
 67
 68        memberAttackPredictions = [[] for _ in range(numClasses)]
 69        nonmemberAttackPredictions = [[] for _ in range(numClasses)]
 70
 71        print("Predicting members.")
 72        for i, targetPrediction in enumerate(memberTargetPredictions):
 73            label = np.argmax(targetPrediction)
 74            # select respective attack model, trained for that class
 75            attackModel = attackModels[label]
 76            modelInput = Dataset.from_tensors(targetPrediction).batch(batchSizeAttack)
 77            attackPrediction = attackModel.predict(modelInput,verbose = 0)
 78            memberAttackPredictions[label].append(np.argmax(attackPrediction))
 79            if i % 100 == 0 and config["verbose"]:
 80                print(f"Predicted {i}/{targetTrainDataSize} member records on attack model.")
 81
 82        print("Predicting nonmembers.")
 83        for i, targetPrediction in enumerate(nonmemberTargetPredictions):
 84            label = np.argmax(targetPrediction)
 85            # select respective attack model, trained for that class
 86            attackModel = attackModels[label]
 87            modelInput = Dataset.from_tensors(targetPrediction).batch(batchSizeAttack)
 88            attackPrediction = attackModel.predict(modelInput, verbose = 0)
 89            nonmemberAttackPredictions[label].append(np.argmax(attackPrediction))
 90            if i % 100 == 0 and config["verbose"]:
 91                print(f"Predicted {i}/{targetTrainDataSize} nonmember records on attack model.")
 92
 93        ds.save_numpy_array(f"{hash}_memberAttackPredictions.npy",memberAttackPredictions)
 94        ds.save_numpy_array(f"{hash}_nonmemberAttackPredictions.npy",nonmemberAttackPredictions)
 95
 96    precisionPerClass = [None for _ in range(numClasses)]
 97    recallPerClass = [None for _ in range(numClasses)]
 98
 99    for _class in range(numClasses):
100        memberAttackPrediction = memberAttackPredictions[_class]
101        if memberAttackPrediction:
102            recallPerClass[_class] = 1 - np.average(memberAttackPrediction)
103
104        nonmemberAttackPrediction = nonmemberAttackPredictions[_class]
105        if nonmemberAttackPrediction:
106            membersInferredAsMembers = len(memberAttackPrediction) - np.count_nonzero(memberAttackPrediction)
107            nonmembersInferredAsMembers = len(nonmemberAttackPrediction) - np.count_nonzero(nonmemberAttackPrediction)
108            if (membersInferredAsMembers + nonmembersInferredAsMembers):
109                precisionPerClass[_class] = membersInferredAsMembers / (membersInferredAsMembers + nonmembersInferredAsMembers)
110
111    membersInferredAsMembers = targetTrainDataSize - sum([sum(x) for x in memberAttackPredictions])
112    nonmembersInferredAsMembers = targetTrainDataSize - sum([sum(x) for x in nonmemberAttackPredictions])
113    totalRecall = membersInferredAsMembers/targetTrainDataSize
114    totalPrecision = membersInferredAsMembers / (membersInferredAsMembers + nonmembersInferredAsMembers)
115    return totalPrecision, totalRecall, precisionPerClass, recallPerClass
116
def process_results(precision, recall, precisionPerClass, recallPerClass):
    """
    Write the attack metrics to two CSV files in the working directory.

    Both file names start with the hash of the module-level `config`. Each
    file has a header line with the overall value, followed by one line per
    entry of the corresponding per-class list.
    """
    filePrefix = utils.hash(str(config))

    with open(f"{filePrefix}_recallPerClass.csv",'w') as recallFile:
        recallFile.write(f"Recall (Overall:{recall})\n")
        recallFile.writelines(f"{classRecall}\n" for classRecall in recallPerClass)

    with open(f"{filePrefix}_precisionPerClass.csv",'w') as precisionFile:
        precisionFile.write(f"Precision (Overall: {precision})\n")
        precisionFile.writelines(f"{classPrecision}\n" for classPrecision in precisionPerClass)
128
# Script entry point: load the configuration, the target model and its data,
# run the attack pipeline and write the resulting metrics to disk.
if __name__ == "__main__":
    import argparse
    import configuration as con

    # The only command line option is the path to the configuration file.
    parser = argparse.ArgumentParser(description='Run the attack pipeline on the target model.')
    parser.add_argument('--config', help='Relative path to config file.',)
    # `config` has to be bound at module level here: run_pipeline and
    # process_results read it as a global instead of taking it as a parameter.
    config = con.from_cli_options(vars(parser.parse_args()))
    set_seed(config["seed"])

    
    targetDataset = ds.load_dataset(config["targetDataset"]["name"])
    targetModel = tm.get_target_model(config, targetDataset)

    # Training records and the records unused so far for the target model.
    targetTrainData, targetRestData = load_target_data(config)

    precision, recall, precisionPerClass, recallPerClass = run_pipeline(targetModel, targetTrainData, targetRestData)

    # Writes the per-class and overall metrics to CSV files.
    process_results(precision, recall, precisionPerClass, recallPerClass)
def set_seed(new_seed: int):
def set_seed(new_seed: int):
    """
    Remember `new_seed` as the module-wide seed and forward it to TensorFlow.

    The value is stored in the module-level `global_seed` and then passed to
    `random_seed.set_seed`.
    """
    global global_seed
    global_seed = new_seed
    random_seed.set_seed(new_seed)

Set the global seed that will be used for all functions that include randomness.

def load_target_data( config: Dict) -> Tuple[tensorflow.python.data.ops.dataset_ops.DatasetV2, tensorflow.python.data.ops.dataset_ops.DatasetV2]:
def load_target_data(config:Dict) -> Tuple[Dataset, Dataset]:
    """
    Load the two datasets stored for the target model described by `config`.

    The dataset names are derived from the target model's name by appending
    "_train_data" and "_rest_data".

    Returns the tuple (trainData, restData), where restData is data that was
    previously used neither for training nor for testing.
    """
    namePrefix = tm.get_model_name(config)
    trainData = ds.load_target(namePrefix + "_train_data")
    restData = ds.load_target(namePrefix + "_rest_data")
    return trainData, restData

Returns tuple trainData, restData.

RestData is data unused for training and testing previously.

def run_pipeline(targetModel, targetTrainData, targetRestData):
 47def run_pipeline(targetModel, targetTrainData, targetRestData):
 48    # TODO: batchSize is hardcoded
 49    numClasses = config["targetModel"]["classes"]
 50    batchSizeTarget = 100
 51    batchSizeAttack = config["attackModel"]["hyperparameters"]["batchSize"]
 52    targetTrainDataSize = config["targetDataset"]["trainSize"]
 53
 54    hash = utils.hash(str(config))
 55
 56    try:
 57        memberAttackPredictions = ds.load_numpy_array(f"{hash}_memberAttackPredictions.npy")
 58        nonmemberAttackPredictions = ds.load_numpy_array(f"{hash}_nonmemberAttackPredictions.npy")
 59
 60    except:
 61        attackModels = am.get_attack_models(config, [])
 62
 63        membersDataset = targetTrainData
 64        nonmembersDataset = targetRestData.take(targetTrainDataSize)
 65
 66        memberTargetPredictions = targetModel.predict(membersDataset.batch(batchSizeTarget))
 67        nonmemberTargetPredictions = targetModel.predict(nonmembersDataset.batch(batchSizeTarget))
 68
 69        memberAttackPredictions = [[] for _ in range(numClasses)]
 70        nonmemberAttackPredictions = [[] for _ in range(numClasses)]
 71
 72        print("Predicting members.")
 73        for i, targetPrediction in enumerate(memberTargetPredictions):
 74            label = np.argmax(targetPrediction)
 75            # select respective attack model, trained for that class
 76            attackModel = attackModels[label]
 77            modelInput = Dataset.from_tensors(targetPrediction).batch(batchSizeAttack)
 78            attackPrediction = attackModel.predict(modelInput,verbose = 0)
 79            memberAttackPredictions[label].append(np.argmax(attackPrediction))
 80            if i % 100 == 0 and config["verbose"]:
 81                print(f"Predicted {i}/{targetTrainDataSize} member records on attack model.")
 82
 83        print("Predicting nonmembers.")
 84        for i, targetPrediction in enumerate(nonmemberTargetPredictions):
 85            label = np.argmax(targetPrediction)
 86            # select respective attack model, trained for that class
 87            attackModel = attackModels[label]
 88            modelInput = Dataset.from_tensors(targetPrediction).batch(batchSizeAttack)
 89            attackPrediction = attackModel.predict(modelInput, verbose = 0)
 90            nonmemberAttackPredictions[label].append(np.argmax(attackPrediction))
 91            if i % 100 == 0 and config["verbose"]:
 92                print(f"Predicted {i}/{targetTrainDataSize} nonmember records on attack model.")
 93
 94        ds.save_numpy_array(f"{hash}_memberAttackPredictions.npy",memberAttackPredictions)
 95        ds.save_numpy_array(f"{hash}_nonmemberAttackPredictions.npy",nonmemberAttackPredictions)
 96
 97    precisionPerClass = [None for _ in range(numClasses)]
 98    recallPerClass = [None for _ in range(numClasses)]
 99
100    for _class in range(numClasses):
101        memberAttackPrediction = memberAttackPredictions[_class]
102        if memberAttackPrediction:
103            recallPerClass[_class] = 1 - np.average(memberAttackPrediction)
104
105        nonmemberAttackPrediction = nonmemberAttackPredictions[_class]
106        if nonmemberAttackPrediction:
107            membersInferredAsMembers = len(memberAttackPrediction) - np.count_nonzero(memberAttackPrediction)
108            nonmembersInferredAsMembers = len(nonmemberAttackPrediction) - np.count_nonzero(nonmemberAttackPrediction)
109            if (membersInferredAsMembers + nonmembersInferredAsMembers):
110                precisionPerClass[_class] = membersInferredAsMembers / (membersInferredAsMembers + nonmembersInferredAsMembers)
111
112    membersInferredAsMembers = targetTrainDataSize - sum([sum(x) for x in memberAttackPredictions])
113    nonmembersInferredAsMembers = targetTrainDataSize - sum([sum(x) for x in nonmemberAttackPredictions])
114    totalRecall = membersInferredAsMembers/targetTrainDataSize
115    totalPrecision = membersInferredAsMembers / (membersInferredAsMembers + nonmembersInferredAsMembers)
116    return totalPrecision, totalRecall, precisionPerClass, recallPerClass
def process_results(precision, recall, precisionPerClass, recallPerClass):
def process_results(precision, recall, precisionPerClass, recallPerClass):
    """
    Write the attack metrics to two CSV files in the working directory.

    Both file names start with the hash of the module-level `config`. Each
    file has a header line with the overall value, followed by one line per
    entry of the corresponding per-class list.
    """
    filePrefix = utils.hash(str(config))

    with open(f"{filePrefix}_recallPerClass.csv",'w') as recallFile:
        recallFile.write(f"Recall (Overall:{recall})\n")
        recallFile.writelines(f"{classRecall}\n" for classRecall in recallPerClass)

    with open(f"{filePrefix}_precisionPerClass.csv",'w') as precisionFile:
        precisionFile.write(f"Precision (Overall: {precision})\n")
        precisionFile.writelines(f"{classPrecision}\n" for classPrecision in precisionPerClass)