SengTak's picture
added necessary files-1
55890ea verified
raw
history blame
1.34 kB
from .dataset import DataSample, TrainSample, Dataset
from accelerate.logging import get_logger
logger = get_logger(__name__, log_level="INFO")


class Wiki1M(Dataset):
    """Line-per-sample text dataset backed by a Wiki1M text file.

    Every line of the file becomes one ``DataSample`` whose ``query`` and
    ``positive`` are the same stripped string, i.e. each sentence is paired
    with itself. (Presumably intended for unsupervised SimCSE-style
    contrastive training, judging by the default file name -- confirm
    against the training recipe.)

    Only the ``"train"`` split can be indexed; this dataset has no
    validation split.
    """

    def __init__(
        self,
        dataset_name: str = "Wiki1M",
        split: str = "validation",
        file_path: str = "cache/wiki1m_for_simcse.txt",
    ):
        """Store metadata and eagerly load every line of ``file_path``.

        Args:
            dataset_name: Name recorded on the instance.
            split: Split this instance serves. Only ``"train"`` can be
                indexed via ``__getitem__``. NOTE(review): the default is
                ``"validation"``, which ``__getitem__`` rejects, so callers
                must pass ``split="train"`` explicitly to get samples.
            file_path: Path to a UTF-8 text file with one sentence per line.
        """
        self.dataset_name = dataset_name
        self.split = split
        self.data = []
        self.load_data(file_path)

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def load_data(self, file_path: str = None):
        """Append one ``DataSample`` per line of ``file_path`` to ``self.data``.

        Each line is stripped of surrounding whitespace and used as both the
        query and the positive. Blank lines are kept as empty-string samples.

        Args:
            file_path: Path to the text file to read (decoded as UTF-8).
        """
        logger.info(f"Loading Wiki1M data from {file_path}...")
        # Continue numbering after any samples already loaded so ids stay
        # unique if this method is called more than once.
        start_id = len(self.data)
        # Explicit encoding: without it open() uses the locale's preferred
        # encoding, which can fail or mis-decode non-ASCII Wikipedia text.
        with open(file_path, "r", encoding="utf-8") as f:
            for id_, line in enumerate(f, start=start_id):
                line = line.strip()
                self.data.append(
                    DataSample(
                        id_=id_,
                        query=line,
                        positive=line,
                    )
                )
        logger.info(f"Loaded {len(self.data)} samples.")

    def __getitem__(self, index):
        """Return the training pair stored at ``index``.

        Returns:
            ``TrainSample`` with ``texts=[query, positive]`` and ``label=1.0``
            when ``self.split == "train"``.

        Raises:
            IndexError: if ``index`` is out of range for ``self.data``.
            AssertionError: if ``self.split == "validation"``.
            ValueError: for any other split that is not ``"train"``.
        """
        sample = self.data[index]
        if self.split == "train":
            return TrainSample(texts=[sample.query, sample.positive], label=1.0)
        if self.split == "validation":
            # Explicit raise instead of `assert False`: asserts are removed
            # under `python -O`, which would make this silently return None.
            # The exception type stays AssertionError for existing callers.
            raise AssertionError("Wiki1M does not have a validation split.")
        # Previously fell through and returned None for unknown splits.
        raise ValueError(f"Unsupported split for Wiki1M: {self.split!r}")