Add save and upload functionality

This commit is contained in:
Philipp Emanuel Weidmann
2025-09-27 11:15:41 +05:30
parent 7573a2eebd
commit 5b01ad4344
3 changed files with 75 additions and 26 deletions
+1
View File
@@ -10,6 +10,7 @@ requires-python = ">=3.10"
dependencies = [ dependencies = [
"accelerate>=1.10.0", "accelerate>=1.10.0",
"datasets>=4.0.0", "datasets>=4.0.0",
"huggingface-hub>=0.34.4",
"optuna>=4.5.0", "optuna>=4.5.0",
"pydantic-settings>=2.10.1", "pydantic-settings>=2.10.1",
"questionary>=2.1.1", "questionary>=2.1.1",
+72 -26
View File
@@ -4,7 +4,9 @@
import sys import sys
import time import time
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path
import huggingface_hub
import optuna import optuna
import questionary import questionary
import torch import torch
@@ -258,33 +260,77 @@ def main():
], ],
).ask() ).ask()
match action: # All actions are wrapped in a try/except block so that if an error occurs,
case "Save the model to a local folder": # another action can be tried, instead of the program crashing and losing
# TODO # the optimized model.
pass try:
case "Upload the model to Hugging Face": match action:
# TODO case "Save the model to a local folder":
pass save_directory = questionary.path("Path to the folder:").ask()
case "Chat with the model": if not save_directory:
print() continue
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
chat = [ print("Saving model...")
{"role": "system", "content": settings.system_prompt}, model.model.save_pretrained(save_directory)
] model.tokenizer.save_pretrained(save_directory)
print(f"Model saved to [bold]{save_directory}[/].")
while True: case "Upload the model to Hugging Face":
try: # We don't use huggingface_hub.login() because that stores the token on disk,
message = questionary.text("User:", qmark=">").unsafe_ask() # and since this program will often be run on rented or shared GPU servers,
if not message: # it's better to not persist credentials.
token = huggingface_hub.get_token()
if not token:
token = questionary.password("Hugging Face access token:").ask()
if not token:
continue
user = huggingface_hub.whoami(token)
print(f"Logged in as [bold]{user['fullname']} ({user['email']})[/]")
repo_id = questionary.text(
"Name of repository:",
default=f"{user['name']}/{Path(settings.model).name}-heretic",
).ask()
visibility = questionary.select(
"Should the repository be public or private?",
choices=[
"Public",
"Private",
],
).ask()
private = visibility == "Private"
print("Uploading model...")
model.model.push_to_hub(repo_id, private=private, token=token)
model.tokenizer.push_to_hub(repo_id, private=private, token=token)
print(f"Model uploaded to [bold]{repo_id}[/].")
case "Chat with the model":
print()
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
chat = [
{"role": "system", "content": settings.system_prompt},
]
while True:
try:
message = questionary.text("User:", qmark=">").unsafe_ask()
if not message:
break
chat.append({"role": "user", "content": message})
print("[bold]Assistant:[/] ", end="")
response = model.stream_chat_response(chat)
chat.append({"role": "assistant", "content": response})
except (KeyboardInterrupt, EOFError):
# Ctrl+C/Ctrl+D
break break
chat.append({"role": "user", "content": message})
print("[bold]Assistant:[/] ", end="") case "Nothing (Quit)":
response = model.stream_chat_response(chat) break
chat.append({"role": "assistant", "content": response})
except (KeyboardInterrupt, EOFError): except Exception as error:
# Ctrl+C/Ctrl+D print(f"[red]Error: {error}[/]")
break
case "Nothing (Quit)":
break
Generated
+2
View File
@@ -479,6 +479,7 @@ source = { editable = "." }
dependencies = [ dependencies = [
{ name = "accelerate" }, { name = "accelerate" },
{ name = "datasets" }, { name = "datasets" },
{ name = "huggingface-hub" },
{ name = "optuna" }, { name = "optuna" },
{ name = "pydantic-settings" }, { name = "pydantic-settings" },
{ name = "questionary" }, { name = "questionary" },
@@ -490,6 +491,7 @@ dependencies = [
requires-dist = [ requires-dist = [
{ name = "accelerate", specifier = ">=1.10.0" }, { name = "accelerate", specifier = ">=1.10.0" },
{ name = "datasets", specifier = ">=4.0.0" }, { name = "datasets", specifier = ">=4.0.0" },
{ name = "huggingface-hub", specifier = ">=0.34.4" },
{ name = "optuna", specifier = ">=4.5.0" }, { name = "optuna", specifier = ">=4.5.0" },
{ name = "pydantic-settings", specifier = ">=2.10.1" }, { name = "pydantic-settings", specifier = ">=2.10.1" },
{ name = "questionary", specifier = ">=2.1.1" }, { name = "questionary", specifier = ">=2.1.1" },