import abc
import hashlib
import json
import os
from base64 import b64encode
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
class SparkJobFailure(Exception):
    """
    Job submission failed, encountered error during execution, or timeout.

    Raised by job launchers and job wrappers in this module.
    """

    pass
# Maven coordinates of the Spark BigQuery connector; added to the Spark
# classpath so jobs can read from / write to BigQuery.
BQ_SPARK_PACKAGE = "com.google.cloud.spark:spark-bigquery-with-dependencies_2.12:0.18.0"
class SparkJobStatus(Enum):
    """Lifecycle states of a submitted Spark job."""

    STARTING = 0
    IN_PROGRESS = 1
    FAILED = 2
    COMPLETED = 3
class SparkJobType(Enum):
    """Kinds of Spark jobs that can be launched through this module."""

    HISTORICAL_RETRIEVAL = 0
    BATCH_INGESTION = 1
    STREAM_INGESTION = 2
    SCHEDULED_BATCH_INGESTION = 3

    def to_pascal_case(self):
        """Return the member name converted to PascalCase (e.g. BatchIngestion)."""
        return "".join(word.capitalize() for word in self.name.split("_"))
class SparkJob(abc.ABC):
    """
    Base class for all spark jobs.

    Concrete subclasses wrap a job running on a specific Spark deployment
    and expose its id, status and lifecycle controls.
    """

    @abc.abstractmethod
    def get_id(self) -> str:
        """
        Getter for the job id. The job id must be unique for each spark job submission.

        Returns:
            str: Job id.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_status(self) -> SparkJobStatus:
        """
        Job Status retrieval.

        Returns:
            SparkJobStatus: Job status
        """
        raise NotImplementedError

    @abc.abstractmethod
    def cancel(self):
        """
        Manually terminate job.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_start_time(self) -> datetime:
        """
        Get job start time.
        """

    def get_log_uri(self) -> Optional[str]:
        """
        Get path to Spark job log, if applicable.

        Returns None by default; launchers that expose logs override this.
        """
        return None

    def get_error_message(self) -> Optional[str]:
        """
        Get Spark job error message, if applicable.

        Returns None by default; launchers that surface errors override this.
        """
        return None
class SparkJobParameters(abc.ABC):
    """
    Base class describing how a particular kind of Spark job is launched:
    its name, type, main file, optional main class, and CLI arguments.
    """

    @abc.abstractmethod
    def get_name(self) -> str:
        """
        Getter for job name.

        Returns:
            str: Job name.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_job_type(self) -> SparkJobType:
        """
        Getter for job type.

        Returns:
            SparkJobType: Job type enum.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_main_file_path(self) -> str:
        """
        Getter for jar | python path.

        Returns:
            str: Full path to file.
        """
        raise NotImplementedError

    def get_class_name(self) -> Optional[str]:
        """
        Getter for main class name if it's applicable (None for Python jobs).

        Returns:
            Optional[str]: java class path, e.g. feast.ingestion.IngestionJob.
        """
        return None

    @abc.abstractmethod
    def get_arguments(self) -> List[str]:
        """
        Getter for job arguments.
        E.g., ["--source", '{"kafka":...}', ...]

        Returns:
            List[str]: List of arguments.
        """
        raise NotImplementedError
