Spark - SCD2 Type Update
Wednesday, August 31, 2022

import inspect
import logging
import warnings
from typing import Optional, Union

from pyspark.sql import DataFrame as SparkDataFrame

# BaseLoader and ReadSpark are part of the surrounding loader framework
# (not shown in this post); BaseLoader is assumed to provide `options`,
# `log_level`, `_spark_session`, and `_parse_dbtable`.

class UpsertSCDSpark(BaseLoader):
def __init__(
self,
primary_keys: Union[str, list, tuple],
path: Optional[str] = None,
dbtable: Optional[str] = None,
tgt_dbtable: Optional[str] = None,
update_timestamp: Optional[str] = None,
src_eff_timestamp: Optional[str] = None,
current_flag: Optional[str] = "current_flag",
soft_delete_flag: Optional[str] = "soft_delete_flag",
delete_timestamp: Optional[str] = "delete_ts",
eff_ts: Optional[str] = "eff_ts",
exp_ts: Optional[str] = "exp_ts",
columns: Optional[list] = None,
format: str = 'delta',
*args,
**kwargs
):
super(UpsertSCDSpark, self).__init__(*args, **kwargs)
        # accept a single key name or an iterable of key names, per the type hint
        self.primary_keys = [primary_keys] if isinstance(primary_keys, str) else list(primary_keys)
self.tgt_dbtable = tgt_dbtable
self.update_timestamp = update_timestamp
self.src_eff_timestamp = src_eff_timestamp
self.current_flag = current_flag
self.soft_delete_flag = soft_delete_flag
self.delete_timestamp = delete_timestamp
self.eff_ts = eff_ts
self.exp_ts = exp_ts
self.dbtable = dbtable
self.columns_list = columns
self.path = path
self.format = format
self._loader = ReadSpark
        if self.dbtable and self.path:
            raise ValueError('Only one of `dbtable` or `path` may be specified!')
        if not self.dbtable and not self.path:
            raise ValueError('No `dbtable` or `path` specified. One must be specified!')
self.database, self.table = self._parse_dbtable(self.dbtable)
self.logger = logging.getLogger(__name__)
self.logger.setLevel(self.log_level)
def _get_df(self):
params = {
'path': self.path,
'dbtable': self.dbtable,
'format': self.format,
'options': self.options,
'spark_session': self._spark_session,
}
read = ReadSpark(**params)
df = read.load()
return df
def _get_cols(self, tbl_a_alias, tbl_b_alias, keys, columns):
exclude_key_list = [x for x in columns if x not in keys][:-1]
a_list = [tbl_a_alias + '.' + col.strip() for col in exclude_key_list]
b_list = [tbl_b_alias + '.' + col.strip() for col in exclude_key_list]
col_conditions = " or ".join(a + "<>" + b for a, b in zip(a_list, b_list))
return col_conditions
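    # Example (hypothetical names): with keys=['id'] and
    # columns=['id', 'name', 'city', 'load_ts'], _get_cols('a', 'b', keys, columns)
    # returns "a.name<>b.name or a.city<>b.city". Note the [:-1] slice drops the
    # last non-key column (here 'load_ts'), so order the column list so that a
    # trailing audit/load-timestamp field is the one excluded from change detection.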
def _get_scd_sql_contents(self, tbl_a_alias, tbl_b_alias, keys, tbl_src_alias="src", tbl_tgt_alias="tgt"):
# get columns
mergeKey_cols = ", ".join(tbl_a_alias + '.' + col + ' as mergeKey_' + col for col in keys)
NULL_mergeKey_cols = ", ".join('NULL as mergeKey_' + col for col in keys)
# get inner join conditions
a_list = [tbl_a_alias + '.' + col.strip() for col in keys]
b_list = [tbl_b_alias + '.' + col.strip() for col in keys]
key_join_conditions = " and ".join(a + "=" + b for a, b in zip(a_list, b_list))
# get outer join conditions
src_list = [tbl_src_alias + '.mergeKey_' + col.strip() for col in keys]
tgt_list = [tbl_tgt_alias + '.' + col.strip() for col in keys]
merge_join_conditions = " and ".join(a + "=" + b for a, b in zip(src_list, tgt_list))
return mergeKey_cols, NULL_mergeKey_cols, key_join_conditions, merge_join_conditions
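    # Example (hypothetical): with keys=['id'], _get_scd_sql_contents('a', 'b', keys)
    # returns the four SQL fragments:
    #   mergeKey_cols         -> "a.id as mergeKey_id"
    #   NULL_mergeKey_cols    -> "NULL as mergeKey_id"
    #   key_join_conditions   -> "a.id=b.id"
    #   merge_join_conditions -> "src.mergeKey_id=tgt.id"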
def _merge_tables(self):
        mergeKey_cols, NULL_mergeKey_cols, key_join_conditions, merge_join_conditions = \
            self._get_scd_sql_contents("a", "b", self.primary_keys)
df_src = self._get_df()
if self.columns_list:
lower_cols = [col.lower() for col in self.columns_list]
else:
lower_cols = [col.lower() for col in df_src.columns]
lower_keys = [col.lower() for col in self.primary_keys]
src_tbl_alias, tgt_tbl_alias = 'src', 'tgt'
inner_col_conditions = self._get_cols("a", "b", lower_keys, lower_cols)
outer_col_conditions = self._get_cols(src_tbl_alias, tgt_tbl_alias, lower_keys, lower_cols)
insert_columns = ",".join(df_src.columns)
insert_values = ",".join([src_tbl_alias + '.' + col for col in df_src.columns])
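        # Classic SCD2 MERGE pattern: the USING source unions
        #   (1) every staged row keyed by its real primary key (mergeKey_*), and
        #   (2) a second copy of each *changed* row with a NULL mergeKey.
        # Copy (1) matches the current target row and expires it
        # (current_flag -> 'N'); copy (2) can never match, so the NOT MATCHED
        # branch inserts it as the new current version of the record.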
sql_merge_query = f"""
MERGE INTO {self.tgt_dbtable} AS tgt
USING (select {mergeKey_cols}, a.*
from {self.dbtable} a
UNION ALL
SELECT {NULL_mergeKey_cols}, a.*
FROM {self.dbtable} a JOIN {self.tgt_dbtable} b ON {key_join_conditions}
            WHERE b.{self.current_flag} = 'Y' AND ({inner_col_conditions})
) AS src ON {merge_join_conditions}
            WHEN MATCHED AND tgt.{self.current_flag} = 'Y' AND ({outer_col_conditions}) THEN
            UPDATE SET tgt.{self.current_flag} = 'N', tgt.{self.update_timestamp} = current_timestamp(), tgt.{self.exp_ts} = src.{self.src_eff_timestamp}
WHEN NOT MATCHED THEN
INSERT ({insert_columns}, {self.current_flag}, {self.soft_delete_flag}, {self.delete_timestamp}, {self.eff_ts}, {self.exp_ts})
            VALUES ({insert_values}, 'Y', 'N', current_timestamp(), src.{self.src_eff_timestamp}, current_timestamp())
"""
self._spark_session.sql(sql_merge_query)
def load(self, df: Optional[SparkDataFrame] = None) -> SparkDataFrame:
self.logger.debug('{:}.{:} started...'.format(self.__class__.__name__, inspect.stack()[0][3]))
        # note: the optional `df` argument is currently ignored; the source is
        # always re-read via _get_df()
        df = df_src = self._get_df()
        if df_src.head():
            self._merge_tables()
            df = self._spark_session.table(f"{self.tgt_dbtable}")
        else:
            warnings.warn("Upsert skipped! The source dataframe is empty.")
self.logger.debug('{:}.{:} completed.'.format(self.__class__.__name__, inspect.stack()[0][3]))
return df
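A minimal usage sketch (hypothetical table and column names; assumes BaseLoader
wires up the SparkSession and reader options, and that both the staging table
and the SCD2 target already exist with the flag/timestamp columns passed below):

upsert = UpsertSCDSpark(
    primary_keys=['customer_id'],
    dbtable='stg.customers',           # staged source rows
    tgt_dbtable='dw.customers_scd',    # SCD2 target table
    update_timestamp='update_ts',
    src_eff_timestamp='eff_from_ts',   # effective-from timestamp column in the source
    format='delta',
)
df_current = upsert.load()             # runs the MERGE and returns the target table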