Count all K Sum Paths in Binary Tree

Last Updated : 4 Oct, 2025

Given the root of a binary tree and an integer k, count the number of paths in the tree whose node values sum to k.
A path may start at any node and end at any node, but it must go downward only (from parent to child).
Note: The nodes can have negative values.

Examples:

Input: k = 7

420046697

Output: 3
Explanation:

4200466971

Input: k = 7

420046771


Output: 3
Explanation:

420046697
Try It Yourself
redirect icon

[Naive Approach] By Exploring All Possible Paths - O(n^2) Time and O(h) Space

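The simplest approach is to treat every node in the tree as the starting point of a path and explore all paths that go downward from it. We keep a running sum along each path and count the path whenever that sum equals k. Because each of the n nodes can trigger a traversal of its whole subtree, this takes O(n^2) time in the worst case, and the recursion stack uses O(h) space, where h is the height of the tree.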

C++
#include <iostream>
#include <vector>
using namespace std;

// Binary-tree node: an integer payload plus pointers to the
// left and right children.
class Node {
public:
    int data;
    Node* left;
    Node* right;

    // Builds a leaf: stores k and leaves both child pointers null.
    Node(int k) : data(k), left(nullptr), right(nullptr) {}
};

// Counts the downward paths that begin at the node where the traversal
// started, end at `node` or one of its descendants, and total exactly k.
// currentSum is the total of the nodes already visited above `node`.
int countPathsFromNode(Node* node, int k, int currentSum) {
    // An empty subtree contributes no paths.
    if (!node) {
        return 0;
    }

    const int total = currentSum + node->data;

    // The path ending at this node counts when its total hits k.
    const int endsHere = (total == k) ? 1 : 0;

    // Keep descending even after a hit: values further down may
    // bring the total back to k.
    const int below = countPathsFromNode(node->left, k, total)
                    + countPathsFromNode(node->right, k, total);

    return endsHere + below;
}

// Counts every downward path in the tree whose node values sum to k,
// by treating each node in turn as a possible starting point.
int countAllPaths(Node* root, int k) {
    if (!root) {
        return 0;
    }

    // Paths that start exactly at this node ...
    const int fromHere = countPathsFromNode(root, k, 0);

    // ... plus the paths that start somewhere inside either subtree.
    const int fromLeft = countAllPaths(root->left, k);
    const int fromRight = countAllPaths(root->right, k);

    return fromHere + fromLeft + fromRight;
}

int main() {
  
    // Create a sample tree:
    //        8
    //      /  \
    //     4    5
    //    / \    \
    //   3   2    2
	//  / \   \
    // 3  -2   1

    Node* root = new Node(8);
    root->left = new Node(4);
    root->right = new Node(5);
    root->left->left = new Node(3);
    root->left->right = new Node(2);
    root->right->right = new Node(2);
  	root->left->left->left = new Node(3);
  	root->left->left->right = new Node(-2);
    root->left->right->right = new Node(1);

    int k = 7; 

    cout << countAllPaths(root, k) << endl;
    return 0;
}
Java
// Binary-tree node: an int payload plus links to the two children.
class Node {
    int data;
    Node left;
    Node right;

    // Creates a leaf carrying k; both child links start out null.
    Node(int k) {
        this.data = k;
        this.left = null;
        this.right = null;
    }
}

class GfG {

    // Counts the downward paths that begin at the node where the
    // traversal started, end at `node` or one of its descendants, and
    // total exactly k. currentSum is the total of the nodes already
    // visited above `node`.
    static int countPathsFromNode(Node node, int k, int currentSum) {
        if (node == null) {
            return 0;
        }

        final int total = currentSum + node.data;
        final int endsHere = (total == k) ? 1 : 0;

        // Keep descending even after a hit: values further down may
        // bring the total back to k.
        return endsHere
                + countPathsFromNode(node.left, k, total)
                + countPathsFromNode(node.right, k, total);
    }

    // Counts every downward path in the tree whose values sum to k,
    // trying each node in turn as the starting point.
    static int countAllPaths(Node root, int k) {
        if (root == null) {
            return 0;
        }

        // Paths starting at this node, then paths starting in a subtree.
        final int fromHere = countPathsFromNode(root, k, 0);
        final int fromLeft = countAllPaths(root.left, k);
        final int fromRight = countAllPaths(root.right, k);

        return fromHere + fromLeft + fromRight;
    }

    public static void main(String[] args) {
        // Sample tree:
        //        8
        //      /  \
        //     4    5
        //    / \    \
        //   3   2    2
        //  / \   \
        // 3  -2   1

        Node root = new Node(8);

        // Left branch, rooted at 4.
        Node four = new Node(4);
        four.left = new Node(3);
        four.left.left = new Node(3);
        four.left.right = new Node(-2);
        four.right = new Node(2);
        four.right.right = new Node(1);
        root.left = four;

        // Right branch, rooted at 5.
        Node five = new Node(5);
        five.right = new Node(2);
        root.right = five;

        final int k = 7;

        System.out.println(countAllPaths(root, k));
    }
}
Python
# Binary-tree node: a payload plus links to the two children.
class Node:
    def __init__(self, k):
        # A new node is a leaf until children are attached.
        self.data = k
        self.left = self.right = None

# Counts the downward paths that begin at the node where the traversal
# started, end at `node` or one of its descendants, and total exactly k.
# currentSum is the total of the nodes already visited above `node`.
def countPathsFromNode(node, k, currentSum):
    if node is None:
        return 0

    total = currentSum + node.data
    hits = 1 if total == k else 0

    # Keep descending even after a hit: values further down may
    # bring the total back to k.
    for child in (node.left, node.right):
        hits += countPathsFromNode(child, k, total)

    return hits

# Counts every downward path in the tree whose values sum to k,
# trying each node in turn as the starting point.
def countAllPaths(root, k):
    if root is None:
        return 0

    # Paths that start exactly at this node ...
    startingHere = countPathsFromNode(root, k, 0)

    # ... plus the paths that start somewhere inside either subtree.
    inLeft = countAllPaths(root.left, k)
    inRight = countAllPaths(root.right, k)

    return startingHere + inLeft + inRight
  	
if __name__ == "__main__":
    # Sample tree:
    #        8
    #      /  \
    #     4    5
    #    / \    \
    #   3   2    2
    #  / \   \
    # 3  -2   1

    root = Node(8)

    # Left branch, rooted at 4.
    four = Node(4)
    four.left = Node(3)
    four.left.left = Node(3)
    four.left.right = Node(-2)
    four.right = Node(2)
    four.right.right = Node(1)
    root.left = four

    # Right branch, rooted at 5.
    five = Node(5)
    five.right = Node(2)
    root.right = five

    k = 7
    print(countAllPaths(root, k))
C#
using System;

// Binary-tree node: an int payload plus links to the two children.
class Node {
    public int data;
    public Node left;
    public Node right;

    // Creates a leaf carrying k; both child links start out null.
    public Node(int k) {
        this.data = k;
        this.left = null;
        this.right = null;
    }
}

class GFG {

    // Counts the downward paths that begin at the node where the
    // traversal started, end at `node` or one of its descendants, and
    // total exactly k. currentSum is the total of the nodes already
    // visited above `node`.
    static int countPathsFromNode(Node node, int k, int currentSum) {
        if (node == null) {
            return 0;
        }

        int total = currentSum + node.data;
        int endsHere = (total == k) ? 1 : 0;

        // Keep descending even after a hit: values further down may
        // bring the total back to k.
        int inLeft = countPathsFromNode(node.left, k, total);
        int inRight = countPathsFromNode(node.right, k, total);

        return endsHere + inLeft + inRight;
    }

    // Counts every downward path in the tree whose values sum to k,
    // trying each node in turn as the starting point.
    static int countAllPaths(Node root, int k) {
        if (root == null) {
            return 0;
        }

        // Paths starting at this node, then paths starting in a subtree.
        int fromHere = countPathsFromNode(root, k, 0);
        int fromLeft = countAllPaths(root.left, k);
        int fromRight = countAllPaths(root.right, k);

        return fromHere + fromLeft + fromRight;
    }

    static void Main(string[] args) {
        // Sample tree:
        //        8
        //      /  \
        //     4    5
        //    / \    \
        //   3   2    2
        //  / \   \
        // 3  -2   1

        Node root = new Node(8);

        // Left branch, rooted at 4.
        Node four = new Node(4);
        four.left = new Node(3);
        four.left.left = new Node(3);
        four.left.right = new Node(-2);
        four.right = new Node(2);
        four.right.right = new Node(1);
        root.left = four;

        // Right branch, rooted at 5.
        Node five = new Node(5);
        five.right = new Node(2);
        root.right = five;

        int k = 7;

        Console.WriteLine(countAllPaths(root, k));
    }
}
JavaScript
// Binary-tree node: a payload plus links to the two children.
class Node {
    // Creates a leaf carrying k; both child links start out null.
    constructor(k) {
        Object.assign(this, { data: k, left: null, right: null });
    }
}

// Counts the downward paths that begin at the node where the traversal
// started, end at `node` or one of its descendants, and total exactly k.
// currentSum is the total of the nodes already visited above `node`.
function countPathsFromNode(node, k, currentSum) {
    if (node === null) {
        return 0;
    }

    const total = currentSum + node.data;
    const endsHere = total === k ? 1 : 0;

    // Keep descending even after a hit: values further down may
    // bring the total back to k.
    return [node.left, node.right].reduce(
        (count, child) => count + countPathsFromNode(child, k, total),
        endsHere
    );
}

// Counts every downward path in the tree whose values sum to k,
// trying each node in turn as the starting point.
function countAllPaths(root, k) {
    if (root === null) {
        return 0;
    }

    // Paths that start exactly at this node ...
    const fromHere = countPathsFromNode(root, k, 0);

    // ... plus the paths that start somewhere inside either subtree.
    const fromLeft = countAllPaths(root.left, k);
    const fromRight = countAllPaths(root.right, k);

    return fromHere + fromLeft + fromRight;
}

// Driver Code
// Sample tree:
//        8
//      /  \
//     4    5
//    / \    \
//   3   2    2
//  / \   \
// 3  -2   1

const root = new Node(8);

// Left branch, rooted at 4.
const four = new Node(4);
four.left = new Node(3);
four.left.left = new Node(3);
four.left.right = new Node(-2);
four.right = new Node(2);
four.right.right = new Node(1);
root.left = four;

// Right branch, rooted at 5.
const five = new Node(5);
five.right = new Node(2);
root.right = five;

const k = 7;

console.log(countAllPaths(root, k));

Output
3

[Expected Approach] Using Prefix Sum Technique - O(n) Time and O(n) Space

Prerequisite: This approach is similar to finding a subarray with a given sum.

We can use prefix sums with a hashmap to efficiently track the sum of paths in the binary tree. The prefix sum up to a node is the sum of all node values from the root to that node.

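We traverse the tree recursively and store the prefix sums of the current root-to-node path in a hashmap. This lets us quickly find sub-paths that sum to the target value k by looking up the difference between the current prefix sum and k. After a node's subtree has been processed, its prefix sum is removed from the hashmap so that the hashmap only ever describes the current path.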

If the difference (current prefix sum - k) exists in the hashmap, then one or more paths ending at the current node sum to k, so we increase our count by the number of times that difference occurs.


C++
#include <iostream>
#include <vector>
#include <unordered_map>
using namespace std;

// Binary-tree node: an integer payload plus pointers to the
// left and right children.
class Node {
public:
    int data;
    Node* left;
    Node* right;

    // Builds a leaf: stores val and leaves both child pointers null.
    Node(int val) : data(val), left(nullptr), right(nullptr) {}
};

// Counts the downward paths that end at `node` or at any node below it
// and whose values add up to k.
//   node     - subtree to scan; nullptr yields 0
//   k        - target path sum
//   currSum  - sum of the values on the path from the root down to
//              node's parent (0 for the root itself)
//   prefSums - for each prefix sum on the path above `node`, how many
//              nodes produced it; the entry added for this node is
//              removed again before returning
int countPathsUtil(Node* node, int k, int currSum,
                   unordered_map<int, int>& prefSums) {

    if (node == nullptr)
        return 0;

    int pathCount = 0;
    currSum += node->data;

    // The whole root-to-node path sums to k.
    if (currSum == k)
        pathCount++;

    // Every ancestor whose prefix sum equals (currSum - k) starts a path
    // that ends here with sum k. find() is used instead of operator[] so
    // that a miss does not insert a zero entry into the map.
    auto match = prefSums.find(currSum - k);
    if (match != prefSums.end())
        pathCount += match->second;

    // Record this node's prefix sum for its descendants.
    ++prefSums[currSum];

    pathCount += countPathsUtil(node->left, k, currSum, prefSums);
    pathCount += countPathsUtil(node->right, k, currSum, prefSums);

    // Backtrack: this node is no longer on the current path. Dropping
    // the entry when its count reaches zero keeps the map limited to
    // sums that are actually on the path.
    auto own = prefSums.find(currSum);
    if (--own->second == 0)
        prefSums.erase(own);

    return pathCount;
}

// Counts every downward path in the tree whose values sum to k, using
// a map of the prefix sums seen on the current root-to-node path.
int countAllPaths(Node* root, int k) {
    // Starts empty: no node has been visited yet.
    unordered_map<int, int> seenPrefixSums;

    const int total = countPathsUtil(root, k, 0, seenPrefixSums);
    return total;
}

int main() {
  
    // Create a sample tree:
    //        8
    //      /  \
    //     4    5
    //    / \    \
    //   3   2    2
	//  / \   \
    // 3  -2   1

    Node* root = new Node(8);
    root->left = new Node(4);
    root->right = new Node(5);
    root->left->left = new Node(3);
    root->left->right = new Node(2);
    root->right->right = new Node(2);
  	root->left->left->left = new Node(3);
  	root->left->left->right = new Node(-2);
    root->left->right->right = new Node(1);

    int k = 7; 

    cout << countAllPaths(root, k) << endl;
    return 0;
}
Java
import java.util.HashMap;

// Binary-tree node: an int payload plus links to the two children.
class Node {
    int data;
    Node left;
    Node right;

    // Creates a leaf carrying val; both child links start out null.
    Node(int val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

class GFG {

    // Counts the downward paths that end at `node` or at any node below
    // it and whose values add up to k. currSum is the sum of the values
    // above `node` on the current path; prefSums maps each prefix sum
    // on that path to the number of nodes that produced it.
    static int countPathsUtil(Node node, int k, int currSum,
                              HashMap<Integer, Integer> prefSums) {
        if (node == null) {
            return 0;
        }

        final int total = currSum + node.data;

        // One path if the whole root-to-node path sums to k ...
        int found = (total == k) ? 1 : 0;

        // ... plus one for every ancestor whose prefix sum is
        // (total - k), since the path from just below it to here is k.
        found += prefSums.getOrDefault(total - k, 0);

        // Record this node's prefix sum for its descendants.
        prefSums.merge(total, 1, Integer::sum);

        found += countPathsUtil(node.left, k, total, prefSums);
        found += countPathsUtil(node.right, k, total, prefSums);

        // Backtrack: this node is no longer on the current path.
        prefSums.merge(total, -1, Integer::sum);

        return found;
    }

    // Counts every downward path in the tree whose values sum to k.
    static int countAllPaths(Node root, int k) {
        final HashMap<Integer, Integer> seenPrefixSums = new HashMap<>();
        return countPathsUtil(root, k, 0, seenPrefixSums);
    }

    public static void main(String[] args) {
        // Sample tree:
        //        8
        //      /  \
        //     4    5
        //    / \    \
        //   3   2    2
        //  / \   \
        // 3  -2   1

        Node root = new Node(8);

        // Left branch, rooted at 4.
        Node four = new Node(4);
        four.left = new Node(3);
        four.left.left = new Node(3);
        four.left.right = new Node(-2);
        four.right = new Node(2);
        four.right.right = new Node(1);
        root.left = four;

        // Right branch, rooted at 5.
        Node five = new Node(5);
        five.right = new Node(2);
        root.right = five;

        final int k = 7;
        System.out.println(countAllPaths(root, k));
    }
}
Python
# Binary-tree node: a payload plus links to the two children.
class Node:
    def __init__(self, val):
        # A new node is a leaf until children are attached.
        self.data = val
        self.left = self.right = None


def countPathsUtil(node, k, currSum, prefSums):
    # Counts the downward paths that end at `node` or at any node below
    # it and whose values add up to k. currSum is the sum of the values
    # above `node` on the current path; prefSums maps each prefix sum on
    # that path to the number of nodes that produced it.
    if node is None:
        return 0

    total = currSum + node.data

    # One path if the whole root-to-node path sums to k ...
    found = 1 if total == k else 0

    # ... plus one for every ancestor whose prefix sum is (total - k),
    # since the path from just below it to here sums to k.
    found += prefSums.get(total - k, 0)

    # Record this node's prefix sum for its descendants.
    prefSums[total] = prefSums.get(total, 0) + 1

    for child in (node.left, node.right):
        found += countPathsUtil(child, k, total, prefSums)

    # Backtrack: this node is no longer on the current path.
    prefSums[total] -= 1

    return found

def countAllPaths(root, k):
    # Counts every downward path in the tree whose values sum to k.
    # The prefix-sum table starts empty: no node has been visited yet.
    seenPrefixSums = dict()
    total = countPathsUtil(root, k, 0, seenPrefixSums)
    return total

if __name__ == "__main__":
    # Sample tree:
    #        8
    #      /  \
    #     4    5
    #    / \    \
    #   3   2    2
    #  / \   \
    # 3  -2   1

    root = Node(8)

    # Left branch, rooted at 4.
    four = Node(4)
    four.left = Node(3)
    four.left.left = Node(3)
    four.left.right = Node(-2)
    four.right = Node(2)
    four.right.right = Node(1)
    root.left = four

    # Right branch, rooted at 5.
    five = Node(5)
    five.right = Node(2)
    root.right = five

    k = 7
    print(countAllPaths(root, k))
C#
using System;
using System.Collections.Generic;

// Binary-tree node: an int payload plus links to the two children.
class Node {
    public int data;
    public Node left;
    public Node right;

    // Creates a leaf carrying val; both child links start out null.
    public Node(int val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

class GFG {

    // Counts the downward paths that end at `node` or at any node below
    // it and whose values add up to k. currSum is the sum of the values
    // above `node` on the current path; prefSums maps each prefix sum
    // on that path to the number of nodes that produced it.
    static int countPathsUtil(Node node, int k, int currSum,
                              Dictionary<int, int> prefSums) {
        if (node == null) {
            return 0;
        }

        int total = currSum + node.data;

        // One path if the whole root-to-node path sums to k ...
        int found = (total == k) ? 1 : 0;

        // ... plus one for every ancestor whose prefix sum is
        // (total - k), since the path from just below it to here is k.
        int occurrences;
        if (prefSums.TryGetValue(total - k, out occurrences)) {
            found += occurrences;
        }

        // Record this node's prefix sum for its descendants.
        // TryGetValue leaves `occurrences` at 0 when the key is absent.
        prefSums.TryGetValue(total, out occurrences);
        prefSums[total] = occurrences + 1;

        found += countPathsUtil(node.left, k, total, prefSums);
        found += countPathsUtil(node.right, k, total, prefSums);

        // Backtrack: this node is no longer on the current path.
        prefSums[total] = prefSums[total] - 1;

        return found;
    }

    // Counts every downward path in the tree whose values sum to k.
    static int countAllPaths(Node root, int k) {
        Dictionary<int, int> seenPrefixSums = new Dictionary<int, int>();
        return countPathsUtil(root, k, 0, seenPrefixSums);
    }

    static void Main() {
        // Sample tree:
        //        8
        //      /  \
        //     4    5
        //    / \    \
        //   3   2    2
        //  / \   \
        // 3  -2   1

        Node root = new Node(8);

        // Left branch, rooted at 4.
        Node four = new Node(4);
        four.left = new Node(3);
        four.left.left = new Node(3);
        four.left.right = new Node(-2);
        four.right = new Node(2);
        four.right.right = new Node(1);
        root.left = four;

        // Right branch, rooted at 5.
        Node five = new Node(5);
        five.right = new Node(2);
        root.right = five;

        int k = 7;
        Console.WriteLine(countAllPaths(root, k));
    }
}
JavaScript
// Binary-tree node: a payload plus links to the two children.
class Node {
    // Creates a leaf carrying val; both child links start out null.
    constructor(val) {
        Object.assign(this, { data: val, left: null, right: null });
    }
}

// Counts the downward paths that end at `node` or at any node below it
// and whose values add up to k. currSum is the sum of the values above
// `node` on the current path; prefSums maps each prefix sum on that
// path to the number of nodes that produced it.
function countPathsUtil(node, k, currSum, prefSums) {
    if (node === null) {
        return 0;
    }

    const total = currSum + node.data;

    // One path if the whole root-to-node path sums to k ...
    let found = total === k ? 1 : 0;

    // ... plus one for every ancestor whose prefix sum is (total - k),
    // since the path from just below it to here sums to k.
    found += prefSums[total - k] || 0;

    // Record this node's prefix sum for its descendants.
    prefSums[total] = (prefSums[total] || 0) + 1;

    for (const child of [node.left, node.right]) {
        found += countPathsUtil(child, k, total, prefSums);
    }

    // Backtrack: this node is no longer on the current path.
    prefSums[total] -= 1;

    return found;
}

// Counts every downward path in the tree whose values sum to k.
function countAllPaths(root, k) {
    // The prefix-sum table starts empty: no node has been visited yet.
    const seenPrefixSums = {};
    const total = countPathsUtil(root, k, 0, seenPrefixSums);
    return total;
}

// Driver Code
// Sample tree:
//        8
//      /  \
//     4    5
//    / \    \
//   3   2    2
//  / \   \
// 3  -2   1

const root = new Node(8);

// Left branch, rooted at 4.
const four = new Node(4);
four.left = new Node(3);
four.left.left = new Node(3);
four.left.right = new Node(-2);
four.right = new Node(2);
four.right.right = new Node(1);
root.left = four;

// Right branch, rooted at 5.
const five = new Node(5);
five.right = new Node(2);
root.right = five;

const k = 7;
console.log(countAllPaths(root, k));

Output
3
Comment