Improve support for loading local datasets (#33)
* Handle loading local datasets * Reorder branches to avoid chain of negatives
This commit is contained in:
+36
-2
@@ -2,8 +2,10 @@
|
||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||
|
||||
import gc
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
@@ -13,7 +15,10 @@ from accelerate.utils import (
|
||||
is_sdaa_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
from datasets import load_dataset
|
||||
from datasets import ReadInstruction, load_dataset, load_from_disk
|
||||
from datasets.config import DATASET_STATE_JSON_FILENAME
|
||||
from datasets.download.download_manager import DownloadMode
|
||||
from datasets.utils.info_utils import VerificationMode
|
||||
from optuna import Trial
|
||||
from rich.console import Console
|
||||
|
||||
@@ -36,7 +41,36 @@ def format_duration(seconds: float) -> str:
|
||||
|
||||
|
||||
def load_prompts(specification: DatasetSpecification) -> list[str]:
    """Load the prompts described by *specification*.

    ``specification.dataset`` is resolved in one of three ways:

    * A local directory that contains ``DATASET_STATE_JSON_FILENAME`` is
      opened with ``load_from_disk`` and the split expression in
      ``specification.split`` is applied to it manually.
    * Any other local directory is handed to ``load_dataset`` with
      verification disabled and a forced re-download.
    * Anything else is passed to ``load_dataset`` unchanged.

    The dataset is loaded exactly once, by whichever branch applies.

    Returns:
        The values of ``specification.column`` for the selected rows.
    """
    path = specification.dataset
    split_str = specification.split

    if os.path.isdir(path):
        if Path(path, DATASET_STATE_JSON_FILENAME).exists():
            # Dataset saved with datasets.save_to_disk; needs special handling.
            # Path should be the subdirectory for a particular split.
            dataset = load_from_disk(path)
            # Parse the split instructions.
            ri = ReadInstruction.from_spec(split_str)
            # Associate the split with its number of examples (lines).
            split_name = str(dataset.split)
            name2len = {split_name: len(dataset)}
            # Convert the instructions to absolute indices and select the first one.
            # NOTE(review): presumably this fails if the split named in split_str
            # differs from the saved dataset's own split name -- confirm.
            abs_i = ri.to_absolute(name2len)[0]
            # Get the dataset by applying the indices.
            # NOTE(review): slicing appears to yield a column-keyed mapping rather
            # than a Dataset; the column lookup below works on either -- confirm.
            dataset = dataset[abs_i.from_ : abs_i.to]
        else:
            # Path is a local directory.
            dataset = load_dataset(
                path,
                split=split_str,
                # Don't require the number of examples (lines) per split to be pre-defined.
                verification_mode=VerificationMode.NO_CHECKS,
                # But also don't use cached data, as the dataset may have changed on disk.
                download_mode=DownloadMode.FORCE_REDOWNLOAD,
            )
    else:
        # Probably a repository path; let load_dataset figure it out.
        dataset = load_dataset(path, split=split_str)

    return list(dataset[specification.column])
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user