Demystifying fit_transform and transform in Scikit-learn: Which Method to Use for Data Preprocessing and When?

Satish Mishra
Apr 24, 2023

Machine learning models often require the data to be preprocessed before training. Preprocessing transforms raw data into a format that is more suitable for machine learning algorithms. In scikit-learn, preprocessing objects such as StandardScaler provide two closely related methods for this: fit_transform and transform. Because these two methods are commonly confused, I decided to write this blog post to explain how they differ and when to use each.

fit_transform

fit_transform combines the fit() and transform() methods into a single step. It learns the necessary parameters from the data, such as the mean and standard deviation used for scaling, and then transforms that same data. It is commonly used on the training set. The learned parameters are stored on the object, so they can later be applied to the testing data (or any new data) with the transform() method. This ensures that the testing data is preprocessed in the same way as the training data.
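To make the "single step" concrete, here is a minimal sketch on a small made-up array (the X_train values are illustrative assumptions, not from any real dataset) that checks fit_transform gives the same result as calling fit and then transform separately:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Illustrative toy training data: 3 samples, 2 features
X_train = np.array([[1.0, 10.0],
                    [2.0, 20.0],
                    [3.0, 30.0]])

# One step: learn the parameters and transform the same data
one_step = StandardScaler().fit_transform(X_train)

# Two steps: fit first, then transform
scaler = StandardScaler()
scaler.fit(X_train)
two_steps = scaler.transform(X_train)

print(np.allclose(one_step, two_steps))  # True
```

For StandardScaler the two routes are equivalent; fit_transform is mainly a convenience, and some transformers implement it more efficiently than separate fit and transform calls.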

Here’s an example of using fit_transform to preprocess a training set of data:

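import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy training data (illustrative values): 3 samples, 2 features
X_train = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])

# Create a StandardScaler object
scaler = StandardScaler()

# Learn the mean and standard deviation from X_train, then scale it
X_train_scaled = scaler.fit_transform(X_train)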

In this example, we create a StandardScaler object called scaler. We then call its fit_transform method on the training set X_train, which learns the necessary parameters (the per-feature mean and standard deviation) and uses them to scale the data. The scaled result is stored in a new variable called X_train_scaled, while the learned parameters stay on the scaler object (as scaler.mean_ and scaler.scale_).
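A short sketch, again on illustrative toy data, that inspects the learned parameters and reproduces the scaling by hand with the standardization formula (x - mean) / std. StandardScaler uses the population standard deviation, which matches NumPy's default ddof=0:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Illustrative toy training data: 3 samples, 2 features
X_train = np.array([[1.0, 10.0],
                    [2.0, 20.0],
                    [3.0, 30.0]])

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)

# Parameters learned during fit, stored on the scaler
print(scaler.mean_)   # per-feature means: 2.0 and 20.0
print(scaler.scale_)  # per-feature (population) standard deviations

# The same result computed by hand: (x - mean) / std with ddof=0
manual = (X_train - X_train.mean(axis=0)) / X_train.std(axis=0)
print(np.allclose(X_train_scaled, manual))  # True
```

After fit_transform, each training column has mean 0 and standard deviation 1.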

transform

The transform method is used to apply the already learned parameters to new data. This is typically done on the testing data after the object has been fitted on the training data with fit_transform. By applying the same preprocessing steps to both the training and testing data, we ensure that the testing data is processed in the same way as the training data.
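Because transform only applies parameters that were learned earlier, calling it on a scaler that has never been fitted fails. A minimal sketch, using a made-up X_new array as an illustrative assumption, showing scikit-learn raising NotFittedError in that case:

```python
import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import StandardScaler

X_new = np.array([[4.0, 40.0]])  # illustrative "new" data

# An unfitted scaler has no mean or standard deviation to apply yet
try:
    StandardScaler().transform(X_new)
    raised = False
except NotFittedError:
    raised = True

print(raised)  # True
```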

Here’s an example of using transform to preprocess a testing set of data:

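# X_test: test features with the same columns as X_train (assumed already defined)
# Apply the scaler already fitted on X_train; nothing is re-learned from X_test
X_test_scaled = scaler.transform(X_test)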

In this example, we apply the transform method to the testing set X_test, which applies the scaling learned from the training set, that is, the training mean and standard deviation. No new parameters are learned from X_test. The result of transform is stored in a new variable called X_test_scaled.
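To confirm that transform reuses the training statistics, this sketch (toy X_train and X_test values are illustrative assumptions) recomputes the scaled test data by hand from scaler.mean_ and scaler.scale_:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Illustrative toy data with the same 2 features
X_train = np.array([[1.0, 10.0],
                    [2.0, 20.0],
                    [3.0, 30.0]])
X_test = np.array([[4.0, 40.0],
                   [2.5, 25.0]])

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # learns mean_ and scale_ from X_train only
X_test_scaled = scaler.transform(X_test)

# transform uses the *training* mean and standard deviation
manual = (X_test - scaler.mean_) / scaler.scale_
print(np.allclose(X_test_scaled, manual))  # True
```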

Note: We used the same scaler object that was fitted on the training data to transform the test data. We did not create a new scaler or call fit_transform on the test set.
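Here is a small sketch (one feature, illustrative values) of what goes wrong if a new scaler is fitted on the test set instead. The refitted version centres the test data on its own mean, so its values are no longer on the same scale as the training features:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Illustrative toy data: one feature
X_train = np.array([[1.0], [2.0], [3.0]])
X_test = np.array([[4.0], [5.0]])

# Correct: fit on the training data, only transform the test data
scaler = StandardScaler().fit(X_train)
right = scaler.transform(X_test)

# Anti-pattern: fitting a separate scaler on the test data
wrong = StandardScaler().fit_transform(X_test)

print(right.ravel())  # test points measured against the training mean and std
print(wrong.ravel())  # always centred on the test set's own mean: -1 and 1
```

Both test points lie above the training range, which right preserves; wrong hides that by forcing the test set to mean 0.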

When to use fit_transform and transform

fit_transform is typically used on the training set to learn the necessary parameters and preprocess the data in one step. The same fitted object (here, scaler) should then be used to preprocess the testing set with its transform method. This ensures that the testing data is processed in exactly the same way as the training data and prevents data leakage, because no statistics computed from the test set influence the preprocessing.
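Reusing the fitted object also helps when the test data contains slightly different values than the training data, for example categories of a categorical feature that never appeared during training. An encoder fitted on the training data keeps producing the same output columns, whereas an encoder fitted separately on the test data would not.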
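A minimal sketch of that situation, using scikit-learn's OneHotEncoder with handle_unknown="ignore" as an illustrative choice (the colour values are made up). Note that transform still expects the same input columns as during fitting; StandardScaler, for instance, raises an error if the number of features differs.

```python
from sklearn.preprocessing import OneHotEncoder

# Illustrative categorical data with one column
train_colors = [["red"], ["green"], ["blue"]]
test_colors = [["green"], ["purple"]]  # "purple" never appears in training

# Fit on the training data; unseen categories become all-zero rows at transform time
encoder = OneHotEncoder(handle_unknown="ignore")
train_encoded = encoder.fit_transform(train_colors).toarray()
test_encoded = encoder.transform(test_colors).toarray()

print(encoder.categories_)  # learned from the training data only
print(test_encoded)         # same 3 columns as training; "purple" is all zeros

# Fitting a separate encoder on the test set gives a different column layout
refit = OneHotEncoder().fit_transform(test_colors).toarray()
print(refit.shape)  # (2, 2)
```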

Putting it together

Here’s an example of how to use fit_transform and transform on a complete dataset:

# Import the dataset loader, the splitting helper and the scaler
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the dataset
iris = load_iris()

# Split the dataset into training and testing sets (random_state makes the split reproducible)
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# Create a StandardScaler object; fit it on the training data and scale that data
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)

# Apply the same fitted scaler object to the testing data
X_test_scaled = scaler.transform(X_test)

In this example, we load the Iris dataset and split it into training and testing sets. We then create a StandardScaler object and use its fit_transform method to learn the scaling parameters from the training data and scale it. Finally, we use the same fitted scaler to preprocess the testing data with the transform method.
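As a further sketch beyond the original example, scikit-learn's Pipeline can make these calls for you. When the pipeline is fitted, the scaler is fitted on the training data and its output is passed to the model; at prediction or scoring time the fitted scaler is only applied with transform. LogisticRegression is an illustrative choice of model here, not part of the original example:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# fit: the scaler learns from X_train only, then the model is trained on the scaled data
model = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
model.fit(X_train, y_train)

# score: X_test goes through the already fitted scaler's transform before the model
accuracy = model.score(X_test, y_test)
print(f"Test accuracy: {accuracy:.2f}")
```

This keeps the fit-on-train, transform-on-test rule in one object, which is harder to get wrong than managing the scaler by hand.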

Conclusion

In summary, use fit_transform on the training data to learn the preprocessing parameters and apply them in one step, then use transform on your test or prediction data so that it is processed with the parameters learned from the training set.
