Great Expectations: writing custom expectations

By Paolo Léonard

If you work with a lot of data, as we do at dataroots, you have most likely encountered your fair share of bad datasets with unexpected or missing values.

At that point you have two choices: either you do not care about those dataset imperfections and use the data as it is, or you do care and use Great Expectations (or GE), the best batteries-included data validation tool out there.

In this article, I will not go into the details of how to set up and start using GE; for that, I refer you to the excellent tutorial made by some dataroots colleagues.
This article is for developers who have already been using GE to the point where they need more from it. I will explain how to develop a custom table expectation using the V3 API.

Table expectation

In this tutorial we will develop a custom table expectation. But what is a table expectation?
GE offers different types of expectations: column, column pair, ... and table. A table expectation validates information about the table itself. For example, expect_table_row_count_to_equal is a table expectation: it checks whether the row count of a table is equal to a given count.
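A table expectation produces one verdict for the whole table, whereas a column expectation evaluates values row by row. The plain-Python sketch below contrasts the two kinds of check; it involves no GE code, and the rows and both helper functions are hypothetical.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]

# Hypothetical in-memory table: one dict per row.
rows: List[Row] = [
    {"trip_id": 1, "fare": 12.5},
    {"trip_id": 2, "fare": None},
    {"trip_id": 3, "fare": 7.0},
]


def table_row_count_equals(table: List[Row], value: int) -> bool:
    """Table-level check: a single verdict about the table as a whole."""
    return len(table) == value


def column_values_not_null(table: List[Row], column: str) -> List[bool]:
    """Column-level check: one verdict per row for the given column."""
    return [row.get(column) is not None for row in table]


print(table_row_count_equals(rows, 3))       # True
print(column_values_not_null(rows, "fare"))  # [True, False, True]
```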

Today, we will implement an expectation that checks the row count of a dataset and compares it to the row counts of other datasets. We can use this expectation to ensure that the amount of data we receive over multiple timeframes is consistent. For example, say you record every transaction a bot makes in a daily dataset. You would not expect fewer transactions from one day to the next, so you would validate that the number of rows in each new dataset is not lower than the row counts of the previous datasets.

To this end, we will implement an expectation called expect_table_row_count_to_be_more_than_others with these parameters:

  • other_table_filenames_list, the list of filenames of the other datasets we want to compare against,
  • comparison_key, the type of comparison (mean or absolute in this tutorial),
  • lower_percentage_threshold, the lower bound for the row count, as a percentage of the compared row count(s).
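Here is a quick worked example of how lower_percentage_threshold scales the bound, on made-up daily row counts for the bot example. It assumes, as the comparator implemented below does, that a threshold of p percent requires the current row count to be at least p / 100 times the reference count; minimum_rows is a hypothetical helper.

```python
from statistics import mean
from typing import List


def minimum_rows(reference_count: float, lower_percentage_threshold: int) -> float:
    """Smallest row count that still passes for a reference count and a threshold in percent."""
    return reference_count * lower_percentage_threshold / 100


# Made-up daily row counts of the bot transaction dataset.
previous_days: List[int] = [1000, 1100, 1200]
today = 1050

# "mean" comparison: today against the average of the previous days (1100).
mean_bound = minimum_rows(mean(previous_days), 90)
print(mean_bound, today >= mean_bound)  # 990.0 True

# "absolute" comparison: today must clear the bound of every previous day
# (900.0, 990.0 and 1080.0), and 1050 < 1080.0.
absolute_ok = all(today >= minimum_rows(count, 90) for count in previous_days)
print(absolute_ok)  # False
```

With a threshold of 100, the check reduces to "not fewer rows than before".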

Architecture

So what do we need to develop this expectation? First, a way to compare row counts based on the comparison_key. Second, a table metric, built with the GE library, to load the other tables and count their rows. And lastly, the actual expectation, which builds on the first two steps.

Implementation

Comparator

Let's start with the easy part, the comparison method. To build it, I used a Python Enum class:

from enum import Enum, auto
from statistics import mean


class SupportedComparisonEnum(Enum):
    """Enum class with the currently supported comparison type."""

    ABSOLUTE = auto()
    MEAN = auto()

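    def __call__(self, current_row_count, other_row_counts, lower_percentage_threshold):
        if self.name == "ABSOLUTE":
            # Current count must reach the threshold for every other table.
            return all(
                current_row_count >= count * lower_percentage_threshold / 100
                for count in other_row_counts
            )
        elif self.name == "MEAN":
            # Current count must reach the threshold of the average of the others.
            return (
                current_row_count
                >= mean(other_row_counts) * lower_percentage_threshold / 100
            )
        else:
            # Guard for members added to the enum without comparison logic.
            raise NotImplementedError("Comparison key is not supported.")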

I wrote the comparator as an Enum because it makes adding new comparison keys easy. In the future, if you want to compare based on a new key, you just have to add a new member with its name and add its logic to the __call__ method. You can then call it like this: SupportedComparisonEnum["YOURNEWKEY"](current_row_count, list_of_previous_row_counts, lower_percentage_threshold) and get the result of the comparison.
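To illustrate how little it takes to add a key, here is a self-contained version of the comparator extended with a hypothetical MEDIAN member (not part of the original expectation), followed by the three comparisons on made-up row counts. The positional arguments are named for readability; the call convention is unchanged.

```python
from enum import Enum, auto
from statistics import mean, median


class SupportedComparisonEnum(Enum):
    """Comparator extended with a hypothetical MEDIAN key."""

    ABSOLUTE = auto()
    MEAN = auto()
    MEDIAN = auto()  # new key: compare against the median of the other row counts

    def __call__(self, current_row_count, other_row_counts, lower_percentage_threshold):
        if self.name == "ABSOLUTE":
            return all(
                current_row_count >= count * lower_percentage_threshold / 100
                for count in other_row_counts
            )
        elif self.name == "MEAN":
            return current_row_count >= mean(other_row_counts) * lower_percentage_threshold / 100
        elif self.name == "MEDIAN":
            return current_row_count >= median(other_row_counts) * lower_percentage_threshold / 100
        else:
            raise NotImplementedError("Comparison key is not supported.")


current = 1000
others = [800, 1000, 1500]
for key in ("ABSOLUTE", "MEAN", "MEDIAN"):
    print(key, SupportedComparisonEnum[key](current, others, 100))
# ABSOLUTE False
# MEAN False
# MEDIAN True
```

The keys disagree here because one previous dataset (1500 rows) is much larger than the others: ABSOLUTE fails on it, the mean is pulled up to 1100, while the median stays at 1000.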

Table metric

Next, we need a custom metric that returns the number of rows in each additional dataset. This metric will be computed once per additional dataset. However, loading external datasets based on their filename is not currently supported by GE, so I implemented my own method:

from typing import Dict, Tuple, Any

from great_expectations.core.batch_spec import PathBatchSpec
from great_expectations.execution_engine import (
    SparkDFExecutionEngine,
    PandasExecutionEngine
)
from great_expectations.expectations.metrics.metric_provider import metric_value
from great_expectations.expectations.metrics.table_metric_provider import (
    TableMetricProvider,
)


class OtherTableRowCount(TableMetricProvider):
    """MetricProvider class to get row count from different tables than the current one."""

    metric_name = "table.row_count_other"

    @metric_value(engine=PandasExecutionEngine)
    def _pandas(
        cls,
        execution_engine: "PandasExecutionEngine",
        metric_domain_kwargs: Dict,
        metric_value_kwargs: Dict,
        metrics: Dict[Tuple, Any],
        runtime_configuration: Dict,
    ) -> int:
        other_table_filename = metric_domain_kwargs.get("table_filename")
        batch_spec = PathBatchSpec(
            {"path": other_table_filename, "reader_method": "read_csv"}
        )
        batch_data = execution_engine.get_batch_data(batch_spec=batch_spec)
        df = batch_data.dataframe
        return df.shape[0]

    @metric_value(engine=SparkDFExecutionEngine)
    def _spark(
        cls,
        execution_engine: "SparkDFExecutionEngine",
        metric_domain_kwargs: Dict,
        metric_value_kwargs: Dict,
        metrics: Dict[Tuple, Any],
        runtime_configuration: Dict,
    ) -> int:
        other_table_filename = metric_domain_kwargs.get("table_filename")
        batch_spec = PathBatchSpec(
            {"path": other_table_filename, "reader_method": "csv"}
        )
        batch_data = execution_engine.get_batch_data(batch_spec=batch_spec)
        df = batch_data.dataframe
        return df.count()

So what do we have here? A table metric called OtherTableRowCount that subclasses the GE base class TableMetricProvider. To create a GE metric, at least two things need to be declared:

  • the metric name which will be used to call the metric. In this example it is "table.row_count_other" (by the way, GE has a naming convention for the metrics which you can find here).
  • an actual method which will compute the metric. GE supports three execution engines, namely pandas, Spark and SQL. Here we only implement the logic for the pandas and Spark engines.
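To picture how a metric name and per-engine methods fit together, here is a toy, pure-Python model of a metric registry. It is not GE's implementation: register_metric, compute_metric and the two "engines" (plain lists and dicts of columns, standing in for pandas and Spark) are hypothetical names chosen for illustration.

```python
from typing import Any, Callable, Dict, List, Tuple

# Toy registry mapping (metric_name, engine) to the function computing it.
METRIC_REGISTRY: Dict[Tuple[str, str], Callable[[Any], int]] = {}


def register_metric(metric_name: str, engine: str) -> Callable:
    """Hypothetical decorator: store one backend implementation of a metric."""

    def decorator(fn: Callable[[Any], int]) -> Callable[[Any], int]:
        METRIC_REGISTRY[(metric_name, engine)] = fn
        return fn

    return decorator


@register_metric("table.row_count_other", engine="python_lists")
def _row_count_lists(table: List[Any]) -> int:
    # Row-oriented table: one element per row.
    return len(table)


@register_metric("table.row_count_other", engine="python_dict_of_columns")
def _row_count_columns(table: Dict[str, List[Any]]) -> int:
    # Column-oriented table: every column holds one value per row.
    return len(next(iter(table.values()), []))


def compute_metric(metric_name: str, engine: str, table: Any) -> int:
    """Look up the implementation by metric name and engine, then run it."""
    return METRIC_REGISTRY[(metric_name, engine)](table)


print(compute_metric("table.row_count_other", "python_lists", [1, 2, 3]))  # 3
print(compute_metric("table.row_count_other", "python_dict_of_columns", {"a": [1, 2], "b": [3, 4]}))  # 2
```

The caller only needs the metric name and the engine in use; the same name resolves to a different backend implementation, which is why OtherTableRowCount defines one method per engine.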

The first requirement is easy to fulfil: it is the metric_name attribute of OtherTableRowCount.

The second requirement takes a bit more effort, because computing the metric can depend on outside parameters. Let's go over the pandas implementation line by line (the Spark one follows the same logic, with the "csv" reader method and df.count() instead of df.shape[0]):

other_table_filename = metric_domain_kwargs.get("table_filename") retrieves the filename of the table whose rows we need to count. metric_domain_kwargs is a dictionary with the parameters the metric needs; it is set by the expectation that calls the metric.

batch_spec = PathBatchSpec({"path": other_table_filename, "reader_method": "read_csv"}) is a PathBatchSpec object which will be used by an execution engine to load the dataset into memory using its system path.

batch_data = execution_engine.get_batch_data(batch_spec=batch_spec) returns the batch data, which contains the dataset as a pandas DataFrame. We need this little workaround to load external datasets since no built-in function exists at the moment. Finally, the last two lines retrieve the pandas DataFrame and return its row count.
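Stripped of GE's batch machinery, the pandas branch essentially reads the CSV and counts its data rows. The sketch below approximates that with the standard csv module so it runs without pandas or GE; count_csv_rows is a hypothetical helper, and it only mimics read_csv's default behaviour of treating the first line as the header and skipping blank lines.

```python
import csv
import os
import tempfile


def count_csv_rows(path: str) -> int:
    """Count data rows in a CSV file, excluding the header and blank lines."""
    with open(path, newline="") as handle:
        reader = csv.reader(handle)
        next(reader, None)  # first line is the header, like read_csv's default
        return sum(1 for row in reader if row)  # csv.reader yields [] for blank lines


# Usage on a throwaway file with a header, three data rows and one blank line.
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="") as tmp:
    tmp.write("vendor_id,fare\n1,12.5\n2,7.0\n\n3,9.1\n")

rows_counted = count_csv_rows(tmp.name)
os.remove(tmp.name)
print(rows_counted)  # 3
```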

All right, we have our table metric. Let us implement the actual expectation.

Expectation

First, the expectation code:

from copy import deepcopy
from typing import Dict, Tuple, Any, Optional, Callable, List

from great_expectations.core import ExpectationConfiguration
from great_expectations.execution_engine import (
    ExecutionEngine
)
from great_expectations.expectations.expectation import TableExpectation
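from great_expectations.exceptions.exceptions import InvalidExpectationKwargsError

# SupportedComparisonEnum (from the Comparator section) is expected to be defined
# in this same module, or imported into it.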


class ExpectTableRowCountToBeMoreThanOthers(TableExpectation):
    """TableExpectation class to compare the row count of the current dataset to other dataset(s)."""

    metric_dependencies = ("table.row_count", "table.row_count_other")
    success_keys = (
        "other_table_filenames_list",
        "comparison_key",
        "lower_percentage_threshold",
    )
    default_kwarg_values = {
        "other_table_filenames_list": None,
        "comparison_key": "MEAN",
        "lower_percentage_threshold": 100,
    }

    @staticmethod
    def _validate_success_key(
        param: str,
        required: bool,
        configuration: Optional[ExpectationConfiguration],
        validation_rules: Dict[Callable, str],
    ) -> None:
        """Simple method to aggregate and apply validation rules to the `param`."""
        if param not in configuration.kwargs:
            if required:
                raise InvalidExpectationKwargsError(
                    f"Param {param} is required but was not found in configuration."
                )
            return

        param_value = configuration.kwargs[param]

        for rule, error_message in validation_rules.items():
            if not rule(param_value):
                raise InvalidExpectationKwargsError(error_message)

    def validate_configuration(
        self, configuration: Optional[ExpectationConfiguration]
    ) -> bool:
        super().validate_configuration(configuration=configuration)
        if configuration is None:
            configuration = self.configuration

        self._validate_success_key(
            param="other_table_filenames_list",
            required=True,
            configuration=configuration,
            validation_rules={
                lambda x: isinstance(x, (str, list)): "other_table_filenames_list should either be a list or a string.",
                lambda x: len(x) > 0: "other_table_filenames_list should not be empty.",
            },
        )

        self._validate_success_key(
            param="comparison_key",
            required=False,
            configuration=configuration,
            validation_rules={
                lambda x: isinstance(x, str): "comparison_key should be a string.",
                lambda x: x.upper() in SupportedComparisonEnum.__members__: "Given comparison_key is not supported.",
            },
        )

        self._validate_success_key(
            param="lower_percentage_threshold",
            required=False,
            configuration=configuration,
            validation_rules={
                lambda x: isinstance(x, int): "lower_percentage_threshold should be an integer.",
                lambda x: x > 0: "lower_percentage_threshold should be strictly greater than 0.",
            },
        )

        return True

    def get_validation_dependencies(
        self,
        configuration: Optional[ExpectationConfiguration] = None,
        execution_engine: Optional[ExecutionEngine] = None,
        runtime_configuration: Optional[dict] = None,
    ) -> dict:
        dependencies = super().get_validation_dependencies(
            configuration, execution_engine, runtime_configuration
        )

        other_table_filenames_list = configuration.kwargs.get(
            "other_table_filenames_list"
        )

        if isinstance(other_table_filenames_list, str):
            other_table_filenames_list = [other_table_filenames_list]

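        # Fan out: one copy of the table.row_count_other metric configuration per
        # extra file, each pointing at its own file and stored under its own key.
        for other_table_filename in other_table_filenames_list: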
            table_row_count_metric_config_other = deepcopy(
                dependencies["metrics"]["table.row_count_other"]
            )
            table_row_count_metric_config_other.metric_domain_kwargs[
                "table_filename"
            ] = other_table_filename

            dependencies["metrics"][
                f"table.row_count_other.{other_table_filename}"
            ] = table_row_count_metric_config_other

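        # Keep the current table's row count under an explicit key, and drop the
        # template table.row_count_other entry, which has no table_filename set.
        dependencies["metrics"]["table.row_count.self"] = dependencies["metrics"].pop(
            "table.row_count"
        )
        dependencies["metrics"].pop("table.row_count_other")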

        return dependencies

    def _validate(
        self,
        configuration: ExpectationConfiguration,
        metrics: Dict,
        runtime_configuration: dict = None,
        execution_engine: ExecutionEngine = None,
    ) -> Dict:
        success_kwargs = self.get_success_kwargs(configuration)
        comparison_key = success_kwargs["comparison_key"]
        other_table_filename_list = success_kwargs["other_table_filenames_list"]
        lower_percentage_threshold = success_kwargs["lower_percentage_threshold"]

        # validate_configuration also accepts a single filename; wrap it in a list
        # (as get_validation_dependencies does) so we do not iterate over characters.
        if isinstance(other_table_filename_list, str):
            other_table_filename_list = [other_table_filename_list]

        current_row_count = metrics["table.row_count.self"]
        previous_row_count_list = [
            metrics[f"table.row_count_other.{other_table_filename}"]
            for other_table_filename in other_table_filename_list
        ]

        comparison_key_fn = SupportedComparisonEnum[comparison_key.upper()]
        success_flag = comparison_key_fn(
            current_row_count, previous_row_count_list, lower_percentage_threshold
        )

        return {
            "success": success_flag,
            "result": {
                "self": current_row_count,
                "other": previous_row_count_list,
                "comparison_key": comparison_key_fn.name,
            },
        }

Let's go through this gradually, starting with the parameters and the table metrics this expectation needs. The parameters are the ones listed at the beginning of the article; they are declared as a tuple assigned to success_keys. You can assign default values to these parameters with the default_kwarg_values dictionary. To make a parameter mandatory, give it the default value None; validate_configuration then checks that it was actually provided.
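To see how defaults and user-supplied values combine, here is a plain-Python sketch of the lookup. As far as I know, it mirrors what GE's get_success_kwargs does for the success keys (the value from the configuration if present, otherwise the entry in default_kwarg_values). The resolve_success_kwargs helper and its mandatory check are hypothetical; the check stands in for what validate_configuration enforces in the expectation code above, and raises ValueError instead of InvalidExpectationKwargsError to stay GE-free.

```python
from typing import Any, Dict, Tuple

success_keys: Tuple[str, ...] = (
    "other_table_filenames_list",
    "comparison_key",
    "lower_percentage_threshold",
)
default_kwarg_values: Dict[str, Any] = {
    "other_table_filenames_list": None,  # None: no usable default, so mandatory
    "comparison_key": "MEAN",
    "lower_percentage_threshold": 100,
}


def resolve_success_kwargs(user_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Fill in defaults for missing keys and reject missing mandatory ones."""
    resolved = {
        key: user_kwargs.get(key, default_kwarg_values.get(key))
        for key in success_keys
    }
    missing = [key for key, value in resolved.items() if value is None]
    if missing:
        raise ValueError(f"Missing required parameter(s): {missing}")
    return resolved


print(resolve_success_kwargs({"other_table_filenames_list": ["a.csv"]}))
# {'other_table_filenames_list': ['a.csv'], 'comparison_key': 'MEAN', 'lower_percentage_threshold': 100}
```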

As for the metrics, we will of course use the one we just implemented, table.row_count_other, plus GE's built-in table.row_count metric; both are listed in metric_dependencies.

Now we need three additional methods to get our expectation working:

  • the validate_configuration method validates the input parameters. To avoid repetition, I used a small helper, _validate_success_key, which takes the name of the parameter to check, whether it is required, the configuration of the expectation, and a dictionary mapping each rule the parameter must satisfy to the error message raised when it does not (e.g. {lambda x: x > 0: "parameter should be positive"} checks that the parameter is positive and, if not, raises an InvalidExpectationKwargsError with the message "parameter should be positive"),
  • the get_validation_dependencies method declares the metric dependencies, i.e. it passes the correct arguments to the metrics we will use. Usually, you only need to override this method if you are going to do some shenanigans like we are about to do.
    Metrics are computed on the dataset GE is handling at that time, so there can be multiple metrics per dataset but only one dataset per metric. However, we need one table.row_count_other metric for each additional dataset. Therefore, for each dataset in the parameter other_table_filenames_list, we add a new metric dependency, stored under the key table.row_count_other.<filename>, that gives us the row count for that particular dataset. In the end, the same metric is computed once per additional dataset,
  • and finally the _validate method validates the dataset using the results of the metrics. We gather each row count, compare the current row count to the others based on the comparison key and return the result. The result is a dictionary that must contain the key "success". We can also add information such as the row counts and the comparison key to make the result easier to understand.
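The fan-out and the final comparison can be sketched in plain Python, with ordinary dicts standing in for GE's MetricConfiguration objects and for the computed metric values (both assumptions for illustration). fan_out_dependencies and validate are hypothetical helpers; the MEAN comparison is hard-coded for brevity.

```python
from statistics import mean
from typing import Any, Dict, List, Union


def fan_out_dependencies(other_table_filenames: Union[str, List[str]]) -> Dict[str, Dict[str, Any]]:
    """One metric entry for the current table plus one per extra file."""
    if isinstance(other_table_filenames, str):
        other_table_filenames = [other_table_filenames]
    dependencies = {"table.row_count.self": {"metric_name": "table.row_count", "domain": {}}}
    for filename in other_table_filenames:
        # Same metric, different domain: each file gets its own key.
        dependencies[f"table.row_count_other.{filename}"] = {
            "metric_name": "table.row_count_other",
            "domain": {"table_filename": filename},
        }
    return dependencies


def validate(metrics: Dict[str, int], other_table_filenames: List[str], threshold: int) -> Dict[str, Any]:
    """Collect the row counts back by key and apply a MEAN comparison."""
    current = metrics["table.row_count.self"]
    others = [metrics[f"table.row_count_other.{name}"] for name in other_table_filenames]
    return {
        "success": current >= mean(others) * threshold / 100,
        "result": {"self": current, "other": others, "comparison_key": "MEAN"},
    }


files = ["day1.csv", "day2.csv"]
print(sorted(fan_out_dependencies(files)))
# ['table.row_count.self', 'table.row_count_other.day1.csv', 'table.row_count_other.day2.csv']

# Pretend the execution engine computed these row counts.
fake_metrics = {
    "table.row_count.self": 950,
    "table.row_count_other.day1.csv": 1000,
    "table.row_count_other.day2.csv": 900,
}
print(validate(fake_metrics, files, 100))
# {'success': True, 'result': {'self': 950, 'other': [1000, 900], 'comparison_key': 'MEAN'}}
```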

And... it's done! You can find the full code here. Save it as a single file named expect_table_row_count_to_be_more_than_others.py in the plugins folder of the great_expectations directory you get after initialising GE. We can test it using the taxi data from the official GE tutorial: follow the tutorial, put multiple taxi datasets into your data folder, create your first expectation suite with the command-line tool and run the following cell in the newly created Jupyter notebook:

import expect_table_row_count_to_be_more_than_others  # importing the module registers the expectation and metric

validator.expect_table_row_count_to_be_more_than_others(
    other_table_filenames_list=["/data/yellow_tripdata_2019-02.csv", "/data/yellow_tripdata_2019-03.csv"],
    comparison_key="mean",
)

If everything went correctly, you should see a dictionary with the row counts and the success flag.

Thank you for following along with this tutorial. I hope it helps you create some sick custom expectations and gives you a better understanding of the amazing tool that Great Expectations is. If you have any questions, do not hesitate to post them in the comments section!

I'll see you in the next article!