Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 86 additions & 62 deletions datafusion/functions-nested/src/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,23 @@

use crate::utils::make_scalar_function;
use arrow::array::{
Array, ArrayRef, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait, UInt64Array,
Array, ArrayRef, BooleanBufferBuilder, GenericListArray, Int64Array, OffsetSizeTrait,
UInt64Array,
};
use arrow::buffer::{NullBuffer, OffsetBuffer};
use arrow::compute;
use arrow::compute::cast;
use arrow::datatypes::DataType;
use arrow::datatypes::{
DataType::{LargeList, List},
Field,
};
use datafusion_common::cast::{as_large_list_array, as_list_array, as_uint64_array};
use datafusion_common::{Result, exec_err, utils::take_function_args};
use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
use datafusion_common::types::{NativeType, logical_int64};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
use datafusion_macros::user_doc;
use std::any::Any;
use std::sync::Arc;
Expand Down Expand Up @@ -88,7 +90,17 @@ impl Default for ArrayRepeat {
impl ArrayRepeat {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
signature: Signature::coercible(
vec![
Coercion::new_exact(TypeSignatureClass::Any),
Coercion::new_implicit(
TypeSignatureClass::Native(logical_int64()),
vec![TypeSignatureClass::Integer],
NativeType::Int64,
),
],
Volatility::Immutable,
),
aliases: vec![String::from("list_repeat")],
}
}
Expand Down Expand Up @@ -132,39 +144,14 @@ impl ScalarUDFImpl for ArrayRepeat {
&self.aliases
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let [first_type, second_type] = take_function_args(self.name(), arg_types)?;

// Coerce the second argument to Int64/UInt64 if it's a numeric type
let second = match second_type {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
DataType::Int64
}
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
DataType::UInt64
}
_ => return exec_err!("count must be an integer type"),
};

Ok(vec![first_type.clone(), second])
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let element = &args[0];
let count_array = &args[1];

let count_array = match count_array.data_type() {
DataType::Int64 => &cast(count_array, &DataType::UInt64)?,
DataType::UInt64 => count_array,
_ => return exec_err!("count must be an integer type"),
};

let count_array = as_uint64_array(count_array)?;
let count_array = as_int64_array(&args[1])?;

