Skip to content

Commit

Permalink
Add rich text formatting for better correlation visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
OrigamiDream committed Oct 31, 2022
1 parent 5d8b28b commit c6968a9
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,11 @@ Install the following main packages manually, or use `requirements.txt`
- torch
- transformers
- scikit-learn
- wandb
- wandb # MLOps service
- pandas
- konlpy
- soynlp
- konlpy # For Mecab
- soynlp # For text normalization
- rich # For text highlighting
```
```bash
pip install -r requirements.txt
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ certifi==2022.9.24
cffi==1.15.1
charset-normalizer==2.1.1
click==8.1.3
commonmark==0.9.1
contourpy==1.0.5
cycler==0.11.0
debugpy==1.6.3
Expand Down Expand Up @@ -106,6 +107,7 @@ QtPy==2.2.0
regex==2022.9.13
requests==2.28.1
requests-oauthlib==1.3.1
rich==12.6.0
rsa==4.9
scikit-learn==1.1.2
scipy==1.9.1
Expand Down
8 changes: 5 additions & 3 deletions run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
import tensorflow as tf

from rich.console import Console
from cort.config import Config
from cort.modeling import CortForSequenceClassification
from utils import utils, formatting_utils
Expand All @@ -16,6 +17,7 @@
from tensorflow_addons import metrics as metrics_tfa

formatting_utils.setup_formatter(logging.INFO)
console = Console()


KOREAN_PATTERN = re.compile('[ㄱ-ㅎ가-힣]')
Expand Down Expand Up @@ -93,7 +95,7 @@ def build_score_unicodes(word_slice, score_index):

def colorize(text, attention_score, c1=(150, 0, 0), c2=(0, 150, 0)):
color = (1 - attention_score) * np.array(list(c1)) + attention_score * np.array(list(c2))
return '\033[38;2;{};{};{}m{}\033[0m'.format(int(color[0]), int(color[1]), int(color[2]), text)
return '[rgb({},{},{})]{}[reset]'.format(int(color[0]), int(color[1]), int(color[2]), text)

ComposedToken = collections.namedtuple('ComposedToken', [
'matched', 'text', 'colorized_text',
Expand Down Expand Up @@ -236,8 +238,8 @@ def perform_interactive_predictions(config, model):
probs = cort_outputs['probs'][0]
index = np.argmax(probs)
print('\nCorrelations:')
print(''.join([composed.colorized_correlation_unicode for composed in composed_tokens]))
print(''.join([composed.colorized_text for composed in composed_tokens]))
console.print(''.join([composed.colorized_correlation_unicode for composed in composed_tokens]))
console.print(''.join([composed.colorized_text for composed in composed_tokens]))
print('\nPrediction: {}: ({:.06f} of confidence score)'.format(LABEL_NAMES[index], probs[index]))
print()

Expand Down

0 comments on commit c6968a9

Please sign in to comment.