From 5b01ad43440405ac81f1532eda9993a277564bf7 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Sat, 27 Sep 2025 11:15:41 +0530 Subject: [PATCH] Add save and upload functionality --- pyproject.toml | 1 + src/heretic/main.py | 98 +++++++++++++++++++++++++++++++++------------ uv.lock | 2 + 3 files changed, 75 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d18a7b7..4ec9132 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ requires-python = ">=3.10" dependencies = [ "accelerate>=1.10.0", "datasets>=4.0.0", + "huggingface-hub>=0.34.4", "optuna>=4.5.0", "pydantic-settings>=2.10.1", "questionary>=2.1.1", diff --git a/src/heretic/main.py b/src/heretic/main.py index 517086f..2f3a6cd 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -4,7 +4,9 @@ import sys import time from importlib.metadata import version +from pathlib import Path +import huggingface_hub import optuna import questionary import torch @@ -258,33 +260,77 @@ def main(): ], ).ask() - match action: - case "Save the model to a local folder": - # TODO - pass - case "Upload the model to Hugging Face": - # TODO - pass - case "Chat with the model": - print() - print("[cyan]Press Ctrl+C at any time to return to the menu.[/]") + # All actions are wrapped in a try/except block so that if an error occurs, + # another action can be tried, instead of the program crashing and losing + # the optimized model. + try: + match action: + case "Save the model to a local folder": + save_directory = questionary.path("Path to the folder:").ask() + if not save_directory: + continue - chat = [ - {"role": "system", "content": settings.system_prompt}, - ] + print("Saving model...") + model.model.save_pretrained(save_directory) + model.tokenizer.save_pretrained(save_directory) + print(f"Model saved to [bold]{save_directory}[/].") - while True: - try: - message = questionary.text("User:", qmark=">").unsafe_ask() - if not message: + case "Upload the model to Hugging Face": + # We don't use huggingface_hub.login() because that stores the token on disk, + # and since this program will often be run on rented or shared GPU servers, + # it's better to not persist credentials. + token = huggingface_hub.get_token() + if not token: + token = questionary.password("Hugging Face access token:").ask() + if not token: + continue + + user = huggingface_hub.whoami(token) + print(f"Logged in as [bold]{user['fullname']} ({user['email']})[/]") + + repo_id = questionary.text( + "Name of repository:", + default=f"{user['name']}/{Path(settings.model).name}-heretic", + ).ask() + + visibility = questionary.select( + "Should the repository be public or private?", + choices=[ + "Public", + "Private", + ], + ).ask() + private = visibility == "Private" + + print("Uploading model...") + model.model.push_to_hub(repo_id, private=private, token=token) + model.tokenizer.push_to_hub(repo_id, private=private, token=token) + print(f"Model uploaded to [bold]{repo_id}[/].") + + case "Chat with the model": + print() + print("[cyan]Press Ctrl+C at any time to return to the menu.[/]") + + chat = [ + {"role": "system", "content": settings.system_prompt}, + ] + + while True: + try: + message = questionary.text("User:", qmark=">").unsafe_ask() + if not message: + break + chat.append({"role": "user", "content": message}) + + print("[bold]Assistant:[/] ", end="") + response = model.stream_chat_response(chat) + chat.append({"role": "assistant", "content": response}) + except (KeyboardInterrupt, EOFError): + # Ctrl+C/Ctrl+D break - chat.append({"role": "user", "content": message}) - print("[bold]Assistant:[/] ", end="") - response = model.stream_chat_response(chat) - chat.append({"role": "assistant", "content": response}) - except (KeyboardInterrupt, EOFError): - # Ctrl+C/Ctrl+D - break - case "Nothing (Quit)": - break + case "Nothing (Quit)": + break + + except Exception as error: + print(f"[red]Error: {error}[/]") diff --git a/uv.lock b/uv.lock index 509e1a8..192a347 100644 --- a/uv.lock +++ b/uv.lock @@ -479,6 +479,7 @@ source = { editable = "." } dependencies = [ { name = "accelerate" }, { name = "datasets" }, + { name = "huggingface-hub" }, { name = "optuna" }, { name = "pydantic-settings" }, { name = "questionary" }, @@ -490,6 +491,7 @@ dependencies = [ requires-dist = [ { name = "accelerate", specifier = ">=1.10.0" }, { name = "datasets", specifier = ">=4.0.0" }, + { name = "huggingface-hub", specifier = ">=0.34.4" }, { name = "optuna", specifier = ">=4.5.0" }, { name = "pydantic-settings", specifier = ">=2.10.1" }, { name = "questionary", specifier = ">=2.1.1" },