diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py index 38ebc1c70..660cbf482 100644 --- a/loopy/library/reduction.py +++ b/loopy/library/reduction.py @@ -368,7 +368,7 @@ def __call__(self, from loopy.translation_unit import add_callable_to_table # getting the callable 'max' from target - max_scalar_callable = target.get_device_ast_builder().known_callables["max"] + max_scalar_callable = target.known_device_callables["max"] # type specialize the callable max_scalar_callable, callables_table = max_scalar_callable.with_types( @@ -404,7 +404,7 @@ def __call__(self, from loopy.translation_unit import add_callable_to_table # getting the callable 'min' from target - min_scalar_callable = target.get_device_ast_builder().known_callables["min"] + min_scalar_callable = target.known_device_callables["min"] # type specialize the callable min_scalar_callable, callables_table = min_scalar_callable.with_types( diff --git a/loopy/target/__init__.py b/loopy/target/__init__.py index 9476c0e2a..ec6eadca0 100644 --- a/loopy/target/__init__.py +++ b/loopy/target/__init__.py @@ -191,6 +191,24 @@ def get_kernel_executor( """ raise NotImplementedError() + @property + def known_host_callables(self): + """ + Returns a mapping from function ids to corresponding + :class:`loopy.kernel.function_interface.InKernelCallable` for the + function ids known to *self* for host code generation. + """ + return {} + + @property + def known_device_callables(self): + """ + Returns a mapping from function ids to corresponding + :class:`loopy.kernel.function_interface.InKernelCallable` for the + function ids known to *self* for device code generation. + """ + return {} + @dataclass(frozen=True) class ASTBuilderBase(ABC, Generic[ASTType]): @@ -206,10 +224,9 @@ def known_callables(self): """ Returns a mapping from function ids to corresponding :class:`loopy.kernel.function_interface.InKernelCallable` for the - function ids known to *self.target*. + function ids known to *self.target* for device code generation. """ - # FIXME: @inducer: Do we need to move this to TargetBase? - return {} + return dict(self.target.known_device_callables) def symbol_manglers(self): return [] @@ -351,6 +368,10 @@ def __str__(self): class DummyHostASTBuilder(ASTBuilderBase[None]): + @property + def known_callables(self): + return dict(self.target.known_host_callables) + def get_function_definition(self, codegen_state, codegen_result, schedule_index, function_decl, function_body): return function_body diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py index a2885b7bf..6c5520bfc 100644 --- a/loopy/target/c/__init__.py +++ b/loopy/target/c/__init__.py @@ -509,6 +509,13 @@ def get_host_ast_builder(self): def get_device_ast_builder(self): return CFamilyASTBuilder(self) + @property + @override + def known_device_callables(self): + callables = super().known_device_callables + callables.update(get_c_callables()) + return callables + # {{{ types @memoize_method @@ -890,13 +897,6 @@ def preamble_generators(self): lambda preamble_info: _preamble_generator( preamble_info, self.preamble_function_qualifier)]) - @property - @override - def known_callables(self): - callables = super().known_callables - callables.update(get_c_callables()) - return callables - # }}} # {{{ code generation @@ -1606,19 +1606,29 @@ class CWithGNULibcTarget(CTarget): def get_device_ast_builder(self): return CWithGNULibcASTBuilder(self) - -class CWithGNULibcASTBuilder(CASTBuilder): @property - def known_callables(self): - callables = super().known_callables + @override + def known_device_callables(self): + callables = super().known_device_callables callables.update(get_gnu_libc_callables()) return callables +class CWithGNULibcASTBuilder(CASTBuilder): + pass + + class ExecutableCWithGNULibcTarget(ExecutableCTarget): def get_device_ast_builder(self): return CWithGNULibcASTBuilder(self) + @property + @override + def known_device_callables(self): + callables = super().known_device_callables + callables.update(get_gnu_libc_callables()) + return callables + # }}} # vim: foldmethod=marker diff --git a/loopy/target/cuda.py b/loopy/target/cuda.py index 23cd2dab0..5e07da951 100644 --- a/loopy/target/cuda.py +++ b/loopy/target/cuda.py @@ -249,6 +249,12 @@ def split_kernel_at_global_barriers(self): def get_device_ast_builder(self): return CUDACASTBuilder(self) + @property + def known_device_callables(self): + callables = super().known_device_callables + callables.update(get_cuda_callables()) + return callables + # {{{ types @memoize_method @@ -330,16 +336,6 @@ class CUDACASTBuilder(CFamilyASTBuilder): preamble_function_qualifier = "inline __device__" - # {{{ library - - @property - def known_callables(self): - callables = super().known_callables - callables.update(get_cuda_callables()) - return callables - - # }}} - # {{{ top-level codegen def get_function_declaration( diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index ed05c7628..e434d608a 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -632,6 +632,13 @@ def split_kernel_at_global_barriers(self): def get_device_ast_builder(self): return OpenCLCASTBuilder(self) + @property + @override + def known_device_callables(self): + callables = super().known_device_callables + callables.update(get_opencl_callables()) + return callables + @memoize_method def get_dtype_registry(self) -> DTypeRegistry: from loopy.target.c.compyte.dtypes import ( @@ -673,12 +680,6 @@ def vector_dtype(self, base, count): class OpenCLCASTBuilder(CFamilyASTBuilder): # {{{ library - @property - def known_callables(self): - callables = super().known_callables - callables.update(get_opencl_callables()) - return callables - def symbol_manglers(self): return ( [*super().symbol_manglers(), opencl_symbol_mangler]) diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index 9d88394c9..dbea17765 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -584,6 +584,26 @@ def get_host_ast_builder(self): def get_device_ast_builder(self): return PyOpenCLCASTBuilder(self) + @property + @override + def known_device_callables(self): + from loopy.library.random123 import get_random123_callables + + # order matters: e.g. prefer our abs() over that of the + # superclass + callables = super().known_device_callables + callables.update(get_pyopencl_callables()) + callables.update(get_random123_callables(self)) + return callables + + @property + @override + def known_host_callables(self): + from loopy.target.c import get_c_callables + callables = super().known_host_callables + callables.update(get_c_callables()) + return callables + # {{{ types @override @@ -1224,17 +1244,6 @@ def get_function_declaration( # {{{ library - @property - def known_callables(self): - from loopy.library.random123 import get_random123_callables - - # order matters: e.g. prefer our abs() over that of the - # superclass - callables = super().known_callables - callables.update(get_pyopencl_callables()) - callables.update(get_random123_callables(self.target)) - return callables - def preamble_generators(self): return ([pyopencl_preamble_generator, *super().preamble_generators()]) diff --git a/loopy/target/python.py b/loopy/target/python.py index 3b4b9795f..5de59b38f 100644 --- a/loopy/target/python.py +++ b/loopy/target/python.py @@ -192,10 +192,7 @@ class PythonASTBuilderBase(ASTBuilderBase[Generable]): @property def known_callables(self): - from loopy.target.c import get_c_callables - callables = super().known_callables - callables.update(get_c_callables()) - return callables + return dict(self.target.known_host_callables) def preamble_generators(self): return ( diff --git a/loopy/tools.py b/loopy/tools.py index 8ab419585..685293558 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -674,7 +674,7 @@ def __init__(self, rule_mapping_context, callables_table, target): @cached_property def known_callables(self): from loopy.kernel.function_interface import CallableKernel - return (frozenset(self.target.get_device_ast_builder().known_callables) + return (frozenset(self.target.known_device_callables) | {name for name, clbl in self.callables_table.items() if isinstance(clbl, CallableKernel)}) diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py index 9d27e24a4..cfb2e9b12 100644 --- a/loopy/translation_unit.py +++ b/loopy/translation_unit.py @@ -873,7 +873,7 @@ def resolve_callables(t_unit: TranslationUnit) -> TranslationUnit: # get registered callables known_callables = dict(t_unit.callables_table) # get target specific callables - known_callables.update(t_unit.target.get_device_ast_builder().known_callables) + known_callables.update(t_unit.target.known_device_callables) # get loopy specific callables known_callables.update(get_loopy_callables())