match element.data_type() {
List(_) => {
Expand Down Expand Up @@ -193,21 +180,31 @@ fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
/// ```
fn general_repeat<O: OffsetSizeTrait>(
array: &ArrayRef,
count_array: &UInt64Array,
count_array: &Int64Array,
) -> Result<ArrayRef> {
// Build offsets and take_indices
let total_repeated_values: usize =
count_array.values().iter().map(|&c| c as usize).sum();
let total_repeated_values: usize = (0..count_array.len())
.map(|i| get_count_with_validity(count_array, i))
.sum();

let mut take_indices = Vec::with_capacity(total_repeated_values);
let mut offsets = Vec::with_capacity(count_array.len() + 1);
offsets.push(O::zero());
let mut running_offset = 0usize;

for (idx, &count) in count_array.values().iter().enumerate() {
let count = count as usize;
running_offset += count;
offsets.push(O::from_usize(running_offset).unwrap());
take_indices.extend(std::iter::repeat_n(idx as u64, count))
for idx in 0..count_array.len() {
let count = get_count_with_validity(count_array, idx);
running_offset = running_offset.checked_add(count).ok_or_else(|| {
DataFusionError::Execution(
"array_repeat: running_offset overflowed usize".to_string(),
)
})?;
let offset = O::from_usize(running_offset).ok_or_else(|| {
DataFusionError::Execution(format!(
"array_repeat: offset {running_offset} exceeds the maximum value for offset type"
))
})?;
offsets.push(offset);
take_indices.extend(std::iter::repeat_n(idx as u64, count));
}

// Build the flattened values
Expand All @@ -222,7 +219,7 @@ fn general_repeat<O: OffsetSizeTrait>(
Arc::new(Field::new_list_field(array.data_type().to_owned(), true)),
OffsetBuffer::new(offsets.into()),
repeated_values,
None,
count_array.nulls().cloned(),
)?))
}

Expand All @@ -238,23 +235,24 @@ fn general_repeat<O: OffsetSizeTrait>(
/// ```
fn general_list_repeat<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
count_array: &UInt64Array,
count_array: &Int64Array,
) -> Result<ArrayRef> {
let counts = count_array.values();
let list_offsets = list_array.value_offsets();

// calculate capacities for pre-allocation
let outer_total = counts.iter().map(|&c| c as usize).sum();
let inner_total = counts
.iter()
.enumerate()
.filter(|&(i, _)| !list_array.is_null(i))
.map(|(i, &c)| {
let len = list_offsets[i + 1].to_usize().unwrap()
- list_offsets[i].to_usize().unwrap();
len * (c as usize)
})
.sum();
let mut outer_total = 0usize;
let mut inner_total = 0usize;
for i in 0..count_array.len() {
let count = get_count_with_validity(count_array, i);
if count > 0 {
outer_total += count;
if list_array.is_valid(i) {
let len = list_offsets[i + 1].to_usize().unwrap()
- list_offsets[i].to_usize().unwrap();
inner_total += len * count;
}
}
}

// Build inner structures
let mut inner_offsets = Vec::with_capacity(outer_total + 1);
Expand All @@ -263,17 +261,27 @@ fn general_list_repeat<O: OffsetSizeTrait>(
let mut inner_running = 0usize;
inner_offsets.push(O::zero());

for (row_idx, &count) in counts.iter().enumerate() {
let is_valid = !list_array.is_null(row_idx);
for row_idx in 0..count_array.len() {
let count = get_count_with_validity(count_array, row_idx);
let list_is_valid = list_array.is_valid(row_idx);
let start = list_offsets[row_idx].to_usize().unwrap();
let end = list_offsets[row_idx + 1].to_usize().unwrap();
let row_len = end - start;

for _ in 0..count {
inner_running += row_len;
inner_offsets.push(O::from_usize(inner_running).unwrap());
inner_nulls.append(is_valid);
if is_valid {
inner_running = inner_running.checked_add(row_len).ok_or_else(|| {
DataFusionError::Execution(
"array_repeat: inner offset overflowed usize".to_string(),
)
})?;
let offset = O::from_usize(inner_running).ok_or_else(|| {
DataFusionError::Execution(format!(
"array_repeat: offset {inner_running} exceeds the maximum value for offset type"
))
})?;
inner_offsets.push(offset);
inner_nulls.append(list_is_valid);
if list_is_valid {
take_indices.extend(start as u64..end as u64);
}
}
Expand All @@ -298,8 +306,24 @@ fn general_list_repeat<O: OffsetSizeTrait>(
list_array.data_type().to_owned(),
true,
)),
OffsetBuffer::<O>::from_lengths(counts.iter().map(|&c| c as usize)),
OffsetBuffer::<O>::from_lengths(
count_array
.iter()
.map(|c| c.map(|v| if v > 0 { v as usize } else { 0 }).unwrap_or(0)),
),
Arc::new(inner_list),
None,
count_array.nulls().cloned(),
)?))
}

/// Helper function to get count from count_array at given index
/// Return 0 for null values or non-positive count.
#[inline]
fn get_count_with_validity(count_array: &Int64Array, idx: usize) -> usize {
if count_array.is_null(idx) {
0
} else {
let c = count_array.value(idx);
if c > 0 { c as usize } else { 0 }
}
}
81 changes: 78 additions & 3 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -3256,24 +3256,99 @@ drop table array_repeat_table;
statement ok
drop table large_array_repeat_table;


# array_repeat: arrays with NULL counts
statement ok
create table array_repeat_null_count_table
as values
(1, 2),
(2, null),
(3, 1);
(3, 1),
(4, -1),
(null, null);

query I?
select column1, array_repeat(column1, column2) from array_repeat_null_count_table;
----
1 [1, 1]
2 []
2 NULL
3 [3]
4 []
NULL NULL

statement ok
drop table array_repeat_null_count_table

# array_repeat: nested arrays with NULL counts
statement ok
create table array_repeat_nested_null_count_table
as values
([[1, 2], [3, 4]], 2),
([[5, 6], [7, 8]], null),
([[null, null], [9, 10]], 1),
(null, 3),
([[11, 12]], -1);

query ??
select column1, array_repeat(column1, column2) from array_repeat_nested_null_count_table;
----
[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]
[[5, 6], [7, 8]] NULL
[[NULL, NULL], [9, 10]] [[[NULL, NULL], [9, 10]]]
NULL [NULL, NULL, NULL]
[[11, 12]] []

statement ok
drop table array_repeat_nested_null_count_table

# array_repeat edge cases: empty arrays
query ???
select array_repeat([], 3), array_repeat([], 0), array_repeat([], null);
----
[[], [], []] [] NULL

query ??
select array_repeat(null::int, 0), array_repeat(null::int, null);
----
[] NULL

# array_repeat LargeList with NULL count
statement ok
create table array_repeat_large_list_null_table
as values
(arrow_cast([1, 2, 3], 'LargeList(Int64)'), 2),
(arrow_cast([4, 5], 'LargeList(Int64)'), null),
(arrow_cast(null, 'LargeList(Int64)'), 3);

query ??
select column1, array_repeat(column1, column2) from array_repeat_large_list_null_table;
----
[1, 2, 3] [[1, 2, 3], [1, 2, 3]]
[4, 5] NULL
NULL [NULL, NULL, NULL]

statement ok
drop table array_repeat_large_list_null_table

# array_repeat edge cases: LargeList nested with NULL count
statement ok
create table array_repeat_large_nested_null_table
as values
(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2),
(arrow_cast([[5, 6], [7, 8]], 'LargeList(List(Int64))'), null),
(arrow_cast([[null, null]], 'LargeList(List(Int64))'), 1),
(null, 3);

query ??
select column1, array_repeat(column1, column2) from array_repeat_large_nested_null_table;
----
[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]
[[5, 6], [7, 8]] NULL
[[NULL, NULL]] [[[NULL, NULL]]]
NULL [NULL, NULL, NULL]

statement ok
drop table array_repeat_large_nested_null_table

## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`)

# test with empty array
Expand Down