diff --git a/encodings/runend/public-api.lock b/encodings/runend/public-api.lock index c6dbee58be2..04b047ec5dc 100644 --- a/encodings/runend/public-api.lock +++ b/encodings/runend/public-api.lock @@ -69,8 +69,6 @@ pub fn vortex_runend::RunEndVTable::is_sorted(&self, array: &vortex_runend::RunE pub fn vortex_runend::RunEndVTable::is_strict_sorted(&self, array: &vortex_runend::RunEndArray) -> vortex_error::VortexResult> impl vortex_array::compute::min_max::MinMaxKernel for vortex_runend::RunEndVTable pub fn vortex_runend::RunEndVTable::min_max(&self, array: &vortex_runend::RunEndArray) -> vortex_error::VortexResult> -impl vortex_array::compute::numeric::NumericKernel for vortex_runend::RunEndVTable -pub fn vortex_runend::RunEndVTable::numeric(&self, array: &vortex_runend::RunEndArray, rhs: &dyn vortex_array::array::Array, op: vortex_scalar::typed_view::primitive::numeric_operator::NumericOperator) -> vortex_error::VortexResult> impl vortex_array::expr::exprs::fill_null::kernel::FillNullReduce for vortex_runend::RunEndVTable pub fn vortex_runend::RunEndVTable::fill_null(array: &vortex_runend::RunEndArray, fill_value: &vortex_scalar::scalar::Scalar) -> vortex_error::VortexResult> impl vortex_array::vtable::VTable for vortex_runend::RunEndVTable diff --git a/encodings/runend/src/compute/binary_numeric.rs b/encodings/runend/src/compute/binary_numeric.rs deleted file mode 100644 index 4d5df750697..00000000000 --- a/encodings/runend/src/compute/binary_numeric.rs +++ /dev/null @@ -1,83 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_array::Array; -use vortex_array::ArrayRef; -use vortex_array::IntoArray; -use vortex_array::arrays::ConstantArray; -use vortex_array::compute::NumericKernel; -use vortex_array::compute::NumericKernelAdapter; -use vortex_array::compute::numeric; -use vortex_array::register_kernel; -use vortex_error::VortexResult; -use vortex_scalar::NumericOperator; - -use crate::RunEndArray; -use crate::RunEndVTable; - -impl NumericKernel for RunEndVTable { - fn numeric( - &self, - array: &RunEndArray, - rhs: &dyn Array, - op: NumericOperator, - ) -> VortexResult> { - let Some(rhs_scalar) = rhs.as_constant() else { - return Ok(None); - }; - - let rhs_const_array = ConstantArray::new(rhs_scalar, array.values().len()).into_array(); - - // SAFETY: ends are preserved. - unsafe { - Ok(Some( - RunEndArray::new_unchecked( - array.ends().clone(), - numeric(array.values(), &rhs_const_array, op)?, - array.offset(), - array.len(), - ) - .into_array(), - )) - } - } -} - -register_kernel!(NumericKernelAdapter(RunEndVTable).lift()); - -#[cfg(test)] -mod tests { - use rstest::rstest; - use vortex_array::IntoArray; - use vortex_array::arrays::PrimitiveArray; - use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array; - use vortex_buffer::buffer; - - use crate::RunEndArray; - - #[rstest] - #[case::runend_i32_basic(RunEndArray::encode( - buffer![10i32, 10, 10, 20, 20, 30, 30, 30, 30].into_array() - ).unwrap())] - #[case::runend_u32_basic(RunEndArray::encode( - buffer![100u32, 100, 200, 200, 200].into_array() - ).unwrap())] - #[case::runend_i64_basic(RunEndArray::encode( - buffer![1000i64, 1000, 2000, 2000, 3000, 3000].into_array() - ).unwrap())] - #[case::runend_u64_basic(RunEndArray::encode( - buffer![5000u64, 5000, 5000, 6000, 6000].into_array() - ).unwrap())] - #[case::runend_f32_basic(RunEndArray::encode( - buffer![1.5f32, 1.5, 2.5, 2.5, 3.5].into_array() - ).unwrap())] - #[case::runend_f64_basic(RunEndArray::encode( - buffer![10.1f64, 10.1, 20.2, 20.2, 20.2].into_array() - ).unwrap())] - #[case::runend_i32_large(RunEndArray::encode( - PrimitiveArray::from_iter((0..100).map(|i| i / 5)).into_array() - ).unwrap())] - fn test_runend_binary_numeric(#[case] array: RunEndArray) { - test_binary_numeric_array(array.into_array()); - } -} diff --git a/encodings/runend/src/compute/mod.rs b/encodings/runend/src/compute/mod.rs index 97ded7b4239..8b390bef750 100644 --- a/encodings/runend/src/compute/mod.rs +++ b/encodings/runend/src/compute/mod.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -mod binary_numeric; mod cast; mod compare; mod fill_null; diff --git a/encodings/sparse/public-api.lock b/encodings/sparse/public-api.lock index 0065da66482..264eba2e6fd 100644 --- a/encodings/sparse/public-api.lock +++ b/encodings/sparse/public-api.lock @@ -43,8 +43,6 @@ impl vortex_array::arrays::slice::SliceKernel for vortex_sparse::SparseVTable pub fn vortex_sparse::SparseVTable::slice(array: &vortex_sparse::SparseArray, range: core::ops::range::Range, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::compute::cast::CastReduce for vortex_sparse::SparseVTable pub fn vortex_sparse::SparseVTable::cast(array: &vortex_sparse::SparseArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> -impl vortex_array::compute::numeric::NumericKernel for vortex_sparse::SparseVTable -pub fn vortex_sparse::SparseVTable::numeric(&self, array: &vortex_sparse::SparseArray, rhs: &dyn vortex_array::array::Array, op: vortex_scalar::typed_view::primitive::numeric_operator::NumericOperator) -> vortex_error::VortexResult> impl vortex_array::expr::exprs::not::kernel::NotReduce for vortex_sparse::SparseVTable pub fn vortex_sparse::SparseVTable::invert(array: &vortex_sparse::SparseArray) -> vortex_error::VortexResult> impl vortex_array::vtable::VTable for vortex_sparse::SparseVTable diff --git a/encodings/sparse/src/compute/binary_numeric.rs b/encodings/sparse/src/compute/binary_numeric.rs deleted file mode 100644 index f0df1797e36..00000000000 --- a/encodings/sparse/src/compute/binary_numeric.rs +++ /dev/null @@ -1,47 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_array::Array; -use vortex_array::ArrayRef; -use vortex_array::IntoArray; -use vortex_array::arrays::ConstantArray; -use vortex_array::compute::NumericKernel; -use vortex_array::compute::NumericKernelAdapter; -use vortex_array::compute::numeric; -use vortex_array::register_kernel; -use vortex_error::VortexResult; -use vortex_error::vortex_err; -use vortex_scalar::NumericOperator; - -use crate::SparseArray; -use crate::SparseVTable; - -impl NumericKernel for SparseVTable { - fn numeric( - &self, - array: &SparseArray, - rhs: &dyn Array, - op: NumericOperator, - ) -> VortexResult> { - let Some(rhs_scalar) = rhs.as_constant() else { - return Ok(None); - }; - - let new_patches = array.patches().clone().map_values(|values| { - let rhs_const_array = ConstantArray::new(rhs_scalar.clone(), values.len()).into_array(); - - numeric(&values, &rhs_const_array, op) - })?; - let new_fill_value = array - .fill_scalar() - .as_primitive() - .checked_binary_numeric(&rhs_scalar.as_primitive(), op) - .ok_or_else(|| vortex_err!("numeric overflow"))? - .into(); - Ok(Some( - SparseArray::try_new_from_patches(new_patches, new_fill_value)?.into_array(), - )) - } -} - -register_kernel!(NumericKernelAdapter(SparseVTable).lift()); diff --git a/encodings/sparse/src/compute/mod.rs b/encodings/sparse/src/compute/mod.rs index 339fb216217..ea34defd285 100644 --- a/encodings/sparse/src/compute/mod.rs +++ b/encodings/sparse/src/compute/mod.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -mod binary_numeric; mod cast; mod filter; mod take; diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 6fd1a7208f6..d5e7515bc1c 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -61,8 +61,6 @@ impl vortex_array::compute::LikeKernel for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::like(&self, array: &vortex_array::arrays::DictArray, pattern: &dyn vortex_array::Array, options: vortex_array::compute::LikeOptions) -> vortex_error::VortexResult> impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::min_max(&self, array: &vortex_array::arrays::DictArray) -> vortex_error::VortexResult> -impl vortex_array::compute::NumericKernel for vortex_array::arrays::DictVTable -pub fn vortex_array::arrays::DictVTable::numeric(&self, lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, op: vortex_scalar::typed_view::primitive::numeric_operator::NumericOperator) -> vortex_error::VortexResult> impl vortex_array::expr::FillNullKernel for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::fill_null(array: &vortex_array::arrays::DictArray, fill_value: &vortex_scalar::scalar::Scalar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::DictVTable @@ -354,8 +352,6 @@ impl vortex_array::compute::MaskKernel for vortex_array::arrays::ConstantVTable pub fn vortex_array::arrays::ConstantVTable::mask(&self, array: &vortex_array::arrays::ConstantArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::ConstantVTable pub fn vortex_array::arrays::ConstantVTable::min_max(&self, array: &vortex_array::arrays::ConstantArray) -> vortex_error::VortexResult> -impl vortex_array::compute::NumericKernel for vortex_array::arrays::ConstantVTable -pub fn vortex_array::arrays::ConstantVTable::numeric(&self, array: &vortex_array::arrays::ConstantArray, rhs: &dyn vortex_array::Array, op: vortex_scalar::typed_view::primitive::numeric_operator::NumericOperator) -> vortex_error::VortexResult> impl vortex_array::compute::SumKernel for vortex_array::arrays::ConstantVTable pub fn vortex_array::arrays::ConstantVTable::sum(&self, array: &vortex_array::arrays::ConstantArray, accumulator: &vortex_scalar::scalar::Scalar) -> vortex_error::VortexResult impl vortex_array::expr::FillNullReduce for vortex_array::arrays::ConstantVTable @@ -573,8 +569,6 @@ impl vortex_array::compute::LikeKernel for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::like(&self, array: &vortex_array::arrays::DictArray, pattern: &dyn vortex_array::Array, options: vortex_array::compute::LikeOptions) -> vortex_error::VortexResult> impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::min_max(&self, array: &vortex_array::arrays::DictArray) -> vortex_error::VortexResult> -impl vortex_array::compute::NumericKernel for vortex_array::arrays::DictVTable -pub fn vortex_array::arrays::DictVTable::numeric(&self, lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, op: vortex_scalar::typed_view::primitive::numeric_operator::NumericOperator) -> vortex_error::VortexResult> impl vortex_array::expr::FillNullKernel for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::fill_null(array: &vortex_array::arrays::DictArray, fill_value: &vortex_scalar::scalar::Scalar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::DictVTable @@ -3031,15 +3025,6 @@ pub fn vortex_array::expr::NotReduceAdaptor::fmt(&self, f: &mut core::fmt::Fo impl vortex_array::optimizer::rules::ArrayParentReduceRule for vortex_array::expr::NotReduceAdaptor where V: vortex_array::expr::NotReduce pub type vortex_array::expr::NotReduceAdaptor::Parent = vortex_array::arrays::ExactScalarFn pub fn vortex_array::expr::NotReduceAdaptor::reduce_parent(&self, array: &::Array, _parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Not>, _child_idx: usize) -> vortex_error::VortexResult> -pub struct vortex_array::compute::NumericKernelAdapter(pub V) -impl vortex_array::compute::NumericKernelAdapter -pub const fn vortex_array::compute::NumericKernelAdapter::lift(&'static self) -> vortex_array::compute::NumericKernelRef -impl core::fmt::Debug for vortex_array::compute::NumericKernelAdapter -pub fn vortex_array::compute::NumericKernelAdapter::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result -impl vortex_array::compute::Kernel for vortex_array::compute::NumericKernelAdapter -pub fn vortex_array::compute::NumericKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> -pub struct vortex_array::compute::NumericKernelRef(_) -impl inventory::Collect for vortex_array::compute::NumericKernelRef pub struct vortex_array::compute::SumArgs<'a> pub vortex_array::compute::SumArgs::accumulator: &'a vortex_scalar::scalar::Scalar pub vortex_array::compute::SumArgs::array: &'a dyn vortex_array::Array @@ -3239,8 +3224,6 @@ impl vort pub fn vortex_array::compute::MinMaxKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> impl vortex_array::compute::Kernel for vortex_array::compute::NaNCountKernelAdapter pub fn vortex_array::compute::NaNCountKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> -impl vortex_array::compute::Kernel for vortex_array::compute::NumericKernelAdapter -pub fn vortex_array::compute::NumericKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> impl vortex_array::compute::Kernel for vortex_array::compute::SumKernelAdapter pub fn vortex_array::compute::SumKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> impl vortex_array::compute::Kernel for vortex_array::compute::ZipKernelAdapter @@ -3319,12 +3302,6 @@ pub trait vortex_array::compute::NotReduce: vortex_array::vtable::VTable pub fn vortex_array::compute::NotReduce::invert(array: &Self::Array) -> vortex_error::VortexResult> impl vortex_array::expr::NotReduce for vortex_array::arrays::ConstantVTable pub fn vortex_array::arrays::ConstantVTable::invert(array: &vortex_array::arrays::ConstantArray) -> vortex_error::VortexResult> -pub trait vortex_array::compute::NumericKernel: vortex_array::vtable::VTable -pub fn vortex_array::compute::NumericKernel::numeric(&self, array: &Self::Array, other: &dyn vortex_array::Array, op: vortex_scalar::typed_view::primitive::numeric_operator::NumericOperator) -> vortex_error::VortexResult> -impl vortex_array::compute::NumericKernel for vortex_array::arrays::ConstantVTable -pub fn vortex_array::arrays::ConstantVTable::numeric(&self, array: &vortex_array::arrays::ConstantArray, rhs: &dyn vortex_array::Array, op: vortex_scalar::typed_view::primitive::numeric_operator::NumericOperator) -> vortex_error::VortexResult> -impl vortex_array::compute::NumericKernel for vortex_array::arrays::DictVTable -pub fn vortex_array::arrays::DictVTable::numeric(&self, lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, op: vortex_scalar::typed_view::primitive::numeric_operator::NumericOperator) -> vortex_error::VortexResult> pub trait vortex_array::compute::Options: 'static pub fn vortex_array::compute::Options::as_any(&self) -> &dyn core::any::Any impl vortex_array::compute::Options for () diff --git a/vortex-array/src/arrays/constant/compute/binary_numeric.rs b/vortex-array/src/arrays/constant/compute/binary_numeric.rs deleted file mode 100644 index 29eae688fe4..00000000000 --- a/vortex-array/src/arrays/constant/compute/binary_numeric.rs +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; -use vortex_error::vortex_err; -use vortex_scalar::NumericOperator; - -use crate::Array; -use crate::ArrayRef; -use crate::IntoArray; -use crate::arrays::ConstantArray; -use crate::arrays::ConstantVTable; -use crate::compute::NumericKernel; -use crate::compute::NumericKernelAdapter; -use crate::register_kernel; - -impl NumericKernel for ConstantVTable { - fn numeric( - &self, - array: &ConstantArray, - rhs: &dyn Array, - op: NumericOperator, - ) -> VortexResult> { - let Some(rhs) = rhs.as_constant() else { - return Ok(None); - }; - - Ok(Some( - ConstantArray::new( - array - .scalar() - .as_primitive() - .checked_binary_numeric(&rhs.as_primitive(), op) - .ok_or_else(|| vortex_err!("numeric overflow"))?, - array.len(), - ) - .into_array(), - )) - } -} - -register_kernel!(NumericKernelAdapter(ConstantVTable).lift()); diff --git a/vortex-array/src/arrays/constant/compute/mod.rs b/vortex-array/src/arrays/constant/compute/mod.rs index 22e9ab4c5b3..96e21b0efd1 100644 --- a/vortex-array/src/arrays/constant/compute/mod.rs +++ b/vortex-array/src/arrays/constant/compute/mod.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -mod binary_numeric; mod cast; mod compare; mod fill_null; diff --git a/vortex-array/src/arrays/dict/compute/binary_numeric.rs b/vortex-array/src/arrays/dict/compute/binary_numeric.rs deleted file mode 100644 index 3ee362ffa9d..00000000000 --- a/vortex-array/src/arrays/dict/compute/binary_numeric.rs +++ /dev/null @@ -1,168 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; -use vortex_scalar::NumericOperator; - -use super::DictArray; -use super::DictVTable; -use crate::Array; -use crate::ArrayRef; -use crate::IntoArray; -use crate::arrays::ConstantArray; -use crate::compute::NumericKernel; -use crate::compute::NumericKernelAdapter; -use crate::compute::numeric; -use crate::register_kernel; - -impl NumericKernel for DictVTable { - fn numeric( - &self, - lhs: &DictArray, - rhs: &dyn Array, - op: NumericOperator, - ) -> VortexResult> { - // If we have more values than codes, it is faster to canonicalise first. - if lhs.values().len() > lhs.codes().len() { - return Ok(None); - } - - // Only push down if all values are referenced to avoid incorrect results - // See: https://github.com/vortex-data/vortex/pull/4560 - // Unchecked operation will be fine to pushdown. - if !lhs.has_all_values_referenced() { - return Ok(None); - } - - // If the RHS is constant, then we just need to apply the operation to our encoded values. - if let Some(rhs_scalar) = rhs.as_constant() { - let values_result = numeric( - lhs.values(), - ConstantArray::new(rhs_scalar, lhs.values().len()).as_ref(), - op, - )?; - - // SAFETY: values len preserved, codes all still point to valid values - // all_values_referenced preserved since operation doesn't change which values are referenced - let result = unsafe { - DictArray::new_unchecked(lhs.codes().clone(), values_result) - .set_all_values_referenced(lhs.has_all_values_referenced()) - .into_array() - }; - - return Ok(Some(result)); - } - - // It's a little more complex, but we could perform binary operations against the dictionary - // values in the future. - Ok(None) - } -} - -register_kernel!(NumericKernelAdapter(DictVTable).lift()); - -#[cfg(test)] -mod tests { - use vortex_buffer::buffer; - use vortex_error::VortexResult; - use vortex_scalar::NumericOperator; - - use crate::IntoArray; - use crate::arrays::ConstantArray; - use crate::arrays::PrimitiveArray; - use crate::arrays::dict::DictArray; - use crate::assert_arrays_eq; - use crate::compute::numeric; - - #[test] - fn test_add_const() -> VortexResult<()> { - // Create a dict with all_values_referenced = true - let dict = unsafe { - DictArray::new_unchecked( - buffer![0u32, 1, 2, 0, 1].into_array(), - buffer![10i32, 20, 30].into_array(), - ) - .set_all_values_referenced(true) - }; - - let res = numeric( - dict.as_ref(), - ConstantArray::new(5i32, 5).as_ref(), - NumericOperator::Add, - ) - .unwrap(); - - let expected = PrimitiveArray::from_iter([15i32, 25, 35, 15, 25]); - assert_arrays_eq!(res.to_canonical()?.into_array(), expected.to_array()); - Ok(()) - } - - #[test] - fn test_mul_const() -> VortexResult<()> { - // Create a dict with all_values_referenced = true - let dict = unsafe { - DictArray::new_unchecked( - buffer![0u32, 1, 2, 1, 0].into_array(), - buffer![2i32, 3, 5].into_array(), - ) - .set_all_values_referenced(true) - }; - - let res = numeric( - dict.as_ref(), - ConstantArray::new(10i32, 5).as_ref(), - NumericOperator::Mul, - ) - .unwrap(); - - let expected = PrimitiveArray::from_iter([20i32, 30, 50, 30, 20]); - assert_arrays_eq!(res.to_canonical()?.into_array(), expected.to_array()); - Ok(()) - } - - #[test] - fn test_no_pushdown_when_not_all_values_referenced() -> VortexResult<()> { - // Create a dict with all_values_referenced = false (default) - let dict = DictArray::try_new( - buffer![0u32, 1, 0, 1].into_array(), - buffer![10i32, 20, 30].into_array(), // value at index 2 is not referenced - ) - .unwrap(); - - // Should return None, indicating no pushdown - let res = numeric( - dict.as_ref(), - ConstantArray::new(5i32, 4).as_ref(), - NumericOperator::Add, - ) - .unwrap(); - - // Verify the result by canonicalizing - let expected = PrimitiveArray::from_iter([15i32, 25, 15, 25]); - assert_arrays_eq!(res.to_canonical()?.into_array(), expected.to_array()); - Ok(()) - } - - #[test] - fn test_sub_const() -> VortexResult<()> { - // Create a dict with all_values_referenced = true - let dict = unsafe { - DictArray::new_unchecked( - buffer![0u32, 1, 2].into_array(), - buffer![100i32, 50, 25].into_array(), - ) - .set_all_values_referenced(true) - }; - - let res = numeric( - dict.as_ref(), - ConstantArray::new(10i32, 3).as_ref(), - NumericOperator::Sub, - ) - .unwrap(); - - let expected = PrimitiveArray::from_iter([90i32, 40, 15]); - assert_arrays_eq!(res.to_canonical()?.into_array(), expected.to_array()); - Ok(()) - } -} diff --git a/vortex-array/src/arrays/dict/compute/mod.rs b/vortex-array/src/arrays/dict/compute/mod.rs index cf5fe445857..056b151ec06 100644 --- a/vortex-array/src/arrays/dict/compute/mod.rs +++ b/vortex-array/src/arrays/dict/compute/mod.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -mod binary_numeric; mod cast; mod compare; mod fill_null; diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index ab2df295b8e..7d7785f7ecb 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -99,7 +99,6 @@ pub fn warm_up_vtables() { mask::warm_up_vtable(); min_max::warm_up_vtable(); nan_count::warm_up_vtable(); - numeric::warm_up_vtable(); sum::warm_up_vtable(); zip::warm_up_vtable(); } diff --git a/vortex-array/src/compute/numeric.rs b/vortex-array/src/compute/numeric.rs index 22fe2817f59..4998f926806 100644 --- a/vortex-array/src/compute/numeric.rs +++ b/vortex-array/src/compute/numeric.rs @@ -2,14 +2,8 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::any::Any; -use std::sync::LazyLock; -use arcref::ArcRef; -use vortex_dtype::DType; -use vortex_error::VortexError; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; use vortex_scalar::NumericOperator; use vortex_scalar::Scalar; @@ -19,25 +13,7 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::arrow::Datum; use crate::arrow::from_arrow_array_with_len; -use crate::compute::ComputeFn; -use crate::compute::ComputeFnVTable; -use crate::compute::InvocationArgs; -use crate::compute::Kernel; use crate::compute::Options; -use crate::compute::Output; -use crate::vtable::VTable; - -static NUMERIC_FN: LazyLock = LazyLock::new(|| { - let compute = ComputeFn::new("numeric".into(), ArcRef::new_ref(&Numeric)); - for kernel in inventory::iter:: { - compute.register_kernel(kernel.0.clone()); - } - compute -}); - -pub(crate) fn warm_up_vtable() -> usize { - NUMERIC_FN.kernels().len() -} /// Point-wise add two numeric arrays. /// @@ -101,142 +77,7 @@ pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult { /// Point-wise numeric operation between two arrays of the same type and length. pub fn numeric(lhs: &dyn Array, rhs: &dyn Array, op: NumericOperator) -> VortexResult { - NUMERIC_FN - .invoke(&InvocationArgs { - inputs: &[lhs.into(), rhs.into()], - options: &op, - })? - .unwrap_array() -} - -pub struct NumericKernelRef(ArcRef); -inventory::collect!(NumericKernelRef); - -pub trait NumericKernel: VTable { - fn numeric( - &self, - array: &Self::Array, - other: &dyn Array, - op: NumericOperator, - ) -> VortexResult>; -} - -#[derive(Debug)] -pub struct NumericKernelAdapter(pub V); - -impl NumericKernelAdapter { - pub const fn lift(&'static self) -> NumericKernelRef { - NumericKernelRef(ArcRef::new_ref(self)) - } -} - -impl Kernel for NumericKernelAdapter { - fn invoke(&self, args: &InvocationArgs) -> VortexResult> { - let inputs = NumericArgs::try_from(args)?; - let Some(lhs) = inputs.lhs.as_opt::() else { - return Ok(None); - }; - Ok(V::numeric(&self.0, lhs, inputs.rhs, inputs.operator)?.map(|array| array.into())) - } -} - -struct Numeric; - -impl ComputeFnVTable for Numeric { - fn invoke( - &self, - args: &InvocationArgs, - kernels: &[ArcRef], - ) -> VortexResult { - let NumericArgs { lhs, rhs, operator } = NumericArgs::try_from(args)?; - - for kernel in kernels { - if let Some(output) = kernel.invoke(args)? { - return Ok(output); - } - } - - // Check if RHS supports the operation directly. - let inverted_args = InvocationArgs { - inputs: &[rhs.into(), lhs.into()], - options: &operator.swap(), - }; - for kernel in kernels { - if let Some(output) = kernel.invoke(&inverted_args)? { - return Ok(output); - } - } - - tracing::debug!( - "No numeric implementation found for LHS {}, RHS {}, and operator {:?}", - lhs.encoding_id(), - rhs.encoding_id(), - operator, - ); - - // If neither side implements the trait, then we delegate to Arrow compute. - Ok(arrow_numeric(lhs, rhs, operator)?.into()) - } - - fn return_dtype(&self, args: &InvocationArgs) -> VortexResult { - let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?; - if !matches!( - (lhs.dtype(), rhs.dtype()), - (DType::Primitive(..), DType::Primitive(..)) | (DType::Decimal(..), DType::Decimal(..)) - ) || !lhs.dtype().eq_ignore_nullability(rhs.dtype()) - { - vortex_bail!( - "Numeric operations are only supported on two arrays sharing the same numeric type: {} {}", - lhs.dtype(), - rhs.dtype() - ) - } - Ok(lhs.dtype().union_nullability(rhs.dtype().nullability())) - } - - fn return_len(&self, args: &InvocationArgs) -> VortexResult { - let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?; - if lhs.len() != rhs.len() { - vortex_bail!( - "Numeric operations aren't supported on arrays of different lengths {} {}", - lhs.len(), - rhs.len() - ) - } - Ok(lhs.len()) - } - - fn is_elementwise(&self) -> bool { - true - } -} - -struct NumericArgs<'a> { - lhs: &'a dyn Array, - rhs: &'a dyn Array, - operator: NumericOperator, -} - -impl<'a> TryFrom<&InvocationArgs<'a>> for NumericArgs<'a> { - type Error = VortexError; - - fn try_from(args: &InvocationArgs<'a>) -> VortexResult { - if args.inputs.len() != 2 { - vortex_bail!("Numeric operations require exactly 2 inputs"); - } - let lhs = args.inputs[0] - .array() - .ok_or_else(|| vortex_err!("LHS is not an array"))?; - let rhs = args.inputs[1] - .array() - .ok_or_else(|| vortex_err!("RHS is not an array"))?; - let operator = *args - .options - .as_any() - .downcast_ref::() - .ok_or_else(|| vortex_err!("Operator is not a numeric operator"))?; - Ok(Self { lhs, rhs, operator }) - } + arrow_numeric(lhs, rhs, op) } impl Options for NumericOperator { @@ -245,11 +86,8 @@ impl Options for NumericOperator { } } -/// Implementation of `BinaryNumericFn` using the Arrow crate. -/// -/// Note that other encodings should handle a constant RHS value, so we can assume here that -/// the RHS is not constant and expand to a full array. -fn arrow_numeric( +/// Implementation of numeric operations using the Arrow crate. +pub(crate) fn arrow_numeric( lhs: &dyn Array, rhs: &dyn Array, operator: NumericOperator, @@ -319,12 +157,4 @@ mod test { let _results = sub_scalar(&values, 1.0f32.into()).unwrap(); let _results = sub_scalar(&values, f32::MAX.into()).unwrap(); } - - #[test] - fn test_scalar_subtract_type_mismatch_fails() { - let values = buffer![1u64, 2, 3].into_array(); - // Subtracting incompatible dtypes should fail - let _results = - sub_scalar(&values, 1.5f64.into()).expect_err("Expected type mismatch error"); - } } diff --git a/vortex-array/src/expr/exprs/binary/mod.rs b/vortex-array/src/expr/exprs/binary/mod.rs index 319bbb841a2..c66eb1a3548 100644 --- a/vortex-array/src/expr/exprs/binary/mod.rs +++ b/vortex-array/src/expr/exprs/binary/mod.rs @@ -14,11 +14,7 @@ use vortex_session::VortexSession; use crate::ArrayRef; use crate::compute; use crate::compute::BooleanOperator; -use crate::compute::add; use crate::compute::compare; -use crate::compute::div; -use crate::compute::mul; -use crate::compute::sub; use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::ExecutionArgs; @@ -33,6 +29,8 @@ use crate::expr::stats::Stat; mod boolean; pub(crate) use boolean::*; +mod numeric; +pub(crate) use numeric::*; pub struct Binary; @@ -118,10 +116,10 @@ impl VTable for Binary { Operator::Gte => compare(lhs, rhs, compute::Operator::Gte), Operator::And => execute_boolean(lhs, rhs, BooleanOperator::AndKleene), Operator::Or => execute_boolean(lhs, rhs, BooleanOperator::OrKleene), - Operator::Add => add(lhs, rhs), - Operator::Sub => sub(lhs, rhs), - Operator::Mul => mul(lhs, rhs), - Operator::Div => div(lhs, rhs), + Operator::Add => execute_numeric(lhs, rhs, vortex_scalar::NumericOperator::Add), + Operator::Sub => execute_numeric(lhs, rhs, vortex_scalar::NumericOperator::Sub), + Operator::Mul => execute_numeric(lhs, rhs, vortex_scalar::NumericOperator::Mul), + Operator::Div => execute_numeric(lhs, rhs, vortex_scalar::NumericOperator::Div), } } diff --git a/vortex-array/src/expr/exprs/binary/numeric.rs b/vortex-array/src/expr/exprs/binary/numeric.rs new file mode 100644 index 00000000000..fe73a0977a0 --- /dev/null +++ b/vortex-array/src/expr/exprs/binary/numeric.rs @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; +use vortex_error::vortex_err; +use vortex_scalar::NumericOperator; + +use crate::Array; +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::ConstantArray; +use crate::arrays::ConstantVTable; +use crate::compute::arrow_numeric; + +/// Execute a numeric operation between two arrays. +/// +/// This is the entry point for numeric operations from the binary expression. +/// Handles constant-constant directly, otherwise falls back to Arrow. +pub(crate) fn execute_numeric( + lhs: &dyn Array, + rhs: &dyn Array, + op: NumericOperator, +) -> VortexResult { + if let Some(result) = constant_numeric(lhs, rhs, op)? { + return Ok(result); + } + arrow_numeric(lhs, rhs, op) +} + +fn constant_numeric( + lhs: &dyn Array, + rhs: &dyn Array, + op: NumericOperator, +) -> VortexResult> { + let (Some(lhs), Some(rhs)) = ( + lhs.as_opt::(), + rhs.as_opt::(), + ) else { + return Ok(None); + }; + + Ok(Some( + ConstantArray::new( + lhs.scalar() + .as_primitive() + .checked_binary_numeric(&rhs.scalar().as_primitive(), op) + .ok_or_else(|| vortex_err!("numeric overflow"))?, + lhs.len(), + ) + .into_array(), + )) +}