Skip to content

Multitable Synthetic Data with Calculated Features

This section demonstrates how to use the Calculated Features module (also known as Business Rules) for multi-table synthesis in ydata-sdk.

Don't forget to set up your license key before running the example:

    import os

    os.environ['YDATA_LICENSE_KEY'] = '{add-your-key}'

Example Code

import json

import pandas as pd

from ydata.connectors.storages.rdbms_connector import MySQLConnector, PostgreSQLConnector
from ydata.metadata.multimetadata import MultiMetadata
from ydata.synthesizers.multitable.model import MultiTableSynthesizer


def get_connector(credentials_path: str, connector_type: str, database: str):
    """Build a database connector from a JSON credentials file.

    Reads ``{credentials_path}/{connector_type}_credentials.json``, overrides
    its ``"database"`` entry with *database*, and passes the resulting dict as
    ``conn_string`` to the connector class selected by *connector_type*.

    Args:
        credentials_path: Directory containing the ``*_credentials.json`` files.
        connector_type: Either ``"mysql"`` or ``"postgresql"``.
        database: Name of the database to connect to.

    Returns:
        A ``MySQLConnector`` for ``"mysql"``, otherwise a ``PostgreSQLConnector``.

    Raises:
        ValueError: If *connector_type* is not one of the supported types.
        OSError: If the credentials file cannot be opened.
    """
    supported_types = ("mysql", "postgresql")
    # Validate before touching the filesystem: an unknown type used to fall
    # through silently to PostgreSQLConnector.
    if connector_type not in supported_types:
        raise ValueError(
            f"Unsupported connector_type {connector_type!r}; "
            f"expected one of {supported_types}"
        )

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(
        f"{credentials_path}/{connector_type}_credentials.json", "r", encoding="utf-8"
    ) as f:
        connection_string = json.load(f)
    connection_string["database"] = database

    if connector_type == "mysql":
        return MySQLConnector(conn_string=connection_string)
    return PostgreSQLConnector(conn_string=connection_string)


def load_database(credentials_path: str, database: str, connector_type: str, lazy: bool = False):
    """Read the whole of *database* through a connector built by ``get_connector``.

    Args:
        credentials_path: Directory containing the credentials JSON files.
        database: Name of the database to read.
        connector_type: Connector kind forwarded to ``get_connector``.
        lazy: Forwarded unchanged to the connector's ``read_database``.

    Returns:
        Whatever the connector's ``read_database`` returns.
    """
    db_connector = get_connector(credentials_path, connector_type, database)
    return db_connector.read_database(lazy=lazy)


def calculate_total_volume(fund_symbol, volume):
    """Sum the volume values per fund symbol.

    The two inputs are placed side by side (aligned on their index) and the
    combined frame is grouped by the column named ``"fund_symbol"``. The
    *fund_symbol* input must therefore carry that column name.

    Args:
        fund_symbol: Fund-symbol column used as the grouping key.
        volume: Volume column to be summed per fund symbol.

    Returns:
        A DataFrame with one row per distinct fund symbol, holding the
        ``fund_symbol`` key column followed by the summed volume column(s),
        with a fresh default integer index.
    """
    combined = pd.concat([fund_symbol, volume], axis=1)
    per_fund = combined.groupby("fund_symbol", as_index=False)
    return per_fund.sum()


def main():
    """Fit a multi-table synthesizer with a calculated feature and check it.

    Loads every table of a PostgreSQL database and declares ``etf.total_vol``
    as a calculated feature produced by ``calculate_total_volume`` from
    ``etf_price.volume``. It then fits a ``MultiTableSynthesizer`` and samples
    from it. Finally it recomputes the per-fund volume totals from the sampled
    ``etf_price`` table, compares them with the sampled ``etf.total_vol``
    values, and prints the results.
    """
    credentials_path = ".secrets"
    dataset = load_database(
        credentials_path,
        database="database-name",
        connector_type="postgresql",
    )
    metadata = MultiMetadata(dataset)

    calculated_features = [
        {
            "calculated_features": "etf.total_vol",
            "function": calculate_total_volume,
            "calculated_from": ["etf_price.volume"],
            # Key columns relating etf to etf_price (fund_symbol on both sides).
            # NOTE(review): a single key may presumably also be given as a
            # plain string, e.g. "etf.fund_symbol" — confirm in ydata docs.
            "reference_keys": {
                "source": ["etf.fund_symbol"],
                "target": ["etf_price.fund_symbol"],
            },
        }
    ]

    synthesizer = MultiTableSynthesizer()
    synthesizer.fit(
        dataset,
        metadata,
        calculated_features=calculated_features,
    )

    sample = synthesizer.sample(n_samples=1)
    sample_dfs = {table: data.to_pandas() for table, data in sample.items()}

    # Expected totals recomputed directly from the sampled prices. groupby
    # sorts by key by default, so no extra sort is needed before printing.
    total_volume = (
        sample_dfs["etf_price"]
        .groupby(by="fund_symbol")["volume"]
        .sum()
        .reset_index()
    )
    print(total_volume.head(5))

    etf = (
        sample_dfs["etf"][["fund_symbol", "total_vol"]]
        .sort_values(by="fund_symbol")
        .reset_index(drop=True)
    )
    print(etf.head(5))

    comparison = etf.merge(total_volume, on="fund_symbol", how="left")
    # Compare with a tolerance instead of `==`: if the columns are floating
    # point, totals summed in a different order can differ by rounding error.
    # Funds without price rows get NaN volume from the left merge; NaN (or NA)
    # comparisons are treated as invalid, as with the previous `==` check.
    rel_tol, abs_tol = 1e-9, 1e-9
    difference = (comparison["total_vol"] - comparison["volume"]).abs()
    tolerance = abs_tol + rel_tol * comparison["volume"].abs()
    comparison["is_valid"] = (difference <= tolerance).fillna(False)
    print(comparison)

    print("is there any invalid result?", (~comparison["is_valid"]).any())


if __name__ == "__main__":
    # Run the example only when executed as a script, not when imported.
    main()