Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

修改bleu相应测试 #454

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions cotk/dataloader/language_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,17 @@ def get_teacher_forcing_metric(self, gen_log_prob_key="gen_log_prob") -> "Metric
return metric

GEN_KEY_ARGUMENTS = MetricBase.GEN_KEY_ARGUMENTS
SAMPLE_ARGUMENTS_IN_BLEU = MetricBase.SAMPLE_ARGUMENTS_IN_BLEU.\
replace("sample (int, optional)", "sample_in_bleu (int, optional)")
SAMPLE_HYP_ARGUMENTS_IN_BLEU = MetricBase.SAMPLE_HYP_ARGUMENTS_IN_BLEU.\
replace("n_sample_hyp (int, optional)", "n_sample_hyp_in_bleu (int, optional)")
SAMPLE_REF_ARGUMENTS_IN_BLEU = MetricBase.SAMPLE_REF_ARGUMENTS_IN_BLEU.\
replace("n_sample_ref (int, optional)", "n_sample_ref_in_bleu (int, optional)")
SAMPLE_ARGUMENTS_IN_NGRAM_PERPLEXITY = MetricBase.SAMPLE_ARGUMENTS_IN_NGRAM_PERPLEXITY.\
replace("sample (int, optional)", "sample_in_ngram_perplexity (int, optional)")
replace("n_sample (int, optional)", "n_sample_in_ngram_perplexity (int, optional)")
SEED_ARGUMENTS = MetricBase.SEED_ARGUMENTS
CPU_COUNT_ARGUMENTS = MetricBase.CPU_COUNT_ARGUMENTS
def get_inference_metric(self, gen_key="gen", sample_in_bleu=1000, \
sample_in_ngram_perplexity=10000, seed=1229, cpu_count=None) -> "MetricChain":
def get_inference_metric(self, gen_key="gen", n_sample_hyp_in_bleu=100, \
n_sample_ref_in_bleu=1000,
n_sample_in_ngram_perplexity=10000, seed=1229, cpu_count=None) -> "MetricChain":
'''Get metrics for inference. In other words, this function provides metrics for
language generation tasks.

Expand All @@ -144,7 +147,8 @@ def get_inference_metric(self, gen_key="gen", sample_in_bleu=1000, \

Arguments:
{GEN_KEY_ARGUMENTS}
{SAMPLE_ARGUMENTS_IN_BLEU}
{SAMPLE_HYP_ARGUMENTS_IN_BLEU}
{SAMPLE_REF_ARGUMENTS_IN_BLEU}
{SAMPLE_ARGUMENTS_IN_NGRAM_PERPLEXITY}
{SEED_ARGUMENTS}
{CPU_COUNT_ARGUMENTS}
Expand All @@ -154,19 +158,21 @@ def get_inference_metric(self, gen_key="gen", sample_in_bleu=1000, \
metric = MetricChain()
metric.add_metric(SelfBleuCorpusMetric(self, \
gen_key=gen_key, \
sample=sample_in_bleu, \
n_sample_hyp=n_sample_hyp_in_bleu, \
n_sample_ref=n_sample_ref_in_bleu, \
seed=seed, \
cpu_count=cpu_count))
metric.add_metric(FwBwBleuCorpusMetric(self, \
reference_test_list=self.get_all_batch("test")["sent"], \
gen_key=gen_key, \
sample=sample_in_bleu, \
n_sample_hyp=n_sample_hyp_in_bleu, \
n_sample_ref=n_sample_ref_in_bleu, \
seed=seed, \
cpu_count=cpu_count))
metric.add_metric(FwBwBleuCorpusMetric(self, \
metric.add_metric(NgramFwBwPerplexityMetric(self, \
reference_test_list=self.get_all_batch("test")["sent"], \
gen_key=gen_key, \
sample=sample_in_ngram_perplexity, \
n_sample=n_sample_in_ngram_perplexity, \
seed=seed, \
cpu_count=cpu_count))
metric.add_metric(LanguageGenerationRecorder(self, gen_key=gen_key))
Expand Down
166 changes: 93 additions & 73 deletions cotk/metric/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,17 @@ class SelfBleuCorpusMetric(MetricBase):
{MetricBase.NGRAM_ARGUMENTS}
{MetricBase.TOKENIZER_ARGUMENTS}
{MetricBase.GEN_KEY_ARGUMENTS}
{MetricBase.SAMPLE_ARGUMENTS_IN_BLEU}
{MetricBase.SAMPLE_HYP_ARGUMENTS_IN_BLEU}
{MetricBase.SAMPLE_REF_ARGUMENTS_IN_BLEU}
{MetricBase.SEED_ARGUMENTS}
{MetricBase.CPU_COUNT_ARGUMENTS}

Warning:
the calculation of ``hashvalue`` considers the actual sample size of hypotheses which
will be less than ``sample`` if the size of hypotheses is smaller than ``sample``.
The calculation of ``hashvalue`` considers the actual sample size of hypotheses.
Therefore ``hashvalue`` may vary if the number of generated samples is smaller
than ``n_sample_hyp`` or ``n_sample_ref``.

Here is an example:
Here is an example (to only show the format but not the exact value of results):

>>> dl = cotk.dataloader.UbuntuCorpus('resources://Ubuntu_small')
>>> gen_key = 'gen'
Expand All @@ -242,20 +244,33 @@ class SelfBleuCorpusMetric(MetricBase):
'''

