diff --git a/rust/cubesql/cubesql/src/sql/dataframe.rs b/rust/cubesql/cubesql/src/sql/dataframe.rs index d932fe75a6212..3c0f856f13a2c 100644 --- a/rust/cubesql/cubesql/src/sql/dataframe.rs +++ b/rust/cubesql/cubesql/src/sql/dataframe.rs @@ -409,14 +409,15 @@ pub fn arrow_to_column_type(arrow_type: DataType) -> Result Ok(ColumnType::Double), DataType::Boolean => Ok(ColumnType::Boolean), DataType::List(field) => Ok(ColumnType::List(field)), - DataType::Int32 | DataType::UInt32 => Ok(ColumnType::Int32), DataType::Decimal(_, _) => Ok(ColumnType::Int32), - DataType::Int8 - | DataType::Int16 - | DataType::Int64 - | DataType::UInt8 + DataType::Int8 //we are missing TableValue::Int8 type to use ColumnType:Int8 + | DataType::UInt8 //we are missing ColumnType::Int16 type + | DataType::Int16 //we are missing ColumnType::Int16 type | DataType::UInt16 - | DataType::UInt64 => Ok(ColumnType::Int64), + | DataType::Int32 => Ok(ColumnType::Int32), + DataType::UInt32 + | DataType::Int64 => Ok(ColumnType::Int64), + DataType::UInt64 => Ok(ColumnType::Decimal(39, 0)), DataType::Null => Ok(ColumnType::String), x => Err(CubeError::internal(format!("unsupported type {:?}", x))), } @@ -452,12 +453,23 @@ pub fn batches_to_dataframe( let array = batch.column(column_index); let num_rows = batch.num_rows(); match array.data_type() { - DataType::UInt16 => convert_array!(array, num_rows, rows, UInt16Array, Int16, i16), + DataType::Int8 => convert_array!(array, num_rows, rows, Int8Array, Int16, i16), + DataType::UInt8 => convert_array!(array, num_rows, rows, UInt8Array, Int16, i16), DataType::Int16 => convert_array!(array, num_rows, rows, Int16Array, Int16, i16), - DataType::UInt32 => convert_array!(array, num_rows, rows, UInt32Array, Int32, i32), + DataType::UInt16 => convert_array!(array, num_rows, rows, UInt16Array, Int32, i32), DataType::Int32 => convert_array!(array, num_rows, rows, Int32Array, Int32, i32), - DataType::UInt64 => convert_array!(array, num_rows, rows, UInt64Array, Int64, i64), + DataType::UInt32 => convert_array!(array, num_rows, rows, UInt32Array, Int64, i64), DataType::Int64 => convert_array!(array, num_rows, rows, Int64Array, Int64, i64), + DataType::UInt64 => { + let a = array.as_any().downcast_ref::().unwrap(); + for i in 0..num_rows { + rows[i].push(if a.is_null(i) { + TableValue::Null + } else { + TableValue::Decimal128(Decimal128Value::new(a.value(i) as i128, 0)) + }); + } + } DataType::Boolean => { convert_array!(array, num_rows, rows, BooleanArray, Boolean, bool) } @@ -685,7 +697,16 @@ pub fn batches_to_dataframe( #[cfg(test)] mod tests { + use std::sync::Arc; + + use datafusion::arrow::array::PrimitiveArray; + use itertools::Itertools; + use super::*; + use crate::compile::arrow::{ + datatypes::{ArrowPrimitiveType, Field}, + record_batch::RecordBatchOptions, + }; #[test] fn test_dataframe_print() { @@ -815,14 +836,14 @@ mod tests { (DataType::Float32, ColumnType::Double), (DataType::Float64, ColumnType::Double), (DataType::Boolean, ColumnType::Boolean), + (DataType::Int8, ColumnType::Int32), + (DataType::UInt8, ColumnType::Int32), + (DataType::Int16, ColumnType::Int32), + (DataType::UInt16, ColumnType::Int32), (DataType::Int32, ColumnType::Int32), - (DataType::UInt32, ColumnType::Int32), - (DataType::Int8, ColumnType::Int64), - (DataType::Int16, ColumnType::Int64), + (DataType::UInt32, ColumnType::Int64), (DataType::Int64, ColumnType::Int64), - (DataType::UInt8, ColumnType::Int64), - (DataType::UInt16, ColumnType::Int64), - (DataType::UInt64, ColumnType::Int64), + (DataType::UInt64, ColumnType::Decimal(39, 0)), (DataType::Null, ColumnType::String), ]; @@ -831,4 +852,197 @@ mod tests { assert_eq!(result, expected_column_type, "Failed for {:?}", arrow_type); } } + + fn create_record_batch( + data_type: DataType, + value: PrimitiveArray, + expected_data_type: ColumnType, + expected_data: Vec, + ) -> Result<(), CubeError> { + let batch = RecordBatch::try_new_with_options( + Arc::new(Schema::new(vec![Field::new("data", data_type, false)])), + vec![Arc::new(value)], + &RecordBatchOptions::default(), + ) + .map_err(|e| CubeError::from(e))?; + + let df = batches_to_dataframe(&batch.schema(), vec![batch.clone()])?; + let colums = df.get_columns().clone(); + let data = df.data; + assert_eq!( + colums.len(), + 1, + "Expecting one column in DF, but: {:?}", + colums + ); + assert_eq!(expected_data_type, colums.get(0).unwrap().column_type); + assert_eq!( + data.len(), + expected_data.len(), + "Expecting {} columns in DF data, but: {:?}", + expected_data.len(), + data + ); + let vec1 = data.into_iter().map(|r| r.values).flatten().collect_vec(); + assert_eq!( + vec1.len(), + expected_data.len(), + "Data len {} != {}", + vec1.len(), + expected_data.len() + ); + assert_eq!(vec1, expected_data); + Ok(()) + } + + #[test] + fn test_timestamp_conversion() -> Result<(), CubeError> { + let data_nano = vec![Some(1640995200000000000)]; + create_record_batch( + DataType::Timestamp(TimeUnit::Nanosecond, None), + TimestampNanosecondArray::from(data_nano.clone()), + ColumnType::Timestamp, + data_nano + .into_iter() + .map(|e| TableValue::Timestamp(TimestampValue::new(e.unwrap(), None))) + .collect::>(), + )?; + + let data_micro = vec![Some(1640995200000000)]; + create_record_batch( + DataType::Timestamp(TimeUnit::Microsecond, None), + TimestampMicrosecondArray::from(data_micro.clone()), + ColumnType::Timestamp, + data_micro + .into_iter() + .map(|e| TableValue::Timestamp(TimestampValue::new(e.unwrap() * 1000, None))) + .collect::>(), + )?; + + let data_milli = vec![Some(1640995200000)]; + create_record_batch( + DataType::Timestamp(TimeUnit::Millisecond, None), + TimestampMillisecondArray::from(data_milli.clone()), + ColumnType::Timestamp, + data_milli + .into_iter() + .map(|e| TableValue::Timestamp(TimestampValue::new(e.unwrap() * 1000000, None))) + .collect::>(), + ) + } + + #[test] + fn test_signed_conversion() -> Result<(), CubeError> { + let data8 = vec![i8::MIN, -1, 0, 1, 2, 3, 4, i8::MAX]; + create_record_batch( + DataType::Int8, + Int8Array::from(data8.clone()), + ColumnType::Int32, //here we are missing TableValue::Int8 to use ColumnType::Int32 + data8 + .into_iter() + .map(|e| TableValue::Int16(e as i16)) + .collect::>(), + )?; + + let data16 = vec![i16::MIN, -1, 0, 1, 2, 3, 4, i16::MAX]; + create_record_batch( + DataType::Int16, + Int16Array::from(data16.clone()), + ColumnType::Int32, //here we are missing ColumnType::Int16 + data16 + .into_iter() + .map(|e| TableValue::Int16(e)) + .collect::>(), + )?; + + let data32 = vec![i32::MIN, -1, 0, 1, 2, 3, 4, i32::MAX]; + create_record_batch( + DataType::Int32, + Int32Array::from(data32.clone()), + ColumnType::Int32, + data32 + .into_iter() + .map(|e| TableValue::Int32(e)) + .collect::>(), + )?; + + let data64 = vec![i64::MIN, -1, 0, 1, 2, 3, 4, i64::MAX]; + create_record_batch( + DataType::Int64, + Int64Array::from(data64.clone()), + ColumnType::Int64, + data64 + .into_iter() + .map(|e| TableValue::Int64(e)) + .collect::>(), + ) + } + + #[test] + fn test_unsigned_conversion() -> Result<(), CubeError> { + let data8 = vec![0, 1, 2, 3, 4, u8::MAX]; + create_record_batch( + DataType::UInt8, + UInt8Array::from(data8.clone()), + ColumnType::Int32, //here we are missing ColumnType::Int16 + data8 + .into_iter() + .map(|e| TableValue::Int16(e as i16)) + .collect::>(), + )?; + + let data16 = vec![0, 1, 2, 3, 4, u16::MAX]; + create_record_batch( + DataType::UInt16, + UInt16Array::from(data16.clone()), + ColumnType::Int32, + data16 + .into_iter() + .map(|e| TableValue::Int32(e as i32)) + .collect::>(), + )?; + + let data32 = vec![0, 1, 2, 3, 4, u32::MAX]; + create_record_batch( + DataType::UInt32, + UInt32Array::from(data32.clone()), + ColumnType::Int64, + data32 + .into_iter() + .map(|e| TableValue::Int64(e as i64)) + .collect::>(), + )?; + + let data64 = vec![0, 1, 2, 3, 4, u64::MAX]; + create_record_batch( + DataType::UInt64, + UInt64Array::from(data64.clone()), + ColumnType::Decimal(39, 0), + data64 + .into_iter() + .map(|e| TableValue::Decimal128(Decimal128Value::new(e as i128, 0))) + .collect::>(), + ) + } + + impl PartialEq for TableValue { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (TableValue::Null, TableValue::Null) => true, + (TableValue::String(a), TableValue::String(b)) => a == b, + (TableValue::Int16(a), TableValue::Int16(b)) => a == b, + (TableValue::Int32(a), TableValue::Int32(b)) => a == b, + (TableValue::Int64(a), TableValue::Int64(b)) => a == b, + (TableValue::Boolean(a), TableValue::Boolean(b)) => a == b, + (TableValue::Float32(a), TableValue::Float32(b)) => a == b, + (TableValue::Float64(a), TableValue::Float64(b)) => a == b, + (TableValue::List(_), TableValue::List(_)) => panic!("unsupported"), + (TableValue::Decimal128(a), TableValue::Decimal128(b)) => a == b, + (TableValue::Date(a), TableValue::Date(b)) => a == b, + (TableValue::Timestamp(a), TableValue::Timestamp(b)) => a == b, + (TableValue::Interval(_), TableValue::Interval(_)) => panic!("unsupported"), + _ => false, + } + } + } }