class RetrievalJobParameters(SparkJobParameters):
    """Parameters for a historical feature retrieval job."""

    def __init__(
        self,
        project: str,
        feature_tables: List[Dict],
        feature_tables_sources: List[Dict],
        entity_source: Dict,
        destination: Dict,
        extra_packages: Optional[List[str]] = None,
        checkpoint_path: Optional[str] = None,
    ):
        """
        Args:
            project (str): Client project
            entity_source (Dict): Entity data source configuration.
            feature_tables_sources (List[Dict]): List of feature tables data sources configurations.
            feature_tables (List[Dict]): List of feature table specification.
                The order of the feature table must correspond to that of feature_tables_sources.
            destination (Dict): Retrieval job output destination.
            extra_packages (Optional[List[str]]): Extra maven packages to be included on Spark driver
                and executors classpath.
            checkpoint_path (Optional[str]): Passed to the retrieval job as its
                Spark checkpoint location, if provided.

        Examples:
            >>> # Entity source from file
            >>> entity_source = {
                    "file": {
                        "format": "parquet",
                        "path": "gs://some-gcs-bucket/customer",
                        "event_timestamp_column": "event_timestamp",
                        "options": {
                            "mergeSchema": "true"
                        },  # Optional. Options to be passed to Spark while reading the dataframe from source.
                        "field_mapping": {
                            "id": "customer_id"
                        }  # Optional. Map the columns, where the key is the original column name and the value is the new column name.
                    }
                }

            >>> # Entity source from BigQuery
            >>> entity_source = {
                    "bq": {
                        "project": "gcp_project_id",
                        "dataset": "bq_dataset",
                        "table": "customer",
                        "event_timestamp_column": "event_timestamp",
                    }
                }

            >>> feature_tables_sources = [
                    {
                        "bq": {
                            "project": "gcp_project_id",
                            "dataset": "bq_dataset",
                            "table": "customer_transactions",
                            "event_timestamp_column": "event_timestamp",
                            "created_timestamp_column": "created_timestamp"  # This field is mandatory for feature tables.
                        }
                    },
                    {
                        "file": {
                            "format": "parquet",
                            "path": "gs://some-gcs-bucket/customer_profile",
                            "event_timestamp_column": "event_timestamp",
                            "created_timestamp_column": "created_timestamp",
                            "options": {
                                "mergeSchema": "true"
                            }
                        }
                    },
                ]

            >>> feature_tables = [
                    {
                        "name": "customer_transactions",
                        "entities": [
                            {
                                "name": "customer",
                                "type": "int32"
                            }
                        ],
                        "features": [
                            {
                                "name": "total_transactions",
                                "type": "double"
                            },
                            {
                                "name": "total_discounts",
                                "type": "double"
                            }
                        ],
                        "max_age": 86400  # In seconds.
                    },
                    {
                        "name": "customer_profile",
                        "entities": [
                            {
                                "name": "customer",
                                "type": "int32"
                            }
                        ],
                        "features": [
                            {
                                "name": "is_vip",
                                "type": "bool"
                            }
                        ],
                    }
                ]

            >>> destination = {
                    "format": "parquet",
                    "path": "gs://some-gcs-bucket/retrieval_output"
                }
        """
        self._project = project
        self._feature_tables = feature_tables
        self._feature_tables_sources = feature_tables_sources
        self._entity_source = entity_source
        self._destination = destination
        # NOTE(review): _extra_packages is stored but not used in this class;
        # presumably read by the launcher when submitting — confirm.
        self._extra_packages = extra_packages if extra_packages else []
        self._checkpoint_path = checkpoint_path

    def get_project(self) -> str:
        """Client project this retrieval job runs for."""
        return self._project

    def get_name(self) -> str:
        """Job name: PascalCase job type joined with all feature table names."""
        all_feature_tables_names = [ft["name"] for ft in self._feature_tables]
        return f"{self.get_job_type().to_pascal_case()}-{'-'.join(all_feature_tables_names)}"

    def get_job_type(self) -> SparkJobType:
        return SparkJobType.HISTORICAL_RETRIEVAL

    def get_main_file_path(self) -> str:
        # PySpark driver script shipped alongside this module.
        return os.path.join(
            os.path.dirname(__file__), "historical_feature_retrieval_job.py"
        )

    def get_arguments(self) -> List[str]:
        """Render the configuration dicts into CLI arguments for the driver."""

        # Configs are base64-encoded JSON so they survive shell argument
        # parsing without quoting issues.
        def json_b64_encode(obj) -> str:
            return b64encode(json.dumps(obj).encode("utf8")).decode("ascii")

        args = [
            "--feature-tables",
            json_b64_encode(self._feature_tables),
            "--feature-tables-sources",
            json_b64_encode(self._feature_tables_sources),
            "--entity-source",
            json_b64_encode(self._entity_source),
            "--destination",
            json_b64_encode(self._destination),
        ]

        if self._checkpoint_path:
            args.extend(["--checkpoint", self._checkpoint_path])

        return args

    def get_destination_path(self) -> str:
        """Output path the retrieval result will be written to."""
        return self._destination["path"]
