Source code for feast_spark.pyspark.launchers.gcloud.dataproc

import os
import time
import uuid
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.parse import urlparse

from google.cloud.dataproc_v1 import Job, JobControllerClient, JobStatus

from feast.staging.storage_client import get_staging_client
from feast_spark.pyspark.abc import (
    BatchIngestionJob,
    BatchIngestionJobParameters,
    JobLauncher,
    RetrievalJob,
    RetrievalJobParameters,
    ScheduledBatchIngestionJobParameters,
    SparkJob,
    SparkJobFailure,
    SparkJobParameters,
    SparkJobStatus,
    SparkJobType,
    StreamIngestionJob,
    StreamIngestionJobParameters,
)


def _truncate_label(label: str) -> str:
    return label[:63]
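
# GCP/Dataproc label values are limited to 63 characters, which is why the
# truncation above stops there. Illustrative check:
#
#   >>> _truncate_label("a" * 100) == "a" * 63
#   True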


class DataprocJobMixin:
    def __init__(
        self,
        job: Job,
        refresh_fn: Callable[[], Job],
        cancel_fn: Callable[[], None],
        project: str,
        region: str,
    ):
        """
        Implementation of methods common to the different types of SparkJob
        running on a Dataproc cluster.

        Args:
            job (Job): Dataproc job resource.
            refresh_fn (Callable[[], Job]): A function that returns the latest job resource.
            cancel_fn (Callable[[], None]): A function that cancels the current job.
        """
        self._job = job
        self._refresh_fn = refresh_fn
        self._cancel_fn = cancel_fn
        self._project = project
        self._region = region

    def get_id(self) -> str:
        """
        Getter for the job id.

        Returns:
            str: Dataproc job id.
        """
        return self._job.reference.job_id

    def get_status(self) -> SparkJobStatus:
        """
        Job status retrieval.

        Returns:
            SparkJobStatus: Job status.
        """
        self._job = self._refresh_fn()
        status = self._job.status
        if status.state in (
            JobStatus.State.ERROR,
            JobStatus.State.CANCEL_PENDING,
            JobStatus.State.CANCEL_STARTED,
            JobStatus.State.CANCELLED,
        ):
            return SparkJobStatus.FAILED
        elif status.state == JobStatus.State.RUNNING:
            return SparkJobStatus.IN_PROGRESS
        elif status.state in (
            JobStatus.State.PENDING,
            JobStatus.State.SETUP_DONE,
            JobStatus.State.STATE_UNSPECIFIED,
        ):
            return SparkJobStatus.STARTING
        return SparkJobStatus.COMPLETED

    def cancel(self):
        """
        Manually terminate the job.
        """
        self._cancel_fn()

    def get_error_message(self) -> Optional[str]:
        """
        Getter for the job's error message, if applicable.

        Returns:
            str: Status detail of the job, or None if the job was successful.
        """
        self._job = self._refresh_fn()
        status = self._job.status
        if status.state == JobStatus.State.ERROR:
            return status.details
        elif status.state in (
            JobStatus.State.CANCEL_PENDING,
            JobStatus.State.CANCEL_STARTED,
            JobStatus.State.CANCELLED,
        ):
            return "Job was cancelled."
        return None

    def block_polling(self, interval_sec=30, timeout_sec=3600) -> SparkJobStatus:
        """
        Blocks until the Dataproc job has completed or failed.

        Args:
            interval_sec (int): Polling interval, in seconds.
            timeout_sec (int): Timeout limit, in seconds.

        Returns:
            SparkJobStatus: Latest job status.

        Raises:
            SparkJobFailure: Raised if the job has neither failed nor completed
                within the timeout limit.
        """
        start = time.time()
        while True:
            elapsed_time = time.time() - start
            if timeout_sec and elapsed_time >= timeout_sec:
                raise SparkJobFailure(
                    f"Job is still not completed after {timeout_sec} seconds."
                )
            status = self.get_status()
            if status in [SparkJobStatus.FAILED, SparkJobStatus.COMPLETED]:
                break
            time.sleep(interval_sec)
        return status
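
    # Usage sketch (illustrative; `job` stands for any DataprocJobMixin
    # instance): wait up to ten minutes, polling every 15 seconds.
    #
    #   status = job.block_polling(interval_sec=15, timeout_sec=600)
    #   if status == SparkJobStatus.FAILED:
    #       print(job.get_error_message())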

    def get_start_time(self):
        return self._job.status.state_start_time

    def get_log_uri(self) -> Optional[str]:
        return (
            f"https://console.cloud.google.com/dataproc/jobs/{self.get_id()}"
            f"?region={self._region}&project={self._project}"
        )
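
    # For a job with id "job-1234" in region "us-central1" of project
    # "my-project" (hypothetical values), get_log_uri() returns:
    #
    #   https://console.cloud.google.com/dataproc/jobs/job-1234?region=us-central1&project=my-project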


class DataprocRetrievalJob(DataprocJobMixin, RetrievalJob):
    """
    Historical feature retrieval job result for a Dataproc cluster.
    """

    def __init__(
        self,
        job: Job,
        refresh_fn: Callable[[], Job],
        cancel_fn: Callable[[], None],
        project: str,
        region: str,
        output_file_uri: str,
    ):
        """
        Historical feature retrieval job result returned by DataprocClusterLauncher.

        Args:
            output_file_uri (str): URI of the historical feature retrieval job's output file.
        """
        super().__init__(job, refresh_fn, cancel_fn, project, region)
        self._output_file_uri = output_file_uri

    def get_output_file_uri(self, timeout_sec=None, block=True):
        if not block:
            return self._output_file_uri

        status = self.block_polling(timeout_sec=timeout_sec)
        if status == SparkJobStatus.COMPLETED:
            return self._output_file_uri
        raise SparkJobFailure(self.get_error_message())
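
    # Usage sketch (illustrative): block until the retrieval job finishes, then
    # load the result. This assumes the job wrote parquet output and that pandas
    # (plus gcsfs, for gs:// URIs) is installed; neither is guaranteed here.
    #
    #   import pandas as pd
    #   uri = retrieval_job.get_output_file_uri(timeout_sec=1800)
    #   df = pd.read_parquet(uri)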


class DataprocBatchIngestionJob(DataprocJobMixin, BatchIngestionJob):
    """
    Batch ingestion job result for a Dataproc cluster.
    """

    def get_feature_table(self) -> str:
        return self._job.labels.get(DataprocClusterLauncher.FEATURE_TABLE_LABEL_KEY, "")


class DataprocStreamingIngestionJob(DataprocJobMixin, StreamIngestionJob):
    """
    Streaming ingestion job result for a Dataproc cluster.
    """

    def __init__(
        self,
        job: Job,
        refresh_fn: Callable[[], Job],
        cancel_fn: Callable[[], None],
        project: str,
        region: str,
        job_hash: str,
    ) -> None:
        super().__init__(job, refresh_fn, cancel_fn, project, region)
        self._job_hash = job_hash

    def get_hash(self) -> str:
        return self._job_hash

    def get_feature_table(self) -> str:
        return self._job.labels.get(DataprocClusterLauncher.FEATURE_TABLE_LABEL_KEY, "")


class DataprocClusterLauncher(JobLauncher):
    """
    Submits jobs to an existing Dataproc cluster. Depends on google-cloud-dataproc
    and google-cloud-storage, which are optional dependencies that the user has to
    install in addition to the Feast SDK.
    """

    EXTERNAL_JARS = ["gs://spark-lib/bigquery/spark-bigquery-latest_2.12.jar"]
    JOB_TYPE_LABEL_KEY = "feast_job_type"
    JOB_HASH_LABEL_KEY = "feast_job_hash"
    FEATURE_TABLE_LABEL_KEY = "feast_feature_tables"
    PROJECT_LABEL_KEY = "feast_project"

    def __init__(
        self,
        cluster_name: str,
        staging_location: str,
        region: str,
        project_id: str,
        executor_instances: str,
        executor_cores: str,
        executor_memory: str,
    ):
        """
        Initialize a Dataproc job controller client, used internally for job
        submission and result retrieval.

        Args:
            cluster_name (str): Dataproc cluster name.
            staging_location (str): GCS directory for the storage of files generated
                by the launcher, such as the pyspark scripts.
            region (str): Dataproc cluster region.
            project_id (str): GCP project id for the Dataproc cluster.
            executor_instances (str): Number of executor instances for the Dataproc job.
            executor_cores (str): Number of cores per executor for the Dataproc job.
            executor_memory (str): Amount of memory per executor for the Dataproc job.
        """
        self.cluster_name = cluster_name

        scheme, self.staging_bucket, self.remote_path, _, _, _ = urlparse(
            staging_location
        )
        if scheme != "gs":
            raise ValueError(
                "Only GCS staging locations are supported for DataprocLauncher."
            )
        self.project_id = project_id
        self.region = region
        self.job_client = JobControllerClient(
            client_options={"api_endpoint": f"{region}-dataproc.googleapis.com:443"}
        )
        self.executor_instances = executor_instances
        self.executor_cores = executor_cores
        self.executor_memory = executor_memory

    def _stage_file(self, file_path: str, job_id: str) -> str:
        # Local files are uploaded to the GCS staging location; URIs that do not
        # point to a local file are passed through unchanged.
        if not os.path.isfile(file_path):
            return file_path

        staging_client = get_staging_client("gs")
        blob_path = os.path.join(
            self.remote_path, job_id, os.path.basename(file_path),
        ).lstrip("/")
        blob_uri_str = f"gs://{self.staging_bucket}/{blob_path}"
        with open(file_path, "rb") as f:
            staging_client.upload_fileobj(
                f, file_path, remote_uri=urlparse(blob_uri_str)
            )
        return blob_uri_str
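
    # Construction sketch (illustrative; all values are hypothetical):
    #
    #   launcher = DataprocClusterLauncher(
    #       cluster_name="feast-dataproc",
    #       staging_location="gs://my-bucket/feast-staging",
    #       region="us-central1",
    #       project_id="my-gcp-project",
    #       executor_instances="2",
    #       executor_cores="2",
    #       executor_memory="4g",
    #   )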

    def dataproc_submit(
        self, job_params: SparkJobParameters, extra_properties: Dict[str, str]
    ) -> Tuple[Job, Callable[[], Job], Callable[[], None]]:
        local_job_id = str(uuid.uuid4())
        main_file_uri = self._stage_file(job_params.get_main_file_path(), local_job_id)
        job_config: Dict[str, Any] = {
            "reference": {"job_id": local_job_id},
            "placement": {"cluster_name": self.cluster_name},
            "labels": {self.JOB_TYPE_LABEL_KEY: job_params.get_job_type().name.lower()},
        }
        maven_package_properties = {
            "spark.jars.packages": ",".join(job_params.get_extra_packages())
        }
        common_properties = {
            "spark.executor.instances": self.executor_instances,
            "spark.executor.cores": self.executor_cores,
            "spark.executor.memory": self.executor_memory,
        }

        if isinstance(job_params, StreamIngestionJobParameters):
            job_config["labels"][self.FEATURE_TABLE_LABEL_KEY] = _truncate_label(
                job_params.get_feature_table_name()
            )
            # Add job hash to labels only for the stream ingestion job
            job_config["labels"][self.JOB_HASH_LABEL_KEY] = job_params.get_job_hash()
            job_config["labels"][self.PROJECT_LABEL_KEY] = job_params.get_project()

        if isinstance(job_params, BatchIngestionJobParameters):
            job_config["labels"][self.FEATURE_TABLE_LABEL_KEY] = _truncate_label(
                job_params.get_feature_table_name()
            )
            job_config["labels"][self.PROJECT_LABEL_KEY] = job_params.get_project()

        if job_params.get_class_name():
            scala_job_properties = {
                "spark.yarn.user.classpath.first": "true",
                "spark.executor.instances": self.executor_instances,
                "spark.executor.cores": self.executor_cores,
                "spark.executor.memory": self.executor_memory,
                "spark.pyspark.driver.python": "python3.7",
                "spark.pyspark.python": "python3.7",
            }
            job_config.update(
                {
                    "spark_job": {
                        "jar_file_uris": [main_file_uri] + self.EXTERNAL_JARS,
                        "main_class": job_params.get_class_name(),
                        "args": job_params.get_arguments(),
                        "properties": {
                            **scala_job_properties,
                            **common_properties,
                            **maven_package_properties,
                            **extra_properties,
                        },
                    }
                }
            )
        else:
            job_config.update(
                {
                    "pyspark_job": {
                        "main_python_file_uri": main_file_uri,
                        "jar_file_uris": self.EXTERNAL_JARS,
                        "args": job_params.get_arguments(),
                        "properties": {
                            **common_properties,
                            **maven_package_properties,
                            **extra_properties,
                        },
                    }
                }
            )

        job = self.job_client.submit_job(
            request={
                "project_id": self.project_id,
                "region": self.region,
                "job": job_config,
            }
        )
        refresh_fn = partial(
            self.job_client.get_job,
            project_id=self.project_id,
            region=self.region,
            job_id=job.reference.job_id,
        )
        cancel_fn = partial(self.dataproc_cancel, job.reference.job_id)
        return job, refresh_fn, cancel_fn
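
    # For a PySpark submission, the job_config assembled above takes roughly
    # this shape (values illustrative):
    #
    #   {
    #       "reference": {"job_id": "<uuid4>"},
    #       "placement": {"cluster_name": "feast-dataproc"},
    #       "labels": {"feast_job_type": "historical_retrieval"},
    #       "pyspark_job": {
    #           "main_python_file_uri": "gs://my-bucket/feast-staging/<uuid4>/main.py",
    #           "jar_file_uris": ["gs://spark-lib/bigquery/spark-bigquery-latest_2.12.jar"],
    #           "args": [...],
    #           "properties": {"spark.executor.instances": "2", ...},
    #       },
    #   }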

    def dataproc_cancel(self, job_id):
        self.job_client.cancel_job(
            project_id=self.project_id, region=self.region, job_id=job_id
        )

    def historical_feature_retrieval(
        self, job_params: RetrievalJobParameters
    ) -> RetrievalJob:
        job, refresh_fn, cancel_fn = self.dataproc_submit(
            job_params, {"dev.feast.outputuri": job_params.get_destination_path()}
        )
        return DataprocRetrievalJob(
            job=job,
            refresh_fn=refresh_fn,
            cancel_fn=cancel_fn,
            project=self.project_id,
            region=self.region,
            output_file_uri=job_params.get_destination_path(),
        )
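
    # Usage sketch (illustrative; building RetrievalJobParameters is out of
    # scope for this module):
    #
    #   retrieval_job = launcher.historical_feature_retrieval(job_params)
    #   print(retrieval_job.get_id(), retrieval_job.get_log_uri())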

    def offline_to_online_ingestion(
        self, ingestion_job_params: BatchIngestionJobParameters
    ) -> BatchIngestionJob:
        job, refresh_fn, cancel_fn = self.dataproc_submit(ingestion_job_params, {})
        return DataprocBatchIngestionJob(
            job=job,
            refresh_fn=refresh_fn,
            cancel_fn=cancel_fn,
            project=self.project_id,
            region=self.region,
        )

    def schedule_offline_to_online_ingestion(
        self, ingestion_job_params: ScheduledBatchIngestionJobParameters
    ):
        raise NotImplementedError(
            "Scheduling Spark jobs is not supported by the Dataproc launcher"
        )

    def unschedule_offline_to_online_ingestion(self, project: str, feature_table: str):
        raise NotImplementedError(
            "Unscheduling Spark jobs is not supported by the Dataproc launcher"
        )

    def start_stream_to_online_ingestion(
        self, ingestion_job_params: StreamIngestionJobParameters
    ) -> StreamIngestionJob:
        job, refresh_fn, cancel_fn = self.dataproc_submit(ingestion_job_params, {})
        job_hash = ingestion_job_params.get_job_hash()
        return DataprocStreamingIngestionJob(
            job=job,
            refresh_fn=refresh_fn,
            cancel_fn=cancel_fn,
            project=self.project_id,
            region=self.region,
            job_hash=job_hash,
        )

    def get_job_by_id(self, job_id: str) -> SparkJob:
        job = self.job_client.get_job(
            project_id=self.project_id, region=self.region, job_id=job_id
        )
        return self._dataproc_job_to_spark_job(job)

    def _dataproc_job_to_spark_job(self, job: Job) -> SparkJob:
        # Dispatch on the feast_job_type label to reconstruct the right SparkJob
        # subclass from a raw Dataproc job resource.
        job_type = job.labels[self.JOB_TYPE_LABEL_KEY]
        job_id = job.reference.job_id
        refresh_fn = partial(
            self.job_client.get_job,
            project_id=self.project_id,
            region=self.region,
            job_id=job_id,
        )
        cancel_fn = partial(self.dataproc_cancel, job_id)

        if job_type == SparkJobType.HISTORICAL_RETRIEVAL.name.lower():
            output_path = job.pyspark_job.properties.get("dev.feast.outputuri", "")
            return DataprocRetrievalJob(
                job=job,
                refresh_fn=refresh_fn,
                cancel_fn=cancel_fn,
                project=self.project_id,
                region=self.region,
                output_file_uri=output_path,
            )

        if job_type == SparkJobType.BATCH_INGESTION.name.lower():
            return DataprocBatchIngestionJob(
                job=job,
                refresh_fn=refresh_fn,
                cancel_fn=cancel_fn,
                project=self.project_id,
                region=self.region,
            )

        if job_type == SparkJobType.STREAM_INGESTION.name.lower():
            job_hash = job.labels[self.JOB_HASH_LABEL_KEY]
            return DataprocStreamingIngestionJob(
                job=job,
                refresh_fn=refresh_fn,
                cancel_fn=cancel_fn,
                project=self.project_id,
                region=self.region,
                job_hash=job_hash,
            )

        raise ValueError(f"Unrecognized job type: {job_type}")

    def list_jobs(
        self,
        include_terminated: bool,
        project: Optional[str] = None,
        table_name: Optional[str] = None,
    ) -> List[SparkJob]:
        job_filter = (
            f"labels.{self.JOB_TYPE_LABEL_KEY} = * AND clusterName = {self.cluster_name}"
        )
        if project:
            job_filter = job_filter + f" AND labels.{self.PROJECT_LABEL_KEY} = {project}"
        if table_name:
            job_filter = (
                job_filter
                + f" AND labels.{self.FEATURE_TABLE_LABEL_KEY} = {_truncate_label(table_name)}"
            )
        if not include_terminated:
            job_filter = job_filter + " AND status.state = ACTIVE"
        return [
            self._dataproc_job_to_spark_job(job)
            for job in self.job_client.list_jobs(
                project_id=self.project_id, region=self.region, filter=job_filter
            )
        ]
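
    # With project="default", table_name="driver_stats" (hypothetical values)
    # and include_terminated=False, the filter built above is the single line:
    #
    #   labels.feast_job_type = * AND clusterName = <cluster_name>
    #       AND labels.feast_project = default
    #       AND labels.feast_feature_tables = driver_stats
    #       AND status.state = ACTIVE
    #
    # (wrapped here for readability)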