Conditional sampling

YData Synthesizers support conditional sampling. The fit method has an optional parameter, condition_on, which receives the list of features to condition upon. The sample method then receives the conditions to apply through an optional parameter with the same name, condition_on. Two types of conditions are currently supported:

  • Condition upon a categorical (or string) feature. The parameters are the name of the feature and a list of values (i.e., categories) to be considered. Each category can also be given the proportion of rows it should represent. For example, when conditioning upon two categories, we need to define the fraction of the synthetic dataset that each category will account for. Naturally, these proportions must sum to 1. If a proportion is omitted, it defaults to 1, which is the only valid value when a single category is given.
  • Condition upon a numerical feature. The parameters are the name of the feature and the minimum and maximum of the range to be considered. In the synthetic dataset, this feature will follow a uniform distribution over the specified range.
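To make these rules concrete, here is a minimal sketch of a local validator for the dictionary passed to sample. The helper normalize_condition_on is hypothetical and not part of the YData SDK. It assumes the dictionary shape used in the example on this page (a "categories" list of bare names or (name, proportion) tuples, or "minimum"/"maximum" keys). It also treats several categories without proportions as an error, which is an assumption rather than documented behavior.

```python
import math
from typing import Any


def normalize_condition_on(condition_on: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Hypothetical helper (not part of the YData SDK): validate a condition_on
    mapping and make every category proportion explicit."""
    normalized: dict[str, dict[str, Any]] = {}
    for feature, rule in condition_on.items():
        if "categories" in rule:
            categories = list(rule["categories"])
            if len(categories) == 1 and not isinstance(categories[0], tuple):
                # A single bare category gets the default proportion of 1.
                pairs = [(categories[0], 1.0)]
            elif categories and all(isinstance(c, tuple) and len(c) == 2 for c in categories):
                pairs = [(name, float(share)) for name, share in categories]
            else:
                # Assumption: several categories must each carry a proportion.
                raise ValueError(f"{feature}: give a (category, proportion) pair for each category")
            total = sum(share for _, share in pairs)
            if not math.isclose(total, 1.0):
                raise ValueError(f"{feature}: proportions sum to {total}, not 1")
            normalized[feature] = {"categories": pairs}
        elif "minimum" in rule and "maximum" in rule:
            if rule["minimum"] > rule["maximum"]:
                raise ValueError(f"{feature}: minimum is greater than maximum")
            normalized[feature] = {"minimum": rule["minimum"], "maximum": rule["maximum"]}
        else:
            raise ValueError(f"{feature}: expected 'categories' or 'minimum'/'maximum'")
    return normalized


conditions = normalize_condition_on({
    "sex": {"categories": ["Female"]},
    "native-country": {"categories": [("United-States", 0.6), ("Mexico", 0.4)]},
    "age": {"minimum": 55, "maximum": 60},
})
print(conditions["sex"])  # {'categories': [('Female', 1.0)]}
```

Checking the conditions locally like this catches a mistake such as proportions summing to 0.9 before any sampling request is made.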

The example below demonstrates how to train and sample from a synthesizer using conditional sampling:

import os

from ydata.sdk.dataset import get_dataset
from ydata.sdk.synthesizers import RegularSynthesizer

# Do not forget to add your token as an environment variable.
os.environ["YDATA_TOKEN"] = '<TOKEN>'  # Remove if already defined.


def main():
    """In this example, we demonstrate how to train and
    sample from a synthesizer using conditional sampling."""
    X = get_dataset('census')

    # We initialize a regular synthesizer.
    # Until `fit` is called, the synthesizer exists only locally.
    synth = RegularSynthesizer()

    # We train the synthesizer on our dataset, specifying
    # the features to condition upon.
    synth.fit(
        X,
        name="census_synthesizer",
        condition_on=["sex", "native-country", "age"]
    )

    # We request a synthetic dataset with specific condition rules.
    sample = synth.sample(
        n_samples=500,
        condition_on={
            "sex": {
                "categories": ["Female"]
            },
            "native-country": {
                "categories": [("United-States", 0.6),
                               ("Mexico", 0.4)]
            },
            "age": {
                "minimum": 55,
                "maximum": 60
            }
        }
    )
    print(sample)


if __name__ == "__main__":
    main()
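After sampling, it can be useful to confirm that the returned rows respect the requested conditions. The sketch below assumes the sample has been converted to a list of row dictionaries (if it is a pandas DataFrame, sample.to_dict("records") produces this shape) and that the numeric range is inclusive at both ends. check_sample is a hypothetical helper, not an SDK function. It raises if a row falls outside the conditions and returns the observed share of each category for comparison with the requested proportions. With n_samples=500, for instance, the 0.6/0.4 split targets about 300 "United-States" rows and 200 "Mexico" rows.

```python
from collections import Counter
from typing import Any


def check_sample(
    rows: list[dict[str, Any]],
    condition_on: dict[str, dict[str, Any]],
) -> dict[str, dict[Any, float]]:
    """Hypothetical helper (not part of the YData SDK): raise ValueError if any
    row breaks a condition, and return the observed share of each category."""
    if not rows:
        raise ValueError("the sample is empty")
    observed: dict[str, dict[Any, float]] = {}
    for feature, rule in condition_on.items():
        values = [row[feature] for row in rows]
        if "categories" in rule:
            # Accept both bare categories and (category, proportion) tuples.
            allowed = {c[0] if isinstance(c, tuple) else c for c in rule["categories"]}
            unexpected = set(values) - allowed
            if unexpected:
                raise ValueError(f"{feature}: unexpected values {sorted(map(str, unexpected))}")
            counts = Counter(values)
            observed[feature] = {cat: counts[cat] / len(values) for cat in allowed}
        else:
            # Assumption: both range bounds are inclusive.
            low, high = rule["minimum"], rule["maximum"]
            outside = [v for v in values if not low <= v <= high]
            if outside:
                raise ValueError(f"{feature}: {len(outside)} value(s) outside [{low}, {high}]")
    return observed


# Toy rows for illustration only; in practice they would come from synth.sample.
conditions = {
    "sex": {"categories": ["Female"]},
    "native-country": {"categories": [("United-States", 0.6), ("Mexico", 0.4)]},
    "age": {"minimum": 55, "maximum": 60},
}
rows = [
    {"sex": "Female", "native-country": "United-States", "age": 57},
    {"sex": "Female", "native-country": "Mexico", "age": 60},
]
print(check_sample(rows, conditions))
```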