class RetrievalJob(SparkJob):
    """
    Container for the historical feature retrieval job result.
    """

    @abc.abstractmethod
    def get_output_file_uri(self, timeout_sec=None, block=True):
        """
        Get output file uri to the result file. This method will block until the
        job succeeded, or if the job didn't execute successfully within timeout.

        Args:
            timeout_sec (int):
                Max no of seconds to wait until job is done. If "timeout_sec"
                is exceeded or if the job fails, an exception will be raised.
            block (bool):
                If false, don't block until the job is done. If block=True, timeout parameter is
                ignored.

        Raises:
            SparkJobFailure:
                The spark job submission failed, encountered error during execution,
                or timeout.

        Returns:
            str: file uri to the result file.
        """
        raise NotImplementedError
class IngestionJobParameters(SparkJobParameters):
    """
    Common parameters shared by batch, scheduled and stream ingestion jobs.

    Carries the feature table spec, the source configuration, the ingestion
    jar, and optional online-store (Redis / Bigtable / Cassandra), statsd,
    dead-letter and stencil settings that are rendered into CLI arguments.
    """

    def __init__(
        self,
        feature_table: Dict,
        source: Dict,
        jar: str,
        redis_host: Optional[str] = None,
        redis_port: Optional[int] = None,
        redis_password: Optional[str] = None,
        redis_ssl: Optional[bool] = None,
        bigtable_project: Optional[str] = None,
        bigtable_instance: Optional[str] = None,
        cassandra_host: Optional[str] = None,
        cassandra_port: Optional[int] = None,
        statsd_host: Optional[str] = None,
        statsd_port: Optional[int] = None,
        deadletter_path: Optional[str] = None,
        stencil_url: Optional[str] = None,
        stencil_token: Optional[str] = None,
        drop_invalid_rows: bool = False,
    ):
        self._feature_table = feature_table
        self._source = source
        self._jar = jar
        self._redis_host = redis_host
        self._redis_port = redis_port
        self._redis_password = redis_password
        self._redis_ssl = redis_ssl
        self._bigtable_project = bigtable_project
        self._bigtable_instance = bigtable_instance
        self._cassandra_host = cassandra_host
        self._cassandra_port = cassandra_port
        self._statsd_host = statsd_host
        self._statsd_port = statsd_port
        self._deadletter_path = deadletter_path
        self._stencil_url = stencil_url
        self._stencil_token = stencil_token
        self._drop_invalid_rows = drop_invalid_rows

    def _get_redis_config(self):
        # Redis connection settings, serialized as JSON for the --redis flag.
        return {
            "host": self._redis_host,
            "port": self._redis_port,
            "password": self._redis_password,
            "ssl": self._redis_ssl,
        }

    def _get_bigtable_config(self):
        return {
            "project_id": self._bigtable_project,
            "instance_id": self._bigtable_instance,
        }

    def _get_cassandra_config(self):
        return {"host": self._cassandra_host, "port": self._cassandra_port}

    def _get_statsd_config(self):
        # Statsd reporting is considered configured only when a host is set.
        if not self._statsd_host:
            return None
        return {"host": self._statsd_host, "port": self._statsd_port}

    def get_project(self) -> str:
        """Project the ingested feature table belongs to."""
        return self._feature_table["project"]

    def get_feature_table_name(self) -> str:
        """Name of the feature table being ingested."""
        return self._feature_table["name"]

    def get_main_file_path(self) -> str:
        return self._jar

    def get_class_name(self) -> Optional[str]:
        return "feast.ingestion.IngestionJob"

    def get_arguments(self) -> List[str]:
        """Render the configuration into ingestion job CLI arguments."""
        args: List[str] = [
            "--feature-table",
            json.dumps(self._feature_table),
            "--source",
            json.dumps(self._source),
        ]

        if self._redis_host and self._redis_port:
            args += ["--redis", json.dumps(self._get_redis_config())]
        if self._bigtable_project and self._bigtable_instance:
            args += ["--bigtable", json.dumps(self._get_bigtable_config())]
        if self._cassandra_host and self._cassandra_port:
            args += ["--cassandra", json.dumps(self._get_cassandra_config())]

        statsd_config = self._get_statsd_config()
        if statsd_config:
            args += ["--statsd", json.dumps(statsd_config)]

        if self._deadletter_path:
            args += [
                "--deadletter-path",
                os.path.join(self._deadletter_path, self.get_feature_table_name()),
            ]
        if self._stencil_url:
            args += ["--stencil-url", self._stencil_url]
        if self._stencil_token:
            args += ["--stencil-token", self._stencil_token]
        if self._drop_invalid_rows:
            args += ["--drop-invalid"]

        return args
