feat: add functionality for collecting reproduce.json files from Hugging Face
This commit is contained in:
@@ -103,6 +103,16 @@ class Settings(BaseSettings):
|
|||||||
exclude=True,
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
collect_reproducibles: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If this directory path is set, then instead of abliterating a model, "
|
||||||
|
"download all reproduce.json files from public Heretic model repositories "
|
||||||
|
"on Hugging Face, and store them in that directory for archival purposes."
|
||||||
|
),
|
||||||
|
exclude=True,
|
||||||
|
)
|
||||||
|
|
||||||
dtypes: list[str] = Field(
|
dtypes: list[str] = Field(
|
||||||
default=[
|
default=[
|
||||||
# In practice, "auto" almost always means bfloat16.
|
# In practice, "auto" almost always means bfloat16.
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ from .analyzer import Analyzer
|
|||||||
from .config import QuantizationMethod
|
from .config import QuantizationMethod
|
||||||
from .evaluator import Evaluator
|
from .evaluator import Evaluator
|
||||||
from .model import AbliterationParameters, Model, get_model_class
|
from .model import AbliterationParameters, Model, get_model_class
|
||||||
|
from .reproduce import collect_reproducibles
|
||||||
from .system import empty_cache, get_accelerator_info
|
from .system import empty_cache, get_accelerator_info
|
||||||
from .utils import (
|
from .utils import (
|
||||||
format_duration,
|
format_duration,
|
||||||
@@ -177,6 +178,8 @@ def run():
|
|||||||
if (
|
if (
|
||||||
# There is at least one argument (argv[0] is the program name).
|
# There is at least one argument (argv[0] is the program name).
|
||||||
len(sys.argv) > 1
|
len(sys.argv) > 1
|
||||||
|
# Heretic is being invoked in standard (model processing) mode.
|
||||||
|
and "--collect-reproducibles" not in sys.argv
|
||||||
# No model has been explicitly provided.
|
# No model has been explicitly provided.
|
||||||
and "--model" not in sys.argv
|
and "--model" not in sys.argv
|
||||||
# The last argument is a parameter value rather than a flag (such as "--help").
|
# The last argument is a parameter value rather than a flag (such as "--help").
|
||||||
@@ -185,6 +188,11 @@ def run():
|
|||||||
# Assume the last argument is the model.
|
# Assume the last argument is the model.
|
||||||
sys.argv.insert(-1, "--model")
|
sys.argv.insert(-1, "--model")
|
||||||
|
|
||||||
|
# Work around the "model" argument being required
|
||||||
|
# when Heretic is invoked in a non-processing mode.
|
||||||
|
if "--collect-reproducibles" in sys.argv and "--model" not in sys.argv:
|
||||||
|
sys.argv.extend(["--model", ""])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# The required argument "model" must be provided by the user,
|
# The required argument "model" must be provided by the user,
|
||||||
# either on the command line or in the configuration file.
|
# either on the command line or in the configuration file.
|
||||||
@@ -201,6 +209,10 @@ def run():
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if settings.collect_reproducibles is not None:
|
||||||
|
collect_reproducibles(settings.collect_reproducibles)
|
||||||
|
return
|
||||||
|
|
||||||
if settings.seed is None:
|
if settings.seed is None:
|
||||||
settings.seed = random.randint(0, 2**32 - 1)
|
settings.seed = random.randint(0, 2**32 - 1)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,83 @@
|
|||||||
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
|
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
|
||||||
|
|
||||||
|
from .utils import print
|
||||||
|
|
||||||
|
|
||||||
|
def collect_reproducibles(path: str) -> None:
    """Archive reproduce.json files from Heretic model repositories on Hugging Face.

    Lists the models on Hugging Face that carry both the "heretic" and
    "reproducible" tags, skips repositories tagged "gguf", and copies each
    repository's ``reproduce/reproduce.json`` to::

        <path>/huggingface.co/<user>/<repository>-<short commit hash>.json

    where the short commit hash is the first 7 characters of the commit that
    last changed the file. Files that already exist at their target location
    are not downloaded again. Progress and a final summary are printed.

    Args:
        path: Directory in which the archive is stored. Missing directories
            are created as needed.
    """
    print(
        f"Collecting [bold]reproduce.json[/] files from Hugging Face and storing them in [bold]{path}[/]..."
    )
    print()

    # Location of the file inside each model repository.
    reproduce_file = "reproduce/reproduce.json"
    # Root of the archive; invariant across repositories, so built only once.
    archive_root = Path(path) / "huggingface.co"

    api = HfApi()

    models = api.list_models(
        filter=["heretic", "reproducible"],
        sort="created_at",
    )

    found = 0
    downloaded = 0

    # We're only downloading tiny files, so the progress bars are just noise.
    disable_progress_bars()

    try:
        for model in models:
            # Ignore repositories containing quantizations.
            if model.tags is not None and "gguf" in model.tags:
                continue

            print(f"[bold]{model.id}[/]...", end="")

            # rpartition never raises: an identifier without a namespace yields
            # an empty user, which pathlib drops when joining, so such a file
            # ends up directly below the archive root instead of crashing the run.
            user, _, repository = model.id.rpartition("/")

            paths_info = api.get_paths_info(
                model.id,
                reproduce_file,
                expand=True,
            )
            # The reproduce.json file might not exist in the repository
            # despite the relevant tags being present.
            if not paths_info:
                print(" [yellow]no reproduce.json found[/]")
                continue

            found += 1

            # NOTE(review): relies on expand=True populating last_commit — confirm
            # against the huggingface_hub version in use.
            commit_hash = paths_info[0].last_commit.oid

            file_path = archive_root / user / f"{repository}-{commit_hash[:7]}.json"
            if file_path.exists():
                print(" already stored")
                continue

            # Pin the download to the commit whose hash is embedded in the file
            # name. Without this, a repository updated between the metadata
            # lookup and the download would have its newer content archived
            # under the older hash.
            cache_path = hf_hub_download(
                model.id,
                reproduce_file,
                revision=commit_hash,
            )

            file_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copyfile(cache_path, file_path)
            print(" [green]downloaded[/]")

            downloaded += 1
    finally:
        # Restore progress bars even if listing or downloading fails.
        enable_progress_bars()

    print()
    print(f"Found: [bold]{found}[/] files")
    print(f"Downloaded: [bold]{downloaded}[/] files")
    print(f"Already stored: [bold]{found - downloaded}[/] files")
|
||||||
Reference in New Issue
Block a user