_name = 'SelfBleuCorpusMetric'
_version = 2
_version = 3
MINIMAL_PARALLEL_N = 100
SHOW_PROGRESS = 100

def __init__(self, dataloader: Union["LanguageProcessing", "Sentence", "Session"], ngram: int = 4, *, \
tokenizer: Union[None, Tokenizer, str] = None, \
gen_key: str = "gen", \
sample: int = 1000, \
n_sample_hyp: int = 100, \
n_sample_ref: int = 1000, \
seed: int = 1229, \
cpu_count: Optional[int] = None):
super().__init__(self._name, self._version)
self.dataloader = dataloader
self.ngram = ngram
self.tokenizer = tokenizer
self.gen_key = gen_key
self.sample = sample
self.n_sample_hyp = n_sample_hyp
self.n_sample_ref = n_sample_ref

if self.n_sample_hyp <= 1:
raise RuntimeError('`sample_hyp` should be more than 1, \
whose value is `{}`'.format(self.n_sample_hyp))

if self.n_sample_ref <= 1:
raise RuntimeError('`sample_hyp` should be more than 1, \
whose value is `{}`'.format(self.n_sample_ref))

self.hyps: List[Any] = []
self.seed = seed
if cpu_count is not None:
Expand Down Expand Up @@ -300,27 +315,24 @@ def close(self) -> Dict[str, Any]:
if not self.hyps:
raise RuntimeError("The metric has not been forwarded data correctly.")
if len(self.hyps) == 1:
raise RuntimeError("Self-Bleu can't be computed because there is only 1 generated sentence.")
if self.sample <= 1:
raise RuntimeError('`self.sample` should be more than 1, \
whose value is `{}`'.format(self.sample))

if self.sample > len(self.hyps):
self.sample = len(self.hyps)
raise RuntimeError("Selfbleu can't be computed because there is only 1 generated sentence.")

rng_state = random.getstate()
random.seed(self.seed)
random.shuffle(self.hyps)
random.setstate(rng_state)

ref = self.hyps[:self.sample]
n_sample_hyp = min(self.n_sample_hyp, len(self.hyps))
n_sample_ref = min(self.n_sample_ref, len(self.hyps))

ref = self.hyps[:n_sample_ref]

if self.tokenizer:
tokenizer: Tokenizer
if isinstance(self.tokenizer, str):
tokenizer = SimpleTokenizer(self.tokenizer)
else:
tokenizer = tokenizer
tokenizer = self.tokenizer
ref = [self.dataloader.convert_ids_to_sentence(ids, remove_special=True, trim=True) for ids in ref]
ref = tokenizer.tokenize_sentences(ref)
else:
Expand All @@ -330,31 +342,29 @@ def close(self) -> Dict[str, Any]:
_ref = replace_unk(ref, self.dataloader.get_special_tokens_mapping()["unk"])
else:
_ref = ref