class BatchIngestionJobParameters(IngestionJobParameters):
    """Parameters for a one-off batch (offline to online) ingestion job over
    a fixed start/end time window."""

    def __init__(
        self,
        feature_table: Dict,
        source: Dict,
        start: datetime,
        end: datetime,
        jar: str,
        redis_host: Optional[str],
        redis_port: Optional[int],
        redis_password: Optional[str],
        redis_ssl: Optional[bool],
        bigtable_project: Optional[str],
        bigtable_instance: Optional[str],
        cassandra_host: Optional[str] = None,
        cassandra_port: Optional[int] = None,
        statsd_host: Optional[str] = None,
        statsd_port: Optional[int] = None,
        deadletter_path: Optional[str] = None,
        stencil_url: Optional[str] = None,
        stencil_token: Optional[str] = None,
    ):
        super().__init__(
            feature_table,
            source,
            jar,
            redis_host=redis_host,
            redis_port=redis_port,
            redis_password=redis_password,
            redis_ssl=redis_ssl,
            bigtable_project=bigtable_project,
            bigtable_instance=bigtable_instance,
            cassandra_host=cassandra_host,
            cassandra_port=cassandra_port,
            statsd_host=statsd_host,
            statsd_port=statsd_port,
            deadletter_path=deadletter_path,
            stencil_url=stencil_url,
            stencil_token=stencil_token,
        )
        self._start = start
        self._end = end

    def get_name(self) -> str:
        """Job name embeds the job type, feature table and ingestion window dates."""
        window = f"{self._start.strftime('%Y-%m-%d')}-{self._end.strftime('%Y-%m-%d')}"
        return (
            f"{self.get_job_type().to_pascal_case()}-"
            f"{self.get_feature_table_name()}-{window}"
        )

    def get_job_type(self) -> SparkJobType:
        return SparkJobType.BATCH_INGESTION

    def get_arguments(self) -> List[str]:
        """Common ingestion arguments plus offline mode and the time window."""
        extra = [
            "--mode",
            "offline",
            "--start",
            self._start.strftime("%Y-%m-%dT%H:%M:%S"),
            "--end",
            self._end.strftime("%Y-%m-%dT%H:%M:%S"),
        ]
        return super().get_arguments() + extra
class ScheduledBatchIngestionJobParameters(IngestionJobParameters):
    """Parameters for a batch ingestion job that runs on a cron schedule,
    each run covering a trailing ingestion timespan."""

    def __init__(
        self,
        feature_table: Dict,
        source: Dict,
        ingestion_timespan: int,
        cron_schedule: str,
        jar: str,
        redis_host: Optional[str],
        redis_port: Optional[int],
        redis_password: Optional[str],
        redis_ssl: Optional[bool],
        bigtable_project: Optional[str],
        bigtable_instance: Optional[str],
        cassandra_host: Optional[str] = None,
        cassandra_port: Optional[int] = None,
        statsd_host: Optional[str] = None,
        statsd_port: Optional[int] = None,
        deadletter_path: Optional[str] = None,
        stencil_url: Optional[str] = None,
        stencil_token: Optional[str] = None,
    ):
        super().__init__(
            feature_table,
            source,
            jar,
            redis_host=redis_host,
            redis_port=redis_port,
            redis_password=redis_password,
            redis_ssl=redis_ssl,
            bigtable_project=bigtable_project,
            bigtable_instance=bigtable_instance,
            cassandra_host=cassandra_host,
            cassandra_port=cassandra_port,
            statsd_host=statsd_host,
            statsd_port=statsd_port,
            deadletter_path=deadletter_path,
            stencil_url=stencil_url,
            stencil_token=stencil_token,
        )
        self._ingestion_timespan = ingestion_timespan
        self._cron_schedule = cron_schedule

    def get_name(self) -> str:
        """Job name embeds the job type and the feature table name."""
        return f"{self.get_job_type().to_pascal_case()}-{self.get_feature_table_name()}"

    def get_job_type(self) -> SparkJobType:
        return SparkJobType.SCHEDULED_BATCH_INGESTION

    def get_job_schedule(self) -> str:
        """Cron expression describing when the job runs."""
        return self._cron_schedule

    def get_arguments(self) -> List[str]:
        """Common ingestion arguments plus offline mode and the timespan."""
        extra = [
            "--mode",
            "offline",
            "--ingestion-timespan",
            str(self._ingestion_timespan),
        ]
        return super().get_arguments() + extra
