Model-Agnostic Meta-Learning (MAML) is a meta-learning algorithm designed to train models that can adapt to a new task using very few data points and very few gradient steps; in essence, the model learns to learn.
MAML learns an initialization of the parameters (weights) such that the model can adapt to any new task from the same task distribution in fewer gradient steps than it would need when starting from a random initialization.

Why MAML Exists
In standard training, a model:
- Learns one task.
- Requires many labelled examples (training data).
- Does not generalize well to tasks outside its domain, and requires re-training or fine-tuning to handle them.

MAML addresses this by learning to learn, which makes it an effective few-shot learner. MAML shines when:
- All tasks are drawn from a single task distribution p(T).
- Each task has very little data.
- Computation is limited, or we want fast adaptation rather than learning from scratch.
Instead of learning parameters θ that are optimal for one task, MAML learns parameters θ such that, after one to five gradient steps on a new task, the adapted parameters perform well on that task. So θ is not the final solution; it is a good starting point.
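This idea can be written as an optimization over the initialization θ. The block below states the MAML objective for a single inner gradient step with step size α. The notation (f_θ for the model with parameters θ, and $\mathcal{L}_{\mathcal{T}_i}$ for the loss on task $\mathcal{T}_i$) is introduced here for illustration.

```latex
\min_{\theta} \; \sum_{\mathcal{T}_i \sim p(\mathcal{T})}
  \mathcal{L}_{\mathcal{T}_i}\!\left( f_{\theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta})} \right)
```

The loss is evaluated after the inner update, so θ is judged by how well it performs once adapted to each task, not by how well it performs as is.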
Algorithm

We begin by understanding the algorithm mathematically.
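Requirements / Hyper-Parameters
- p(T): the distribution over tasks
- α, β: step sizes for the inner-loop (task adaptation) and outer-loop (meta) updates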
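Step 1: Initialize the model parameters θ randomly
The weights are typically drawn from a uniform distribution (Keras Dense layers, used later, default to Glorot-uniform kernels and zero biases).
Step 2: Sample a batch of tasks T_i from the task distribution p(T)
Step 3: Sample K data points from each sampled task (the training, or support, set)
Step 4: Compute the loss on these K points and evaluate its gradients with respect to θ
Step 5: Compute the adapted parameters θ'_i with one (or a few) gradient-descent steps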
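Steps 4 and 5 can be written out for the regression setting used later in this article, assuming a mean-squared-error loss over the K sampled points $(x^{(j)}, y^{(j)})$ of task $\mathcal{T}_i$:

```latex
\mathcal{L}_{\mathcal{T}_i}(f_{\theta}) = \frac{1}{K} \sum_{j=1}^{K} \left( f_{\theta}(x^{(j)}) - y^{(j)} \right)^2,
\qquad
\theta'_i = \theta - \alpha \, \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta})
```

With more than one inner step, the same update is applied repeatedly, starting from θ'_i instead of θ.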
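Step 6: Sample fresh validation (query) data from the same task for evaluation
Step 7: Compute the loss of the adapted parameters θ'_i on the validation data, sum it over the batch of tasks, and update the original parameters θ with its gradient. Because θ'_i depends on θ through the inner gradient step, this gradient involves second-order derivatives.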
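The seven steps can be traced end to end in a deliberately tiny setting where every derivative has a closed form. This is a minimal NumPy sketch, not the TensorFlow implementation that follows. It assumes hypothetical linear tasks y = a·x with a slope a drawn from [0.5, 1.5] and a one-parameter model f_θ(x) = θ·x; the helper names and step sizes are illustrative. The meta-update of Step 7 is $\theta \leftarrow \theta - \beta \nabla_\theta \sum_i \mathcal{L}_{\mathcal{T}_i}(f_{\theta'_i})$, and here the chain rule through the inner step contributes the factor (1 − α·H), where H is the second derivative of the training loss.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, beta = 0.1, 0.05          # inner (alpha) and outer (beta) step sizes
K, meta_batch_size = 10, 4       # points per set, tasks per meta-update


def sample_task():
    """Hypothetical task family: a task is identified by its slope a."""
    return rng.uniform(0.5, 1.5)


def sample_data(a, k):
    x = rng.uniform(-2.0, 2.0, size=k)
    return x, a * x


def loss_grad_hess(theta, x, y):
    """MSE of f_theta(x) = theta * x, plus its first and second derivatives in theta."""
    r = theta * x - y
    return np.mean(r ** 2), 2 * np.mean(x * r), 2 * np.mean(x ** 2)


theta = rng.uniform(-1.0, 1.0)                       # Step 1: random initialization
for step in range(500):
    meta_grad = 0.0
    for _ in range(meta_batch_size):                 # Step 2: batch of tasks
        a = sample_task()
        x_tr, y_tr = sample_data(a, K)               # Step 3: K training points
        _, g, h = loss_grad_hess(theta, x_tr, y_tr)  # Step 4: loss and gradient
        theta_i = theta - alpha * g                  # Step 5: adapted parameter
        x_val, y_val = sample_data(a, K)             # Step 6: validation points, same task
        _, g_val, _ = loss_grad_hess(theta_i, x_val, y_val)
        # Step 7: d(theta_i)/d(theta) = 1 - alpha * h is the second-order term
        meta_grad += g_val * (1 - alpha * h)
    theta -= beta * meta_grad / meta_batch_size      # meta-update of the initialization

print(f"meta-learned initialization: {theta:.2f}")
```

Because every task here shares the same input distribution, the best initialization for one-step adaptation is the average slope, so the learned θ should end up close to 1.0.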
Implementation
Now we look at an implementation of MAML on the task of regressing a family of sine-wave functions.
Step 1: Import Necessary Libraries
```python
import numpy as np
import tensorflow as tf
```
Step 2: Define Task Distribution
Our task distribution is a family of sine functions with varying amplitude (A) and phase (φ):

$y = A \sin(x + \phi)$
Step 3: Code for Randomly Sampling from the task distribution
- The amplitude (amp) is drawn uniformly from [0.1, 5.0].
- The phase is drawn uniformly from [0, π]. Since sin has period 2π, this range covers only half of all possible phase shifts; it is simply a choice of task distribution.
- X holds K input points drawn uniformly from [-5, 5]; these are the values whose sine is taken, i.e. sin(X + phase).
- y = amp · sin(X + phase) combines the three variables above. Both arrays are cast to float32, TensorFlow's default floating-point type (NumPy defaults to float64).
```python
def sample_sine_task(K=10):
    # Each call draws a NEW task (amplitude and phase) and K points from it
    amp = np.random.uniform(0.1, 5.0)
    phase = np.random.uniform(0, np.pi)
    X = np.random.uniform(-5, 5, size=(K, 1))
    y = amp * np.sin(X + phase)
    return X.astype(np.float32), y.astype(np.float32)  # float32 is TensorFlow's default dtype
```
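Each call to sample_sine_task draws a new amplitude and phase, so two separate calls return points from two different tasks. MAML needs its training (support) and validation (query) sets to come from the same task. One way to get that, sketched below with a hypothetical helper sample_sine_task_split (not part of the original code), is to draw 2K points from a single task and split them; the rng argument is an added convenience for reproducibility.

```python
import numpy as np


def sample_sine_task_split(K=10, rng=None):
    """Hypothetical helper: one sine task, K support points and K query points."""
    rng = np.random.default_rng() if rng is None else rng
    amp = rng.uniform(0.1, 5.0)
    phase = rng.uniform(0, np.pi)
    X = rng.uniform(-5, 5, size=(2 * K, 1))
    y = amp * np.sin(X + phase)                # every point shares amp and phase
    X, y = X.astype(np.float32), y.astype(np.float32)
    support = (X[:K], y[:K])                   # used for the inner-loop update
    query = (X[K:], y[K:])                     # used for the meta-loss
    return support, query, (amp, phase)


(x_tr, y_tr), (x_val, y_val), (amp, phase) = sample_sine_task_split(K=10, rng=np.random.default_rng(0))
print(x_tr.shape, x_val.shape)  # (10, 1) (10, 1)
```

Returning the task parameters as well is optional, but it makes it easy to check that both sets lie on the same curve.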
Step 4: Define a simple model
- The model consists of three Dense layers: two hidden layers with 40 units each and a 'relu' activation, followed by an output layer with a single linear unit.
- We use Keras' Sequential API, since the layers are simply stacked one after another.
- We run a dummy forward pass on a zero input to build the weights, because Keras creates a layer's weights only once it knows the input shape.
```python
def create_model():
    return tf.keras.Sequential([
        tf.keras.layers.Dense(40, activation='relu'),
        tf.keras.layers.Dense(40, activation='relu'),
        tf.keras.layers.Dense(1)
    ])

model = create_model()
model(tf.zeros((1, 1)))  # dummy forward pass builds the weights
```
Step 5: Define the hyperparameters, loss function and optimizer
```python
inner_lr = 0.01         # inner-loop learning rate (alpha)
outer_lr = 0.001        # meta learning rate (beta)
inner_steps = 1         # number of inner-loop gradient steps (the training step below performs one)
meta_batch_size = 4     # number of tasks per meta-update
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(outer_lr)
```
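Step 6: Define a function to simulate a forward pass
We define a function that takes an input and a list of weights and runs the network's forward pass with those weights, without changing the model's own weights. This is what lets us evaluate the adapted weights, which are plain tensors rather than the model's variables.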
```python
def forward_pass_with_weights(x, weights):
    # weights follows the order of model.trainable_variables:
    # [kernel1, bias1, kernel2, bias2, kernel3, bias3]
    h = x
    idx = 0
    # layer 1
    h = tf.matmul(h, weights[idx]) + weights[idx + 1]
    h = tf.nn.relu(h)
    idx += 2
    # layer 2
    h = tf.matmul(h, weights[idx]) + weights[idx + 1]
    h = tf.nn.relu(h)
    idx += 2
    # layer 3 (output, linear)
    h = tf.matmul(h, weights[idx]) + weights[idx + 1]
    return h
```
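The same functional forward pass can be written as a loop over (kernel, bias) pairs, which avoids hard-coding three layers. The NumPy sketch below is illustrative only: the helper names init_weights and forward_pass_with_weights_np are hypothetical, and the Glorot-uniform initialization imitates the Keras default for Dense kernels. The weight list uses the same order as model.trainable_variables for the model above: kernel then bias, for each layer in turn.

```python
import numpy as np


def init_weights(rng, sizes=(1, 40, 40, 1)):
    """Hypothetical helper: flat list [W1, b1, W2, b2, W3, b3]."""
    weights = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        limit = np.sqrt(6.0 / (fan_in + fan_out))  # Glorot-uniform bound
        weights.append(rng.uniform(-limit, limit, size=(fan_in, fan_out)).astype(np.float32))
        weights.append(np.zeros(fan_out, dtype=np.float32))  # biases start at zero
    return weights


def forward_pass_with_weights_np(x, weights):
    h = x
    n_layers = len(weights) // 2
    for i in range(n_layers):
        h = h @ weights[2 * i] + weights[2 * i + 1]
        if i < n_layers - 1:                     # ReLU on hidden layers only
            h = np.maximum(h, 0.0)
    return h                                     # linear output layer


w = init_weights(np.random.default_rng(0))
print([p.shape for p in w])
print(sum(p.size for p in w))  # 1761
y = forward_pass_with_weights_np(np.zeros((5, 1), dtype=np.float32), w)
print(y.shape)  # (5, 1)
```

The 1761 parameters are 40 + 40 for the first layer, 1600 + 40 for the second and 40 + 1 for the output layer.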
Step 7: Core MAML training loop (most important)
- The function first initializes meta_loss to 0 and opens the outer tf.GradientTape(), which records operations for automatic differentiation.
- For each task in the meta-batch, we sample a training (support) set and a validation (query) set from the same task.
- We take the model's current weights and open the inner tf.GradientTape().
- We run a forward pass with these weights, compute the training loss, and calculate its gradients with respect to the weights.
- We compute the adapted weights by taking one gradient step (Note: the gradients are not applied to the model itself; the adapted weights are new tensors).
- We run a simulated forward pass with the adapted weights, using the function from Step 6, and record the output.
- We add the loss between this output and the true validation targets to meta_loss, and average it over the meta-batch.
- The outer tape then computes the gradient of meta_loss with respect to the original weights. Because the adapted weights were built from the inner gradients, this is a gradient of gradients, i.e. a second-order (double) differentiation, which is one of the biggest drawbacks of MAML.
- Finally, the optimizer applies these meta-gradients to the model.
```python
# Runs eagerly: inside @tf.function, the np.random calls in sample_sine_task would
# execute only once (at tracing time), so every step would reuse the same tasks.
def maml_train_step():
    K = 10  # points per training / validation set
    meta_loss = 0.0
    with tf.GradientTape() as outer_tape:
        for _ in range(meta_batch_size):
            # Draw 2K points from ONE task and split them, so the training (support)
            # and validation (query) sets share the same amplitude and phase
            x, y = sample_sine_task(K=2 * K)
            x_train, y_train = x[:K], y[:K]
            x_val, y_val = x[K:], y[K:]
            weights = model.trainable_variables  # current meta-parameters (theta)
            with tf.GradientTape() as inner_tape:
                y_pred = model(x_train)  # predictions with the current weights
                train_loss = loss_fn(y_train, y_pred)
            # Computed inside outer_tape, so these gradients can be differentiated again
            grads = inner_tape.gradient(train_loss, weights)
            # One gradient step gives the adapted weights (theta'); the model is unchanged
            adapted_weights = [w - inner_lr * g for w, g in zip(weights, grads)]
            # Validation predictions as if the adapted weights were applied to the model
            y_val_pred = forward_pass_with_weights(x_val, adapted_weights)
            meta_loss += loss_fn(y_val, y_val_pred)
        meta_loss /= meta_batch_size  # average over tasks (still recorded by outer_tape)
    # Gradient of the meta-loss w.r.t. the ORIGINAL weights (second order through grads)
    meta_grads = outer_tape.gradient(meta_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(meta_grads, model.trainable_variables))
    return meta_loss
```
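The "gradient of the gradients" can be made concrete on a problem small enough to differentiate by hand. The NumPy sketch below is not the sine model above: it assumes a hypothetical one-parameter model f_θ(x) = θ·x, a single task with slope 1.3, and one inner step. It compares three meta-gradients: the exact one, which includes the second-order factor (1 − α·H) with H the second derivative of the training loss; a first-order approximation that drops this factor; and a central finite difference of the meta-loss.

```python
import numpy as np

# Hypothetical toy problem: f_theta(x) = theta * x, MSE loss, one inner step.
rng = np.random.default_rng(1)
alpha = 0.1
x_tr = rng.uniform(-2, 2, 10)
y_tr = 1.3 * x_tr
x_val = rng.uniform(-2, 2, 10)
y_val = 1.3 * x_val


def grad(theta, x, y):
    """d/dtheta of mean((theta * x - y) ** 2)."""
    return 2 * np.mean(x * (theta * x - y))


def meta_loss(theta):
    theta_i = theta - alpha * grad(theta, x_tr, y_tr)  # inner step
    return np.mean((theta_i * x_val - y_val) ** 2)     # validation loss after adapting


theta = 0.2
theta_i = theta - alpha * grad(theta, x_tr, y_tr)
hess = 2 * np.mean(x_tr ** 2)                          # second derivative of the training loss

second_order = grad(theta_i, x_val, y_val) * (1 - alpha * hess)  # exact MAML meta-gradient
first_order = grad(theta_i, x_val, y_val)                        # Hessian term dropped
eps = 1e-5
numeric = (meta_loss(theta + eps) - meta_loss(theta - eps)) / (2 * eps)
print(second_order, first_order, numeric)
```

The exact meta-gradient matches the finite difference, while the first-order version is off by the missing factor. In the TensorFlow code above, that factor arises because the inner gradients are computed inside the outer tape; a first-order variant could, for example, wrap the inner gradients in tf.stop_gradient before forming the adapted weights, trading meta-gradient accuracy for less computation and memory.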
Step 8: Train for 2000 meta-iterations
```python
for step in range(2000):
    loss = maml_train_step()
    if step % 200 == 0:
        print(f"Step {step}, Meta Loss: {loss.numpy():.4f}")
```
Step 9: Evaluate one-step adaptation on an unseen task
Now we check how the meta-learned model performs on a task it has never seen, before and after a single adaptation step on K = 10 examples from that task.
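```python
# New unseen task: 10 adaptation points and 100 test points from the SAME task
x_all, y_all = sample_sine_task(K=110)
x_train, y_train = x_all[:10], y_all[:10]
x_test, y_test = x_all[10:], y_all[10:]
# Before adaptation: predictions of the meta-learned initialization
y_before = model(x_test)
# One-step adaptation on the 10 training points (the model itself is not modified)
with tf.GradientTape() as tape:
    loss = loss_fn(y_train, model(x_train))
grads = tape.gradient(loss, model.trainable_variables)
adapted_weights = [w - inner_lr * g for w, g in zip(model.trainable_variables, grads)]
y_after = forward_pass_with_weights(x_test, adapted_weights)
# Compare test MSE before and after adaptation (lower is better)
print("MSE before adaptation:", loss_fn(y_test, y_before).numpy())
print("MSE after adaptation: ", loss_fn(y_test, y_after).numpy())
```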
Applications of MAML
- Few-shot image classification: MAML frames few-shot learning as a meta-learning problem and has shown strong performance on datasets such as miniImageNet and Omniglot, making it a competitive baseline for few-shot tasks.
- Robotics: Robots must interact with their environment and make quick decisions from limited data in limited time. MAML suits this setting because it optimizes for fast adaptation from little data.
- Reinforcement learning: MAML accelerates policy adaptation by initializing neural-network policies in regions of parameter space from which task-specific policies can be learned quickly via gradient descent.
Limitations
- Computational complexity and memory overhead: MAML requires second-order derivatives through the inner-loop updates, which makes it computationally expensive. This results in higher memory consumption and training times that can be two to three times longer than first-order alternatives.
- Training instability: Because MAML is a bi-level optimization problem, small changes to the hyperparameters can lead to poor training.
- Questionable objective: MAML optimizes for fast optimization, but in some cases this does more harm than good. MAML tends to settle in regions where useful gradients are easy to reach, yet ordinary first-order training can often reach those regions in a few iterations anyway, which makes MAML unnecessary.