Add save and upload functionality

This commit is contained in:
Philipp Emanuel Weidmann
2025-09-27 11:15:41 +05:30
parent 7573a2eebd
commit 5b01ad4344
3 changed files with 75 additions and 26 deletions
+1
View File
@@ -10,6 +10,7 @@ requires-python = ">=3.10"
dependencies = [
"accelerate>=1.10.0",
"datasets>=4.0.0",
"huggingface-hub>=0.34.4",
"optuna>=4.5.0",
"pydantic-settings>=2.10.1",
"questionary>=2.1.1",
+50 -4
View File
@@ -4,7 +4,9 @@
import sys
import time
from importlib.metadata import version
from pathlib import Path
import huggingface_hub
import optuna
import questionary
import torch
@@ -258,13 +260,53 @@ def main():
],
).ask()
# All actions are wrapped in a try/except block so that if an error occurs,
# another action can be tried, instead of the program crashing and losing
# the optimized model.
try:
match action:
case "Save the model to a local folder":
# TODO
pass
save_directory = questionary.path("Path to the folder:").ask()
if not save_directory:
continue
print("Saving model...")
model.model.save_pretrained(save_directory)
model.tokenizer.save_pretrained(save_directory)
print(f"Model saved to [bold]{save_directory}[/].")
case "Upload the model to Hugging Face":
# TODO
pass
# We don't use huggingface_hub.login() because that stores the token on disk,
# and since this program will often be run on rented or shared GPU servers,
# it's better to not persist credentials.
token = huggingface_hub.get_token()
if not token:
token = questionary.password("Hugging Face access token:").ask()
if not token:
continue
user = huggingface_hub.whoami(token)
print(f"Logged in as [bold]{user['fullname']} ({user['email']})[/]")
repo_id = questionary.text(
"Name of repository:",
default=f"{user['name']}/{Path(settings.model).name}-heretic",
).ask()
visibility = questionary.select(
"Should the repository be public or private?",
choices=[
"Public",
"Private",
],
).ask()
private = visibility == "Private"
print("Uploading model...")
model.model.push_to_hub(repo_id, private=private, token=token)
model.tokenizer.push_to_hub(repo_id, private=private, token=token)
print(f"Model uploaded to [bold]{repo_id}[/].")
case "Chat with the model":
print()
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
@@ -286,5 +328,9 @@ def main():
except (KeyboardInterrupt, EOFError):
# Ctrl+C/Ctrl+D
break
case "Nothing (Quit)":
break
except Exception as error:
print(f"[red]Error: {error}[/]")
Generated
+2
View File
@@ -479,6 +479,7 @@ source = { editable = "." }
dependencies = [
{ name = "accelerate" },
{ name = "datasets" },
{ name = "huggingface-hub" },
{ name = "optuna" },
{ name = "pydantic-settings" },
{ name = "questionary" },
@@ -490,6 +491,7 @@ dependencies = [
requires-dist = [
{ name = "accelerate", specifier = ">=1.10.0" },
{ name = "datasets", specifier = ">=4.0.0" },
{ name = "huggingface-hub", specifier = ">=0.34.4" },
{ name = "optuna", specifier = ">=4.5.0" },
{ name = "pydantic-settings", specifier = ">=2.10.1" },
{ name = "questionary", specifier = ">=2.1.1" },