Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions backend/app/api/routes/stt_evaluations/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from app.api.deps import AuthContextDep, SessionDep
from app.api.permissions import Permission, require_permission
from app.celery.utils import start_low_priority_job
from app.celery.utils import start_stt_batch_submission
from app.core.cloud import get_cloud_storage
from app.crud.stt_evaluations import (
create_stt_run,
Expand Down Expand Up @@ -83,8 +83,7 @@ def start_stt_evaluation(
# Offload batch submission (signed URLs, JSONL, Gemini upload) to Celery worker
trace_id = correlation_id.get() or "N/A"
try:
celery_task_id = start_low_priority_job(
function_path="app.services.stt_evaluations.batch_job.execute_batch_submission",
celery_task_id = start_stt_batch_submission(
project_id=auth_context.project_.id,
job_id=str(run.id),
trace_id=trace_id,
Expand Down
5 changes: 2 additions & 3 deletions backend/app/api/routes/tts_evaluations/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from app.api.deps import AuthContextDep, SessionDep
from app.api.permissions import Permission, require_permission
from app.celery.utils import start_low_priority_job
from app.celery.utils import start_tts_batch_submission
from app.core.cloud import get_cloud_storage
from app.crud.tts_evaluations import (
create_tts_run,
Expand Down Expand Up @@ -86,8 +86,7 @@ def start_tts_evaluation(
# Offload batch submission (result creation, JSONL, Gemini upload) to Celery worker
trace_id = correlation_id.get() or "N/A"
try:
celery_task_id = start_low_priority_job(
function_path="app.services.tts_evaluations.batch_job.execute_batch_submission",
celery_task_id = start_tts_batch_submission(
project_id=auth_context.project_.id,
job_id=str(run.id),
trace_id=trace_id,
Expand Down
14 changes: 2 additions & 12 deletions backend/app/celery/celery_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,10 @@
Queue("cron", exchange=default_exchange, routing_key="cron"),
Queue("default", exchange=default_exchange, routing_key="default"),
),
# Task routing
# Task routing — queue is set per-task via @celery_app.task(queue=...).
# Only cron tasks need an explicit override here.
task_routes={
"app.celery.tasks.job_execution.execute_high_priority_task": {
"queue": "high_priority",
"priority": 9,
},
"app.celery.tasks.job_execution.execute_low_priority_task": {
"queue": "low_priority",
"priority": 1,
},
"app.celery.tasks.*_cron_*": {"queue": "cron"},
"app.celery.tasks.*": {"queue": "default"},
},
task_default_queue="default",
# Enable priority support
Expand Down Expand Up @@ -93,5 +85,3 @@
broker_pool_limit=settings.CELERY_BROKER_POOL_LIMIT,
)

# Auto-discover tasks
# celery_app.autodiscover_tasks()
268 changes: 139 additions & 129 deletions backend/app/celery/tasks/job_execution.py
Original file line number Diff line number Diff line change
@@ -1,143 +1,153 @@
import logging
from collections.abc import Callable
from celery import current_task

from asgi_correlation_id import correlation_id
from celery import current_task

from app.celery.celery_app import celery_app
import app.services.llm.jobs as _llm_jobs
import app.services.response.jobs as _response_jobs
import app.services.doctransform.job as _doctransform_job
import app.services.collections.create_collection as _create_collection
import app.services.collections.delete_collection as _delete_collection
import app.services.stt_evaluations.batch_job as _stt_batch_job
import app.services.stt_evaluations.metric_job as _stt_metric_job
import app.services.tts_evaluations.batch_job as _tts_batch_job
import app.services.tts_evaluations.batch_result_processing as _tts_result_processing

logger = logging.getLogger(__name__)

# Hardcoded dispatch table — avoids dynamic importlib at task execution time.
# Imports above happen once in the main Celery process before worker forks,
# so all child workers inherit them via copy-on-write instead of each loading
# them independently (which was causing OOM with warmup_job_modules).
_FUNCTION_REGISTRY: dict[str, Callable] = {
"app.services.llm.jobs.execute_job": _llm_jobs.execute_job,
"app.services.llm.jobs.execute_chain_job": _llm_jobs.execute_chain_job,
"app.services.response.jobs.execute_job": _response_jobs.execute_job,
"app.services.doctransform.job.execute_job": _doctransform_job.execute_job,
"app.services.collections.create_collection.execute_job": _create_collection.execute_job,
"app.services.collections.delete_collection.execute_job": _delete_collection.execute_job,
"app.services.stt_evaluations.batch_job.execute_batch_submission": _stt_batch_job.execute_batch_submission,
"app.services.stt_evaluations.metric_job.execute_metric_computation": _stt_metric_job.execute_metric_computation,
"app.services.tts_evaluations.batch_job.execute_batch_submission": _tts_batch_job.execute_batch_submission,
"app.services.tts_evaluations.batch_result_processing.execute_tts_result_processing": _tts_result_processing.execute_tts_result_processing,
}


@celery_app.task(bind=True, queue="high_priority")
def execute_high_priority_task(
self,
function_path: str,
project_id: int,
job_id: str,
trace_id: str,
**kwargs,
):
"""
High priority Celery task to execute any job function.
Use this for urgent operations that need immediate processing.

Args:
function_path: Import path to the execute_job function (e.g., "app.services.doctransform.service.execute_job")
project_id: ID of the project executing the job
job_id: ID of the job (should already exist in database)
trace_id: Trace/correlation ID to preserve context across Celery tasks
**kwargs: Additional arguments to pass to the execute_job function
"""
return _execute_job_internal(
self, function_path, project_id, job_id, "high_priority", trace_id, **kwargs

def _set_trace(trace_id: str) -> None:
    """Bind *trace_id* as the correlation ID for the current task context.

    Called at the top of every Celery task so log lines emitted by the job
    carry the same trace ID as the originating HTTP request.
    """
    correlation_id.set(trace_id)
    # Lazy %-style args: the message is only formatted if INFO is enabled
    # (avoids f-string work on every task when logging is quieter).
    logger.info("[_set_trace] Set correlation ID: %s", trace_id)


@celery_app.task(bind=True, queue="high_priority", priority=9)
def run_llm_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
    """High-priority Celery task that delegates to the LLM job service.

    The service module is imported lazily inside the task body; the Celery
    task id and bound task instance are forwarded to the service function.
    """
    from app.services.llm.jobs import execute_job as _execute

    _set_trace(trace_id)
    task_id = current_task.request.id
    return _execute(
        project_id=project_id,
        job_id=job_id,
        task_id=task_id,
        task_instance=self,
        **kwargs,
    )


@celery_app.task(bind=True, queue="low_priority")
def execute_low_priority_task(
self,
function_path: str,
project_id: int,
job_id: str,
trace_id: str,
**kwargs,
):
"""
Low priority Celery task to execute any job function.
Use this for background operations that can wait.

Args:
function_path: Import path to the execute_job function (e.g., "app.services.doctransform.service.execute_job")
project_id: ID of the project executing the job
job_id: ID of the job (should already exist in database)
trace_id: Trace/correlation ID to preserve context across Celery tasks
**kwargs: Additional arguments to pass to the execute_job function
"""
return _execute_job_internal(
self, function_path, project_id, job_id, "low_priority", trace_id, **kwargs
@celery_app.task(bind=True, queue="high_priority", priority=9)
def run_llm_chain_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
    """High-priority Celery task that delegates to the LLM chain-job service."""
    from app.services.llm.jobs import execute_chain_job

    _set_trace(trace_id)
    # Standardized job parameters; extra kwargs are forwarded untouched.
    params = {
        "project_id": project_id,
        "job_id": job_id,
        "task_id": current_task.request.id,
        "task_instance": self,
    }
    return execute_chain_job(**params, **kwargs)


def _execute_job_internal(
task_instance,
function_path: str,
project_id: int,
job_id: str,
priority: str,
trace_id: str,
**kwargs,
):
"""
Internal function to execute job logic for both priority levels.

Args:
task_instance: Celery task instance (for progress updates, retries, etc.)
function_path: Import path to the execute_job function
project_id: ID of the project executing the job
job_id: ID of the job (should already exist in database)
priority: Priority level ("high_priority" or "low_priority")
trace_id: Trace/correlation ID to preserve context across Celery tasks
**kwargs: Additional arguments to pass to the execute_job function
"""
task_id = current_task.request.id
@celery_app.task(bind=True, queue="high_priority", priority=9)
def run_response_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
from app.services.response.jobs import execute_job

correlation_id.set(trace_id)
logger.info(f"Set correlation ID context: {trace_id} for job {job_id}")

try:
execute_function = _FUNCTION_REGISTRY.get(function_path)
if execute_function is None:
raise ValueError(
f"[_execute_job_internal] Unknown function path: {function_path}"
)

logger.info(
f"Executing {priority} job {job_id} (task {task_id}) using function {function_path}"
)

# Execute the business logic function with standardized parameters
result = execute_function(
project_id=project_id,
job_id=job_id,
task_id=task_id,
task_instance=task_instance, # For progress updates, retries if needed
**kwargs,
)

logger.info(
f"{priority.capitalize()} job {job_id} (task {task_id}) completed successfully"
)
return result

except Exception as exc:
logger.error(
f"{priority.capitalize()} job {job_id} (task {task_id}) failed: {exc}",
exc_info=True,
)
raise
_set_trace(trace_id)
return execute_job(
project_id=project_id,
job_id=job_id,
task_id=current_task.request.id,
task_instance=self,
**kwargs,
)


@celery_app.task(bind=True, queue="low_priority", priority=1)
def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
    """Low-priority Celery task that delegates to the document-transform job service."""
    from app.services.doctransform.job import execute_job

    _set_trace(trace_id)
    task_id = current_task.request.id
    return execute_job(
        project_id=project_id,
        job_id=job_id,
        task_id=task_id,
        task_instance=self,
        **kwargs,
    )


@celery_app.task(bind=True, queue="low_priority", priority=1)
def run_create_collection_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
    """Low-priority Celery task that delegates to the create-collection service."""
    from app.services.collections.create_collection import execute_job as _execute

    _set_trace(trace_id)
    return _execute(
        project_id=project_id,
        job_id=job_id,
        task_id=current_task.request.id,
        task_instance=self,
        **kwargs,
    )


@celery_app.task(bind=True, queue="low_priority", priority=1)
def run_delete_collection_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
    """Low-priority Celery task that delegates to the delete-collection service."""
    from app.services.collections.delete_collection import execute_job

    _set_trace(trace_id)
    params = {
        "project_id": project_id,
        "job_id": job_id,
        "task_id": current_task.request.id,
        "task_instance": self,
    }
    return execute_job(**params, **kwargs)


@celery_app.task(bind=True, queue="low_priority", priority=1)
def run_stt_batch_submission(self, project_id: int, job_id: str, trace_id: str, **kwargs):
    """Low-priority Celery task that delegates to the STT batch-submission service."""
    from app.services.stt_evaluations.batch_job import execute_batch_submission

    _set_trace(trace_id)
    task_id = current_task.request.id
    return execute_batch_submission(
        project_id=project_id,
        job_id=job_id,
        task_id=task_id,
        task_instance=self,
        **kwargs,
    )


@celery_app.task(bind=True, queue="low_priority", priority=1)
def run_stt_metric_computation(self, project_id: int, job_id: str, trace_id: str, **kwargs):
    """Low-priority Celery task that delegates to the STT metric-computation service."""
    from app.services.stt_evaluations.metric_job import (
        execute_metric_computation as _execute,
    )

    _set_trace(trace_id)
    return _execute(
        project_id=project_id,
        job_id=job_id,
        task_id=current_task.request.id,
        task_instance=self,
        **kwargs,
    )


@celery_app.task(bind=True, queue="low_priority", priority=1)
def run_tts_batch_submission(self, project_id: int, job_id: str, trace_id: str, **kwargs):
    """Low-priority Celery task that delegates to the TTS batch-submission service."""
    from app.services.tts_evaluations.batch_job import execute_batch_submission

    _set_trace(trace_id)
    params = {
        "project_id": project_id,
        "job_id": job_id,
        "task_id": current_task.request.id,
        "task_instance": self,
    }
    return execute_batch_submission(**params, **kwargs)


@celery_app.task(bind=True, queue="low_priority", priority=1)
def run_tts_result_processing(self, project_id: int, job_id: str, trace_id: str, **kwargs):
    """Low-priority Celery task that delegates to the TTS result-processing service."""
    from app.services.tts_evaluations.batch_result_processing import (
        execute_tts_result_processing,
    )

    _set_trace(trace_id)
    task_id = current_task.request.id
    return execute_tts_result_processing(
        project_id=project_id,
        job_id=job_id,
        task_id=task_id,
        task_instance=self,
        **kwargs,
    )
Loading
Loading