Source code for feast_spark.pyspark.launcher

import json
import os
import tempfile
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from urllib.parse import urlparse, urlunparse

from feast.config import Config
from feast.data_format import ParquetFormat
from feast.data_source import BigQuerySource, DataSource, FileSource, KafkaSource
from feast.feature_table import FeatureTable
from feast.staging.storage_client import get_staging_client
from feast.value_type import ValueType
from feast_spark.constants import ConfigOptions as opt
from feast_spark.pyspark.abc import (
    BatchIngestionJob,
    BatchIngestionJobParameters,
    JobLauncher,
    RetrievalJob,
    RetrievalJobParameters,
    ScheduledBatchIngestionJobParameters,
    SparkJob,
    StreamIngestionJob,
    StreamIngestionJobParameters,
)

if TYPE_CHECKING:
    from feast_spark.client import Client


def _standalone_launcher(config: Config) -> JobLauncher:
    from feast_spark.pyspark.launchers import standalone

    return standalone.StandaloneClusterLauncher(
        config.get(opt.SPARK_STANDALONE_MASTER), config.get(opt.SPARK_HOME),
    )


def _dataproc_launcher(config: Config) -> JobLauncher:
    from feast_spark.pyspark.launchers import gcloud

    return gcloud.DataprocClusterLauncher(
        cluster_name=config.get(opt.DATAPROC_CLUSTER_NAME),
        staging_location=config.get(opt.SPARK_STAGING_LOCATION),
        region=config.get(opt.DATAPROC_REGION),
        project_id=config.get(opt.DATAPROC_PROJECT),
        executor_instances=config.get(opt.DATAPROC_EXECUTOR_INSTANCES),
        executor_cores=config.get(opt.DATAPROC_EXECUTOR_CORES),
        executor_memory=config.get(opt.DATAPROC_EXECUTOR_MEMORY),
    )


def _emr_launcher(config: Config) -> JobLauncher:
    from feast_spark.pyspark.launchers import aws

    def _get_optional(option):
        if config.exists(option):
            return config.get(option)

    return aws.EmrClusterLauncher(
        region=config.get(opt.EMR_REGION),
        existing_cluster_id=_get_optional(opt.EMR_CLUSTER_ID),
        new_cluster_template_path=_get_optional(opt.EMR_CLUSTER_TEMPLATE_PATH),
        staging_location=config.get(opt.SPARK_STAGING_LOCATION),
        emr_log_location=config.get(opt.EMR_LOG_LOCATION),
    )


def _k8s_launcher(config: Config) -> JobLauncher:
    from feast_spark.pyspark.launchers import k8s

    staging_location = config.get(opt.SPARK_STAGING_LOCATION)
    staging_uri = urlparse(staging_location)

    return k8s.KubernetesJobLauncher(
        namespace=config.get(opt.SPARK_K8S_NAMESPACE),
        generic_resource_template_path=config.get(opt.SPARK_K8S_JOB_TEMPLATE_PATH),
        batch_ingestion_resource_template_path=config.get(
            opt.SPARK_K8S_BATCH_INGESTION_TEMPLATE_PATH, None
        ),
        stream_ingestion_resource_template_path=config.get(
            opt.SPARK_K8S_STREAM_INGESTION_TEMPLATE_PATH, None
        ),
        historical_retrieval_resource_template_path=config.get(
            opt.SPARK_K8S_HISTORICAL_RETRIEVAL_TEMPLATE_PATH, None
        ),
        staging_location=staging_location,
        incluster=config.getboolean(opt.SPARK_K8S_USE_INCLUSTER_CONFIG),
        staging_client=get_staging_client(staging_uri.scheme, config),
    )


_launchers = {
    "standalone": _standalone_launcher,
    "dataproc": _dataproc_launcher,
    "emr": _emr_launcher,
    "k8s": _k8s_launcher,
}


def resolve_launcher(config: Config) -> JobLauncher:
    return _launchers[config.get(opt.SPARK_LAUNCHER)](config)
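
# Example (hypothetical values): resolving a launcher from a Feast config.
# A sketch only -- the option keys are assumed to match the string values of
# feast_spark.constants.ConfigOptions, and "spark_launcher" must be one of the
# keys in _launchers above:
#
#     config = Config(options={
#         "spark_launcher": "standalone",
#         "spark_standalone_master": "local[*]",
#         "spark_home": "/opt/spark",
#     })
#     launcher = resolve_launcher(config)  # -> StandaloneClusterLauncher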

def _source_to_argument(source: DataSource, config: Config):
    common_properties = {
        "field_mapping": dict(source.field_mapping),
        "event_timestamp_column": source.event_timestamp_column,
        "created_timestamp_column": source.created_timestamp_column,
        "date_partition_column": source.date_partition_column,
    }

    properties = {**common_properties}

    if isinstance(source, FileSource):
        properties["path"] = source.file_options.file_url
        properties["format"] = dict(
            json_class=source.file_options.file_format.__class__.__name__
        )
        return {"file": properties}

    if isinstance(source, BigQuerySource):
        project, dataset_and_table = source.bigquery_options.table_ref.split(":")
        dataset, table = dataset_and_table.split(".")
        properties["project"] = project
        properties["dataset"] = dataset
        properties["table"] = table
        if config.exists(opt.SPARK_BQ_MATERIALIZATION_PROJECT) and config.exists(
            opt.SPARK_BQ_MATERIALIZATION_DATASET
        ):
            properties["materialization"] = dict(
                project=config.get(opt.SPARK_BQ_MATERIALIZATION_PROJECT),
                dataset=config.get(opt.SPARK_BQ_MATERIALIZATION_DATASET),
            )
        return {"bq": properties}

    if isinstance(source, KafkaSource):
        properties["bootstrap_servers"] = source.kafka_options.bootstrap_servers
        properties["topic"] = source.kafka_options.topic
        properties["format"] = {
            **source.kafka_options.message_format.__dict__,
            "json_class": source.kafka_options.message_format.__class__.__name__,
        }
        return {"kafka": properties}

    raise NotImplementedError(f"Unsupported Datasource: {type(source)}")


def _feature_table_to_argument(
    client: "Client",
    project: str,
    feature_table: FeatureTable,
    use_gc_threshold=True,
):
    max_age = feature_table.max_age.ToSeconds() if feature_table.max_age else None
    if use_gc_threshold:
        try:
            gc_threshold = int(feature_table.labels["gcThresholdSec"])
        except (KeyError, ValueError, TypeError):
            pass
        else:
            max_age = max(max_age or 0, gc_threshold)

    return {
        "features": [
            {"name": f.name, "type": ValueType(f.dtype).name}
            for f in feature_table.features
        ],
        "project": project,
        "name": feature_table.name,
        "entities": [
            {
                "name": n,
                "type": client.feature_store.get_entity(n, project=project).value_type,
            }
            for n in feature_table.entities
        ],
        "max_age": max_age,
        "labels": dict(feature_table.labels),
    }
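
# Example (illustrative, hypothetical column names): the argument dict that
# _source_to_argument would produce for a parquet-backed FileSource:
#
#     {
#         "file": {
#             "field_mapping": {},
#             "event_timestamp_column": "event_timestamp",
#             "created_timestamp_column": "created_timestamp",
#             "date_partition_column": "date",
#             "path": "gs://bucket/path/to/data",
#             "format": {"json_class": "ParquetFormat"},
#         }
#     }
#
# The "json_class" key carries the format's class name so the Spark job can
# reconstruct the data format on the driver side.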

def start_historical_feature_retrieval_spark_session(
    client: "Client",
    project: str,
    entity_source: Union[FileSource, BigQuerySource],
    feature_tables: List[FeatureTable],
):
    from pyspark.sql import SparkSession

    from feast_spark.pyspark.historical_feature_retrieval_job import (
        retrieve_historical_features,
    )

    spark_session = SparkSession.builder.getOrCreate()
    return retrieve_historical_features(
        spark=spark_session,
        entity_source_conf=_source_to_argument(entity_source, client.config),
        feature_tables_sources_conf=[
            _source_to_argument(feature_table.batch_source, client.config)
            for feature_table in feature_tables
        ],
        feature_tables_conf=[
            _feature_table_to_argument(
                client, project, feature_table, use_gc_threshold=False
            )
            for feature_table in feature_tables
        ],
    )

def start_historical_feature_retrieval_job(
    client: "Client",
    project: str,
    entity_source: Union[FileSource, BigQuerySource],
    feature_tables: List[FeatureTable],
    output_format: str,
    output_path: str,
) -> RetrievalJob:
    launcher = resolve_launcher(client.config)
    feature_sources = [
        _source_to_argument(feature_table.batch_source, client.config)
        for feature_table in feature_tables
    ]

    return launcher.historical_feature_retrieval(
        RetrievalJobParameters(
            project=project,
            entity_source=_source_to_argument(entity_source, client.config),
            feature_tables_sources=feature_sources,
            feature_tables=[
                _feature_table_to_argument(
                    client, project, feature_table, use_gc_threshold=False
                )
                for feature_table in feature_tables
            ],
            destination={"format": output_format, "path": output_path},
            checkpoint_path=client.config.get(opt.CHECKPOINT_PATH),
        )
    )

def table_reference_from_string(table_ref: str):
    """
    Parses a reference string with format "{project}:{dataset}.{table}" into a
    bigquery.TableReference.
    """
    from google.cloud import bigquery

    project, dataset_and_table = table_ref.split(":")
    dataset, table_id = dataset_and_table.split(".")
    return bigquery.TableReference(
        bigquery.DatasetReference(project, dataset), table_id
    )
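
# Example (hypothetical ref): splitting "my-project:my_dataset.my_table" yields
#
#     project  = "my-project"
#     dataset  = "my_dataset"
#     table_id = "my_table"
#
# which become a google.cloud.bigquery.TableReference. Note this is the legacy
# colon-separated style, not the dotted "project.dataset.table" style.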

def start_offline_to_online_ingestion(
    client: "Client",
    project: str,
    feature_table: FeatureTable,
    start: datetime,
    end: datetime,
) -> BatchIngestionJob:
    launcher = resolve_launcher(client.config)

    return launcher.offline_to_online_ingestion(
        BatchIngestionJobParameters(
            jar=client.config.get(opt.SPARK_INGESTION_JAR),
            source=_source_to_argument(feature_table.batch_source, client.config),
            feature_table=_feature_table_to_argument(client, project, feature_table),
            start=start,
            end=end,
            redis_host=client.config.get(opt.REDIS_HOST),
            redis_port=bool(client.config.get(opt.REDIS_HOST))
            and client.config.getint(opt.REDIS_PORT),
            redis_password=client.config.get(opt.REDIS_PASSWORD),
            redis_ssl=client.config.getboolean(opt.REDIS_SSL),
            bigtable_project=client.config.get(opt.BIGTABLE_PROJECT),
            bigtable_instance=client.config.get(opt.BIGTABLE_INSTANCE),
            cassandra_host=client.config.get(opt.CASSANDRA_HOST),
            cassandra_port=bool(client.config.get(opt.CASSANDRA_HOST))
            and client.config.getint(opt.CASSANDRA_PORT),
            statsd_host=(
                client.config.getboolean(opt.STATSD_ENABLED)
                and client.config.get(opt.STATSD_HOST)
            ),
            statsd_port=(
                client.config.getboolean(opt.STATSD_ENABLED)
                and client.config.getint(opt.STATSD_PORT)
            ),
            deadletter_path=client.config.get(opt.DEADLETTER_PATH),
            stencil_url=client.config.get(opt.STENCIL_URL),
            stencil_token=client.config.get(opt.STENCIL_TOKEN),
        )
    )
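
# Example (hypothetical): kicking off a one-day backfill for a registered
# feature table. A sketch assuming a configured client, not a tested
# invocation; the returned BatchIngestionJob can be polled via the SparkJob
# interface:
#
#     job = start_offline_to_online_ingestion(
#         client=client,
#         project="default",
#         feature_table=client.feature_store.get_feature_table("driver_stats"),
#         start=datetime(2021, 1, 1),
#         end=datetime(2021, 1, 2),
#     )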

def schedule_offline_to_online_ingestion(
    client: "Client",
    project: str,
    feature_table: FeatureTable,
    ingestion_timespan: int,
    cron_schedule: str,
):
    launcher = resolve_launcher(client.config)
    launcher.schedule_offline_to_online_ingestion(
        ScheduledBatchIngestionJobParameters(
            jar=client.config.get(opt.SPARK_INGESTION_JAR),
            source=_source_to_argument(feature_table.batch_source, client.config),
            feature_table=_feature_table_to_argument(client, project, feature_table),
            ingestion_timespan=ingestion_timespan,
            cron_schedule=cron_schedule,
            redis_host=client.config.get(opt.REDIS_HOST),
            redis_port=bool(client.config.get(opt.REDIS_HOST))
            and client.config.getint(opt.REDIS_PORT),
            redis_password=client.config.get(opt.REDIS_PASSWORD),
            redis_ssl=client.config.getboolean(opt.REDIS_SSL),
            bigtable_project=client.config.get(opt.BIGTABLE_PROJECT),
            bigtable_instance=client.config.get(opt.BIGTABLE_INSTANCE),
            cassandra_host=client.config.get(opt.CASSANDRA_HOST),
            cassandra_port=bool(client.config.get(opt.CASSANDRA_HOST))
            and client.config.getint(opt.CASSANDRA_PORT),
            statsd_host=(
                client.config.getboolean(opt.STATSD_ENABLED)
                and client.config.get(opt.STATSD_HOST)
            ),
            statsd_port=(
                client.config.getboolean(opt.STATSD_ENABLED)
                and client.config.getint(opt.STATSD_PORT)
            ),
            deadletter_path=client.config.get(opt.DEADLETTER_PATH),
            stencil_url=client.config.get(opt.STENCIL_URL),
            stencil_token=client.config.get(opt.STENCIL_TOKEN),
        )
    )
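
# Example (hypothetical): scheduling a nightly ingestion. cron_schedule uses
# standard cron syntax; ingestion_timespan is assumed here to be the number of
# days of data each scheduled run ingests:
#
#     schedule_offline_to_online_ingestion(
#         client, "default", feature_table,
#         ingestion_timespan=1,
#         cron_schedule="0 0 * * *",  # daily at midnight
#     )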

def unschedule_offline_to_online_ingestion(
    client: "Client", project: str, feature_table: FeatureTable,
):
    launcher = resolve_launcher(client.config)
    launcher.unschedule_offline_to_online_ingestion(project, feature_table.name)

def get_stream_to_online_ingestion_params(
    client: "Client", project: str, feature_table: FeatureTable, extra_jars: List[str]
) -> StreamIngestionJobParameters:
    return StreamIngestionJobParameters(
        jar=client.config.get(opt.SPARK_INGESTION_JAR),
        extra_jars=extra_jars,
        source=_source_to_argument(feature_table.stream_source, client.config),
        feature_table=_feature_table_to_argument(client, project, feature_table),
        redis_host=client.config.get(opt.REDIS_HOST),
        redis_port=bool(client.config.get(opt.REDIS_HOST))
        and client.config.getint(opt.REDIS_PORT),
        redis_password=client.config.get(opt.REDIS_PASSWORD),
        redis_ssl=client.config.getboolean(opt.REDIS_SSL),
        bigtable_project=client.config.get(opt.BIGTABLE_PROJECT),
        bigtable_instance=client.config.get(opt.BIGTABLE_INSTANCE),
        statsd_host=client.config.getboolean(opt.STATSD_ENABLED)
        and client.config.get(opt.STATSD_HOST),
        statsd_port=client.config.getboolean(opt.STATSD_ENABLED)
        and client.config.getint(opt.STATSD_PORT),
        deadletter_path=client.config.get(opt.DEADLETTER_PATH),
        checkpoint_path=client.config.get(opt.CHECKPOINT_PATH),
        stencil_url=client.config.get(opt.STENCIL_URL),
        stencil_token=client.config.get(opt.STENCIL_TOKEN),
        drop_invalid_rows=client.config.get(opt.INGESTION_DROP_INVALID_ROWS),
        triggering_interval=client.config.getint(
            opt.SPARK_STREAMING_TRIGGERING_INTERVAL, default=None
        ),
    )

def start_stream_to_online_ingestion(
    client: "Client", project: str, feature_table: FeatureTable, extra_jars: List[str]
) -> StreamIngestionJob:
    launcher = resolve_launcher(client.config)
    return launcher.start_stream_to_online_ingestion(
        get_stream_to_online_ingestion_params(
            client, project, feature_table, extra_jars
        )
    )
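
# Example (hypothetical): starting a streaming ingestion job for a table whose
# stream_source is a KafkaSource. extra_jars lets callers ship additional
# dependencies to the Spark cluster; pass an empty list when none are needed:
#
#     job = start_stream_to_online_ingestion(
#         client, "default", feature_table, extra_jars=[]
#     )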

def list_jobs(
    include_terminated: bool,
    client: "Client",
    project: Optional[str] = None,
    table_name: Optional[str] = None,
) -> List[SparkJob]:
    launcher = resolve_launcher(client.config)
    return launcher.list_jobs(
        include_terminated=include_terminated, table_name=table_name, project=project
    )
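
# Example (illustrative): listing only the still-running jobs for one table;
# project and table_name are optional filters:
#
#     active = list_jobs(
#         include_terminated=False,
#         client=client,
#         project="default",
#         table_name="driver_stats",
#     )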

def get_job_by_id(job_id: str, client: "Client") -> SparkJob:
    launcher = resolve_launcher(client.config)
    return launcher.get_job_by_id(job_id)

def get_health_metrics(
    client: "Client", project: str, table_names: List[str],
) -> Dict[str, List[str]]:
    all_redis_keys = [f"{project}:{table}" for table in table_names]
    metrics = client.metrics_redis.mget(all_redis_keys)
    passed_feature_tables = []
    failed_feature_tables = []

    for metric, name in zip(metrics, table_names):
        feature_table = client.feature_store.get_feature_table(
            project=project, name=name
        )
        max_age = feature_table.max_age

        # Only perform ingestion health checks for feature tables with a max_age.
        if not max_age:
            passed_feature_tables.append(name)
            continue

        # mget returns None for keys missing in Redis; treat tables without
        # metrics as passing.
        if not metric:
            passed_feature_tables.append(name)
            continue

        # Both timestamps below are compared as epoch seconds.
        last_ingestion_time = json.loads(metric)["last_processed_event_timestamp"][
            "value"
        ]
        valid_ingestion_time = datetime.timestamp(
            datetime.now() - timedelta(seconds=max_age.ToSeconds())
        )

        # The table fails if its latest ingested event timestamp is older than
        # cur_time - max_age.
        if valid_ingestion_time > last_ingestion_time:
            failed_feature_tables.append(name)
        else:
            passed_feature_tables.append(name)

    return {"passed": passed_feature_tables, "failed": failed_feature_tables}
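
# Example (illustrative): the returned shape, given hypothetical tables "a"
# (healthy) and "b" (whose last ingested event is older than its max_age):
#
#     get_health_metrics(client, "default", ["a", "b"])
#     # -> {"passed": ["a"], "failed": ["b"]}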

def stage_dataframe(df, event_timestamp_column: str, config: Config) -> FileSource:
    """
    Helper function to upload a pandas dataframe in parquet format to a temporary
    location (under SPARK_STAGING_LOCATION) and return it wrapped in a FileSource.

    Args:
        df: the pandas dataframe to upload.
        event_timestamp_column(str): the name of the timestamp column in the dataframe.
        config(Config): feast config.
    """
    staging_location = config.get(opt.SPARK_STAGING_LOCATION)
    staging_uri = urlparse(staging_location)

    with tempfile.NamedTemporaryFile() as f:
        df.to_parquet(f)

        file_url = urlunparse(
            get_staging_client(staging_uri.scheme, config).upload_fileobj(
                f,
                f.name,
                remote_path_prefix=os.path.join(staging_location, "dataframes"),
                remote_path_suffix=".parquet",
            )
        )

    return FileSource(
        event_timestamp_column=event_timestamp_column,
        file_format=ParquetFormat(),
        file_url=file_url,
    )
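
# Example (hypothetical): staging a small dataframe and using the result as
# the entity source of a retrieval job. Assumes pandas is installed and
# spark_staging_location points at a writable bucket; a sketch only:
#
#     import pandas as pd
#
#     df = pd.DataFrame({
#         "driver_id": [1001, 1002],
#         "event_timestamp": pd.to_datetime(["2021-01-01", "2021-01-02"]),
#     })
#     entity_source = stage_dataframe(df, "event_timestamp", client.config)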