black formatted files before changes
This commit is contained in:
@@ -19,15 +19,22 @@ class ODDSDataset(Dataset):
|
||||
"""
|
||||
|
||||
urls = {
|
||||
'arrhythmia': 'https://www.dropbox.com/s/lmlwuspn1sey48r/arrhythmia.mat?dl=1',
|
||||
'cardio': 'https://www.dropbox.com/s/galg3ihvxklf0qi/cardio.mat?dl=1',
|
||||
'satellite': 'https://www.dropbox.com/s/dpzxp8jyr9h93k5/satellite.mat?dl=1',
|
||||
'satimage-2': 'https://www.dropbox.com/s/hckgvu9m6fs441p/satimage-2.mat?dl=1',
|
||||
'shuttle': 'https://www.dropbox.com/s/mk8ozgisimfn3dw/shuttle.mat?dl=1',
|
||||
'thyroid': 'https://www.dropbox.com/s/bih0e15a0fukftb/thyroid.mat?dl=1'
|
||||
"arrhythmia": "https://www.dropbox.com/s/lmlwuspn1sey48r/arrhythmia.mat?dl=1",
|
||||
"cardio": "https://www.dropbox.com/s/galg3ihvxklf0qi/cardio.mat?dl=1",
|
||||
"satellite": "https://www.dropbox.com/s/dpzxp8jyr9h93k5/satellite.mat?dl=1",
|
||||
"satimage-2": "https://www.dropbox.com/s/hckgvu9m6fs441p/satimage-2.mat?dl=1",
|
||||
"shuttle": "https://www.dropbox.com/s/mk8ozgisimfn3dw/shuttle.mat?dl=1",
|
||||
"thyroid": "https://www.dropbox.com/s/bih0e15a0fukftb/thyroid.mat?dl=1",
|
||||
}
|
||||
|
||||
def __init__(self, root: str, dataset_name: str, train=True, random_state=None, download=False):
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
dataset_name: str,
|
||||
train=True,
|
||||
random_state=None,
|
||||
download=False,
|
||||
):
|
||||
super(Dataset, self).__init__()
|
||||
|
||||
self.classes = [0, 1]
|
||||
@@ -37,25 +44,25 @@ class ODDSDataset(Dataset):
|
||||
self.root = Path(root)
|
||||
self.dataset_name = dataset_name
|
||||
self.train = train # training set or test set
|
||||
self.file_name = self.dataset_name + '.mat'
|
||||
self.file_name = self.dataset_name + ".mat"
|
||||
self.data_file = self.root / self.file_name
|
||||
|
||||
if download:
|
||||
self.download()
|
||||
|
||||
mat = loadmat(self.data_file)
|
||||
X = mat['X']
|
||||
y = mat['y'].ravel()
|
||||
X = mat["X"]
|
||||
y = mat["y"].ravel()
|
||||
idx_norm = y == 0
|
||||
idx_out = y == 1
|
||||
|
||||
# 60% data for training and 40% for testing; keep outlier ratio
|
||||
X_train_norm, X_test_norm, y_train_norm, y_test_norm = train_test_split(X[idx_norm], y[idx_norm],
|
||||
test_size=0.4,
|
||||
random_state=random_state)
|
||||
X_train_out, X_test_out, y_train_out, y_test_out = train_test_split(X[idx_out], y[idx_out],
|
||||
test_size=0.4,
|
||||
random_state=random_state)
|
||||
X_train_norm, X_test_norm, y_train_norm, y_test_norm = train_test_split(
|
||||
X[idx_norm], y[idx_norm], test_size=0.4, random_state=random_state
|
||||
)
|
||||
X_train_out, X_test_out, y_train_out, y_test_out = train_test_split(
|
||||
X[idx_out], y[idx_out], test_size=0.4, random_state=random_state
|
||||
)
|
||||
X_train = np.concatenate((X_train_norm, X_train_out))
|
||||
X_test = np.concatenate((X_test_norm, X_test_out))
|
||||
y_train = np.concatenate((y_train_norm, y_train_out))
|
||||
@@ -88,7 +95,11 @@ class ODDSDataset(Dataset):
|
||||
Returns:
|
||||
tuple: (sample, target, semi_target, index)
|
||||
"""
|
||||
sample, target, semi_target = self.data[index], int(self.targets[index]), int(self.semi_targets[index])
|
||||
sample, target, semi_target = (
|
||||
self.data[index],
|
||||
int(self.targets[index]),
|
||||
int(self.semi_targets[index]),
|
||||
)
|
||||
|
||||
return sample, target, semi_target, index
|
||||
|
||||
@@ -107,4 +118,4 @@ class ODDSDataset(Dataset):
|
||||
# download file
|
||||
download_url(self.urls[self.dataset_name], self.root, self.file_name)
|
||||
|
||||
print('Done!')
|
||||
print("Done!")
|
||||
|
||||
Reference in New Issue
Block a user