MAML: Model-Agnostic Meta-Learning

Last Updated : 16 Jan, 2026

Model-Agnostic Meta-Learning (MAML) is a meta-learning algorithm for training models that can adapt to a new task using very few data points and only a few gradient steps; in essence, the model learns to learn.

MAML learns an initialization of the parameters (weights) from which the model can adapt to a new task from the same distribution in far fewer steps than it would need starting from a random initialization.

Why MAML Exists

In standard training, a model:

  • Learns one task.
  • Requires many labelled training examples.
  • Generalizes poorly to tasks outside its domain and needs retraining or fine-tuning.
For all tasks in the distribution, MAML optimizes for a good starting point.

MAML addresses this by learning to learn, which makes it an effective few-shot learner. MAML shines when:

  • All tasks are drawn from a single task distribution p(T).
  • Each task has very little data.
  • Computation is limited, or we want fast adaptation rather than learning from scratch.

Instead of learning parameters θ that are optimal for one task, MAML learns parameters θ such that, after 1–5 gradient steps on a new task, the adapted parameters perform well on that task. So θ is not the final solution; it is a good starting point.
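Formally, for a single inner gradient step with step size α, the quantity MAML minimizes over the initialization θ can be written as below (the one-step form; with several inner steps, θ_i' is obtained by repeating the update). The loss is measured after adaptation, so θ is judged by how good it is as a starting point rather than by its own loss.

```latex
\theta_i' = \theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta}),
\qquad
\min_{\theta} \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}\big(f_{\theta_i'}\big)
= \min_{\theta} \sum_{\mathcal{T}_i \sim p(\mathcal{T})}
  \mathcal{L}_{\mathcal{T}_i}\big(f_{\theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta})}\big)
```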

Algorithm


We begin by describing the algorithm mathematically.

Requirements / Hyperparameters

p(\mathcal{T}) : probability distribution over tasks, from which tasks are sampled
\alpha, \beta : step sizes (learning rates) for the inner-loop task-adaptation update and the outer-loop meta-update, respectively

Step 1: Initialize model weights randomly

Model weights are sampled randomly, for example from a uniform distribution:
\theta \sim \mathcal{U}(-a, +a) : where \mathcal{U}(-a, +a) is the uniform distribution on the interval [-a, +a].
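As one concrete choice of a: Keras `Dense` layers, used in the implementation below, default to the Glorot (Xavier) uniform initializer, which draws the kernel from U(-a, +a) with a = sqrt(6 / (fan_in + fan_out)) and sets biases to zero. The NumPy sketch below reproduces that rule; the helper name `init_dense` is hypothetical.

```python
import numpy as np


def init_dense(fan_in, fan_out, rng):
    """Hypothetical helper: Glorot-uniform kernel and zero bias (the Keras Dense defaults)."""
    a = np.sqrt(6.0 / (fan_in + fan_out))            # half-width of U(-a, +a)
    W = rng.uniform(-a, a, size=(fan_in, fan_out))  # theta ~ U(-a, +a)
    b = np.zeros(fan_out)
    return W, b


rng = np.random.default_rng(0)
# Same layer sizes as the sine-regression model later: 1 -> 40 -> 40 -> 1
theta = [p for fi, fo in [(1, 40), (40, 40), (40, 1)] for p in init_dense(fi, fo, rng)]
print([p.shape for p in theta])  # [(1, 40), (40,), (40, 40), (40,), (40, 1), (1,)]
```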

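Step 2: Sample a batch of tasks from the task distribution

\mathcal{T}_i \sim p(\mathcal{T}) : a batch of tasks is drawn from the task distribution; Steps 3 to 6 are carried out for each task \mathcal{T}_i in the batch.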

Step 3: Sample K data points from each sampled task

\mathcal{D}_i = \{(x^{(k)}, y^{(k)})\}_{k=1}^{K} : x^{(k)} is an input, y^{(k)} is its label, and \mathcal{D}_i is the training (support) set of task \mathcal{T}_i.

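Step 4: Compute the loss and its gradients

\nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta}) : where \mathcal{L}_{\mathcal{T}_i} is the loss of the model f_{\theta} on the K training examples of task \mathcal{T}_i, and \nabla_{\theta} denotes the gradient with respect to the parameters \theta.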

Step 5: Compute adapted parameters with gradient descent

\theta_i' = \theta - \alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta}) : one standard gradient-descent step with the inner learning rate \alpha, giving task-specific adapted parameters \theta_i'
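A one-parameter worked example of Step 5, assuming a toy task loss \mathcal{L}(\theta) = (\theta - 3)^2, an initialization \theta = 1 and \alpha = 0.1 (all values chosen only for illustration):

```latex
\nabla_{\theta}\mathcal{L}(\theta) = 2(\theta - 3) = 2(1 - 3) = -4,
\qquad
\theta' = \theta - \alpha \nabla_{\theta}\mathcal{L}(\theta) = 1 - 0.1 \cdot (-4) = 1.4,
\qquad
\mathcal{L}(\theta') = (1.4 - 3)^2 = 2.56 < \mathcal{L}(\theta) = 4
```

One step moves \theta from 1 toward the task optimum at 3 and lowers the task loss.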

Step 6: Sample validation data from the same task for evaluation

\mathcal{D}_i' = \{(x^{(k)}, y^{(k)})\} : a validation (query) set sampled from the same task \mathcal{T}_i (new data points, not the ones used for adaptation), used to evaluate the adapted parameters \theta_i'.

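Step 7: Compute the meta-gradient and update the initial parameters

\theta \leftarrow \theta - \beta \nabla_{\theta} \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'}) : the meta-update with learning rate \beta, where each task's loss is computed on its validation data from Step 6. Because \theta_i' depends on \theta through the gradient step of Step 5, this meta-gradient involves second-order derivatives of the training loss. Steps 2 to 7 are repeated until training converges.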
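To make Steps 1 to 7 concrete, the NumPy sketch below runs exact (second-order) MAML in a deliberately simple setting: a linear model f_θ(x) = w·x + b and a hypothetical family of linear-regression tasks y = a·x + c. Because the model is linear and the loss is MSE, the Hessian of the training loss has a closed form, so the second-order term of the meta-gradient can be written out explicitly instead of relying on automatic differentiation. The task ranges, step sizes and helper names (`features`, `mse_grad_hess`, `sample_batch`, `meta_loss_and_grad`) are assumptions for illustration; the loss is averaged over tasks, which only rescales β relative to the sum in Step 7.

```python
import numpy as np


def features(x):
    # Rows [x, 1], so the model is f_theta(x) = theta[0] * x + theta[1]
    return np.stack([x, np.ones_like(x)], axis=1)


def mse_grad_hess(theta, x, y):
    # MSE loss, its gradient and its Hessian w.r.t. theta (the Hessian is constant here)
    phi = features(x)
    r = phi @ theta - y
    loss = np.mean(r ** 2)
    grad = 2.0 * phi.T @ r / len(x)
    hess = 2.0 * phi.T @ phi / len(x)
    return loss, grad, hess


def sample_batch(rng, n_tasks=4, K=10):
    # Hypothetical task family y = a*x + c; training and validation sets share one task
    batch = []
    for _ in range(n_tasks):
        a, c = rng.uniform(-2, 2), rng.uniform(-1, 1)  # Step 2: draw a task
        x = rng.uniform(-1, 1, size=2 * K)
        y = a * x + c
        batch.append((x[:K], y[:K], x[K:], y[K:]))     # Steps 3 and 6
    return batch


def meta_loss_and_grad(theta, batch, alpha):
    meta_loss, meta_grad = 0.0, np.zeros_like(theta)
    for x_tr, y_tr, x_val, y_val in batch:
        _, g_tr, H_tr = mse_grad_hess(theta, x_tr, y_tr)        # Step 4
        theta_i = theta - alpha * g_tr                           # Step 5
        l_val, g_val, _ = mse_grad_hess(theta_i, x_val, y_val)  # validation loss at theta_i'
        # Chain rule: d(theta_i')/d(theta) = I - alpha * H_tr, the second-order term
        J = np.eye(len(theta)) - alpha * H_tr
        meta_grad += J.T @ g_val
        meta_loss += l_val
    return meta_loss / len(batch), meta_grad / len(batch)       # averaged over tasks


rng = np.random.default_rng(0)
theta = np.array([3.0, -3.0])  # fixed, deliberately poor start instead of a random draw
alpha, beta = 0.1, 0.05
history = []
for step in range(1000):
    batch = sample_batch(rng)
    loss, grad = meta_loss_and_grad(theta, batch, alpha)
    theta = theta - beta * grad                                  # Step 7: meta-update
    history.append(loss)

print("meta-learned initialization:", theta)
print("meta-loss, first 10 steps:", np.mean(history[:10]))
print("meta-loss, last 100 steps:", np.mean(history[-100:]))
```

For a linear model the Hessian product is cheap; for a neural network, this same term is what nested automatic differentiation computes, and it is the main source of MAML's extra cost.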

Implementation

Next, we implement MAML on a regression problem: predicting a family of sine-wave functions.

Step 1: Import Necessary Libraries

Python
import numpy as np
import tensorflow as tf

Step 2: Define Task Distribution

Our task distribution is a family of sine functions with varying amplitude (A) and phase (\phi), given by y(x) = A \sin(x + \phi).

Sine waves with varying amplitude and phase

Step 3: Code for randomly sampling a task from the task distribution

  • Amplitude (amp) is drawn uniformly from the range [0.1, 5.0].
  • Phase is drawn uniformly from [0, \pi]. The sine function itself repeats every 2\pi, so this range covers half of all possible phase shifts.
  • X holds K input points drawn uniformly from [-5, 5]; these are the values whose sine is taken.
  • y = amp * sin(X + phase) is the target. Both arrays are cast to float32 because that is TensorFlow's default floating-point type.
Python
def sample_sine_task(K=10):
    amp = np.random.uniform(0.1, 5.0)    # amplitude A
    phase = np.random.uniform(0, np.pi)  # phase phi
    X = np.random.uniform(-5, 5, size=(K, 1))
    y = amp * np.sin(X + phase)
    return X.astype(np.float32), y.astype(np.float32)  # float32 is TensorFlow's default dtype

Step 4: Define a simple model

  • The model consists of three Dense layers: two hidden layers with 40 units each and a 'relu' activation, and an output layer with 1 unit and no activation.
  • We use Keras' Sequential API because the layers are applied one after another.
  • We run a dummy forward pass on a zero input so that Keras builds (creates) the weights.
Python
def create_model():
    return tf.keras.Sequential([
        tf.keras.layers.Dense(40, activation='relu'),
        tf.keras.layers.Dense(40, activation='relu'),
        tf.keras.layers.Dense(1)
    ])

model = create_model()
model(tf.zeros((1, 1))) #dummy forward pass

Output:

Deep learning model representation
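The `forward_pass_with_weights` helper defined in Step 6 relies on the order of `model.trainable_variables`: for each Dense layer, the kernel comes first, then the bias. The short calculation below (plain Python, no TensorFlow needed) lists the expected shapes for this 1 → 40 → 40 → 1 network and the resulting parameter count; the variable names are illustrative.

```python
import math

# Kernel and bias shapes of each Dense layer, in model.trainable_variables order
layer_sizes = [1, 40, 40, 1]  # input dim, two hidden layers, output
shapes = []
for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    shapes.append((fan_in, fan_out))  # kernel
    shapes.append((fan_out,))         # bias

print(shapes)                                  # [(1, 40), (40,), (40, 40), (40,), (40, 1), (1,)]
print(sum(math.prod(s) for s in shapes))       # 1761 = 80 + 1640 + 41
```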

Step 5: Define hyperparameters, loss function and optimizer

Python
inner_lr = 0.01       # inner-loop learning rate (alpha)
outer_lr = 0.001      # meta learning rate (beta)
inner_steps = 1       # number of inner-loop gradient steps (the training step below performs one)
meta_batch_size = 4   # number of tasks per meta-update

loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(outer_lr)

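Step 6: Define a function to run a forward pass with given weights

We define a function that takes an input and a list of weights and runs a forward pass with those weights, without changing the model's own variables. This is needed because the adapted weights are ordinary tensors, not the model's variables. The list must follow the order of model.trainable_variables: kernel, then bias, for each of the three Dense layers.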

Python
def forward_pass_with_weights(x, weights):
    h = x
    idx = 0

    # layer 1
    h = tf.matmul(h, weights[idx]) + weights[idx + 1]
    h = tf.nn.relu(h)
    idx += 2

    # layer 2
    h = tf.matmul(h, weights[idx]) + weights[idx + 1]
    h = tf.nn.relu(h)
    idx += 2

    # layer 3 (output)
    h = tf.matmul(h, weights[idx]) + weights[idx + 1]

    return h

Step 7: Core MAML training step (most important)

  • The function initializes meta_loss to 0 and opens the outer tf.GradientTape(), which records operations for automatic differentiation.
  • For each task in the meta-batch, we sample a training (support) set and a validation (query) set. Both must come from the same task, i.e. the same amplitude and phase.
  • We take the model's current weights and open the inner tf.GradientTape().
  • We run a forward pass on the training set, compute the loss, and compute its gradients with respect to the weights.
  • We compute the adapted weights by taking one gradient step (weights minus inner_lr times the gradients). Note: these adapted weights are not written back to the model.
  • We run a forward pass with the adapted weights using the function from Step 6 and record the output.
  • We add the loss between this output and the true validation labels to meta_loss.
  • The outer tape then differentiates meta_loss with respect to the original weights. Because the adapted weights were themselves produced by a gradient step, this is a gradient of a gradient, i.e. it involves second-order derivatives (one of the biggest drawbacks of MAML).
  • Finally, we apply these outer (meta) gradients to the model with the optimizer.
Python
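def sample_task_batch(K=10):
    # For each task, draw ONE amplitude/phase pair and take both the training
    # (support) and validation (query) sets from that same sine wave
    x_tr, y_tr, x_val, y_val = [], [], [], []
    for _ in range(meta_batch_size):
        amp = np.random.uniform(0.1, 5.0)
        phase = np.random.uniform(0, np.pi)
        X = np.random.uniform(-5, 5, size=(2 * K, 1)).astype(np.float32)
        y = (amp * np.sin(X + phase)).astype(np.float32)
        x_tr.append(X[:K])
        y_tr.append(y[:K])
        x_val.append(X[K:])
        y_val.append(y[K:])
    # Each stacked tensor has shape (meta_batch_size, K, 1)
    return [tf.convert_to_tensor(np.stack(a)) for a in (x_tr, y_tr, x_val, y_val)]


@tf.function
def _maml_step(x_tr, y_tr, x_val, y_val):
    meta_loss = 0.0
    with tf.GradientTape() as outer_tape:
        for i in range(meta_batch_size):
            weights = model.trainable_variables  # current initialization theta

            with tf.GradientTape() as inner_tape:
                y_pred = model(x_tr[i])                # predictions on the training set
                train_loss = loss_fn(y_tr[i], y_pred)

            # Computed inside outer_tape, so this gradient is differentiated again below
            grads = inner_tape.gradient(train_loss, weights)

            # theta_i' = theta - alpha * grad (the model's variables are not modified)
            adapted_weights = [w - inner_lr * g for w, g in zip(weights, grads)]

            # Validation loss of the adapted weights
            y_val_pred = forward_pass_with_weights(x_val[i], adapted_weights)
            meta_loss += loss_fn(y_val[i], y_val_pred)

        meta_loss /= meta_batch_size

    # Gradient w.r.t. theta through the inner update (second-order derivatives)
    meta_grads = outer_tape.gradient(meta_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(meta_grads, model.trainable_variables))
    return meta_loss


def maml_train_step():
    # Sample fresh tasks in Python on every call: NumPy calls inside a tf.function
    # run only once, at trace time, so every step would otherwise reuse the same tasks
    return _maml_step(*sample_task_batch())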

Step 8: Train for 2000 meta-iterations

Python
for step in range(2000):
    loss = maml_train_step()
    if step % 200 == 0:
        print(f"Step {step}, Meta Loss: {loss.numpy():.4f}")


Step 9: Evaluate the model on an unseen task after one-step adaptation

We now compare the meta-trained model's predictions on a task it has not seen, before and after a single adaptation step, using the mean squared error against the true sine curve.

Python
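# New unseen task: sample amplitude and phase explicitly so the ground truth is known
amp = np.random.uniform(0.1, 5.0)
phase = np.random.uniform(0, np.pi)
x_train = np.random.uniform(-5, 5, size=(10, 1)).astype(np.float32)  # K = 10 support points
y_train = (amp * np.sin(x_train + phase)).astype(np.float32)
x_test = np.linspace(-5, 5, 100).reshape(-1, 1).astype(np.float32)
y_test = (amp * np.sin(x_test + phase)).astype(np.float32)

# Before adaptation: predictions of the meta-learned initialization
y_before = model(x_test)

# One-step adaptation on the support set, with the same inner_lr as in meta-training
with tf.GradientTape() as tape:
    loss = loss_fn(y_train, model(x_train))
grads = tape.gradient(loss, model.trainable_variables)

adapted_weights = [
    w - inner_lr * g for w, g in zip(model.trainable_variables, grads)
]

# After adaptation: forward pass with the adapted weights (model variables unchanged)
y_after = forward_pass_with_weights(x_test, adapted_weights)

# Mean squared error against the true sine curve on the test grid
print("Test MSE before adaptation:", loss_fn(y_test, y_before).numpy())
print("Test MSE after one step:   ", loss_fn(y_test, y_after).numpy())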



Applications of MAML

  1. Few-shot image classification: MAML frames few-shot learning as a meta-learning problem and has shown strong results on benchmarks such as Omniglot and miniImageNet, making it a competitive baseline for few-shot tasks.
  2. Robotics: Robots must interact with their environment and make quick decisions from limited data and limited time. MAML helps here by providing an initialization that adapts quickly from a small amount of new experience.
  3. Reinforcement learning: MAML accelerates policy adaptation by initializing neural-network policies in regions of parameter space from which task-specific policies can be learned quickly via gradient descent.

Limitations

  1. Computational complexity and memory overhead: MAML backpropagates through the inner-loop updates, which requires second-order derivatives and makes it computationally expensive. This results in higher memory consumption and training times that can be 2x or 3x longer than first-order alternatives.
  2. Training instability: Because MAML is a bi-level optimization problem, small changes to hyperparameters (such as the inner and outer learning rates) can lead to poor or unstable training.
  3. Questionable objective: MAML optimizes for performance after adaptation, but this does not always help and can even make things worse. MAML tends to settle in regions of parameter space from which good task solutions are reachable in a few gradient steps; in some settings, ordinary first-order training and fine-tuning reach comparable solutions in a few iterations anyway, which makes MAML's extra cost unnecessary.
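The first-order alternatives mentioned above avoid the second derivatives. One common variant, first-order MAML (FOMAML), drops the Hessian term and uses the validation-set gradient at the adapted parameters as the meta-gradient. The NumPy sketch below compares the two meta-gradients for a single task with a linear model f_θ(x) = w·x + b; the data, the step size and the helper names `grad_hess` and `meta_grads` are illustrative assumptions.

```python
import numpy as np


def grad_hess(theta, x, y):
    # Gradient and Hessian of the MSE loss of f(x) = theta[0] * x + theta[1]
    phi = np.stack([x, np.ones_like(x)], axis=1)
    r = phi @ theta - y
    return 2.0 * phi.T @ r / len(x), 2.0 * phi.T @ phi / len(x)


def meta_grads(theta, support, query, alpha):
    g_tr, H_tr = grad_hess(theta, *support)
    theta_i = theta - alpha * g_tr                        # inner update
    g_val, _ = grad_hess(theta_i, *query)                 # validation gradient at theta_i'
    maml = (np.eye(len(theta)) - alpha * H_tr) @ g_val    # exact: includes the Hessian term
    fomaml = g_val                                        # first-order: Hessian term dropped
    return maml, fomaml


rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=20)
y = 1.5 * x - 0.5                                         # one illustrative task
theta = np.array([0.0, 0.0])
maml, fomaml = meta_grads(theta, (x[:10], y[:10]), (x[10:], y[10:]), alpha=0.1)
print("MAML:  ", maml)
print("FOMAML:", fomaml)
```

The two differ by exactly α·H·g; skipping that term (and the graph needed to compute it) is what makes FOMAML cheaper, at the cost of only approximating the true meta-gradient.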