bleu_irl = []

weights = np.ones(self.ngram) / self.ngram
tasks = ((ref[:i]+ref[i+1:self.sample], _ref[i], weights) for i in range(self.sample))
tasks = ((ref[:i]+ref[i+1:], _ref[i], weights) for i in range(n_sample_hyp))

pool: Optional[Any]
pool: Optional[Any] = None
values: Iterable[Any]
if self.sample >= 1000 and self.cpu_count > 1:
if n_sample_hyp >= SelfBleuCorpusMetric.MINIMAL_PARALLEL_N and self.cpu_count > 1:
# use multiprocessing
pool = Pool(self.cpu_count)
if pool is None:
pool = Pool(self.cpu_count)
values = pool.imap_unordered(_sentence_bleu, tasks, chunksize=20)
else:
pool = None
values = map(_sentence_bleu, tasks)
if self.sample >= 1000:
# use tqdm
values = tqdm.tqdm(values, total=self.sample)
if n_sample_hyp >= SelfBleuCorpusMetric.SHOW_PROGRESS:
values = tqdm.tqdm(values, total=n_sample_hyp) # use tqdm
for ans in values:
bleu_irl.append(ans)
if pool is not None:
pool.close()
pool.join()

self._hash_ordered_data((self.seed, self.sample))
self._hash_ordered_data((self.seed, n_sample_hyp, n_sample_ref))
res.update({"self-bleu" : 1.0 * sum(bleu_irl) / len(bleu_irl),\
"self-bleu hashvalue": self._hashvalue()})
return res
Expand All @@ -368,15 +378,17 @@ class FwBwBleuCorpusMetric(MetricBase):
{MetricBase.NGRAM_ARGUMENTS}
{MetricBase.TOKENIZER_ARGUMENTS}
{MetricBase.GEN_KEY_ARGUMENTS}
{MetricBase.SAMPLE_ARGUMENTS_IN_BLEU}
{MetricBase.SAMPLE_HYP_ARGUMENTS_IN_BLEU}
{MetricBase.SAMPLE_REF_ARGUMENTS_IN_BLEU}
{MetricBase.SEED_ARGUMENTS}
{MetricBase.CPU_COUNT_ARGUMENTS}
Warning:
The calculation of ``hashvalue`` considers the actual sample size of hypotheses and
references. Therefore ``hashvalue`` may vary with the size of hypothesis or references
if the size of them is smaller than ``sample``.

Here is an example:
.. warning::
``fw-bw-bleu hashvalue`` considers the actual sample size of generated samples.
Therefore ``hashvalue`` may vary if the number of generated samples is smaller
than ``n_sample_hyp`` or ``n_sample_ref``.

Here is an example (to only show the format but not the exact value of results):

