# NOTE: scraped page header ("Spaces: Sleeping") — not part of this module.
from accelerate.logging import get_logger

from .dataset import DataSample, TrainSample, Dataset

# Module-level logger obtained from accelerate's logging helper, at INFO level.
logger = get_logger(__name__, log_level="INFO")
class Wiki1M(Dataset):
    """Line-per-sample text dataset that pairs every sentence with itself.

    Each line of the input file becomes one ``DataSample`` whose ``query`` and
    ``positive`` are the same stripped line. Only the ``"train"`` split can be
    indexed; note that the constructor's default split is ``"validation"``,
    which cannot.
    """

    def __init__(
        self,
        dataset_name: str = "Wiki1M",
        split: str = "validation",
        file_path: str = "cache/wiki1m_for_simcse.txt",
    ):
        """Store the configuration and eagerly load all lines of ``file_path``.

        Args:
            dataset_name: Name stored on the instance as ``dataset_name``.
            split: Split name; ``__getitem__`` only serves ``"train"``.
            file_path: Path to a text file with one sample per line.
        """
        self.dataset_name = dataset_name
        self.split = split
        self.data = []
        self.load_data(file_path)

    def __len__(self):
        """Return the number of samples currently held in ``self.data``."""
        return len(self.data)

    def load_data(self, file_path: str = None):
        """Append one ``DataSample`` per line of ``file_path`` to ``self.data``.

        Lines are stripped of surrounding whitespace; blank lines are kept as
        empty-string samples. Ids are the zero-based line numbers of this call.

        Args:
            file_path: Path to the text file to read.
        """
        logger.info(f"Loading Wiki1M data from {file_path}...")
        # Pin the encoding: without it, open() falls back to the platform's
        # locale encoding, which can mis-decode or reject non-ASCII text.
        # NOTE(review): assumes the corpus file is UTF-8 — confirm.
        with open(file_path, "r", encoding="utf-8") as f:
            for id_, raw_line in enumerate(f):
                text = raw_line.strip()
                self.data.append(
                    DataSample(
                        id_=id_,
                        query=text,
                        positive=text,
                    )
                )
        logger.info(f"Loaded {len(self.data)} samples.")

    def __getitem__(self, index):
        """Return the ``TrainSample`` for ``index`` when the split is ``"train"``.

        Raises:
            IndexError: If ``index`` is out of range for the loaded data.
            AssertionError: If the split is ``"validation"``.
            ValueError: If the split is neither ``"train"`` nor ``"validation"``.
        """
        # Look the sample up first so an out-of-range index still raises
        # IndexError regardless of the configured split.
        sample = self.data[index]
        if self.split == "train":
            return TrainSample(texts=[sample.query, sample.positive], label=1.0)
        if self.split == "validation":
            # Raised explicitly rather than via `assert False`, which is
            # stripped under `python -O` and would silently return None.
            raise AssertionError("Wiki1M does not have a validation split.")
        # Any other split previously fell through and returned None.
        raise ValueError(f"Unsupported split for Wiki1M: {self.split!r}")