download
This module downloads all necessary datasets.
Usage examples:
import download
download.download_cifar100()
download.download_cifar10()
download.download_kaggle()
download.download_all_datasets() # Download all 3 at once
"""
.. include:: ../docs/download.md
"""

import tarfile
from os import makedirs, path, rename
from os.path import join, dirname, isdir

import requests
import tensorflow as tf

dataDir = join(dirname(__file__), "../data")
kaggleDir = join(dataDir, "kaggle")
cifar10DataDir: str = join(dataDir, "cifar10", "dataset")
cifar100DataDir: str = join(dataDir, "cifar100", "dataset")

# Seconds to wait for the server to connect / send data before giving up.
# Without a timeout, requests.get can block forever on a stalled connection.
_REQUEST_TIMEOUT_SECONDS = 60


def _download_raw_kaggle_data() -> None:
    """Fetch the compressed Kaggle purchase dataset into ``kaggleDir``.

    The archive is stored as ``raw_data.tgz``. Nothing is downloaded when
    that file already exists.

    Raises:
        requests.HTTPError: If the server answers with an error status.
            Raising here prevents an error page from being saved as the
            archive, which the ``isfile`` guard would then never replace.
    """
    kaggleCompressed = path.join(kaggleDir, "raw_data.tgz")
    if path.isfile(kaggleCompressed):
        return

    url = 'https://github.com/OliverRoss/replicating_mia_datasets/raw/master/dataset_purchase.tgz'
    response = requests.get(url, timeout=_REQUEST_TIMEOUT_SECONDS)
    response.raise_for_status()
    with open(kaggleCompressed, mode='wb') as file:
        file.write(response.content)


def _extract_kaggle_data() -> None:
    """Unpack ``raw_data.tgz`` and rename its payload to ``raw_data``.

    Extraction is skipped when ``raw_data`` already exists as a file.
    """
    kaggleRaw = path.join(kaggleDir, "raw_data")
    kaggleCompressed = path.join(kaggleDir, "raw_data.tgz")

    if path.isfile(kaggleRaw):
        return

    # NOTE(review): extractall trusts the member paths inside the downloaded
    # archive. Consider an extraction filter once the minimum supported
    # Python version provides one.
    with tarfile.open(kaggleCompressed) as archive:
        archive.extractall(kaggleDir)
    # "dataset_purchase" is the file name, we use the one in kaggleRaw
    rename(path.join(kaggleDir, "dataset_purchase"), kaggleRaw)


def download_kaggle() -> None:
    """Make sure the Kaggle dataset is downloaded and extracted."""
    # exist_ok avoids the check-then-create race of a separate isdir test.
    makedirs(kaggleDir, exist_ok=True)
    _download_raw_kaggle_data()
    _extract_kaggle_data()


def download_cifar10() -> None:
    """Trigger the Keras CIFAR-10 download unless ``cifar10DataDir`` exists."""
    if not isdir(cifar10DataDir):
        # Only the download side effect is wanted; the arrays are discarded.
        tf.keras.datasets.cifar10.load_data()


def download_cifar100() -> None:
    """Trigger the Keras CIFAR-100 download unless ``cifar100DataDir`` exists."""
    if not isdir(cifar100DataDir):
        # Only the download side effect is wanted; the arrays are discarded.
        tf.keras.datasets.cifar100.load_data()


def download_all_datasets() -> None:
    """Download CIFAR-100, CIFAR-10 and the Kaggle dataset, in that order."""
    print("Downloading all datasets.")
    download_cifar100()
    download_cifar10()
    download_kaggle()


def download_dataset(datasetName: str) -> None:
    """Download the single dataset identified by ``datasetName``.

    Args:
        datasetName: ``"cifar10"``, ``"cifar100"``, or any string that
            contains ``"kaggle"``.

    Raises:
        ValueError: If ``datasetName`` matches none of the known datasets.
    """
    if datasetName == "cifar10":
        download_cifar10()
    elif datasetName == "cifar100":
        download_cifar100()
    elif "kaggle" in datasetName:
        download_kaggle()
    else:
        raise ValueError(f"{datasetName} is not a known dataset.")


if __name__ == "__main__":
    import argparse
    import configuration as con

    parser = argparse.ArgumentParser(description='Make sure the needed dataset is downloaded.')
    parser.add_argument('--config', help='Relative path to config file.',)
    config = con.from_cli_options(vars(parser.parse_args()))
    dataName = config["targetDataset"]["name"]
    print(f"Downloading data for {dataName}, if necessary.")
    download_dataset(dataName)
def download_kaggle():
def download_cifar10():
def download_cifar100():
def download_all_datasets():
def download_dataset(datasetName: str):