feat: support plain text files as prompt datasets (#337)

A dataset path that points to a plain file is now read as one prompt per
line, with empty lines ignored. For text files, "column" is ignored and
"split" is optional; when given, it selects a subset of lines using slice
notation (e.g. "[:400]").

Detection uses os.path.isfile, so files without an extension also work. The
split-parsing logic is factored into a shared get_split_slice helper, which
derives the split name from the specification. The split and column fields
are now optional in DatasetSpecification; the dataset branches raise a clear
error when either is missing. An invalid split raises an error instead of
being silently ignored.

A bare slice such as "[:400]" does not parse with the pinned datasets
version, because ReadInstruction.from_spec expects a named split. The text
branch therefore prepends a synthetic split name ("_"), giving "_[:400]".

Revives the approach from #103.

Closes #98.

Co-authored-by: Ric <ricyoung@gmail.com>
This commit is contained in:
Rocker Zhang
2026-05-31 17:36:47 +08:00
committed by GitHub
parent 6338e2c99b
commit b790094193
3 changed files with 56 additions and 19 deletions
+5
View File
@@ -173,6 +173,11 @@ refusal_markers = [
# System prompt to use when prompting the model. # System prompt to use when prompting the model.
system_prompt = "You are a helpful assistant." system_prompt = "You are a helpful assistant."
# Each "dataset" below can be a Hugging Face dataset ID, a path to a dataset on disk,
# or a path to a plain text file with one prompt per line (empty lines are ignored).
# For text files, "column" is ignored and "split" is optional; when given, it selects
# a subset of the lines using slice notation (e.g. "[:400]").
# Dataset of prompts that tend to not result in refusals (used for calculating refusal directions). # Dataset of prompts that tend to not result in refusals (used for calculating refusal directions).
[good_prompts] [good_prompts]
dataset = "mlabonne/harmless_alpaca" dataset = "mlabonne/harmless_alpaca"
+8 -2
View File
@@ -42,9 +42,15 @@ class DatasetSpecification(BaseModel):
description="Hugging Face commit hash of the dataset.", description="Hugging Face commit hash of the dataset.",
) )
split: str = Field(description="Portion of the dataset to use.") split: str | None = Field(
default=None,
description="Portion of the dataset to use. Required for datasets, optional for plain text files.",
)
column: str = Field(description="Column in the dataset that contains the prompts.") column: str | None = Field(
default=None,
description="Column in the dataset that contains the prompts. Required for datasets, ignored for plain text files.",
)
prefix: str = Field( prefix: str = Field(
default="", default="",
+43 -17
View File
@@ -188,6 +188,20 @@ class Prompt:
user: str user: str
def get_split_slice(split_str: str, length: int) -> tuple[int, int]:
    """Translate a split specification into absolute ``(start, end)`` indices.

    Args:
        split_str: Split specification consisting of a split name optionally
            followed by a slice, e.g. ``"train[:400]"``.
        length: Number of examples (lines) available in the split.

    Returns:
        The ``from_`` and ``to`` values of the first absolute instruction
        obtained from the specification.
    """
    # Everything before the first "[" is the split name ("train" in
    # "train[:400]"); a specification without a slice is all name.
    name, _, _ = split_str.partition("[")

    # ReadInstruction needs the size of each named split to turn relative
    # bounds into absolute ones.
    sizes = {name: length}
    instruction = ReadInstruction.from_spec(split_str)

    # Only the first absolute instruction is used.
    first = instruction.to_absolute(sizes)[0]
    return first.from_, first.to
def load_prompts( def load_prompts(
settings: Settings, settings: Settings,
specification: DatasetSpecification, specification: DatasetSpecification,
@@ -195,29 +209,41 @@ def load_prompts(
path = specification.dataset path = specification.dataset
split_str = specification.split split_str = specification.split
if is_hf_path(path): if os.path.isfile(path):
dataset = load_dataset( # Plain text file with one prompt per line. Empty lines are ignored.
path, with open(path, encoding="utf-8") as file:
revision=specification.commit, prompts = [line.strip() for line in file if line.strip()]
split=split_str,
) # The split is optional for text files. When given, it selects a subset
# of the lines using slice notation (e.g. "[:400]"). A synthetic split
# name is prepended because ReadInstruction expects a named split.
if split_str is not None:
start, end = get_split_slice(f"_{split_str}", len(prompts))
prompts = prompts[start:end]
else: else:
if Path(path, DATASET_STATE_JSON_FILENAME).exists(): # All dataset sources require an explicit split and column.
if split_str is None:
raise ValueError(f'The "split" field is required for datasets: {path}')
if specification.column is None:
raise ValueError(f'The "column" field is required for datasets: {path}')
if is_hf_path(path):
dataset = load_dataset(
path,
revision=specification.commit,
split=split_str,
)
elif Path(path, DATASET_STATE_JSON_FILENAME).exists():
# Dataset saved with datasets.save_to_disk; needs special handling. # Dataset saved with datasets.save_to_disk; needs special handling.
# Path should be the subdirectory for a particular split. # Path should be the subdirectory for a particular split.
dataset = load_from_disk(path) dataset = load_from_disk(path)
assert not isinstance(dataset, DatasetDict), ( assert not isinstance(dataset, DatasetDict), (
"Loading dataset dicts is not supported" "Loading dataset dicts is not supported"
) )
# Parse the split instructions. # Parse the split instructions and apply them.
instruction = ReadInstruction.from_spec(split_str) start, end = get_split_slice(split_str, len(dataset))
# Associate the split with its number of examples (lines). dataset = dataset[start:end]
split_name = str(dataset.split)
name2len = {split_name: len(dataset)}
# Convert the instructions to absolute indices and select the first one.
abs_instruction = instruction.to_absolute(name2len)[0]
# Get the dataset by applying the indices.
dataset = dataset[abs_instruction.from_ : abs_instruction.to]
else: else:
# Path should be a local directory. # Path should be a local directory.
dataset = load_dataset( dataset = load_dataset(
@@ -229,7 +255,7 @@ def load_prompts(
download_mode=DownloadMode.FORCE_REDOWNLOAD, download_mode=DownloadMode.FORCE_REDOWNLOAD,
) )
prompts = list(dataset[specification.column]) prompts = list(dataset[specification.column])
if specification.prefix: if specification.prefix:
prompts = [f"{specification.prefix} {prompt}" for prompt in prompts] prompts = [f"{specification.prefix} {prompt}" for prompt in prompts]