Skip to content

Commit

Permalink
Wandb fixes
Browse files Browse the repository at this point in the history
* Always log out (shouldn't be an issue but just in case)
* Truncate long image file names
* Catch all wandb exceptions
  • Loading branch information
andreasjansson committed Sep 5, 2024
1 parent a889701 commit 6a4e57b
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 52 deletions.
85 changes: 45 additions & 40 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,33 +206,6 @@ def train(
if wandb_sample_prompts:
sample_prompts = [p.strip() for p in wandb_sample_prompts.split("\n")]

wandb_client = None
if wandb_api_key:
wandb_config = {
"trigger_word": trigger_word,
"autocaption": autocaption,
"autocaption_prefix": autocaption_prefix,
"autocaption_suffix": autocaption_suffix,
"steps": steps,
"learning_rate": learning_rate,
"batch_size": batch_size,
"resolution": resolution,
"lora_rank": lora_rank,
"caption_dropout_rate": caption_dropout_rate,
"optimizer": optimizer,
}
wandb_client = WeightsAndBiasesClient(
api_key=wandb_api_key.get_secret_value(),
config=wandb_config,
sample_prompts=sample_prompts,
project=wandb_project,
entity=wandb_entity,
name=wandb_run,
)

download_weights()
extract_zip(input_images, INPUT_DIR)

train_config = OrderedDict(
{
"job": "custom_job",
Expand Down Expand Up @@ -309,23 +282,55 @@ def train(
}
)

if not trigger_word:
del train_config["config"]["process"][0]["trigger_word"]
wandb_client = None
if wandb_api_key:
wandb_config = {
"trigger_word": trigger_word,
"autocaption": autocaption,
"autocaption_prefix": autocaption_prefix,
"autocaption_suffix": autocaption_suffix,
"steps": steps,
"learning_rate": learning_rate,
"batch_size": batch_size,
"resolution": resolution,
"lora_rank": lora_rank,
"caption_dropout_rate": caption_dropout_rate,
"optimizer": optimizer,
}
wandb_client = WeightsAndBiasesClient(
api_key=wandb_api_key.get_secret_value(),
config=wandb_config,
sample_prompts=sample_prompts,
project=wandb_project,
entity=wandb_entity,
name=wandb_run,
)

try:
download_weights()
extract_zip(input_images, INPUT_DIR)

if not trigger_word:
del train_config["config"]["process"][0]["trigger_word"]

captioner = Captioner()
if autocaption and not captioner.all_images_are_captioned(INPUT_DIR):
captioner.load_models()
captioner.caption_images(INPUT_DIR, autocaption_prefix, autocaption_suffix)

captioner = Captioner()
if autocaption and not captioner.all_images_are_captioned(INPUT_DIR):
captioner.load_models()
captioner.caption_images(INPUT_DIR, autocaption_prefix, autocaption_suffix)
del captioner
torch.cuda.empty_cache()

del captioner
torch.cuda.empty_cache()
print("Starting train job")
job = CustomJob(get_config(train_config, name=None), wandb_client)
job.run()

print("Starting train job")
job = CustomJob(get_config(train_config, name=None), wandb_client)
job.run()
if wandb_client:
wandb_client.finish()

if wandb_client:
wandb_client.finish()
finally:
if wandb_client:
wandb_client.logout()

job.cleanup()

Expand Down
53 changes: 41 additions & 12 deletions wandb_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import netrc
from pathlib import Path
from typing import Any, Sequence
from contextlib import suppress
Expand All @@ -18,28 +19,56 @@ def __init__(
self.api_key = api_key
self.sample_prompts = sample_prompts
wandb.login(key=self.api_key, verify=True)
self.run = wandb.init(
project=project,
entity=entity,
name=name,
config=config,
save_code=False,
settings=Settings(_disable_machine_info=True),
)
try:
self.run = wandb.init(
project=project,
entity=entity,
name=name,
config=config,
save_code=False,
settings=Settings(_disable_machine_info=True),
)
except Exception as e:
raise ValueError(f"Failed to log in to Weights & Biases: {e}")

def log_loss(self, loss_dict: dict[str, Any], step: int | None):
wandb.log(data=loss_dict, step=step)
try:
wandb.log(data=loss_dict, step=step)
except Exception as e:
print(f"Failed to log to Weights & Biases: {e}")

def log_samples(self, image_paths: Sequence[Path], step: int | None):
data = {
f"samples/{prompt}": wandb.Image(str(path))
f"samples/{truncate(prompt)}": wandb.Image(str(path))
for prompt, path in zip(self.sample_prompts, image_paths)
}
wandb.log(data=data, step=step)
try:
wandb.log(data=data, step=step)
except Exception as e:
print(f"Failed to log to Weights & Biases: {e}")

def save_weights(self, lora_path: Path):
wandb.save(lora_path)
try:
wandb.save(lora_path)
except Exception as e:
print(f"Failed to save to Weights & Biases: {e}")

def finish(self):
with suppress(Exception):
wandb.finish()

def logout(self):
netrc_path = Path("/root/.netrc")
n = netrc.netrc(netrc_path)

if "api.wandb.ai" in n.hosts:
del n.hosts["api.wandb.ai"]

netrc_path.write_text(repr(n))


def truncate(text, max_chars=50):
if len(text) <= max_chars:
return text
half = (max_chars - 3) // 2
return f"{text[:half]}...{text[-half:]}"

0 comments on commit 6a4e57b

Please sign in to comment.