From b790094193e0f54ba38f60fd3be41687de4e3bb3 Mon Sep 17 00:00:00 2001 From: Rocker Zhang Date: Sun, 31 May 2026 17:36:47 +0800 Subject: [PATCH] feat: support plain text files as prompt datasets (#337) A dataset path that points to a plain file is now read as one prompt per line, with empty lines ignored. For text files, "column" is ignored and "split" is optional; when given, it selects a subset of lines using slice notation (e.g. "[:400]"). Detection uses os.path.isfile so files without an extension also work. The split-parsing logic is factored into a shared get_split_slice helper, which derives the split name from the specification, and split/column are now optional in DatasetSpecification, with the dataset branches raising a clear error when either is missing. An invalid split raises instead of being silently ignored. A bare slice does not parse with the pinned datasets version, since ReadInstruction.from_spec expects a named split, so the text branch prepends a synthetic split name. Revives the approach from #103. Closes #98. Co-authored-by: Ric --- config.default.toml | 5 ++++ src/heretic/config.py | 10 ++++++-- src/heretic/utils.py | 60 +++++++++++++++++++++++++++++++------------ 3 files changed, 56 insertions(+), 19 deletions(-) diff --git a/config.default.toml b/config.default.toml index cdfe826..9424bfb 100644 --- a/config.default.toml +++ b/config.default.toml @@ -173,6 +173,11 @@ refusal_markers = [ # System prompt to use when prompting the model. system_prompt = "You are a helpful assistant." +# Each "dataset" below can be a Hugging Face dataset ID, a path to a dataset on disk, +# or a path to a plain text file with one prompt per line (empty lines are ignored). +# For text files, "column" is ignored and "split" is optional; when given, it selects +# a subset of the lines using slice notation (e.g. "[:400]"). + # Dataset of prompts that tend to not result in refusals (used for calculating refusal directions). [good_prompts] dataset = "mlabonne/harmless_alpaca" diff --git a/src/heretic/config.py b/src/heretic/config.py index 668073a..ada5792 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -42,9 +42,15 @@ class DatasetSpecification(BaseModel): description="Hugging Face commit hash of the dataset.", ) - split: str = Field(description="Portion of the dataset to use.") + split: str | None = Field( + default=None, + description="Portion of the dataset to use. Required for datasets, optional for plain text files.", + ) - column: str = Field(description="Column in the dataset that contains the prompts.") + column: str | None = Field( + default=None, + description="Column in the dataset that contains the prompts. Required for datasets, ignored for plain text files.", + ) prefix: str = Field( default="", diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 32f78b7..778d52e 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -188,6 +188,20 @@ class Prompt: user: str +def get_split_slice(split_str: str, length: int) -> tuple[int, int]: + """Resolves a split specification into absolute (start, end) indices.""" + + # The split name is the part before the slice, e.g. "train" in "train[:400]". + split_name = split_str.split("[")[0] + # Associate the split with its number of examples (lines). + name_to_length = {split_name: length} + # Convert the instructions to absolute indices and select the first one. + absolute_instruction = ReadInstruction.from_spec(split_str).to_absolute( + name_to_length + )[0] + return absolute_instruction.from_, absolute_instruction.to + + def load_prompts( settings: Settings, specification: DatasetSpecification, @@ -195,29 +209,41 @@ def load_prompts( path = specification.dataset split_str = specification.split - if is_hf_path(path): - dataset = load_dataset( - path, - revision=specification.commit, - split=split_str, - ) + if os.path.isfile(path): + # Plain text file with one prompt per line. Empty lines are ignored. + with open(path, encoding="utf-8") as file: + prompts = [line.strip() for line in file if line.strip()] + + # The split is optional for text files. When given, it selects a subset + # of the lines using slice notation (e.g. "[:400]"). A synthetic split + # name is prepended because ReadInstruction expects a named split. + if split_str is not None: + start, end = get_split_slice(f"_{split_str}", len(prompts)) + prompts = prompts[start:end] else: - if Path(path, DATASET_STATE_JSON_FILENAME).exists(): + # All dataset sources require an explicit split and column. + if split_str is None: + raise ValueError(f'The "split" field is required for datasets: {path}') + + if specification.column is None: + raise ValueError(f'The "column" field is required for datasets: {path}') + + if is_hf_path(path): + dataset = load_dataset( + path, + revision=specification.commit, + split=split_str, + ) + elif Path(path, DATASET_STATE_JSON_FILENAME).exists(): # Dataset saved with datasets.save_to_disk; needs special handling. # Path should be the subdirectory for a particular split. dataset = load_from_disk(path) assert not isinstance(dataset, DatasetDict), ( "Loading dataset dicts is not supported" ) - # Parse the split instructions. - instruction = ReadInstruction.from_spec(split_str) - # Associate the split with its number of examples (lines). - split_name = str(dataset.split) - name2len = {split_name: len(dataset)} - # Convert the instructions to absolute indices and select the first one. - abs_instruction = instruction.to_absolute(name2len)[0] - # Get the dataset by applying the indices. - dataset = dataset[abs_instruction.from_ : abs_instruction.to] + # Parse the split instructions and apply them. + start, end = get_split_slice(split_str, len(dataset)) + dataset = dataset[start:end] else: # Path should be a local directory. dataset = load_dataset( @@ -229,7 +255,7 @@ def load_prompts( download_mode=DownloadMode.FORCE_REDOWNLOAD, ) - prompts = list(dataset[specification.column]) + prompts = list(dataset[specification.column]) if specification.prefix: prompts = [f"{specification.prefix} {prompt}" for prompt in prompts]