Skip to content

Commit

Permalink
Check strong encoding and coerce if necessary in fmatch() / join(). S…
Browse files Browse the repository at this point in the history
…hould fix #566, #579, and #618.
  • Loading branch information
SebKrantz committed Aug 20, 2024
1 parent 77fd62d commit f4b5deb
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/base_radixsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#define SET_TRLEN(x, v) SET_STDVEC_TRUELENGTH(x, ((int) (v)))

#define MYLEV(x) (((SEXPREC_partial *)(x))->sxpinfo.gp)
#define IS_UTF8(x) (MYLEV(x) & 8)
#define IS_ASCII(x) (MYLEV(x) & 64) // from data.table.h

#define SETTOF(x,v) ((((SEXPREC_partial *)(x))->sxpinfo.type)=(v))
Expand Down
6 changes: 6 additions & 0 deletions src/data.table.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
#define SEXPPTR(x) ((SEXP *)DATAPTR(x)) // to avoid overhead of looped VECTOR_ELT
#define SEXPPTR_RO(x) ((const SEXP *)DATAPTR_RO(x)) // to avoid overhead of looped VECTOR_ELT

// Needed for match.c and join.c
#define NEED2UTF8(s) !(IS_ASCII(s) || (s)==NA_STRING || IS_UTF8(s))
#define ENC2UTF8(s) (!NEED2UTF8(s) ? (s) : mkCharCE(translateCharUTF8(s), CE_UTF8))

// for use with bit64::integer64
#define NA_INTEGER64 INT64_MIN
#define MAX_INTEGER64 INT64_MAX
Expand Down Expand Up @@ -52,6 +56,8 @@ extern size_t sizes[100]; // max appears to be FUNSXP = 99, see Rinternals.h
extern size_t typeorder[100];

// data.table_utils.c
int need2utf8(SEXP x);
SEXP coerceUtf8IfNeeded(SEXP x);
SEXP setnames(SEXP x, SEXP nam);
bool allNA(SEXP x, bool errorForBadType);
SEXP allNAv(SEXP x, SEXP errorForBadType);
Expand Down
26 changes: 26 additions & 0 deletions src/data.table_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,32 @@
#include "data.table.h"


int need2utf8(SEXP x) {
const int xlen = length(x);
const SEXP *xd = STRING_PTR_RO(x);
// for (int i=0; i<xlen; i++) {
// if (NEED2UTF8(xd[i]))
// return(true);
// }
// return(false);
if (xlen <= 1) return xlen == 1 ? NEED2UTF8(xd[0]) : 0;
return NEED2UTF8(xd[0]) || NEED2UTF8(xd[xlen/2]) || NEED2UTF8(xd[xlen-1]);
}

SEXP coerceUtf8IfNeeded(SEXP x) {
if (!need2utf8(x))
return(x);
const int xlen = length(x);
SEXP ans = PROTECT(allocVector(STRSXP, xlen));
const SEXP *xd = STRING_PTR_RO(x);
for (int i=0; i<xlen; i++) {
SET_STRING_ELT(ans, i, ENC2UTF8(xd[i]));
}
UNPROTECT(1);
return(ans);
}


