From c685ff5d5479449f73e61299d49e9c93dd6d2cdd Mon Sep 17 00:00:00 2001 From: Oscar Kjell Date: Mon, 20 Jan 2025 18:43:10 +0100 Subject: [PATCH] updating textTrainExamples --- R/0_1_globals.R | 1 + R/2_5_textTrainPredictExamples.R | 80 ++++++++++++++++---------- man/textTrainExamples.Rd | 27 ++++++++- tests/testthat/2_9_textTrainExamples.R | 14 ++++- 4 files changed, 86 insertions(+), 36 deletions(-) diff --git a/R/0_1_globals.R b/R/0_1_globals.R index 26d5a026..38ef3da4 100644 --- a/R/0_1_globals.R +++ b/R/0_1_globals.R @@ -40,6 +40,7 @@ utils::globalVariables(c( # textTrainExamples "category", "distance_to_mean", "error", "language", "topic", + "height", "save_dir", "width", "y_axes_1", # textPredcitTest diff --git a/R/2_5_textTrainPredictExamples.R b/R/2_5_textTrainPredictExamples.R index 58e803d4..3eadd5e9 100644 --- a/R/2_5_textTrainPredictExamples.R +++ b/R/2_5_textTrainPredictExamples.R @@ -1,12 +1,3 @@ -#Sys.setenv(OMP_NUM_THREADS = "1") #Limit the number of threads to prevent conflicts. -# -#Sys.setenv(OMP_MAX_ACTIVE_LEVELS = "1") -# -## If above does not work, you can also try this; although this solution might have some risks assocaited with it (for more information see https://github.com/dmlc/xgboost/issues/1715) -#Sys.setenv(KMP_DUPLICATE_LIB_OK = "TRUE") #Temporarily allows execution despite duplicate OpenMP libraries. - - - #' Show language examples (Experimental) #' @@ -23,8 +14,18 @@ #' @param error_color = (string)"darkred", #' @param distribution_color (string) colors of the distribution plot #' @param figure_format (string) file format of the figures. +#' @param scatter_legend_dot_size (integer) The size of dots in the scatter legend. +#' @param scatter_legend_bg_dot_size (integer) The size of background dots in the scatter legend. +#' @param scatter_legend_n (numeric or vector) A vector determining the number of dots to emphasize in each quadrant of the scatter legend. 
+#' For example: c(1,0,1) results in one dot in each quadrant except for the middle quadrant. +# @param scatter_legend_method (string) The method to filter topics to be emphasized in the scatter legend; either "mean", "max_x", or "max_y". +# @param scatter_legend_specified_topics (vector) Specify which topic(s) to emphasize in the scatter legend. +# For example, c("t_1", "t_2"). If set, scatter_legend_method will have no effect. +# @param scatter_legend_topic_n (boolean) If TRUE, the topic numbers are shown in the scatter legend. +#' @param scatter_show_axis_values (boolean) If TRUE, the estimate values are shown on the distribution plot axes. + #' @returns A tibble including examples with descriptive variables. -#' @importFrom dplyr filter select arrange slice group_by summarize mutate +#' @importFrom dplyr filter select arrange slice group_by summarize mutate rename #' @importFrom stringi stri_detect_fixed #' @importFrom purrr map_lgl #' @export @@ -40,7 +41,14 @@ textTrainExamples <- function( predictions_color = "darkblue", error_color = "darkred", distribution_color = c("darkgreen", "gray", "darkred"), - figure_format = "svg" + figure_format = "svg", + scatter_legend_dot_size = 4, + scatter_legend_bg_dot_size = 2, + scatter_legend_n = c(3,3,3), + scatter_show_axis_values = TRUE, + grid_legend_x_axes_label = "x", + grid_legend_y_axes_label = "y", + seed = 42 ){ # Combine responses with predictions and target scores @@ -175,7 +183,9 @@ textTrainExamples <- function( ) + ggplot2::theme_minimal() - if(selection_method == "min_max"){ + if(selection_method == "min_max" | + selection_method == "min" | + selection_method == "max"){ df <- df %>% dplyr::left_join(df_short %>% dplyr::select(topic, category), by = "topic") %>% dplyr::mutate( @@ -188,14 +198,14 @@ textTrainExamples <- function( dplyr::select(-category) num_popout = c(n_examples, 0, n_examples) - + user_spec_topics <- paste0(df_short$topic) } if(selection_method == "min_mean_max"){ df <- df %>% 
dplyr::left_join(df_short %>% dplyr::select(topic, category), by = "topic") %>% - mutate( + dplyr::mutate( color_categories = dplyr::case_when( category == "min" ~ 1, category == "mean" ~ 2, @@ -205,14 +215,18 @@ textTrainExamples <- function( ) %>% dplyr::select(-category) # Optionally remove the temporary `category` column - distribution_color = distribution_color[c(3, 2, 1)] - num_popout = c(n_examples, n_examples, n_examples) - - user_spec_topics <- paste0("t_", df_short$topic) - #user_spec_topics <- df_short$topic + user_spec_topics <- paste0(df_short$topic) } + + # The current way of selecting colours is off in topicsScatterLegend + if(selection_method %in% c("min_mean_max", "min_max")){ + distribution_color = distribution_color[c(3, 2, 1)] + } + if(selection_method %in% c("min", "max")){ + distribution_color = distribution_color[c(1, 2)] + } #table(df$color_categories) # Dynamically move a column to the fifth position @@ -223,24 +237,28 @@ textTrainExamples <- function( scatter_plot <- topics::topicsScatterLegend( bivariate_color_codes = distribution_color, filtered_test = df, - num_popout = num_popout, + num_popout = scatter_legend_n, way_popout_topics = "mean", user_spec_topics = user_spec_topics, allow_topic_num_legend = FALSE, - scatter_show_axis_values = TRUE, - y_axes_1 = 1, + scatter_show_axis_values = scatter_show_axis_values, + y_axes_1 = y_axes_1, cor_var = "", - label_x_name = "x", - label_y_name = "y", - save_dir = NULL, + label_x_name = grid_legend_x_axes_label, + label_y_name = grid_legend_y_axes_label, + save_dir = save_dir, figure_format = figure_format, - scatter_popout_dot_size = 8, - scatter_bg_dot_size = 4, - width = 10, - height = 8, - seed = 42 + scatter_popout_dot_size = scatter_legend_dot_size, + scatter_bg_dot_size = scatter_legend_bg_dot_size, + width = width, + height = height, + seed = seed ) - scatter_plot + + + # Renaming variable + df_short <- df_short %>% + dplyr::rename(id = topic) results <- list( error_plot = error_plot, 
diff --git a/man/textTrainExamples.Rd b/man/textTrainExamples.Rd index 08560e93..9a0682cb 100644 --- a/man/textTrainExamples.Rd +++ b/man/textTrainExamples.Rd @@ -17,7 +17,14 @@ textTrainExamples( predictions_color = "darkblue", error_color = "darkred", distribution_color = c("darkgreen", "gray", "darkred"), - figure_format = "svg" + figure_format = "svg", + scatter_legend_dot_size = 4, + scatter_legend_bg_dot_size = 2, + scatter_legend_n = c(3, 3, 3), + scatter_show_axis_values = TRUE, + grid_legend_x_axes_label = "x", + grid_legend_y_axes_label = "y", + seed = 42 ) textPredictExamples( @@ -32,7 +39,14 @@ textPredictExamples( predictions_color = "darkblue", error_color = "darkred", distribution_color = c("darkgreen", "gray", "darkred"), - figure_format = "svg" + figure_format = "svg", + scatter_legend_dot_size = 4, + scatter_legend_bg_dot_size = 2, + scatter_legend_n = c(3, 3, 3), + scatter_show_axis_values = TRUE, + grid_legend_x_axes_label = "x", + grid_legend_y_axes_label = "y", + seed = 42 ) } \arguments{ @@ -59,6 +73,15 @@ textPredictExamples( \item{distribution_color}{(string) colors of the distribution plot} \item{figure_format}{(string) file format of the figures.} + +\item{scatter_legend_dot_size}{(integer) The size of dots in the scatter legend.} + +\item{scatter_legend_bg_dot_size}{(integer) The size of background dots in the scatter legend.} + +\item{scatter_legend_n}{(numeric or vector) A vector determining the number of dots to emphasize in each quadrant of the scatter legend. +For example: c(1,0,1) results in one dot in each quadrant except for the middle quadrant.} + +\item{scatter_show_axis_values}{(boolean) If TRUE, the estimate values are shown on the distribution plot axes.} } \value{ A tibble including examples with descriptive variables. 
diff --git a/tests/testthat/2_9_textTrainExamples.R b/tests/testthat/2_9_textTrainExamples.R index 502b2c17..2c15ac19 100644 --- a/tests/testthat/2_9_textTrainExamples.R +++ b/tests/testthat/2_9_textTrainExamples.R @@ -3,7 +3,7 @@ #library(testthat) #library(tibble) #library(text) -# +#library(topics) #context("Testing tasks") # # @@ -91,7 +91,15 @@ # target_color = "darkgreen", # predictions_color = "darkblue", # error_color = "darkred", -# distribution_color = c("darkgreen", "gray", "darkred") +# distribution_color = c("darkgreen", "gray", "darkred"), +# figure_format = "svg", +# scatter_legend_dot_size = 6, +# scatter_legend_bg_dot_size = 3, +# scatter_legend_n = c(3,3,3), +# scatter_show_axis_values = TRUE, +# grid_legend_x_axes_label = "x", +# grid_legend_y_axes_label = "y", +# seed = 42 # ) # # examples @@ -112,5 +120,5 @@ # #}) # - +#