feat: support plain text files as prompt datasets (#337)
A dataset path that points to a plain file is now read as one prompt per line, with empty lines ignored. For text files, "column" is ignored and "split" is optional; when given, it selects a subset of lines using slice notation (e.g. "[:400]"). Detection uses os.path.isfile so files without an extension also work. The split-parsing logic is factored into a shared get_split_slice helper, which derives the split name from the specification, and split/column are now optional in DatasetSpecification, with the dataset branches raising a clear error when either is missing. An invalid split raises instead of being silently ignored. A bare slice does not parse with the pinned datasets version, since ReadInstruction.from_spec expects a named split, so the text branch prepends a synthetic split name. Revives the approach from #103. Closes #98. Co-authored-by: Ric <ricyoung@gmail.com>
This commit is contained in:
@@ -173,6 +173,11 @@ refusal_markers = [
|
||||
# System prompt to use when prompting the model.
|
||||
system_prompt = "You are a helpful assistant."
|
||||
|
||||
# Each "dataset" below can be a Hugging Face dataset ID, a path to a dataset on disk,
|
||||
# or a path to a plain text file with one prompt per line (empty lines are ignored).
|
||||
# For text files, "column" is ignored and "split" is optional; when given, it selects
|
||||
# a subset of the lines using slice notation (e.g. "[:400]").
|
||||
|
||||
# Dataset of prompts that tend to not result in refusals (used for calculating refusal directions).
|
||||
[good_prompts]
|
||||
dataset = "mlabonne/harmless_alpaca"
|
||||
|
||||
@@ -42,9 +42,15 @@ class DatasetSpecification(BaseModel):
|
||||
description="Hugging Face commit hash of the dataset.",
|
||||
)
|
||||
|
||||
split: str = Field(description="Portion of the dataset to use.")
|
||||
split: str | None = Field(
|
||||
default=None,
|
||||
description="Portion of the dataset to use. Required for datasets, optional for plain text files.",
|
||||
)
|
||||
|
||||
column: str = Field(description="Column in the dataset that contains the prompts.")
|
||||
column: str | None = Field(
|
||||
default=None,
|
||||
description="Column in the dataset that contains the prompts. Required for datasets, ignored for plain text files.",
|
||||
)
|
||||
|
||||
prefix: str = Field(
|
||||
default="",
|
||||
|
||||
+37
-11
@@ -188,6 +188,20 @@ class Prompt:
|
||||
user: str
|
||||
|
||||
|
||||
def get_split_slice(split_str: str, length: int) -> tuple[int, int]:
|
||||
"""Resolves a split specification into absolute (start, end) indices."""
|
||||
|
||||
# The split name is the part before the slice, e.g. "train" in "train[:400]".
|
||||
split_name = split_str.split("[")[0]
|
||||
# Associate the split with its number of examples (lines).
|
||||
name_to_length = {split_name: length}
|
||||
# Convert the instructions to absolute indices and select the first one.
|
||||
absolute_instruction = ReadInstruction.from_spec(split_str).to_absolute(
|
||||
name_to_length
|
||||
)[0]
|
||||
return absolute_instruction.from_, absolute_instruction.to
|
||||
|
||||
|
||||
def load_prompts(
|
||||
settings: Settings,
|
||||
specification: DatasetSpecification,
|
||||
@@ -195,29 +209,41 @@ def load_prompts(
|
||||
path = specification.dataset
|
||||
split_str = specification.split
|
||||
|
||||
if os.path.isfile(path):
|
||||
# Plain text file with one prompt per line. Empty lines are ignored.
|
||||
with open(path, encoding="utf-8") as file:
|
||||
prompts = [line.strip() for line in file if line.strip()]
|
||||
|
||||
# The split is optional for text files. When given, it selects a subset
|
||||
# of the lines using slice notation (e.g. "[:400]"). A synthetic split
|
||||
# name is prepended because ReadInstruction expects a named split.
|
||||
if split_str is not None:
|
||||
start, end = get_split_slice(f"_{split_str}", len(prompts))
|
||||
prompts = prompts[start:end]
|
||||
else:
|
||||
# All dataset sources require an explicit split and column.
|
||||
if split_str is None:
|
||||
raise ValueError(f'The "split" field is required for datasets: {path}')
|
||||
|
||||
if specification.column is None:
|
||||
raise ValueError(f'The "column" field is required for datasets: {path}')
|
||||
|
||||
if is_hf_path(path):
|
||||
dataset = load_dataset(
|
||||
path,
|
||||
revision=specification.commit,
|
||||
split=split_str,
|
||||
)
|
||||
else:
|
||||
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
|
||||
elif Path(path, DATASET_STATE_JSON_FILENAME).exists():
|
||||
# Dataset saved with datasets.save_to_disk; needs special handling.
|
||||
# Path should be the subdirectory for a particular split.
|
||||
dataset = load_from_disk(path)
|
||||
assert not isinstance(dataset, DatasetDict), (
|
||||
"Loading dataset dicts is not supported"
|
||||
)
|
||||
# Parse the split instructions.
|
||||
instruction = ReadInstruction.from_spec(split_str)
|
||||
# Associate the split with its number of examples (lines).
|
||||
split_name = str(dataset.split)
|
||||
name2len = {split_name: len(dataset)}
|
||||
# Convert the instructions to absolute indices and select the first one.
|
||||
abs_instruction = instruction.to_absolute(name2len)[0]
|
||||
# Get the dataset by applying the indices.
|
||||
dataset = dataset[abs_instruction.from_ : abs_instruction.to]
|
||||
# Parse the split instructions and apply them.
|
||||
start, end = get_split_slice(split_str, len(dataset))
|
||||
dataset = dataset[start:end]
|
||||
else:
|
||||
# Path should be a local directory.
|
||||
dataset = load_dataset(
|
||||
|
||||
Reference in New Issue
Block a user