Merge pull request #9 from WenjieDu/dev
Merge dev into main
WenjieDu authored Mar 19, 2023
2 parents 25cec2b + ecaab5a commit 2f2f8d7
Showing 6 changed files with 132 additions and 90 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish_to_PyPI.yml
@@ -28,7 +28,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build numpy setuptools
pip install build
- name: Build package
run: python -m build
- name: Publish the new package to PyPI
32 changes: 19 additions & 13 deletions tsdb/__init__.py
@@ -5,18 +5,24 @@
# Created by Wenjie Du <[email protected]>
# License: GPL-v3

from tsdb.__version__ import version as __version__

from .__version__ import version as __version__
try:
from tsdb.database import (
list_database,
list_available_datasets,
)

from .data_processing import (
window_truncate,
download_and_extract,
load_dataset,
delete_cached_data,
list_cached_data,
CACHED_DATASET_DIR,
pickle_load,
pickle_dump,
)

from .database import list_database, list_available_datasets
from tsdb.data_processing import (
window_truncate,
download_and_extract,
load_dataset,
delete_cached_data,
purge_given_path,
list_cached_data,
CACHED_DATASET_DIR,
pickle_dump,
pickle_load,
)
except Exception as e:
print(e)
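
For orientation, the names imported here make up the package's top-level API. A minimal usage sketch, assuming tsdb is installed and the dataset name is one of those returned by list_available_datasets():

import tsdb

print(tsdb.list_available_datasets())  # datasets TSDB knows how to download
data = tsdb.load_dataset("physionet_2012")  # fetch (or reuse the cache of) one dataset
print(tsdb.CACHED_DATASET_DIR)  # where cached copies are stored
print(tsdb.list_cached_data())  # what is currently cached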
10 changes: 5 additions & 5 deletions tsdb/data_loading_funcs/__init__.py
@@ -5,8 +5,8 @@
# Created by Wenjie Du <[email protected]>
# License: GPL-v3

from .beijing_multisite_air_quality import load_beijing_air_quality
from .electricity_load_diagrams import load_electricity
from .physionet_2012 import load_physionet2012
from .physionet_2019 import load_physionet2019
from .ucr_uea_datasets import load_ucr_uea_dataset
from tsdb.data_loading_funcs.beijing_multisite_air_quality import load_beijing_air_quality
from tsdb.data_loading_funcs.electricity_load_diagrams import load_electricity
from tsdb.data_loading_funcs.physionet_2012 import load_physionet2012
from tsdb.data_loading_funcs.physionet_2019 import load_physionet2019
from tsdb.data_loading_funcs.ucr_uea_datasets import load_ucr_uea_dataset
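
These loader functions are what load_dataset() in tsdb/data_processing.py dispatches to; each takes the local directory holding the extracted raw files. A minimal sketch, assuming the PhysioNet 2012 files have already been downloaded and extracted (the path below mirrors what download_and_extract() uses):

import os
from tsdb.data_loading_funcs import load_physionet2012
from tsdb.data_processing import CACHED_DATASET_DIR

saving_path = os.path.join(CACHED_DATASET_DIR, "physionet_2012")  # raw-data directory
result = load_physionet2012(saving_path)  # processed dataset; the exact structure depends on the loader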
2 changes: 1 addition & 1 deletion tsdb/data_loading_funcs/ucr_uea_datasets.py
@@ -313,5 +313,5 @@ def _load_txt_uea(dataset_path):
"""
data = numpy.loadtxt(dataset_path)
X = to_time_series_dataset(data[:, 1:])
y = data[:, 0].astype(numpy.int)
y = data[:, 0].astype(int)
return X, y
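
The one-line change above replaces numpy.int, a deprecated alias of the builtin int that newer NumPy releases no longer provide, with the builtin itself. A minimal sketch of the equivalent cast on a toy array (the values are made up; the UEA/UCR txt layout puts the label in column 0 and the series values after it):

import numpy

data = numpy.array([[1.0, 0.1, 0.2, 0.3], [2.0, 0.4, 0.5, 0.6]])  # two labelled series
X = data[:, 1:]  # time series values
y = data[:, 0].astype(int)  # labels cast with the builtin int, no deprecated numpy.int
print(y.dtype)  # platform default integer dtype, e.g. int64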
159 changes: 97 additions & 62 deletions tsdb/data_processing.py
@@ -8,21 +8,21 @@
import os
import pickle
import shutil
import sys
import tempfile
import urllib.request
import warnings
from sys import exit
from urllib.request import urlretrieve

import numpy as np
import numpy

from tsdb.data_loading_funcs import *
from tsdb.database import DATABASE, AVAILABLE_DATASETS

CACHED_DATASET_DIR = os.path.join(os.path.expanduser('~'), ".tsdb_cached_datasets")
CACHED_DATASET_DIR = os.path.join(os.path.expanduser("~"), ".tsdb_cached_datasets")


def window_truncate(feature_vectors, seq_len):
""" Generate time series samples, truncating windows from time-series data with a given sequence length.
"""Generate time series samples, truncating windows from time-series data with a given sequence length.
Parameters
----------
@@ -36,16 +36,16 @@ def window_truncate(feature_vectors, seq_len):
array,
Truncated time series with given sequence length.
"""
start_indices = np.asarray(range(feature_vectors.shape[0] // seq_len)) * seq_len
start_indices = numpy.asarray(range(feature_vectors.shape[0] // seq_len)) * seq_len
sample_collector = []
for idx in start_indices:
sample_collector.append(feature_vectors[idx: idx + seq_len])

return np.asarray(sample_collector).astype('float32')
return numpy.asarray(sample_collector).astype("float32")
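
A quick sketch of how window_truncate() behaves, using a toy array (shape and values are made up for illustration):

import numpy
from tsdb import window_truncate

feature_vectors = numpy.arange(20).reshape(10, 2)  # 10 time steps, 2 features
samples = window_truncate(feature_vectors, seq_len=4)
print(samples.shape, samples.dtype)  # (2, 4, 2) float32: 10 // 4 = 2 non-overlapping windows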


def _download_and_extract(url, saving_path):
""" Download dataset from the given url and extract to the given saving path.
"""Download dataset from the given url and extract to the given saving path.
Parameters
----------
@@ -58,12 +58,12 @@ def _download_and_extract(url, saving_path):
-------
saving_path if successful else None
"""
no_need_decompression_format = ['csv', 'txt']
no_need_decompression_format = ["csv", "txt"]
supported_compression_format = ["zip", "tar", "gz", "bz", "xz"]

# truncate the file name from url
file_name = os.path.basename(url)
suffix = file_name.split('.')[-1]
suffix = file_name.split(".")[-1]

if suffix in no_need_decompression_format:
raw_data_saving_path = os.path.join(saving_path, file_name)
@@ -75,28 +75,31 @@ def _download_and_extract(url, saving_path):
warnings.warn(
"The compression format is not supported, aborting. "
"If necessary, please create a pull request to add according supports.",
category=RuntimeWarning
category=RuntimeWarning,
)
return None

# download and save the raw dataset
try:
urlretrieve(url, raw_data_saving_path)
urllib.request.urlretrieve(url, raw_data_saving_path)
# except Exception as e:
except Exception as e:
shutil.rmtree(saving_path, ignore_errors=True)
print(f"Exception: {e}\n"
f"Download failed. Aborting.")
exit()
print(f"Exception: {e}\n" f"Download failed. Aborting.")
sys.exit()
print(f"Successfully downloaded data to {raw_data_saving_path}.")

if suffix in supported_compression_format: # if the file is compressed, then unpack it
if (
suffix in supported_compression_format
): # if the file is compressed, then unpack it
try:
os.makedirs(saving_path, exist_ok=True)
shutil.unpack_archive(raw_data_saving_path, saving_path)
print(f"Successfully extracted data to {saving_path}")
except shutil.Error:
warnings.warn("The compressed file is corrupted, aborting.", category=RuntimeWarning)
warnings.warn(
"The compressed file is corrupted, aborting.", category=RuntimeWarning
)
return None
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
@@ -105,7 +108,7 @@ def _download_and_extract(url, saving_path):


def download_and_extract(dataset_name, dataset_saving_path):
""" Wrapper of _download_and_extract.
"""Wrapper of _download_and_extract.
Parameters
----------
@@ -119,7 +122,7 @@ def download_and_extract(dataset_name, dataset_saving_path):
-------
"""
print('Start downloading...')
print("Start downloading...")
os.makedirs(dataset_saving_path)
if isinstance(DATABASE[dataset_name], list):
for link in DATABASE[dataset_name]:
@@ -129,7 +132,7 @@ def download_and_extract(dataset_name, dataset_saving_path):


def list_cached_data():
""" List names of all cached datasets.
"""List names of all cached datasets.
Returns
-------
@@ -145,37 +148,59 @@ def list_cached_data():


def delete_cached_data(dataset_name=None):
""" Delete CACHED_DATASET_DIR if exists.
"""
"""Delete CACHED_DATASET_DIR if exists."""
# if CACHED_DATASET_DIR does not exist, abort
if not os.path.exists(CACHED_DATASET_DIR):
print('No cached data. Operation aborted.')
exit()
print("No cached data. Operation aborted.")
sys.exit()
# if CACHED_DATASET_DIR exists, then purge
if dataset_name is not None:
assert (
dataset_name in AVAILABLE_DATASETS
), f"{dataset_name} is not available in TSDB, so it has no cache. Please check your dataset name."
dir_to_delete = os.path.join(CACHED_DATASET_DIR, dataset_name)
if not os.path.exists(dir_to_delete):
print(f"Dataset {dataset_name} is not cached. Operation aborted.")
sys.exit()
print(f"Purging cached dataset {dataset_name} under {dir_to_delete}...")
else:
dir_to_delete = CACHED_DATASET_DIR
print(f"Purging all cached data under {CACHED_DATASET_DIR}...")
purge_given_path(dir_to_delete)


def purge_given_path(path):
"""Delete the given path.
If a file is given, it will be deleted. If a folder is given, the folder itself and all its contents will be purged.
Parameters
----------
path: str,
It could be a file or a folder.
"""
assert os.path.exists(
path
), f"The given path {path} does not exists. Operation aborted."

try:
if dataset_name is not None:
assert dataset_name in AVAILABLE_DATASETS, \
f'{dataset_name} is not available in TSDB, so it has no cache. Please check your dataset name.'
dir_to_delete = os.path.join(CACHED_DATASET_DIR, dataset_name)
if not os.path.exists(dir_to_delete):
print(f'Dataset {dataset_name} is not cached. Operation aborted.')
exit()
print(f'Purging cached dataset {dataset_name} under {dir_to_delete}...')
if os.path.isdir(path):
shutil.rmtree(path, ignore_errors=True)
else:
dir_to_delete = CACHED_DATASET_DIR
print(f'Purging all cached data under {CACHED_DATASET_DIR}...')
shutil.rmtree(dir_to_delete, ignore_errors=True)
os.remove(path)
# check if succeed
if not os.path.exists(dir_to_delete):
print('Purged successfully!')
if not os.path.exists(path):
print(f"Successfully deleted {path}.")
else:
raise FileExistsError(f'Deleting operation failed. {CACHED_DATASET_DIR} still exists.')
raise FileExistsError(
f"Deleting operation failed. {CACHED_DATASET_DIR} still exists."
)
except shutil.Error:
raise shutil.Error('Operation failed.')
raise shutil.Error("Operation failed.")
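
A minimal sketch of how the two helpers above are meant to be used (the dataset name is only an example; it must appear in AVAILABLE_DATASETS and be cached locally):

import tsdb

tsdb.delete_cached_data("physionet_2012")  # purge one cached dataset
tsdb.delete_cached_data()  # purge the whole cache directory under CACHED_DATASET_DIR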


def pickle_dump(data, path):
""" Pickle the given object.
"""Pickle the given object.
Parameters
----------
@@ -191,17 +216,17 @@ def pickle_dump(data, path):
"""
try:
with open(path, 'wb') as f:
with open(path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
except pickle.PicklingError:
print('Pickling failed. No cache will be saved.')
print("Pickling failed. No cache will be saved.")
return None
print(f'Successfully saved to {path}')
print(f"Successfully saved to {path}")
return path


def pickle_load(path):
""" Load pickled object from file.
"""Load pickled object from file.
Parameters
----------
@@ -215,21 +240,20 @@ def pickle_load(path):
"""
try:
with open(path, 'rb') as f:
with open(path, "rb") as f:
data = pickle.load(f)
except pickle.UnpicklingError as e:
print('Cached data corrupted. Aborting...\n'
f'{e}')
print("Cached data corrupted. Aborting...\n" f"{e}")
return data
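
A round-trip sketch of the two pickle helpers above, writing to a throwaway path (the filename is arbitrary):

import os
import tempfile
from tsdb import pickle_dump, pickle_load

cache_file = os.path.join(tempfile.gettempdir(), "tsdb_demo_cache.pkl")
pickle_dump({"answer": 42}, cache_file)  # serialises with the highest pickle protocol
restored = pickle_load(cache_file)  # loads the object back
assert restored == {"answer": 42}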


def load_dataset(dataset_name, use_cache=True):
""" Load dataset with given name.
"""Load dataset with given name.
Parameters
----------
dataset_name : str,
The name of the specific dataset in DATABASE.
The name of the specific dataset in database.DATABASE.
use_cache : bool,
Whether to use cache (including data downloading and processing)
@@ -239,43 +263,54 @@
pandas.DataFrame,
Loaded dataset.
"""
assert dataset_name in AVAILABLE_DATASETS, f'Input dataset name "{dataset_name}" is not in the database {AVAILABLE_DATASETS}.'
assert dataset_name in AVAILABLE_DATASETS, \
f'The given dataset name "{dataset_name}" is not in the database. ' \
f'Please fetch the full list of the available datasets with tsdb.list_available_datasets()'

dataset_saving_path = os.path.join(CACHED_DATASET_DIR, dataset_name)
if not os.path.exists(dataset_saving_path): # if the dataset is not cached, then download it
if not os.path.exists(
dataset_saving_path
): # if the dataset is not cached, then download it
download_and_extract(dataset_name, dataset_saving_path)
else:
if use_cache:
print(f'Dataset {dataset_name} has already been downloaded. Processing directly...')
print(
f"Dataset {dataset_name} has already been downloaded. Processing directly..."
)
else:
# if not use cache, then delete the downloaded data dir (including processing cache)
shutil.rmtree(dataset_saving_path, ignore_errors=True)
download_and_extract(dataset_name, dataset_saving_path)

# if cached, then load directly
cache_path = os.path.join(dataset_saving_path, dataset_name + '_cache.pkl')
cache_path = os.path.join(dataset_saving_path, dataset_name + "_cache.pkl")
if os.path.exists(cache_path):
print(f'Dataset {dataset_name} has already been cached. Loading from cache directly...')
print(
f"Dataset {dataset_name} has already been cached. Loading from cache directly..."
)
result = pickle_load(cache_path)
else:
try:
if dataset_name == 'physionet_2012':
if dataset_name == "physionet_2012":
result = load_physionet2012(dataset_saving_path)
if dataset_name == 'physionet_2019':
if dataset_name == "physionet_2019":
result = load_physionet2019(dataset_saving_path)
elif dataset_name == 'electricity_load_diagrams':
elif dataset_name == "electricity_load_diagrams":
result = load_electricity(dataset_saving_path)
elif dataset_name == 'beijing_multisite_air_quality':
elif dataset_name == "beijing_multisite_air_quality":
result = load_beijing_air_quality(dataset_saving_path)
elif 'UCR_UEA_' in dataset_name:
actual_dataset_name = dataset_name.replace('UCR_UEA_', '') # delete 'UCR_UEA_' in the name
elif "UCR_UEA_" in dataset_name:
actual_dataset_name = dataset_name.replace(
"UCR_UEA_", ""
) # delete 'UCR_UEA_' in the name
result = load_ucr_uea_dataset(dataset_saving_path, actual_dataset_name)

except FileExistsError:
shutil.rmtree(dataset_saving_path, ignore_errors=True)
warnings.warn(
'Dataset corrupted, already deleted. Please rerun load_specific_dataset() to re-download the raw data.'
"Dataset corrupted, already deleted. Please rerun load_specific_dataset() to re-download the raw data."
)
pickle_dump(result, cache_path)

print('Loaded successfully!')
print("Loaded successfully!")
return result
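
To close the file, a short sketch of typical load_dataset() calls, including the 'UCR_UEA_' prefix convention handled above (the concrete dataset names are examples; check tsdb.list_available_datasets() for the real list):

import tsdb

data = tsdb.load_dataset("physionet_2012")  # download on first call, reuse the cache afterwards
fresh = tsdb.load_dataset("physionet_2012", use_cache=False)  # force re-download and re-processing
ucr = tsdb.load_dataset("UCR_UEA_Coffee")  # the 'UCR_UEA_' prefix selects a UCR/UEA archive dataset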