Skip to content

Commit

Permalink
wandb logout as part of initial cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasjansson committed Sep 5, 2024
1 parent 6a4e57b commit 9434603
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 30 deletions.
39 changes: 18 additions & 21 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from toolkit.config import get_config

from caption import Captioner
from wandb_client import WeightsAndBiasesClient
from wandb_client import WeightsAndBiasesClient, logout_wandb


JOB_NAME = "flux_train_replicate"
Expand Down Expand Up @@ -306,31 +306,26 @@ def train(
name=wandb_run,
)

try:
download_weights()
extract_zip(input_images, INPUT_DIR)
download_weights()
extract_zip(input_images, INPUT_DIR)

if not trigger_word:
del train_config["config"]["process"][0]["trigger_word"]
if not trigger_word:
del train_config["config"]["process"][0]["trigger_word"]

captioner = Captioner()
if autocaption and not captioner.all_images_are_captioned(INPUT_DIR):
captioner.load_models()
captioner.caption_images(INPUT_DIR, autocaption_prefix, autocaption_suffix)
captioner = Captioner()
if autocaption and not captioner.all_images_are_captioned(INPUT_DIR):
captioner.load_models()
captioner.caption_images(INPUT_DIR, autocaption_prefix, autocaption_suffix)

del captioner
torch.cuda.empty_cache()
del captioner
torch.cuda.empty_cache()

print("Starting train job")
job = CustomJob(get_config(train_config, name=None), wandb_client)
job.run()
print("Starting train job")
job = CustomJob(get_config(train_config, name=None), wandb_client)
job.run()

if wandb_client:
wandb_client.finish()

finally:
if wandb_client:
wandb_client.logout()
if wandb_client:
wandb_client.finish()

job.cleanup()

Expand Down Expand Up @@ -441,6 +436,8 @@ def extract_zip(input_images: Path, input_dir: Path):


def clean_up():
logout_wandb()

if INPUT_DIR.exists():
shutil.rmtree(INPUT_DIR)

Expand Down
22 changes: 13 additions & 9 deletions wandb_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@
from wandb.sdk.wandb_settings import Settings


def logout_wandb():
netrc_path = Path("/root/.netrc")
if not netrc_path.exists():
return

n = netrc.netrc(netrc_path)

if "api.wandb.ai" in n.hosts:
del n.hosts["api.wandb.ai"]

netrc_path.write_text(repr(n))


class WeightsAndBiasesClient:
def __init__(
self,
Expand Down Expand Up @@ -57,15 +70,6 @@ def finish(self):
with suppress(Exception):
wandb.finish()

def logout(self):
netrc_path = Path("/root/.netrc")
n = netrc.netrc(netrc_path)

if "api.wandb.ai" in n.hosts:
del n.hosts["api.wandb.ai"]

netrc_path.write_text(repr(n))


def truncate(text, max_chars=50):
if len(text) <= max_chars:
Expand Down

0 comments on commit 9434603

Please sign in to comment.