Multitable Synthetic Data with Calculated Features
This section demonstrates how to use the Calculated Features module (also known as Business Rules) for multi-table synthesis in ydata-sdk.
Don't forget to set up your license key before running the example.
Example Code
import json
import pandas as pd
from ydata.connectors.storages.rdbms_connector import MySQLConnector, PostgreSQLConnector
from ydata.metadata.multimetadata import MultiMetadata
from ydata.synthesizers.multitable.model import MultiTableSynthesizer
def get_connector(credentials_path: str, connector_type: str, database: str):
    """Build an RDBMS connector for ``database`` from a JSON credentials file.

    Args:
        credentials_path: Directory containing
            ``<connector_type>_credentials.json``.
        connector_type: ``"mysql"`` selects ``MySQLConnector``; any other
            value falls back to ``PostgreSQLConnector``.
        database: Database name; it overrides any ``"database"`` entry
            present in the credentials file.

    Returns:
        A ``MySQLConnector`` or ``PostgreSQLConnector`` constructed with the
        loaded credentials as its ``conn_string``.

    Raises:
        FileNotFoundError: If the credentials file does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    credentials_file = f"{credentials_path}/{connector_type}_credentials.json"
    # JSON text is UTF-8 by specification; without an explicit encoding,
    # open() uses the locale encoding and can mis-decode non-ASCII
    # credentials (e.g. on Windows).
    with open(credentials_file, "r", encoding="utf-8") as f:
        connection_string = json.load(f)
    connection_string["database"] = database

    if connector_type == "mysql":
        return MySQLConnector(conn_string=connection_string)
    # Any value other than "mysql" is treated as PostgreSQL.
    return PostgreSQLConnector(conn_string=connection_string)
def load_database(credentials_path: str, database: str, connector_type: str, lazy: bool = False):
    """Read all tables of ``database`` through a connector built by ``get_connector``.

    Args:
        credentials_path: Directory holding the connector's credentials JSON.
        database: Name of the database to read.
        connector_type: Connector selector forwarded to ``get_connector``.
        lazy: Forwarded unchanged to the connector's ``read_database``.

    Returns:
        Whatever ``read_database`` returns for the selected connector.
    """
    return get_connector(
        credentials_path, connector_type, database
    ).read_database(lazy=lazy)
def calculate_total_volume(fund_symbol, volume):
    """Sum ``volume`` per distinct ``fund_symbol`` value.

    Args:
        fund_symbol: Series whose name must be ``"fund_symbol"``, because the
            grouping is done by that column name.
        volume: Series of volumes; it is aligned with ``fund_symbol`` on the
            index by the column-wise concatenation.

    Returns:
        DataFrame with a ``fund_symbol`` column (sorted, one row per symbol)
        followed by the summed volume column, which keeps ``volume``'s name.
    """
    return (
        pd.concat([fund_symbol, volume], axis=1)
        .groupby(by="fund_symbol")
        .sum()
        .reset_index()
    )
def _compare_total_volume(sample_dfs):
    """Recompute per-fund volume totals and compare them with ``etf.total_vol``.

    Args:
        sample_dfs: Mapping of table name to pandas DataFrame. It must contain
            ``"etf"`` (with ``fund_symbol`` and ``total_vol`` columns) and
            ``"etf_price"`` (with ``fund_symbol`` and ``volume`` columns).

    Returns:
        A tuple ``(total_volume, etf, comparison)`` made of:
        - the per-fund ``volume`` sums recomputed from ``etf_price``;
        - the ``etf`` subset, sorted by ``fund_symbol``;
        - their left merge, with a boolean ``is_valid`` column.
    """
    total_volume = (
        sample_dfs['etf_price']
        .groupby(by="fund_symbol")['volume']
        .sum()
        .reset_index()
    )
    etf = (
        sample_dfs['etf'][['fund_symbol', 'total_vol']]
        .sort_values(by="fund_symbol")
        .reset_index(drop=True)
    )
    # Left merge keeps every etf row. A fund with no etf_price rows gets a NaN
    # volume, and NaN never compares equal, so such rows are marked invalid.
    comparison = etf.merge(total_volume, on='fund_symbol', how='left')
    # NOTE(review): exact equality; if volumes are floats, differences in
    # summation order could produce spurious mismatches — confirm the dtype.
    comparison['is_valid'] = comparison['total_vol'] == comparison['volume']
    return total_volume, etf, comparison


def main(
    credentials_path: str = ".secrets",
    database: str = "database-name",
    connector_type: str = "postgresql",
):
    """Fit a multi-table synthesizer with a calculated feature and validate it.

    The function:
    - loads every table of ``database``;
    - declares ``etf.total_vol`` as a calculated feature derived from
      ``etf_price.volume`` via ``calculate_total_volume``;
    - fits a ``MultiTableSynthesizer`` and draws a sample;
    - prints whether each synthetic ``total_vol`` equals the sum of its
      synthetic ``etf_price`` volumes.

    Args:
        credentials_path: Directory holding ``<connector_type>_credentials.json``.
        database: Name of the database to read.
        connector_type: ``"mysql"`` for MySQL; any other value uses PostgreSQL.
    """
    dataset = load_database(
        credentials_path,
        database=database,
        connector_type=connector_type,
    )
    metadata = MultiMetadata(dataset)

    calculated_features = [
        {
            "calculated_features": "etf.total_vol",
            "function": calculate_total_volume,
            "calculated_from": ["etf_price.volume"],
            # Key columns relating the etf table to etf_price. They are given
            # as lists; whether a bare string (e.g. "etf.fund_symbol") is
            # also accepted should be confirmed in the ydata-sdk docs.
            "reference_keys": {
                "source": ["etf.fund_symbol"],
                "target": ["etf_price.fund_symbol"],
            },
        }
    ]

    synthesizer = MultiTableSynthesizer()
    synthesizer.fit(
        dataset,
        metadata,
        calculated_features=calculated_features,
    )

    sample = synthesizer.sample(n_samples=1)
    sample_dfs = {t: df.to_pandas() for t, df in sample.items()}

    total_volume, etf, comparison = _compare_total_volume(sample_dfs)
    print(total_volume.sort_values(by="fund_symbol").head(5))
    # etf is already sorted by fund_symbol; no second sort is needed.
    print(etf.head(5))
    print(comparison)
    print('is there any invalid result?', (~comparison['is_valid']).any())
if __name__ == '__main__':
    # Run the example only when this file is executed as a script,
    # not when it is imported.
    main()