from torch.utils.data import Dataset
class MyDataset(Dataset):
    """Map-style dataset over a 2-D array whose last column is the label.

    Args:
        total_data: 2-D ``numpy.ndarray``. Every column except the last is a
            feature and the last column is the label (e.g. a 10000 x 10 array
            gives 9 features plus 1 label per row).
    """

    def __init__(self, total_data):
        # Slice with :-1 rather than a hard-coded :9 so any column count works.
        self.features = total_data[:, :-1]
        self.label = total_data[:, -1]

    def __getitem__(self, idx):
        """Return row ``idx`` as ``{'x': features, 'y': label}``."""
        # x is the model input (feature row), y is the target (label value).
        return {'x': self.features[idx], 'y': self.label[idx]}

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.features)
from torch.utils.data import Sampler
import numpy as np
class MySampler(Sampler):
    """Yield dataset indices: sequential for large datasets, shuffled otherwise.

    Args:
        data_source: Any sized object; only ``len(data_source)`` is used.
        threshold: Datasets with at least this many items are iterated in
            order; smaller ones are iterated in a random order.
    """

    def __init__(self, data_source, threshold=10000):
        self.data_source = data_source
        self.threshold = threshold

    def __iter__(self):
        """Return an iterator over ``len(data_source)`` integer indices."""
        n = len(self.data_source)
        if n >= self.threshold:
            return iter(range(n))
        # __iter__ must return an iterator. A random permutation yields each
        # index exactly once, so the count matches __len__; .tolist() gives
        # plain Python ints suitable for indexing.
        return iter(np.random.permutation(n).tolist())

    def __len__(self):
        """Return the number of indices produced per iteration."""
        return len(self.data_source)
import torch
def my_collate_fn(data):
    """Collate a list of ``{'x': ..., 'y': ...}`` samples into a batch dict.

    Args:
        data: List of sample dicts. Each ``'x'`` (and each ``'y'``) must be a
            number or an array with the same shape across the batch.

    Returns:
        ``{'x': Tensor, 'y': Tensor}``: float32 tensors whose first dimension
        is the batch size.
    """
    # Stack into a single ndarray first: building a tensor from a Python list
    # of ndarrays is very slow (and triggers a PyTorch warning).
    x_array = np.asarray([sample['x'] for sample in data])
    y_array = np.asarray([sample['y'] for sample in data])
    return {
        'x': torch.as_tensor(x_array, dtype=torch.float32),
        'y': torch.as_tensor(y_array, dtype=torch.float32),
    }