Skip to content

Commit

Permalink
Align Model Path instead of Model File Name
Browse files Browse the repository at this point in the history
Signed-off-by: Vibhu Jawa <[email protected]>
  • Loading branch information
VibhuJawa committed May 21, 2024
1 parent c6b5d9e commit e1ec135
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions docs/user-guide/DistributedDataClassification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ Let's see how ``DomainClassifier`` works in a small excerpt taken from ``example
"Travel_and_Transportation",
]
model_file_name = "pytorch_model_file.pth"
model_path = "pytorch_model_file.pth"
files = get_all_files_paths_under("books_dataset/")
input_dataset = DocumentDataset.read_json(files, backend="cudf", add_filename=True)
domain_classifier = DomainClassifier(
model_file_name=model_file_name,
model_path=model_path,
labels=labels,
filter_by=["Games", "Sports"],
)
Expand Down
4 changes: 2 additions & 2 deletions examples/domain_classifier_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def main(args):
"Travel_and_Transportation",
]

model_file_name = "/path/to/pytorch_model_file.pth"
model_path = "/path/to/pytorch_model_file.pth"

# Input can be a string or list
input_file_path = "/path/to/data"
Expand All @@ -66,7 +66,7 @@ def main(args):
)

domain_classifier = DomainClassifier(
model_path=model_file_name,
model_path=model_path,
labels=labels,
filter_by=["Games", "Sports"],
)
Expand Down
4 changes: 2 additions & 2 deletions examples/quality_classifier_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def main(args):
global_st = time.time()

labels = ["High", "Medium", "Low"]
model_file_name = "/path/to/pytorch_model_file.pth"
model_path = "/path/to/pytorch_model_file.pth"

# Input can be a string or list
input_file_path = "/path/to/data"
Expand All @@ -38,7 +38,7 @@ def main(args):
)

quality_classifier = QualityClassifier(
model_path=model_file_name,
model_path=model_path,
labels=labels,
filter_by=["High", "Medium"],
)
Expand Down

0 comments on commit e1ec135

Please sign in to comment.