class StreamIngestionJobParameters(IngestionJobParameters):
    """
    Parameters for a streaming (online) ingestion job.

    Extends the common ingestion parameters with streaming-specific options:
    extra jars for the Spark classpath, a checkpoint location, and an
    optional micro-batch triggering interval.
    """

    def __init__(
        self,
        feature_table: Dict,
        source: Dict,
        jar: str,
        # Fixed annotation: was `List[str] = None`, which is not a valid type
        # for a None default. Optional[...] keeps the interface identical.
        extra_jars: Optional[List[str]] = None,
        redis_host: Optional[str] = None,
        redis_port: Optional[int] = None,
        redis_password: Optional[str] = None,
        redis_ssl: Optional[bool] = None,
        bigtable_project: Optional[str] = None,
        bigtable_instance: Optional[str] = None,
        cassandra_host: Optional[str] = None,
        cassandra_port: Optional[int] = None,
        statsd_host: Optional[str] = None,
        statsd_port: Optional[int] = None,
        deadletter_path: Optional[str] = None,
        checkpoint_path: Optional[str] = None,
        stencil_url: Optional[str] = None,
        stencil_token: Optional[str] = None,
        drop_invalid_rows: bool = False,
        triggering_interval: Optional[int] = None,
    ):
        super().__init__(
            feature_table,
            source,
            jar,
            redis_host,
            redis_port,
            redis_password,
            redis_ssl,
            bigtable_project,
            bigtable_instance,
            cassandra_host,
            cassandra_port,
            statsd_host,
            statsd_port,
            deadletter_path,
            stencil_url,
            stencil_token,
            drop_invalid_rows,
        )
        # Kept as-is (possibly None) so callers that distinguish None from []
        # are unaffected.
        self._extra_jars = extra_jars
        self._checkpoint_path = checkpoint_path
        self._triggering_interval = triggering_interval

    def get_name(self) -> str:
        """Job name embeds the job type and the feature table name."""
        return f"{self.get_job_type().to_pascal_case()}-{self.get_feature_table_name()}"

    def get_job_type(self) -> SparkJobType:
        return SparkJobType.STREAM_INGESTION

    def get_arguments(self) -> List[str]:
        """Common ingestion arguments plus online mode and streaming options."""
        args = super().get_arguments()
        args.extend(["--mode", "online"])
        if self._checkpoint_path:
            args.extend(["--checkpoint-path", self._checkpoint_path])
        if self._triggering_interval:
            args.extend(["--triggering-interval", str(self._triggering_interval)])
        return args

    def get_job_hash(self) -> str:
        """
        Compute a deterministic fingerprint of this job's source and feature
        table configuration.

        Entities and features are sorted by name and the JSON is dumped with
        sorted keys, so logically identical specs always hash the same
        regardless of ordering. md5 is used as a stable fingerprint here,
        not for security.

        Returns:
            str: hex digest identifying this job configuration.
        """
        # Shallow copy is sufficient: only top-level keys are replaced below,
        # the original dict is never mutated.
        sorted_feature_table = self._feature_table.copy()
        sorted_feature_table["entities"] = sorted(
            self._feature_table["entities"], key=lambda x: x["name"]
        )
        sorted_feature_table["features"] = sorted(
            self._feature_table["features"], key=lambda x: x["name"]
        )
        job_json = json.dumps(
            {"source": self._source, "feature_table": sorted_feature_table},
            sort_keys=True,
        )
        return hashlib.md5(job_json.encode()).hexdigest()