SEXP setnames(SEXP x, SEXP nam) {
if(TYPEOF(nam) != STRSXP) error("names need to be character typed");
if(INHERITS(x, char_datatable)) {
Expand Down
9 changes: 5 additions & 4 deletions src/join.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "collapse_c.h" // Needs to be first because includes OpenMP, to avoid namespace conflicts.
#include "data.table.h"
#include "kit.h"
#include "base_radixsort.h"


/* A Sort-Merge Join
See: https://www.dcs.ed.ac.uk/home/tz/phd/thesis/node20.htm
Expand Down Expand Up @@ -301,9 +302,9 @@ SEXP sort_merge_join(SEXP x, SEXP table, SEXP ot, SEXP count) {
else sort_merge_join_double_second(REAL_RO(pci[0]), REAL_RO(pci[1])-1, pg, ptab, pot, nx, nt, pres);
break;
case STRSXP:
checkEncodings(pci[0]); checkEncodings(pci[1]);
if(i == 0) sort_merge_join_string(SEXPPTR_RO(pci[0]), SEXPPTR_RO(pci[1])-1, pg, ptab, pot, nx, nt, pres);
else sort_merge_join_string_second(SEXPPTR_RO(pci[0]), SEXPPTR_RO(pci[1])-1, pg, ptab, pot, nx, nt, pres);
if(i == 0) sort_merge_join_string(SEXPPTR_RO(PROTECT(coerceUtf8IfNeeded(pci[0]))), SEXPPTR_RO(PROTECT(coerceUtf8IfNeeded(pci[1])))-1, pg, ptab, pot, nx, nt, pres);
else sort_merge_join_string_second(SEXPPTR_RO(PROTECT(coerceUtf8IfNeeded(pci[0]))), SEXPPTR_RO(PROTECT(coerceUtf8IfNeeded(pci[1])))-1, pg, ptab, pot, nx, nt, pres);
UNPROTECT(2);
break;
case CPLXSXP:
if(i == 0) sort_merge_join_complex(COMPLEX_RO(pci[0]), COMPLEX_RO(pci[1])-1, pg, ptab, pot, nx, nt, pres);
Expand Down
34 changes: 32 additions & 2 deletions src/match.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "collapse_c.h" // Needs to be first because includes OpenMP, to avoid namespace conflicts.
#include "data.table.h"
#include "kit.h"


Expand Down Expand Up @@ -250,6 +251,12 @@ SEXP match_single(SEXP x, SEXP table, SEXP nomatch) {
}
} break;
case STRSXP: {
if (need2utf8(x)) {
PROTECT(x = coerceUtf8IfNeeded(x)); ++nprotect;
}
if (need2utf8(table)) {
PROTECT(table = coerceUtf8IfNeeded(table)); ++nprotect;
}
const SEXP *restrict px = SEXPPTR(x), *restrict pt = SEXPPTR(table);
// fill hash table with indices of 'table'
for (int i = 0; i != nt; ++i) {
Expand Down Expand Up @@ -428,6 +435,10 @@ SEXP match_two_vectors(SEXP x, SEXP table, SEXP nomatch) {
}
} break;
case STRSXP: {
for(int i = 0; i < 2; ++i) {
if(need2utf8(pc1[i])) SET_VECTOR_ELT(pc[0], i, coerceUtf8IfNeeded(pc1[i]));
if(need2utf8(pc2[i])) SET_VECTOR_ELT(pc[1], i, coerceUtf8IfNeeded(pc2[i]));
}
const SEXP *restrict px1 = SEXPPTR(pc1[0]), *restrict px2 = SEXPPTR(pc2[0]),
*restrict pt1 = SEXPPTR(pc1[1]), *restrict pt2 = SEXPPTR(pc2[1]);
// fill hash table with indices of 'table'
Expand Down Expand Up @@ -523,6 +534,9 @@ SEXP match_two_vectors(SEXP x, SEXP table, SEXP nomatch) {
} else if ((t1 == REALSXP && t2 == STRSXP) || (t1 == STRSXP && t2 == REALSXP)) {
const int rev = t1 == STRSXP;
const double *restrict pxr = REAL(VECTOR_ELT(pc[rev], 0)), *restrict ptr = REAL(VECTOR_ELT(pc[rev], 1));
for(int i = 0; i < 2; ++i) {
if(need2utf8(VECTOR_ELT(pc[1-rev], i))) SET_VECTOR_ELT(pc[1-rev], i, coerceUtf8IfNeeded(VECTOR_ELT(pc[1-rev], i)));
}
const SEXP *restrict pxs = SEXPPTR(VECTOR_ELT(pc[1-rev], 0)), *restrict pts = SEXPPTR(VECTOR_ELT(pc[1-rev], 1));
union uno tpv;
// fill hash table with indices of 'table'
Expand Down Expand Up @@ -554,6 +568,9 @@ SEXP match_two_vectors(SEXP x, SEXP table, SEXP nomatch) {
} else if((t1 == INTSXP && t2 == STRSXP) || (t1 == STRSXP && t2 == INTSXP)) {
const int rev = t1 == STRSXP;
const int *restrict pxi = INTEGER(VECTOR_ELT(pc[rev], 0)), *restrict pti = INTEGER(VECTOR_ELT(pc[rev], 1));
for(int i = 0; i < 2; ++i) {
if(need2utf8(VECTOR_ELT(pc[1-rev], i))) SET_VECTOR_ELT(pc[1-rev], i, coerceUtf8IfNeeded(VECTOR_ELT(pc[1-rev], i)));
}
const SEXP *restrict pxs = SEXPPTR(VECTOR_ELT(pc[1-rev], 0)), *restrict pts = SEXPPTR(VECTOR_ELT(pc[1-rev], 1));

// fill hash table with indices of 'table'
Expand Down Expand Up @@ -643,6 +660,10 @@ void match_two_vectors_extend(const SEXP *pc, const int nmv, const int n, const
}
} break;
case STRSXP: {
for(int i = 0; i < 2; ++i) {
if(need2utf8(pc1[i])) SET_VECTOR_ELT(pc[0], i, coerceUtf8IfNeeded(pc1[i]));
if(need2utf8(pc2[i])) SET_VECTOR_ELT(pc[1], i, coerceUtf8IfNeeded(pc2[i]));
}
const SEXP *restrict px1 = SEXPPTR(pc1[0]), *restrict px2 = SEXPPTR(pc2[0]),
*restrict pt1 = SEXPPTR(pc1[1]), *restrict pt2 = SEXPPTR(pc2[1]);
// fill hash table with indices of 'table'
Expand Down Expand Up @@ -747,6 +768,9 @@ void match_two_vectors_extend(const SEXP *pc, const int nmv, const int n, const
} else if ((t1 == REALSXP && t2 == STRSXP) || (t1 == STRSXP && t2 == REALSXP)) {
const int rev = t1 == STRSXP;
const double *restrict pxr = REAL(VECTOR_ELT(pc[rev], 0)), *restrict ptr = REAL(VECTOR_ELT(pc[rev], 1));
for(int i = 0; i < 2; ++i) {
if(need2utf8(VECTOR_ELT(pc[1-rev], i))) SET_VECTOR_ELT(pc[1-rev], i, coerceUtf8IfNeeded(VECTOR_ELT(pc[1-rev], i)));
}
const SEXP *restrict pxs = SEXPPTR(VECTOR_ELT(pc[1-rev], 0)), *restrict pts = SEXPPTR(VECTOR_ELT(pc[1-rev], 1));
union uno tpv;
// fill hash table with indices of 'table'
Expand Down Expand Up @@ -781,6 +805,9 @@ void match_two_vectors_extend(const SEXP *pc, const int nmv, const int n, const
} else if((t1 == INTSXP && t2 == STRSXP) || (t1 == STRSXP && t2 == INTSXP)) {
const int rev = t1 == STRSXP;
const int *restrict pxi = INTEGER(VECTOR_ELT(pc[rev], 0)), *restrict pti = INTEGER(VECTOR_ELT(pc[rev], 1));
for(int i = 0; i < 2; ++i) {
if(need2utf8(VECTOR_ELT(pc[1-rev], i))) SET_VECTOR_ELT(pc[1-rev], i, coerceUtf8IfNeeded(VECTOR_ELT(pc[1-rev], i)));
}
const SEXP *restrict pxs = SEXPPTR(VECTOR_ELT(pc[1-rev], 0)), *restrict pts = SEXPPTR(VECTOR_ELT(pc[1-rev], 1));

// fill hash table with indices of 'table'
Expand Down Expand Up @@ -871,7 +898,8 @@ void match_additional(const SEXP *pcj, const int nmv, const int n, const int nt,
}
} break;
case STRSXP: {
const SEXP *restrict px = SEXPPTR(pcj[0]), *restrict pt = SEXPPTR(pcj[1]);
const SEXP *restrict px = SEXPPTR(PROTECT(coerceUtf8IfNeeded(pcj[0]))),
*restrict pt = SEXPPTR(PROTECT(coerceUtf8IfNeeded(pcj[1])));
// fill hash table with indices of 'table'
for (int i = 0; i != nt; ++i) {
if(ptab_copy[i] == nmv) {
Expand Down Expand Up @@ -903,6 +931,7 @@ void match_additional(const SEXP *pcj, const int nmv, const int n, const int nt,
pans[i] = nmv;
stbl2:;
}
UNPROTECT(2);
} break;
case REALSXP: {
const double *restrict px = REAL(pcj[0]), *restrict pt = REAL(pcj[1]);
Expand Down Expand Up @@ -964,11 +993,12 @@ void match_rest(const SEXP *pcj, const int nmv, const int n, const int nt, int *
}
} break;
case STRSXP: {
const SEXP *restrict px = SEXPPTR(pcj[0]), *restrict pt = SEXPPTR(pcj[1])-1;
const SEXP *restrict px = SEXPPTR(PROTECT(coerceUtf8IfNeeded(pcj[0]))), *restrict pt = SEXPPTR(PROTECT(coerceUtf8IfNeeded(pcj[1])))-1;
for (int i = 0; i != n; ++i) {
if(pans[i] == nmv) continue;
if(px[i] != pt[pans[i]]) pans[i] = nmv;
}
UNPROTECT(2);
} break;
case REALSXP: {
const double *restrict px = REAL(pcj[0]), *restrict pt = REAL(pcj[1])-1;
Expand Down

0 comments on commit f4b5deb

Please sign in to comment.