>>> dl = cotk.dataloader.UbuntuCorpus('resources://Ubuntu_small')
>>> gen_key = 'gen'
Expand All @@ -396,23 +408,36 @@ class FwBwBleuCorpusMetric(MetricBase):
'''

_name = 'FwBwBleuCorpusMetric'
_version = 2
_version = 3
MINIMAL_PARALLEL_N = 100
SHOW_PROGRESS = 100

def __init__(self, dataloader: Union["LanguageProcessing", "Sentence", "Session"], \
reference_test_list: List[Any], ngram: int = 4, *, \
tokenizer: Union[None, Tokenizer, str] = None, \
gen_key: str = "gen", \
sample: int = 1000, \
n_sample_hyp: int = 100, \
n_sample_ref: int = 1000, \
seed: int = 1229, \
cpu_count: Optional[int] = None):
super().__init__(self._name, self._version)
self.dataloader = dataloader
self.tokenizer = tokenizer
self.reference_test_list = reference_test_list
self.gen_key = gen_key
self.sample = sample
self.n_sample_hyp = n_sample_hyp
self.n_sample_ref = n_sample_ref

if self.n_sample_hyp <= 1:
raise RuntimeError('`sample_hyp` should be more than 1, \
whose value is `{}`'.format(self.n_sample_hyp))

if self.n_sample_ref <= 1:
raise RuntimeError('`sample_hyp` should be more than 1, \
whose value is `{}`'.format(self.n_sample_ref))

self.seed = seed
self.ngram=ngram
self.ngram = ngram
if cpu_count is not None:
self.cpu_count = cpu_count
elif "CPU_COUNT" in os.environ and os.environ["CPU_COUNT"] is not None:
Expand Down Expand Up @@ -442,8 +467,7 @@ def forward(self, data: Dict[str, Any]):
if not isinstance(gen, (np.ndarray, list)):
raise TypeError("Unknown type for gen.")

for gen_sen in gen:
self.hyps.append(list(self.dataloader.trim_in_ids(gen_sen)))
self.hyps.extend(gen)

def close(self) -> Dict[str, Any]:
'''Return a dict which contains
Expand All @@ -460,18 +484,17 @@ def close(self) -> Dict[str, Any]:
if not self.reference_test_list:
raise RuntimeError("Reference cannot be empty")

sample_hyps_num = self.sample if self.sample < len(self.hyps) else len(self.hyps)
sample_refs_num = self.sample if self.sample < len(self.reference_test_list) else len(self.reference_test_list)

if sample_hyps_num <= 1:
raise RuntimeError('`sample_hyps` should be more than 1, \
whose value is `{}`'.format(sample_hyps_num))
if sample_refs_num <= 1:
raise RuntimeError('`sample_refs` should be more than 1, \
whose value is `{}`'.format(sample_refs_num))
rng_state = random.getstate()
random.seed(self.seed)
sample_hyps = self.hyps.copy()
sample_refs = self.reference_test_list.copy()
random.shuffle(sample_hyps)
random.shuffle(sample_refs)
random.setstate(rng_state)

sample_hyps = self.hyps[:sample_hyps_num]
sample_refs = self.reference_test_list[:sample_refs_num]
n_sample_max = max(self.n_sample_hyp, self.n_sample_ref)
sample_hyps = sample_hyps[:n_sample_max]
sample_refs = sample_refs[:n_sample_max]

refs: List[Any]
hyps: List[Any]
Expand All @@ -480,7 +503,7 @@ def close(self) -> Dict[str, Any]:
if isinstance(self.tokenizer, str):
tokenizer = SimpleTokenizer(self.tokenizer)
else:
tokenizer = tokenizer
tokenizer = self.tokenizer
if isinstance(sample_refs[0], List):
ref_sents = [self.dataloader.convert_ids_to_sentence(ids, remove_special=True, trim=True) for ids in sample_refs]
else:
Expand All @@ -493,47 +516,44 @@ def close(self) -> Dict[str, Any]:
refs = [self.dataloader.convert_ids_to_tokens(ids, remove_special=True, trim=True) for ids in sample_refs]
hyps = [self.dataloader.convert_ids_to_tokens(ids, remove_special=True, trim=True) for ids in sample_hyps]

rng_state = random.getstate()
random.seed(self.seed)
random.shuffle(hyps)
random.shuffle(refs)
random.setstate(rng_state)

if "unk" in self.dataloader.get_special_tokens_mapping():
refs = replace_unk(refs, self.dataloader.get_special_tokens_mapping()["unk"])


bleu_irl_fw, bleu_irl_bw = [], []
weights = np.ones(self.ngram) / self.ngram

