Skip to content

Commit

Permalink
updating textTrainExamples
Browse files Browse the repository at this point in the history
  • Loading branch information
OscarKjell committed Jan 20, 2025
1 parent 82f28b4 commit c685ff5
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 36 deletions.
1 change: 1 addition & 0 deletions R/0_1_globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ utils::globalVariables(c(

# textTrainExamples
"category", "distance_to_mean", "error", "language", "topic",
"height", "save_dir", "width", "y_axes_1",


# textPredcitTest
Expand Down
80 changes: 49 additions & 31 deletions R/2_5_textTrainPredictExamples.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,3 @@
#Sys.setenv(OMP_NUM_THREADS = "1") #Limit the number of threads to prevent conflicts.
#
#Sys.setenv(OMP_MAX_ACTIVE_LEVELS = "1")
#
## If above does not work, you can also try this; although this solution might have some risks assocaited with it (for more information see https://github.com/dmlc/xgboost/issues/1715)
#Sys.setenv(KMP_DUPLICATE_LIB_OK = "TRUE") #Temporarily allows execution despite duplicate OpenMP libraries.




#' Show language examples (Experimental)
#'
Expand All @@ -23,8 +14,18 @@
#' @param error_color = (string)"darkred",
#' @param distribution_color (string) colors of the distribution plot
#' @param figure_format (string) file format of the figures.
#' @param scatter_legend_dot_size (integer) The size of dots in the scatter legend.
#' @param scatter_legend_bg_dot_size (integer) The size of background dots in the scatter legend.
#' @param scatter_legend_n (numeric or vector) A vector determining the number of dots to emphasize in each quadrant of the scatter legend.
#' For example: c(1,0,1) result in one dot in each quadrant except for the middle quadrant.
# @param scatter_legend_method (string) The method to filter topics to be emphasized in the scatter legend; either "mean", "max_x", or "max_y".
# @param scatter_legend_specified_topics (vector) Specify which topic(s) to emphasize in the scatter legend.
# For example, c("t_1", "t_2"). If set, scatter_legend_method will have no effect.
# @param scatter_legend_topic_n (boolean) If TRUE, the topic numbers are shown in the scatter legend.
#' @param scatter_show_axis_values (boolean) If TRUE, the estimate values are shown on the distribution plot axes.

#' @returns A tibble including examples with descriptive variables.
#' @importFrom dplyr filter select arrange slice group_by summarize mutate
#' @importFrom dplyr filter select arrange slice group_by summarize mutate rename
#' @importFrom stringi stri_detect_fixed
#' @importFrom purrr map_lgl
#' @export
Expand All @@ -40,7 +41,14 @@ textTrainExamples <- function(
predictions_color = "darkblue",
error_color = "darkred",
distribution_color = c("darkgreen", "gray", "darkred"),
figure_format = "svg"
figure_format = "svg",
scatter_legend_dot_size = 4,
scatter_legend_bg_dot_size = 2,
scatter_legend_n = c(3,3,3),
scatter_show_axis_values = TRUE,
grid_legend_x_axes_label = "x",
grid_legend_y_axes_label = "y",
seed = 42
){

# Combine responses with predictions and target scores
Expand Down Expand Up @@ -175,7 +183,9 @@ textTrainExamples <- function(
) +
ggplot2::theme_minimal()

if(selection_method == "min_max"){
if(selection_method == "min_max" |
selection_method == "min" |
selection_method == "max"){
df <- df %>%
dplyr::left_join(df_short %>% dplyr::select(topic, category), by = "topic") %>%
dplyr::mutate(
Expand All @@ -188,14 +198,14 @@ textTrainExamples <- function(
dplyr::select(-category)

num_popout = c(n_examples, 0, n_examples)

user_spec_topics <- paste0(df_short$topic)
}

if(selection_method == "min_mean_max"){

df <- df %>%
dplyr::left_join(df_short %>% dplyr::select(topic, category), by = "topic") %>%
mutate(
dplyr::mutate(
color_categories = dplyr::case_when(
category == "min" ~ 1,
category == "mean" ~ 2,
Expand All @@ -205,14 +215,18 @@ textTrainExamples <- function(
) %>%
dplyr::select(-category) # Optionally remove the temporary `category` column

distribution_color = distribution_color[c(3, 2, 1)]

num_popout = c(n_examples, n_examples, n_examples)

user_spec_topics <- paste0("t_", df_short$topic)
#user_spec_topics <- df_short$topic
user_spec_topics <- paste0(df_short$topic)
}


# The current way of selecting colours are off in topcsScatterLegend
if(selection_method %in% c("min_mean_max", "min_max")){
distribution_color = distribution_color[c(3, 2, 1)]
}
if(selection_method %in% c("min", "max")){
distribution_color = distribution_color[c(1, 2)]
}
#table(df$color_categories)

# Dynamically move a column to the fifth position
Expand All @@ -223,24 +237,28 @@ textTrainExamples <- function(
scatter_plot <- topics::topicsScatterLegend(
bivariate_color_codes = distribution_color,
filtered_test = df,
num_popout = num_popout,
num_popout = scatter_legend_n,
way_popout_topics = "mean",
user_spec_topics = user_spec_topics,
allow_topic_num_legend = FALSE,
scatter_show_axis_values = TRUE,
y_axes_1 = 1,
scatter_show_axis_values = scatter_show_axis_values,
y_axes_1 = y_axes_1,
cor_var = "",
label_x_name = "x",
label_y_name = "y",
save_dir = NULL,
label_x_name = grid_legend_x_axes_label,
label_y_name = grid_legend_y_axes_label,
save_dir = save_dir,
figure_format = figure_format,
scatter_popout_dot_size = 8,
scatter_bg_dot_size = 4,
width = 10,
height = 8,
seed = 42
scatter_popout_dot_size = scatter_legend_dot_size,
scatter_bg_dot_size = scatter_legend_bg_dot_size,
width = width,
height = height,
seed = seed
)
scatter_plot


# Renaming variable
df_short <- df_short %>%
dplyr::rename(id = topic)

results <- list(
error_plot = error_plot,
Expand Down
27 changes: 25 additions & 2 deletions man/textTrainExamples.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 11 additions & 3 deletions tests/testthat/2_9_textTrainExamples.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#library(testthat)
#library(tibble)
#library(text)
#
#library(topics)
#context("Testing tasks")
#
#
Expand Down Expand Up @@ -91,7 +91,15 @@
# target_color = "darkgreen",
# predictions_color = "darkblue",
# error_color = "darkred",
# distribution_color = c("darkgreen", "gray", "darkred")
# distribution_color = c("darkgreen", "gray", "darkred"),
# figure_format = "svg",
# scatter_legend_dot_size = 6,
# scatter_legend_bg_dot_size = 3,
# scatter_legend_n = c(3,3,3),
# scatter_show_axis_values = TRUE,
# grid_legend_x_axes_label = "x",
# grid_legend_y_axes_label = "y",
# seed = 42
# )
#
# examples
Expand All @@ -112,5 +120,5 @@
#
#})
#

#

0 comments on commit c685ff5

Please sign in to comment.