class BatchIngestionJob(SparkJob):
    """
    Container for the ingestion job result.
    """

    @abc.abstractmethod
    def get_feature_table(self) -> str:
        """
        Get the feature table name associated with this job. Return empty string if unable to
        determine the feature table, such as when the job is created by the earlier
        version of Feast.

        Returns:
            str: Feature table name
        """
        raise NotImplementedError
class StreamIngestionJob(SparkJob):
    """
    Container for the streaming ingestion job result.
    """

    # NOTE(review): unlike get_feature_table below, this is not marked
    # @abc.abstractmethod, so subclasses are not forced to override it —
    # confirm whether that is intentional.
    def get_hash(self) -> str:
        """Gets the consistent hash of this stream ingestion job.

        The hash needs to be persisted at the data processing layer, so that we can get the same
        hash when retrieving the job from Spark.

        Returns:
            str: The hash for this streaming ingestion job
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_feature_table(self) -> str:
        """
        Get the feature table name associated with this job. Return `None` if unable to
        determine the feature table, such as when the job is created by the earlier
        version of Feast.

        (Note: the return annotation is `str` and the sibling BatchIngestionJob
        documents an empty string for this case — confirm which contract
        implementations actually follow.)

        Returns:
            str: Feature table name
        """
        raise NotImplementedError
class JobLauncher(abc.ABC):
    """
    Submits spark jobs to a spark cluster: historical feature retrieval,
    one-off and scheduled batch ingestion, and stream ingestion jobs.
    """

    @abc.abstractmethod
    def historical_feature_retrieval(
        self, retrieval_job_params: RetrievalJobParameters
    ) -> RetrievalJob:
        """
        Submits a historical feature retrieval job to a Spark cluster.

        Raises:
            SparkJobFailure: The spark job submission failed, encountered error
                during execution, or timeout.

        Returns:
            RetrievalJob: wrapper around remote job that returns file uri to the result file.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def offline_to_online_ingestion(
        self, ingestion_job_params: BatchIngestionJobParameters
    ) -> BatchIngestionJob:
        """
        Submits a batch ingestion job to a Spark cluster.

        Raises:
            SparkJobFailure: The spark job submission failed, encountered error
                during execution, or timeout.

        Returns:
            BatchIngestionJob: wrapper around remote job that can be used to check when job completed.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def schedule_offline_to_online_ingestion(
        self, ingestion_job_params: ScheduledBatchIngestionJobParameters
    ):
        """
        Submits a scheduled batch ingestion job to a Spark cluster.

        Raises:
            SparkJobFailure: The spark job submission failed, encountered error
                during execution, or timeout.

        Returns:
            ScheduledBatchIngestionJob: wrapper around remote job that can be used to check when job completed.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def unschedule_offline_to_online_ingestion(self, project: str, feature_table: str):
        """
        Unschedule a scheduled batch ingestion job.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def start_stream_to_online_ingestion(
        self, ingestion_job_params: StreamIngestionJobParameters
    ) -> StreamIngestionJob:
        """
        Starts a stream ingestion job to a Spark cluster.

        Raises:
            SparkJobFailure: The spark job submission failed, encountered error
                during execution, or timeout.

        Returns:
            StreamIngestionJob: wrapper around remote job.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_job_by_id(self, job_id: str) -> SparkJob:
        """
        Look up a previously submitted job by its id.

        Returns:
            SparkJob: handle to the job.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def list_jobs(
        self,
        include_terminated: bool,
        project: Optional[str],
        table_name: Optional[str],
    ) -> List[SparkJob]:
        """
        List submitted jobs.

        Args:
            include_terminated (bool): also include jobs that have finished.
            project (Optional[str]): presumably filters by project when set —
                confirm against implementations.
            table_name (Optional[str]): presumably filters by feature table
                name when set — confirm against implementations.

        Returns:
            List[SparkJob]: matching jobs.
        """
        raise NotImplementedError