Skip to content

Commit

Permalink
Adding stop warning and chaning names for implicit motives.
Browse files Browse the repository at this point in the history
  • Loading branch information
Oscar Kjell authored and Oscar Kjell committed Dec 4, 2024
1 parent d243a96 commit 6e5a063
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 13 deletions.
2 changes: 1 addition & 1 deletion R/2_4_0_textPredict_Assess_Classify.R
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ textPredict <- function(
}

# display message to user
message(colourise("Predictions are ready!", fg = "green"))
message(colourise("Assessments are ready!", fg = "green"))
message("\n")
return(results)

Expand Down
24 changes: 19 additions & 5 deletions R/2_4_2_textPredictImplicitMotives.R
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,7 @@ get_implicit_model_info <- function(
show_texts,
type,
texts,
story_id#,
#lower_case_model
story_id
) {
# show_prob is by default FALSE
show_prob <- FALSE
Expand Down Expand Up @@ -618,6 +617,11 @@ textPredictImplicitMotives <- function(
story_id <- NULL
}

if ((previous_sentence == T & is.null(story_id)) ||
previous_sentence == T & is.null(participant_id)){
stop("error: there must be story_id and participant_id when previous_sentence = T")
}

use_row_id_name = FALSE

#### Special treatment for implicit motives - see private functions ####
Expand Down Expand Up @@ -709,8 +713,7 @@ textPredictImplicitMotives <- function(
predicted_scores2 <- tibble(
!!class_name:=ifelse(classifications_rev > 0.5 , 1, 0),
.pred_0 = 1-classifications_rev,
.pred_1 = classifications_rev#,
#texts = PSE_stories_sentence_level$Story_Text[c(1, 79)]
.pred_1 = classifications_rev
)
}

Expand All @@ -735,7 +738,18 @@ textPredictImplicitMotives <- function(
# change participant_id to row_id
if(use_row_id_name){
colnames(predicted_scores2[[2]])[colnames(predicted_scores2[[2]]) == "participant_id"] <- "row_id"
}
}



# Check and rename if necessary
if ("person_predictions" %in% names(predicted_scores2)) {
names(predicted_scores2)[names(predicted_scores2) == "person_predictions"] <- "person_assessments"
}

if ("sentence_predictions" %in% names(predicted_scores2)) {
names(predicted_scores2)[names(predicted_scores2) == "sentence_predictions"] <- "sentence_assessments"
}

return(predicted_scores2)
}
Expand Down
14 changes: 7 additions & 7 deletions tests/testthat/test_2_6_textPredict_implicitmotives.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ test_that("textPredict Implicit motives", {
)
testthat::expect_that(predictions_participant_1, testthat::is_a("list"))
testthat::expect_equal(length(predictions_participant_1), 3)
testthat::expect_equal(predictions_participant_1$sentence_predictions$.pred_0[[1]], 0.9233226, tolerance = 0.0001)
testthat::expect_equal(predictions_participant_1$sentence_predictions$.pred_1[[1]], 0.07667742, tolerance = 0.0001)
testthat::expect_equal(predictions_participant_1$person_predictions$person_prob[[1]], -0.07572437, tolerance = 0.0001)
testthat::expect_equal(predictions_participant_1$person_predictions$person_class[[2]], -0.1359569, tolerance = 0.0001)
testthat::expect_equal(predictions_participant_1$person_predictions$person_prob_no_wc_correction[[3]], 0.1402563, tolerance = 0.0001)
testthat::expect_equal(predictions_participant_1$sentence_assessments$.pred_0[[1]], 0.9233226, tolerance = 0.0001)
testthat::expect_equal(predictions_participant_1$sentence_assessments$.pred_1[[1]], 0.07667742, tolerance = 0.0001)
testthat::expect_equal(predictions_participant_1$person_assessments$person_prob[[1]], -0.07572437, tolerance = 0.0001)
testthat::expect_equal(predictions_participant_1$person_assessments$person_class[[2]], -0.1359569, tolerance = 0.0001)
testthat::expect_equal(predictions_participant_1$person_assessments$person_prob_no_wc_correction[[3]], 0.1402563, tolerance = 0.0001)

testthat::expect_equal(predictions_participant_1$dataset$Participant_ID[[2]], "P02", tolerance = 0.0001)

Expand Down Expand Up @@ -199,7 +199,7 @@ test_that("textPredict Implicit motives", {
)
testthat::expect_that(predictions_sentence_2, testthat::is_a("list"))
testthat::expect_equal(length(predictions_sentence_2), 3)
testthat::expect_equal(predictions_sentence_2$sentence_predictions$.pred_0[[1]], 0.9089251, tolerance = 0.0001)
testthat::expect_equal(predictions_sentence_2$sentence_predictions$.pred_1[[2]], 0.09193248, tolerance = 0.0001)
testthat::expect_equal(predictions_sentence_2$sentence_assessments$.pred_0[[1]], 0.9089251, tolerance = 0.0001)
testthat::expect_equal(predictions_sentence_2$sentence_assessments$.pred_1[[2]], 0.09193248, tolerance = 0.0001)

})

0 comments on commit 6e5a063

Please sign in to comment.