Skip to content

Commit

Permalink
Merge pull request #32 from replicate/temp-logging
Browse files Browse the repository at this point in the history
Add some lang logging
  • Loading branch information
fofr authored Sep 13, 2024
2 parents beb330e + 9f0968b commit acc81ad
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
OUTPUT_DIR = Path("output")
JOB_DIR = OUTPUT_DIR / JOB_NAME

print(f"Environment language: {os.environ.get('LANG', 'Not set')}")
os.environ["LANG"] = "en_US.UTF-8"
print(f"Updated environment language: {os.environ.get('LANG', 'Not set')}")


class CustomSDTrainer(SDTrainer):
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -242,9 +246,9 @@ def train(
},
"save": {
"dtype": "float16",
"save_every": wandb_save_interval
if wandb_api_key
else steps + 1,
"save_every": (
wandb_save_interval if wandb_api_key else steps + 1
),
"max_step_saves_to_keep": 1,
},
"datasets": [
Expand Down Expand Up @@ -282,9 +286,11 @@ def train(
},
"sample": {
"sampler": "flowmatch",
"sample_every": wandb_sample_interval
if wandb_api_key and sample_prompts
else steps + 1,
"sample_every": (
wandb_sample_interval
if wandb_api_key and sample_prompts
else steps + 1
),
"width": 1024,
"height": 1024,
"prompts": sample_prompts,
Expand Down

0 comments on commit acc81ad

Please sign in to comment.