Add save and upload functionality

This commit is contained in:
Philipp Emanuel Weidmann
2025-09-27 11:15:41 +05:30
parent 7573a2eebd
commit 5b01ad4344
3 changed files with 75 additions and 26 deletions
+72 -26
View File
@@ -4,7 +4,9 @@
import sys
import time
from importlib.metadata import version
from pathlib import Path
import huggingface_hub
import optuna
import questionary
import torch
@@ -258,33 +260,77 @@ def main():
],
).ask()
match action:
case "Save the model to a local folder":
# TODO
pass
case "Upload the model to Hugging Face":
# TODO
pass
case "Chat with the model":
print()
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
# All actions are wrapped in a try/except block so that if an error occurs,
# another action can be tried, instead of the program crashing and losing
# the optimized model.
try:
match action:
case "Save the model to a local folder":
save_directory = questionary.path("Path to the folder:").ask()
if not save_directory:
continue
chat = [
{"role": "system", "content": settings.system_prompt},
]
print("Saving model...")
model.model.save_pretrained(save_directory)
model.tokenizer.save_pretrained(save_directory)
print(f"Model saved to [bold]{save_directory}[/].")
while True:
try:
message = questionary.text("User:", qmark=">").unsafe_ask()
if not message:
case "Upload the model to Hugging Face":
# We don't use huggingface_hub.login() because that stores the token on disk,
# and since this program will often be run on rented or shared GPU servers,
# it's better to not persist credentials.
token = huggingface_hub.get_token()
if not token:
token = questionary.password("Hugging Face access token:").ask()
if not token:
continue
user = huggingface_hub.whoami(token)
print(f"Logged in as [bold]{user['fullname']} ({user['email']})[/]")
repo_id = questionary.text(
"Name of repository:",
default=f"{user['name']}/{Path(settings.model).name}-heretic",
).ask()
visibility = questionary.select(
"Should the repository be public or private?",
choices=[
"Public",
"Private",
],
).ask()
private = visibility == "Private"
print("Uploading model...")
model.model.push_to_hub(repo_id, private=private, token=token)
model.tokenizer.push_to_hub(repo_id, private=private, token=token)
print(f"Model uploaded to [bold]{repo_id}[/].")
case "Chat with the model":
print()
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
chat = [
{"role": "system", "content": settings.system_prompt},
]
while True:
try:
message = questionary.text("User:", qmark=">").unsafe_ask()
if not message:
break
chat.append({"role": "user", "content": message})
print("[bold]Assistant:[/] ", end="")
response = model.stream_chat_response(chat)
chat.append({"role": "assistant", "content": response})
except (KeyboardInterrupt, EOFError):
# Ctrl+C/Ctrl+D
break
chat.append({"role": "user", "content": message})
print("[bold]Assistant:[/] ", end="")
response = model.stream_chat_response(chat)
chat.append({"role": "assistant", "content": response})
except (KeyboardInterrupt, EOFError):
# Ctrl+C/Ctrl+D
break
case "Nothing (Quit)":
break
case "Nothing (Quit)":
break
except Exception as error:
print(f"[red]Error: {error}[/]")