# Source code for the Feast EMR Spark job launcher.

import os
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, List, Optional
from urllib.parse import urlunparse

import boto3
from botocore.config import Config as BotoConfig

from feast.staging.storage_client import get_staging_client
from import (

from .emr_utils import (

class EmrJobMixin:
    """
    Common EMR-backed job behaviour: id, status, cancellation and start time.

    Subclasses pass an EMR client and an ``EmrJobRef`` to ``__init__``; every
    method below delegates to the EMR helper functions with those two values.
    """

    def __init__(self, emr_client, job_ref: EmrJobRef):
        """
        Args:
            emr_client: boto3 emr client
            job_ref: job reference
        """
        self._emr_client = emr_client
        self._job_ref = job_ref

    def get_id(self) -> str:
        """Return the string form of this job's reference."""
        job_id = _job_ref_to_str(self._job_ref)
        return job_id

    def get_status(self) -> SparkJobStatus:
        """
        Translate the current EMR job state into a ``SparkJobStatus``.

        Raises:
            Exception: if the state is in none of the known state groups.
        """
        current_state = _get_job_state(self._emr_client, self._job_ref)

        # Checked in order: in-progress, then succeeded, then failed.
        state_groups = (
            (IN_PROGRESS_STEP_STATES, SparkJobStatus.IN_PROGRESS),
            (SUCCEEDED_STEP_STATES, SparkJobStatus.COMPLETED),
            (FAILED_STEP_STATES, SparkJobStatus.FAILED),
        )
        for group, status in state_groups:
            if current_state in group:
                return status

        # we should never get here
        raise Exception("Invalid EMR state")

    def cancel(self):
        """Ask EMR to cancel this job."""
        _cancel_job(self._emr_client, self._job_ref)

    def get_start_time(self) -> datetime:
        """Return the job's creation time as reported by EMR."""
        started_at = _get_job_creation_time(self._emr_client, self._job_ref)
        return started_at
class EmrRetrievalJob(EmrJobMixin, RetrievalJob):
    """
    Historical feature retrieval job result for a EMR cluster
    """

    def __init__(self, emr_client, job_ref: EmrJobRef, output_file_uri: str):
        """
        This is the job object representing the historical retrieval job, returned by EmrClusterLauncher.

        Args:
            output_file_uri (str): Uri to the historical feature retrieval job output file.
        """
        super().__init__(emr_client, job_ref)
        self._output_file_uri = output_file_uri

    def get_output_file_uri(self, timeout_sec=None, block=True):
        """
        Return the URI of the retrieval output.

        Args:
            timeout_sec: forwarded to the wait helper; only used when ``block`` is true.
            block: when true, first wait for the job to reach a terminal state.

        Raises:
            SparkJobFailure: if, after waiting, the job did not end in a succeeded state.
        """
        if block:
            final_state = _wait_for_job_state(
                self._emr_client, self._job_ref, TERMINAL_STEP_STATES, timeout_sec
            )
            if final_state not in SUCCEEDED_STEP_STATES:
                raise SparkJobFailure("Spark job failed")
        return self._output_file_uri
class EmrBatchIngestionJob(EmrJobMixin, BatchIngestionJob):
    """
    Ingestion job result for a EMR cluster
    """

    def __init__(self, emr_client, job_ref: EmrJobRef, project: str, table_name: str):
        """
        Args:
            emr_client: boto3 emr client
            job_ref: job reference
            project: Feast project the ingested feature table belongs to
            table_name: name of the feature table being ingested
        """
        super().__init__(emr_client, job_ref)
        self._table_name = table_name
        self._project = project

    def get_feature_table(self) -> str:
        """Return the name of the feature table this job ingests."""
        table = self._table_name
        return table

    def get_project(self) -> str:
        """Return the project of the feature table this job ingests."""
        project = self._project
        return project
class EmrStreamIngestionJob(EmrJobMixin, StreamIngestionJob):
    """
    Ingestion streaming job for a EMR cluster
    """

    def __init__(
        self,
        emr_client,
        job_ref: EmrJobRef,
        job_hash: str,
        project: str,
        table_name: str,
    ):
        """
        Args:
            emr_client: boto3 emr client
            job_ref: job reference
            job_hash: hash identifying this stream ingestion job's parameters
            project: Feast project the ingested feature table belongs to
            table_name: name of the feature table being ingested
        """
        super().__init__(emr_client, job_ref)
        self._job_hash = job_hash
        self._project = project
        self._table_name = table_name

    def get_hash(self) -> str:
        """Return the job hash this job was created with."""
        return self._job_hash

    def get_feature_table(self) -> str:
        """Return the name of the feature table this job ingests."""
        return self._table_name

    def get_project(self) -> str:
        """
        Return the project of the feature table this job ingests.

        Mirrors ``EmrBatchIngestionJob.get_project`` so both ingestion job
        types expose the project they were created with.
        """
        return self._project
class EmrClusterLauncher(JobLauncher):
    """
    Submits jobs to an existing or new EMR cluster. Requires boto3 as an additional dependency.
    """

    _existing_cluster_id: Optional[str]
    _new_cluster_template: Optional[Dict[str, Any]]
    _staging_location: str
    _emr_log_location: str
    _region: str

    def __init__(
        self,
        *,
        region: str,
        existing_cluster_id: Optional[str],
        new_cluster_template_path: Optional[str],
        staging_location: str,
        emr_log_location: str,
    ):
        """
        Initialize an EMR job launcher, used internally for job submission and result retrieval.
        Can work with either an existing EMR cluster, or create a cluster on-demand for each job.

        Args:
            region (str): AWS region name.
            existing_cluster_id (str): Existing EMR cluster id, if using an existing cluster.
            new_cluster_template_path (str): Path to yaml new cluster template, if using a new cluster.
            staging_location: An S3 staging location for artifacts.
            emr_log_location: S3 location for EMR logs.

        Raises:
            ValueError: if neither ``existing_cluster_id`` nor ``new_cluster_template_path`` is given.
        """
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not (existing_cluster_id or new_cluster_template_path):
            raise ValueError(
                "Either existing_cluster_id or new_cluster_template_path must be provided"
            )

        self._existing_cluster_id = existing_cluster_id
        if new_cluster_template_path:
            self._new_cluster_template = _load_new_cluster_template(
                new_cluster_template_path
            )
        else:
            self._new_cluster_template = None

        self._staging_location = staging_location
        self._emr_log_location = emr_log_location
        self._region = region

    def _emr_client(self):
        """Create a boto3 EMR client for the configured region."""
        # Use an increased number of retries since DescribeStep calls have a pretty low rate limit.
        config = BotoConfig(retries={"max_attempts": 10, "mode": "standard"})

        return boto3.client("emr", region_name=self._region, config=config)

    def _submit_emr_job(self, step: Dict[str, Any]) -> EmrJobRef:
        """
        Submit EMR job using a new or existing cluster.

        Note: sets ``step["ActionOnFailure"]`` in place.

        Returns a job reference (cluster_id and step_id).
        """
        emr = self._emr_client()

        if self._existing_cluster_id:
            # Don't tear down a pre-existing cluster because one step failed.
            step["ActionOnFailure"] = "CONTINUE"
            step_ids = emr.add_job_flow_steps(
                JobFlowId=self._existing_cluster_id, Steps=[step],
            )
            return EmrJobRef(self._existing_cluster_id, step_ids["StepIds"][0])

        assert self._new_cluster_template is not None
        # Shallow copy: only top-level keys are overwritten below. Without it, the
        # template loaded in __init__ would be mutated and shared between submissions.
        job_template = dict(self._new_cluster_template)
        # The cluster is created for this single step, so terminate it on failure.
        step["ActionOnFailure"] = "TERMINATE_CLUSTER"
        job_template["Steps"] = [step]

        if self._emr_log_location:
            job_template["LogUri"] = os.path.join(
                self._emr_log_location, _random_string(5)
            )

        job = emr.run_job_flow(**job_template)
        return EmrJobRef(job["JobFlowId"], None)

    def historical_feature_retrieval(
        self, job_params: RetrievalJobParameters
    ) -> RetrievalJob:
        """
        Submits a historical feature retrieval job to an EMR cluster.

        The PySpark script at ``job_params.get_main_file_path()`` is uploaded to
        the S3 staging location and run as an EMR step.

        Returns:
            RetrievalJob: wrapper around the remote job that can be used to get the output location.
        """
        # Read explicitly as UTF-8 to match the UTF-8 encoding used for the upload,
        # independent of the local locale.
        with open(job_params.get_main_file_path(), encoding="utf-8") as f:
            pyspark_script = f.read()

        pyspark_script_path = urlunparse(
            get_staging_client("s3").upload_fileobj(
                BytesIO(pyspark_script.encode("utf8")),
                local_path="",
                remote_path_prefix=self._staging_location,
                remote_path_suffix=".py",
            )
        )

        step = _historical_retrieval_step(
            pyspark_script_path,
            args=job_params.get_arguments(),
            output_file_uri=job_params.get_destination_path(),
            packages=job_params.get_extra_packages(),
        )

        job_ref = self._submit_emr_job(step)

        return EmrRetrievalJob(
            self._emr_client(), job_ref, job_params.get_destination_path(),
        )

    def offline_to_online_ingestion(
        self, ingestion_job_params: BatchIngestionJobParameters
    ) -> BatchIngestionJob:
        """
        Submits a batch ingestion job to a Spark cluster.

        Raises:
            SparkJobFailure: The spark job submission failed, encountered error
                during execution, or timeout.

        Returns:
            BatchIngestionJob: wrapper around remote job that can be used to check when job completed.
        """
        jar_s3_path = _upload_jar(
            self._staging_location, ingestion_job_params.get_main_file_path()
        )
        step = _sync_offline_to_online_step(
            jar_path=jar_s3_path,
            project=ingestion_job_params.get_project(),
            feature_table_name=ingestion_job_params.get_feature_table_name(),
            args=ingestion_job_params.get_arguments(),
        )

        job_ref = self._submit_emr_job(step)

        return EmrBatchIngestionJob(
            self._emr_client(),
            job_ref,
            ingestion_job_params.get_project(),
            ingestion_job_params.get_feature_table_name(),
        )

    def schedule_offline_to_online_ingestion(
        self, ingestion_job_params: ScheduledBatchIngestionJobParameters
    ):
        """Not supported by the EMR launcher; always raises NotImplementedError."""
        raise NotImplementedError(
            "Schedule spark jobs are not supported by emr launcher"
        )

    def unschedule_offline_to_online_ingestion(self, project: str, feature_table: str):
        """Not supported by the EMR launcher; always raises NotImplementedError."""
        raise NotImplementedError(
            "Unschedule spark jobs are not supported by emr launcher"
        )

    def start_stream_to_online_ingestion(
        self, ingestion_job_params: StreamIngestionJobParameters
    ) -> StreamIngestionJob:
        """
        Starts a stream ingestion job on a Spark cluster.

        Extra jars that are already ``s3://`` URIs are used as-is; any other
        path is uploaded to the staging location first.

        Returns:
            StreamIngestionJob: wrapper around remote job that can be used to check on the job.
        """
        jar_s3_path = _upload_jar(
            self._staging_location, ingestion_job_params.get_main_file_path()
        )

        extra_jar_paths: List[str] = [
            extra_jar
            if extra_jar.startswith("s3://")
            else _upload_jar(self._staging_location, extra_jar)
            for extra_jar in ingestion_job_params.get_extra_jar_paths()
        ]

        job_hash = ingestion_job_params.get_job_hash()

        step = _stream_ingestion_step(
            jar_path=jar_s3_path,
            extra_jar_paths=extra_jar_paths,
            project=ingestion_job_params.get_project(),
            feature_table_name=ingestion_job_params.get_feature_table_name(),
            args=ingestion_job_params.get_arguments(),
            job_hash=job_hash,
        )

        job_ref = self._submit_emr_job(step)

        return EmrStreamIngestionJob(
            self._emr_client(),
            job_ref,
            job_hash,
            ingestion_job_params.get_project(),
            ingestion_job_params.get_feature_table_name(),
        )

    def _job_from_job_info(self, job_info: JobInfo) -> SparkJob:
        """
        Wrap a ``JobInfo`` record in the EMR job class matching its job type.

        Missing project / table names are normalized to empty strings.

        Raises:
            ValueError: if ``job_info.job_type`` is not a known job type.
        """
        if job_info.job_type == HISTORICAL_RETRIEVAL_JOB_TYPE:
            assert job_info.output_file_uri is not None
            return EmrRetrievalJob(
                emr_client=self._emr_client(),
                job_ref=job_info.job_ref,
                output_file_uri=job_info.output_file_uri,
            )
        elif job_info.job_type == OFFLINE_TO_ONLINE_JOB_TYPE:
            return EmrBatchIngestionJob(
                emr_client=self._emr_client(),
                job_ref=job_info.job_ref,
                project=job_info.project or "",
                table_name=job_info.table_name or "",
            )
        elif job_info.job_type == STREAM_TO_ONLINE_JOB_TYPE:
            # job_hash must not be None for stream ingestion jobs
            assert job_info.job_hash is not None
            return EmrStreamIngestionJob(
                emr_client=self._emr_client(),
                job_ref=job_info.job_ref,
                job_hash=job_info.job_hash,
                project=job_info.project or "",
                table_name=job_info.table_name or "",
            )
        else:
            # We should never get here
            raise ValueError(f"Unknown job type {job_info.job_type}")

    def list_jobs(
        self,
        include_terminated: bool,
        project: Optional[str] = None,
        table_name: Optional[str] = None,
    ) -> List[SparkJob]:
        """
        List EMR jobs, optionally filtered by project and feature table.

        Args:
            include_terminated: whether to include terminated jobs.
            project: Feast project name to filter by (None for no filter).
            table_name: FeatureTable name to filter by (None for no filter).

        Returns:
            A list of SparkJob instances.
        """
        jobs = _list_jobs(
            emr_client=self._emr_client(),
            job_type=None,
            project=project,
            table_name=table_name,
            active_only=not include_terminated,
        )

        return [self._job_from_job_info(job_info) for job_info in jobs]

    def get_job_by_id(self, job_id: str) -> SparkJob:
        """
        Find EMR job by a string id. Note that it will also return terminated jobs.

        Raises:
            KeyError if the job not found.
        """
        # FIXME: this doesn't have to be a linear search but that'll do for now
        jobs = _list_jobs(
            emr_client=self._emr_client(),
            job_type=None,
            project=None,
            table_name=None,
            active_only=False,
        )

        for job_info in jobs:
            if _job_ref_to_str(job_info.job_ref) == job_id:
                return self._job_from_job_info(job_info)

        raise KeyError(f"Job not found {job_id}")