tasks = ((refs, hyps[i], weights) for i in range(sample_hyps_num))
pool: Optional[Any]
n_fw_sample = min(len(hyps), self.n_sample_hyp)
n_fw_reference = min(len(refs), self.n_sample_ref)
tasks = ((refs[:n_fw_reference], hyps[i], weights) for i in range(n_fw_sample))
pool: Optional[Any] = None
values: Iterable[Any]
if sample_hyps_num >= 1000 and self.cpu_count > 1:
pool = Pool(self.cpu_count)
if n_fw_sample >= FwBwBleuCorpusMetric.MINIMAL_PARALLEL_N and self.cpu_count > 1:
if not pool:
pool = Pool(self.cpu_count)
values = pool.imap_unordered(_sentence_bleu, tasks, chunksize=20)
else:
pool = None
values = map(_sentence_bleu, tasks)
if sample_hyps_num >= 1000:
values = tqdm.tqdm(values, total=sample_hyps_num)

if n_fw_sample >= FwBwBleuCorpusMetric.MINIMAL_PARALLEL_N:
values = tqdm.tqdm(values, total=n_fw_sample)
for ans in values:
bleu_irl_fw.append(ans)
if pool is not None:
pool.close()
pool.join()

tasks = ((hyps, refs[i], weights) for i in range(sample_refs_num))
if sample_refs_num >= 1000 and self.cpu_count > 1:
pool = Pool(self.cpu_count)
n_bw_sample = min(len(refs), self.n_sample_hyp)
n_bw_reference = min(len(hyps), self.n_sample_ref)
tasks = ((hyps[:n_fw_reference], refs[i], weights) for i in range(n_bw_sample))
if n_bw_sample >= FwBwBleuCorpusMetric.SHOW_PROGRESS and self.cpu_count > 1:
if pool is None:
pool = Pool(self.cpu_count)
values = pool.imap_unordered(_sentence_bleu, tasks, chunksize=20)
else:
pool = None
values = map(_sentence_bleu, tasks)
if sample_refs_num >= 1000:
values = tqdm.tqdm(values, total=sample_refs_num)
if n_bw_sample >= FwBwBleuCorpusMetric.SHOW_PROGRESS:
values = tqdm.tqdm(values, total=n_bw_sample)
for ans in values:
bleu_irl_bw.append(ans)

if pool is not None:
pool.close()
pool.join()
Expand All @@ -551,7 +571,7 @@ def close(self) -> Dict[str, Any]:
})

self._hash_unordered_list(refs)
self._hash_ordered_data((self.ngram, self.seed, sample_hyps_num, sample_refs_num))
self._hash_ordered_data((self.ngram, self.seed, n_fw_sample, n_fw_reference, n_bw_sample, n_bw_reference))
res.update({"fw-bw-bleu hashvalue" : self._hashvalue()})
return res

Expand Down
11 changes: 8 additions & 3 deletions cotk/metric/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@ class MetricBase(LoadClassInterface, metaclass=DocStringInheritor):
IGNORE_SMOOTHING_ERROR_ARGUMENTS = \
"""ignore_smoothing_error (bool, optional): Specifies whether to ignore the smoothing error when calculating \
BLEU. Default: ``False``."""
SAMPLE_ARGUMENTS_IN_BLEU = \
"""sample (int, optional): Number of examples sampled from the generated sentences. Default: ``1000``."""
SAMPLE_HYP_ARGUMENTS_IN_BLEU = \
"""n_sample_hyp (int, optional): Number of hypothesis sampled for calculating the BLEU score. A larger value will reduce \
the variance of the result but become slower. Default: ``100``."""
SAMPLE_REF_ARGUMENTS_IN_BLEU = \
"""n_sample_ref (int, optional): Number of references sampled for calculating the BLEU score. A larger value will lead to a larger \
result, because we have more acceptable references. Default: ``1000``."""
SAMPLE_ARGUMENTS_IN_NGRAM_PERPLEXITY = \
SAMPLE_ARGUMENTS_IN_BLEU.replace("Default: ``1000``.", "Default: ``10000``.")
"""n_sample (int, optional): Number of hypothesis sampled for training the language model and calculating the ngram perplexity. \
A larger value will reduce the variance of the result but become slower. Default: ``10000``."""
SEED_ARGUMENTS = \
"""seed (int, optional): Random seed for sampling. Default: ``1229``."""
REFERENCE_TEST_LIST_ARGUMENTS = \
Expand Down
Loading