Custom Preprocessing Transformers in Snowflake ML
Introduction: Input Data Preprocessing with Snowflake ML
This blog post provides a code sample demonstrating how to implement a custom data preprocessing transformer for Snowflake ML, similar to creating a custom step in a Scikit-learn pipeline.
But wait, can't we just ask the AI to write the code for us? While AI code generation tools are evolving, Snowflake ML is a relatively niche library, so hand-crafting the code is still the more reliable option for now: AI-generated transformer code tends to be poorly designed, if it works at all.
Snowflake is a powerful cloud-based data platform with robust machine learning capabilities. Beyond its highly performant relational database, Snowflake offers tools like the Model Registry and Snowpark (a DataFrame API reminiscent of Spark) to facilitate ML workflows. However, Snowflake ML's focus on tight integration and ease of use sometimes comes at the expense of broader ecosystem compatibility, which can present challenges when implementing specialised functionality. Let us look at one such example: custom data preprocessing transformers.
A typical Snowflake ML preprocessing pipeline closely resembles a Scikit-learn pipeline:
from snowflake.ml.modeling.preprocessing import MinMaxScaler, OneHotEncoder, OrdinalEncoder

pipeline_steps = [
    (
        "OE",
        OrdinalEncoder(
            input_cols=categorical_columns_oe_in,
            output_cols=categorical_columns_oe_out,
            categories=categories_oe,
        ),
    ),
    (
        "MMS",
        MinMaxScaler(
            clip=True,
            input_cols=numerical_columns,
            output_cols=numerical_columns,
        ),
    ),
    (
        "OHE",
        OneHotEncoder(
            input_cols=categorical_columns_ohe_in,
            output_cols=categorical_columns_ohe_out,
        ),
    ),
]
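Once the steps are defined, they can be assembled, fitted, and applied much like a Scikit-learn pipeline. The following is a minimal sketch, assuming a Snowpark DataFrame df that contains the columns referenced in the step definitions above:

from snowflake.ml.modeling.pipeline import Pipeline

# Wrap the preprocessing steps defined above into a single pipeline object
preprocessing_pipeline = Pipeline(steps=pipeline_steps)

# Fit the pipeline against a Snowpark DataFrame `df` (assumed to exist),
# then apply the fitted preprocessing to the same or new data
preprocessing_pipeline.fit(df)
transformed_df = preprocessing_pipeline.transform(df)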
This familiarity makes it easy for those experienced with Scikit-learn and Spark to get started with Snowflake ML. A deeper dive, however, reveals more differences between the libraries: implementing a custom transformer requires an understanding of Snowflake ML's specifics.
Creating a Custom Transformer
A good starting point for developing a custom transformer is examining the existing Snowflake ML implementations on GitHub, particularly the BaseTransformer class and examples like the MinMaxScaler. However, the code sample given below is more concise and may serve as a better template.
Snowflake ML transformers can handle both Snowpark DataFrames and Pandas DataFrames. However, the example below supports only Snowpark DataFrames for performance reasons: unlike Pandas, Snowpark leverages the power of Snowflake's relational database for operations like joins, filtering, and sorting.
The following StringScaler transformer, inheriting from BaseTransformer, demonstrates simple string-length scaling. It calculates a normalised value in the range of 0.0 to 1.0 based on the length of each string, which requires a two-step fit-and-transform process: the fit stage determines the maximum string length, which the transform stage then uses for normalisation.
from snowflake.snowpark import types as T
from snowflake.snowpark import functions as F
from snowflake.ml.modeling.framework import base
from snowflake.ml._internal.exceptions import error_codes, exceptions
from snowflake.snowpark.dataframe import Column, DataFrame
from typing import Any, List, Optional, Union
class StringScaler(base.BaseTransformer):
    def __init__(
        self,
        input_cols: Optional[Union[str, List[str]]] = None,
        output_cols: Optional[Union[str, List[str]]] = None,
        drop_input_cols: Optional[bool] = False,
        remove_underscore: bool = False,
        # additional Transformer inputs come here
    ) -> None:
        super().__init__(drop_input_cols=drop_input_cols)
        self.input_cols = input_cols
        self.output_cols = output_cols
        self.remove_underscore = remove_underscore
        # values to be fitted:
        self.max_string_lengths_fitted = {}

    # here we only support Snowpark DataFrames. If your code also supports
    # Pandas DataFrames, use super()._check_dataset_type() instead.
    @staticmethod
    def _check_dataset_type(dataset: Any) -> None:
        if not isinstance(dataset, DataFrame):
            raise exceptions.SnowflakeMLException(
                error_code=error_codes.INVALID_ARGUMENT,
                original_exception=TypeError(
                    f"Unexpected dataset type: {type(dataset)}. "
                    f"Supported dataset types: {DataFrame}."
                ),
            )

    def _reset(self) -> None:
        super()._reset()
        # nullify fitted dictionaries
        self.max_string_lengths_fitted = {}

    # _fit() is called by BaseTransformer.fit() - make sure your method is named _fit(), not fit().
    def _fit(self, dataset: DataFrame) -> "StringScaler":
        super()._check_input_cols()
        self._check_dataset_type(dataset)
        # if your input datatype is numeric, you may want to check input_cols column data types here
        self._reset()
        for input_column_name in self.input_cols:
            # populate the fitted dictionary with fitted values
            input_column: Column = dataset[input_column_name]
            calculated_column: Column = F.max(F.length(input_column))
            self.max_string_lengths_fitted[input_column_name] = dataset.select(calculated_column).collect()[0][0]
        self._is_fitted = True
        return self

    def transform(self, dataset: DataFrame) -> DataFrame:
        self._enforce_fit()
        super()._check_input_cols()
        super()._check_output_cols()
        self._check_dataset_type(dataset)
        passthrough_columns = [c for c in dataset.columns if c not in self.output_cols]
        output_columns = []
        for input_column_name in self.input_cols:
            # output column - your transformations come here:
            if self.remove_underscore:
                input_column: Column = F.regexp_replace(dataset[input_column_name], "_", "")
            else:
                input_column: Column = dataset[input_column_name]
            input_column_length: Column = F.length(input_column)
            # note: the calculation uses the value captured during fit()
            output_column: Column = input_column_length / self.max_string_lengths_fitted[input_column_name]
            output_columns.append(output_column)
        transformed_dataset: DataFrame = dataset.with_columns(self.output_cols, output_columns)
        transformed_dataset = transformed_dataset[self.output_cols + passthrough_columns]  # output cols at the front
        return self._drop_input_columns(transformed_dataset) if self._drop_input_cols is True else transformed_dataset
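For illustration, here is a minimal usage sketch. It assumes an active Snowpark session object; the column names and sample values are made up:

# `session` is assumed to be an existing snowflake.snowpark.Session
sample_df = session.create_dataframe(
    [["alpha"], ["beta_gamma"], ["delta"]],
    schema=["NAME"],
)

scaler = StringScaler(
    input_cols=["NAME"],
    output_cols=["NAME_SCALED"],
    remove_underscore=True,
)

# fit() learns the maximum string length per input column;
# transform() scales each string's length by that maximum
scaled_df = scaler.fit(sample_df).transform(sample_df)
scaled_df.show()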
This simplified example demonstrates the fundamental structure of a custom Snowflake ML transformer. You can expand on this foundation to create more complex transformations tailored to your specific data preprocessing needs. Remember to consider factors such as handling nulls (see the short sketch below), different data types, and efficient use of Snowpark for optimal performance within the Snowflake environment.
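As one example of null handling, a minimal sketch (not part of the transformer above, just an illustration) is to coalesce NULL strings to empty strings before measuring their length, so that LENGTH() returns 0 rather than NULL:

from snowflake.snowpark import Column
from snowflake.snowpark import functions as F

def null_safe_scaled_length(input_column: Column, fitted_max_length: int) -> Column:
    # Treat NULL as an empty string so LENGTH() yields 0 instead of NULL,
    # then scale by the maximum length captured during fit()
    safe_column = F.coalesce(input_column, F.lit(""))
    return F.length(safe_column) / fitted_max_length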