Code-prose-composition tagger #234

Open · wants to merge 4 commits into base: learn2code
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
@@ -5,12 +5,14 @@ on:
branches:
- main
- master
- learn2code
tags:
- "*"
pull_request:
branches:
- main
- master
- learn2code
workflow_dispatch:

permissions:
2 changes: 2 additions & 0 deletions python/dolma/taggers/code/__init__.py
@@ -4,6 +4,7 @@
CodeSecretsTagger,
CodeStarCoderTaggers,
CodeStarCoderTaggers2,
Learn2CodeTaggers,
)

__all__ = [
@@ -12,4 +13,5 @@
"CodeRedPajamaTaggers",
"CodeStarCoderTaggers",
"CodeStarCoderTaggers2",
"Learn2CodeTaggers",
]
111 changes: 111 additions & 0 deletions python/dolma/taggers/code/code_taggers.py
@@ -21,10 +21,16 @@
if CODE_DEPENDENCIES_AVAILABLE:
from .starcoder import get_nl_ratio
from .utils import (
b64_filter,
filter_html,
get_ext_to_lang_mapping,
get_line_stats,
get_proportion_alphabetic_chars,
get_secrets,
get_whitespace_regex,
hexadecimal_filter,
special_text_file_filter,
unicode_filter,
)


@@ -269,3 +275,108 @@ def predict(self, doc: DocumentWithMetadata) -> DocResult: # type: ignore
spans.append(Span(start=0, end=doc_length, type="code_to_text_ratio_html_doc", score=code_to_text_ratio))

return DocResult(doc=doc, spans=spans)


@TaggerRegistry.add("learn2code_taggers_v1")
class Learn2CodeTaggers(BaseTaggerWithMetadata):
"""
Based on a mix of filters from StarCoder and Granite
"""

def __init__(self) -> None:
check_code_dependencies()
self.ext_to_lang_mapping = get_ext_to_lang_mapping()
super().__init__()

def predict(self, doc: DocumentWithMetadata) -> DocResult: # type: ignore
spans: List[Span] = []
doc_length = len(doc.text)

num_github_stars = doc.metadata.get("max_stars_count", 0) or doc.metadata.get("star_events_count", 0) or 0
proportion_alpha = get_proportion_alphabetic_chars(doc.text)
has_xml_template = 1.0 if "<?xml version=" in doc.text[:100] else 0.0
line_stats = get_line_stats(doc.text)
b64_filter_results = b64_filter(doc.text)
hexadecimal_filter_results = hexadecimal_filter(doc.text)
unicode_filter_results = unicode_filter(doc.text)

try:
lang = self.ext_to_lang_mapping[doc.metadata.get("ext", "-no-lang")]
except KeyError:
lang = "-no-lang"

filename = doc.metadata.get("path", None)

try:
proportion_comments_doc = get_nl_ratio(doc.text, lang)
except: # pylint: disable=bare-except # noqa: E722
proportion_comments_doc = -1

# Not relevant for non-html code
if lang == "html":
try:
proportion_text_in_html = filter_html(doc.text)
except: # pylint: disable=bare-except # noqa: E722
proportion_text_in_html = -1.0
else:
proportion_text_in_html = 1.0

is_special_text_file = 1 if special_text_file_filter(filename, lang) else 0

# document-level scores
spans.append(Span(start=0, end=doc_length, type="num_chars_doc", score=float(doc_length)))
spans.append(Span(start=0, end=doc_length, type="num_github_stars_doc", score=float(num_github_stars)))
spans.append(Span(start=0, end=doc_length, type="proportion_alpha_doc", score=proportion_alpha))
spans.append(Span(start=0, end=doc_length, type="has_xml_template_doc", score=has_xml_template))
spans.append(Span(start=0, end=doc_length, type="num_lines_doc", score=float(line_stats.total_count)))
spans.append(Span(start=0, end=doc_length, type="mean_line_length_doc", score=line_stats.mean_length))
spans.append(Span(start=0, end=doc_length, type="max_line_length_doc", score=float(line_stats.max_length)))
spans.append(
Span(
start=0, end=doc_length, type="longest_seq_b64_doc", score=float(b64_filter_results.longest_match)
)
)
spans.append(
Span(start=0, end=doc_length, type="proportion_b64_doc", score=b64_filter_results.proportion_match)
)
spans.append(
Span(
start=0,
end=doc_length,
type="longest_seq_hexadecimal_doc",
score=float(hexadecimal_filter_results.longest_match),
)
)
spans.append(
Span(
start=0,
end=doc_length,
type="proportion_hexadecimal_doc",
score=hexadecimal_filter_results.proportion_match,
)
)
spans.append(
Span(
start=0,
end=doc_length,
type="longest_seq_unicode_doc",
score=float(unicode_filter_results.longest_match),
)
)
spans.append(
Span(
start=0,
end=doc_length,
type="proportion_unicode_doc",
score=unicode_filter_results.proportion_match,
)
)
spans.append(Span(start=0, end=doc_length, type="proportion_comments_doc", score=proportion_comments_doc))
spans.append(
Span(start=0, end=doc_length, type="proportion_text_in_html_doc", score=proportion_text_in_html)
)
spans.append(
Span(start=0, end=doc_length, type="is_special_text_file_doc", score=float(is_special_text_file))
)

return DocResult(doc=doc, spans=spans)
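
A minimal sketch of driving the new tagger directly, for illustration only: the DocumentWithMetadata import path and constructor fields below are assumptions inferred from what predict() reads, not a confirmed API.

from dolma.core.data_types import DocumentWithMetadata  # assumed import path
from dolma.taggers.code import Learn2CodeTaggers

# Hypothetical document; the metadata keys mirror the ones predict() reads.
doc = DocumentWithMetadata(
    source="example",
    id="0",
    text="def add(a, b):\n    return a + b\n",
    metadata={"ext": "py", "path": "src/add.py", "max_stars_count": 3},
)

tagger = Learn2CodeTaggers()
result = tagger.predict(doc)
for span in result.spans:
    print(span.type, span.score)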
82 changes: 82 additions & 0 deletions python/dolma/taggers/code/utils.py
@@ -8,6 +8,8 @@

import json
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Generator

@@ -54,6 +56,73 @@ def get_secrets(code: str):
return secrets


@dataclass
class StarCoderRegexFilterResults:
longest_match: int
proportion_match: float


def regex_match(regex_string: str, text: str) -> StarCoderRegexFilterResults:
all_matches = re.findall(regex_string, text)

match_lengths = [len(match) for match in all_matches]
longest_match = max(match_lengths) if match_lengths else 0
proportion_match = sum(match_lengths) / len(text) if text else 0.0

return StarCoderRegexFilterResults(longest_match=longest_match, proportion_match=proportion_match)
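

# Worked example (illustrative, not part of the diff):
# regex_match(r"[0-9]+", "ab12345cd678") finds "12345" and "678", so
# longest_match == 5 and proportion_match == 8 / 12 ≈ 0.67.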


def b64_filter(text: str) -> StarCoderRegexFilterResults:
"""
Taken from the StarCoder2 paper.
"""
regex = r"[a-zA-Z0-9+/\n=]{64,}"
return regex_match(regex, text)


def hexadecimal_filter(text: str) -> StarCoderRegexFilterResults:
"""
Taken from StarCoder2 paper.
The escaped literal case, e.g. "\\x48\\x31\\xc0\\x50\\x68\\x2f\\x2f\\x73\\x68",
is a bit broken: \b cannot match before the leading backslash, so the first
byte in the sequence is always dropped from the match.
"""
regex = r"(?:\b(?:0x|\\x)?[0-9a-fA-F]{2}(?:,|\b\s*)){8,}"
return regex_match(regex, text)
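

# Illustration of the first-byte issue noted in the docstring above
# (illustrative, not part of the diff): \b cannot match at the start of the
# string before a backslash, so the match begins at the second escaped byte.
#
#   >>> import re
#   >>> text = "\\x48\\x31\\xc0\\x50\\x68\\x2f\\x2f\\x73\\x68"
#   >>> re.findall(r"(?:\b(?:0x|\\x)?[0-9a-fA-F]{2}(?:,|\b\s*)){8,}", text)
#   ['\\x31\\xc0\\x50\\x68\\x2f\\x2f\\x73\\x68']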


def unicode_filter(text: str) -> StarCoderRegexFilterResults:
"""
Taken from the StarCoder2 paper.
"""
regex = r"(?:\\u[0-9a-fA-F]{4}){8,}"
return regex_match(regex, text)


def get_proportion_alphabetic_chars(text: str) -> float:
"""Calculates the proportion of characters in passed text that are alphabetic"""
nonalpha = re.sub(r"[^A-Za-z]", "", text)
return len(nonalpha) / len(text)


@dataclass
class LineStats:
total_count: int
mean_length: float
max_length: int


def get_line_stats(text: str) -> LineStats:
"""Finds some summary stats about the lines in the passed text"""

lines = text.split("\n")
line_lengths = [len(line) for line in lines]

return LineStats(
total_count=len(lines), mean_length=sum(line_lengths) / len(lines), max_length=max(line_lengths)
)


def filter_html(html: str) -> float:
"""Filter HTML files based on displayed text VS code ratio"""
try:
@@ -80,3 +149,16 @@ def get_ext_to_lang_mapping() -> Dict[str, str]:
path = Path(__file__).parent / "../../data/ext_to_lang_mapping.json"
with smart_open.open(path, "r") as f:
return json.load(f)


def special_text_file_filter(filepath: str, lang: str) -> bool:
"""Flags plain-text files whose names suggest prose content (READMEs, TODOs, requirements, etc.)."""
# "path" may be absent from document metadata, so guard against a None filepath.
if filepath and lang == "text": # TODO: include markdown as well?
filename = Path(filepath).stem.lower()

if "requirement" in filename:
return True

if filename in {"readme", "todo", "description", "cmakelists"}:
return True

return False
64 changes: 64 additions & 0 deletions python/dolma/taggers/quality.py
@@ -6,6 +6,7 @@

"""

import math
from typing import Iterable, List, Tuple

from tokenizers import normalizers, pre_tokenizers
@@ -66,3 +67,66 @@ def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]:
for label, score in sorted(zip(*preds), key=lambda x: x[1], reverse=True)
]
return out


@TaggerRegistry.add("code-prose-composition")
class CodeProseCompositionClassifier(BaseFastTextTagger):
MODEL_PATH = "hf://allenai/code-prose-composition/code-comment-prose-model.bin" # noqa: E501

def __init__(self):
super().__init__(model_path=self.MODEL_PATH, model_mode=self.DOCUMENT_LEVEL_TAGGER)

def calculate_entropy(self, distribution):
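"""Shannon entropy (in bits) of a single probability distribution.

E.g. a uniform [0.5, 0.5] prediction gives 1.0 bit, while a confident
[1.0, 0.0] prediction gives 0.0: higher means the classifier was less certain.
"""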
entropy = 0
for p in distribution:
if p > 0:
entropy -= p * math.log2(p)
return entropy

def mean_entropy(self, list_of_distributions):
if not list_of_distributions:
return 0

total_entropy = 0
for dist in list_of_distributions:
total_entropy += self.calculate_entropy(dist)
return total_entropy / len(list_of_distributions)

def predict_slice(self, text_slice: TextSlice) -> Iterable[Prediction]:
class_counts = {}
composition = {}
prediction_distributions = {}

lines = text_slice.text.splitlines()
for line in lines:
line = line.strip()
if not line:
continue

label = "other"
if len(line) > 3:
labels, probabilities = self.classifier.predict(line, k=-1)

# str.lstrip strips a *set of characters*, not a prefix, so remove the
# fastText "__label__" prefix explicitly instead.
label = labels[0].replace("__label__", "", 1)

if label not in prediction_distributions:
prediction_distributions[label] = []
prediction_distributions[label].append(probabilities)

class_counts[label] = class_counts.get(label, 0) + 1

total_count = sum(class_counts.values())
for key, count in class_counts.items():
composition[key] = round((count / total_count), 2)

# Labels were already stripped of the "__label__" prefix above.
out = [Prediction(label=label, score=score) for label, score in composition.items()]

for key in composition.keys():
out.append(Prediction(label=f"{key}_count", score=class_counts.get(key, 0)))
out.append(
Prediction(label=f"{key}_mean_entropy", score=self.mean_entropy(prediction_distributions.get(key, [])))
)

return out
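
# Illustrative output shape (the label names are assumed from the model file
# name "code-comment-prose-model.bin"; the actual labels may differ). For a
# slice that is mostly code with a few comment lines, predict_slice might yield:
#   Prediction(label="code", score=0.8), Prediction(label="comment", score=0.2),
#   plus one "<label>_count" and "<label>_mean_entropy" prediction per class.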