download

This module downloads all necessary datasets: CIFAR-10 and CIFAR-100 (via Keras) and the Kaggle purchase dataset.

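Usage examples:

    python download.py --config <relative/path/to/config>

or, from Python: `download_dataset("cifar10")` or `download_all_datasets()`.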

 1"""
 2.. include:: ../docs/download.md
 3"""
 4
 5import tarfile
 6from os import makedirs, path, rename
 7from os.path import join, dirname, isdir
 8
 9import requests
10import tensorflow as tf
11
# Root folder for all datasets: "<directory of this module>/../data".
dataDir: str = join(dirname(__file__), "../data")
# Holds the Kaggle purchase archive (raw_data.tgz) and its extracted file (raw_data).
kaggleDir: str = join(dataDir, "kaggle")
# Folders whose existence makes download_cifar10 / download_cifar100 skip the download.
cifar10DataDir: str = join(dataDir, "cifar10", "dataset")
cifar100DataDir: str = join(dataDir, "cifar100", "dataset")
16
17
def _download_raw_kaggle_data():
    """Download the compressed Kaggle purchase archive to ``kaggleDir/raw_data.tgz``.

    Nothing happens when ``raw_data.tgz`` already exists. The archive is
    streamed to ``raw_data.tgz.part`` and moved into place only after the
    transfer succeeded. This way a failed or interrupted download never
    leaves a truncated file, or an HTML error page, that later runs would
    mistake for a valid cached archive. A stale ``.part`` file from an
    earlier failure is simply overwritten.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    from os import replace  # local import: move the finished file into place

    kaggleCompressed = path.join(kaggleDir, "raw_data.tgz")
    if path.isfile(kaggleCompressed):
        return

    url = 'https://github.com/OliverRoss/replicating_mia_datasets/raw/master/dataset_purchase.tgz'
    partial = kaggleCompressed + ".part"
    # Without a timeout, requests may block indefinitely on a stalled server.
    with requests.get(url, stream=True, timeout=60) as response:
        # Never cache an error response as if it were the archive.
        response.raise_for_status()
        with open(partial, mode='wb') as file:
            for chunk in response.iter_content(chunk_size=1 << 20):
                file.write(chunk)
    replace(partial, kaggleCompressed)
25
26
def _extract_kaggle_data():
    """Unpack ``kaggleDir/raw_data.tgz`` and store its payload as ``kaggleDir/raw_data``.

    Does nothing when ``raw_data`` already exists as a file. The archive is
    expected to contain an entry named ``dataset_purchase``, which is
    renamed to ``raw_data`` after extraction.
    """
    kaggleRaw = path.join(kaggleDir, "raw_data")
    kaggleCompressed = path.join(kaggleDir, "raw_data.tgz")

    if path.isfile(kaggleRaw):
        return

    # Use a context manager so the archive handle is always closed.
    with tarfile.open(kaggleCompressed) as archive:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters (Python 3.12+ and security backports) reject
            # absolute paths, ".." traversal and special files in the
            # downloaded archive.
            archive.extractall(kaggleDir, filter="data")
        else:
            archive.extractall(kaggleDir)
    # "dataset_purchase" is the file name, we use the one in kaggleRaw
    rename(path.join(kaggleDir, "dataset_purchase"), kaggleRaw)
35
36
def download_kaggle():
    """Make sure the Kaggle purchase dataset is available under ``kaggleDir``.

    Creates ``kaggleDir`` if needed, downloads the compressed archive unless
    it is already cached, and extracts it to ``kaggleDir/raw_data``.
    """
    # exist_ok removes the check-then-create race of isdir() + makedirs().
    # A non-directory at this path still raises FileExistsError, as before.
    makedirs(kaggleDir, exist_ok=True)
    _download_raw_kaggle_data()
    _extract_kaggle_data()
42
43
def download_cifar10():
    """Fetch CIFAR-10 through Keras unless ``cifar10DataDir`` already exists.

    The arrays returned by ``load_data`` are discarded; the call is made
    for its side effect of downloading the data.
    NOTE(review): Keras presumably caches the download in its own directory
    (e.g. ~/.keras/datasets), not in ``cifar10DataDir`` -- confirm which code
    creates ``cifar10DataDir``.
    """
    if isdir(cifar10DataDir):
        return
    (_trainX, _trainY), (_testX, _testY) = tf.keras.datasets.cifar10.load_data()
47
48
def download_cifar100():
    """Fetch CIFAR-100 through Keras unless ``cifar100DataDir`` already exists.

    The arrays returned by ``load_data`` are discarded; the call is made
    for its side effect of downloading the data.
    NOTE(review): Keras presumably caches the download in its own directory
    (e.g. ~/.keras/datasets), not in ``cifar100DataDir`` -- confirm which
    code creates ``cifar100DataDir``.
    """
    if isdir(cifar100DataDir):
        return
    (_trainX, _trainY), (_testX, _testY) = tf.keras.datasets.cifar100.load_data()
52
53
def download_all_datasets():
    """Download every supported dataset, in order: CIFAR-100, CIFAR-10, Kaggle."""
    print("Downloading all datasets.")
    for fetch in (download_cifar100, download_cifar10, download_kaggle):
        fetch()
59
60
def download_dataset(datasetName: str):
    """Download the dataset identified by ``datasetName``.

    ``"cifar10"`` and ``"cifar100"`` must match exactly. Any name that
    merely contains ``"kaggle"`` selects the Kaggle purchase data.

    Raises:
        ValueError: if the name matches none of the above.
    """
    if datasetName == "cifar10":
        download_cifar10()
        return
    if datasetName == "cifar100":
        download_cifar100()
        return
    if "kaggle" not in datasetName:
        raise ValueError(f"{datasetName} is not a known dataset.")
    download_kaggle()
70
71
if __name__ == "__main__":
    # Script entry point: read the experiment configuration and fetch only
    # the dataset named in its "targetDataset" section.
    import argparse
    import configuration as con

    cliParser = argparse.ArgumentParser(description='Make sure the needed dataset is downloaded.')
    cliParser.add_argument('--config', help='Relative path to config file.')
    cliOptions = vars(cliParser.parse_args())
    config = con.from_cli_options(cliOptions)
    dataName = config["targetDataset"]["name"]
    print(f"Downloading data for {dataName}, if necessary.")
    download_dataset(dataName)
def download_kaggle():
    """Make sure the Kaggle purchase dataset is available under ``kaggleDir``.

    Creates ``kaggleDir`` if needed, downloads the compressed archive unless
    it is already cached, and extracts it to ``kaggleDir/raw_data``.
    """
    # exist_ok removes the check-then-create race of isdir() + makedirs().
    # A non-directory at this path still raises FileExistsError, as before.
    makedirs(kaggleDir, exist_ok=True)
    _download_raw_kaggle_data()
    _extract_kaggle_data()
def download_cifar10():
    """Fetch CIFAR-10 through Keras unless ``cifar10DataDir`` already exists.

    The arrays returned by ``load_data`` are discarded; the call is made
    for its side effect of downloading the data.
    NOTE(review): Keras presumably caches the download in its own directory
    (e.g. ~/.keras/datasets), not in ``cifar10DataDir`` -- confirm which code
    creates ``cifar10DataDir``.
    """
    if isdir(cifar10DataDir):
        return
    (_trainX, _trainY), (_testX, _testY) = tf.keras.datasets.cifar10.load_data()
def download_cifar100():
    """Fetch CIFAR-100 through Keras unless ``cifar100DataDir`` already exists.

    The arrays returned by ``load_data`` are discarded; the call is made
    for its side effect of downloading the data.
    NOTE(review): Keras presumably caches the download in its own directory
    (e.g. ~/.keras/datasets), not in ``cifar100DataDir`` -- confirm which
    code creates ``cifar100DataDir``.
    """
    if isdir(cifar100DataDir):
        return
    (_trainX, _trainY), (_testX, _testY) = tf.keras.datasets.cifar100.load_data()
def download_all_datasets():
    """Download every supported dataset, in order: CIFAR-100, CIFAR-10, Kaggle."""
    print("Downloading all datasets.")
    for fetch in (download_cifar100, download_cifar10, download_kaggle):
        fetch()
def download_dataset(datasetName: str):
    """Download the dataset identified by ``datasetName``.

    ``"cifar10"`` and ``"cifar100"`` must match exactly. Any name that
    merely contains ``"kaggle"`` selects the Kaggle purchase data.

    Raises:
        ValueError: if the name matches none of the above.
    """
    if datasetName == "cifar10":
        download_cifar10()
        return
    if datasetName == "cifar100":
        download_cifar100()
        return
    if "kaggle" not in datasetName:
        raise ValueError(f"{datasetName} is not a known dataset.")
    download_kaggle()