Add all greater values to every node in a given BST
Last Updated : 17 Oct, 2024
Given a Binary Search Tree (BST), the task is to modify it so that the value of every node is replaced by the sum of all values in the original BST that are greater than or equal to that node's value.
Examples:
Input:
Output:
Explanation: The above tree represents the greater sum tree where each node contains the sum of all nodes greater than or equal to that node in original tree.
The root node 50 becomes 260 (sum of 50 + 60 + 70 + 80).
The left child of 50 becomes 330 (sum of 30 + 40 + 50 + 60 + 70 + 80).
The right child of 50 becomes 150 (70 + 80) and so on.
[Naive Approach] By Calculating Sum for Each Node - O(n^2) Time and O(n) Space
The idea is to traverse the binary tree and for each node, find the sum of all nodes with values greater than or equal to it. As we traverse, we compute these sums and replace the current node’s value with the corresponding sum.
This method doesn’t require the tree to be a BST. Following are the steps:
Traverse node by node (in-order, pre-order, etc.).
For each node, find all the nodes whose values are greater than or equal to the current node's value and sum those values. Store all these sums.
Replace each node’s value with its corresponding sum by traversing in the same order as in Step 1.
Below is the implementation of the above approach:
C++
// C++ program to transform a BST to
// sum tree (naive O(n^2) approach)
#include <iostream>
#include <unordered_map>

// A binary tree node holding an integer value.
class Node {
public:
    int data;
    Node* left;
    Node* right;

    Node(int value) : data(value), left(nullptr), right(nullptr) {}
};

// Walks the whole tree rooted at `root` and adds to mp[curr] the value of
// every node whose value is greater than or equal to curr's value
// (curr itself included).
void findGreaterNodes(Node* root, Node* curr,
                      std::unordered_map<Node*, int>& mp) {
    if (root == nullptr)
        return;

    // operator[] creates a missing entry with value 0 before the addition.
    if (root->data >= curr->data)
        mp[curr] += root->data;

    findGreaterNodes(root->left, curr, mp);
    findGreaterNodes(root->right, curr, mp);
}

// Visits `curr` and all of its descendants; for each visited node, stores in
// mp the sum of all values in the tree rooted at `root` that are greater than
// or equal to that node's value. Node values are not modified here.
void transformToGreaterSumTree(Node* curr, Node* root,
                               std::unordered_map<Node*, int>& mp) {
    if (curr == nullptr) {
        return;
    }

    // Scan the whole tree for values >= the current node's value.
    findGreaterNodes(root, curr, mp);

    // Repeat for the left and right subtrees.
    transformToGreaterSumTree(curr->left, root, mp);
    transformToGreaterSumTree(curr->right, root, mp);
}

// Overwrites each node's value with the sum recorded for it in mp.
void preOrderTrav(Node* root, std::unordered_map<Node*, int>& mp) {
    if (root == nullptr)
        return;

    root->data = mp[root];
    preOrderTrav(root->left, mp);
    preOrderTrav(root->right, mp);
}

// Replaces every node's value with the sum of all original values that are
// greater than or equal to it. The sums are computed first and written back
// afterwards so that later comparisons still see the original values.
void transformTree(Node* root) {
    // Map from node to its greater-or-equal sum.
    std::unordered_map<Node*, int> mp;
    transformToGreaterSumTree(root, root, mp);

    // Write the sums back into the nodes.
    preOrderTrav(root, mp);
}

// Prints the node values in in-order sequence, each followed by a space.
void inorder(Node* root) {
    if (root == nullptr) {
        return;
    }
    inorder(root->left);
    std::cout << root->data << " ";
    inorder(root->right);
}

// Releases every node of the tree (children before their parent).
void deleteTree(Node* root) {
    if (root == nullptr) {
        return;
    }
    deleteTree(root->left);
    deleteTree(root->right);
    delete root;
}

int main() {
    /* Representation of input binary tree:
     *
     *           50
     *          /  \
     *        30    70
     *       /  \  /  \
     *      20  40 60  80
     */
    Node* root = new Node(50);
    root->left = new Node(30);
    root->right = new Node(70);
    root->left->left = new Node(20);
    root->left->right = new Node(40);
    root->right->left = new Node(60);
    root->right->right = new Node(80);

    transformTree(root);
    inorder(root);

    // The nodes were allocated with new; free them before exiting.
    deleteTree(root);
    return 0;
}
Java
// Java program to transform a BST to
// sum tree (naive O(n^2) approach)
import java.util.HashMap;

// A binary tree node holding an integer value.
class Node {
    int data;
    Node left, right;

    Node(int value) {
        data = value;
        left = null;
        right = null;
    }
}

class GfG {

    // Walks the whole tree rooted at root and adds to mp[curr] the value
    // of every node whose value is greater than or equal to curr's value
    // (curr itself included).
    static void findGreaterNodes(Node root, Node curr,
                                 HashMap<Node, Integer> mp) {
        if (root == null)
            return;

        // getOrDefault supplies 0 the first time curr is seen.
        if (root.data >= curr.data)
            mp.put(curr, mp.getOrDefault(curr, 0) + root.data);

        findGreaterNodes(root.left, curr, mp);
        findGreaterNodes(root.right, curr, mp);
    }

    // Visits curr and all of its descendants; for each visited node, stores
    // in mp the sum of all values in the tree rooted at root that are
    // greater than or equal to that node's value. Node values are not
    // modified here.
    static void transformToGreaterSumTree(Node curr, Node root,
                                          HashMap<Node, Integer> mp) {
        if (curr == null) {
            return;
        }

        // Scan the whole tree for values >= the current node's value.
        findGreaterNodes(root, curr, mp);

        // Repeat for the left and right subtrees.
        transformToGreaterSumTree(curr.left, root, mp);
        transformToGreaterSumTree(curr.right, root, mp);
    }

    // Overwrites each node's value with the sum recorded for it in mp
    // (0 if the node has no entry).
    static void preOrderTrav(Node root, HashMap<Node, Integer> mp) {
        if (root == null)
            return;

        root.data = mp.getOrDefault(root, 0);
        preOrderTrav(root.left, mp);
        preOrderTrav(root.right, mp);
    }

    // Replaces every node's value with the sum of all original values that
    // are greater than or equal to it. The sums are computed first and
    // written back afterwards so that later comparisons still see the
    // original values.
    static void transformTree(Node root) {
        // Map from node to its greater-or-equal sum.
        HashMap<Node, Integer> mp = new HashMap<>();
        transformToGreaterSumTree(root, root, mp);

        // Write the sums back into the nodes.
        preOrderTrav(root, mp);
    }

    // Prints the node values in in-order sequence, each followed by a space.
    static void inorder(Node root) {
        if (root == null) {
            return;
        }
        inorder(root.left);
        System.out.print(root.data + " ");
        inorder(root.right);
    }

    public static void main(String[] args) {
        // Representation of input binary tree:
        //
        //           50
        //          /  \
        //        30    70
        //       /  \  /  \
        //      20  40 60  80
        Node root = new Node(50);
        root.left = new Node(30);
        root.right = new Node(70);
        root.left.left = new Node(20);
        root.left.right = new Node(40);
        root.right.left = new Node(60);
        root.right.right = new Node(80);

        transformTree(root);
        inorder(root);
    }
}
Python
# Python program to transform a BST
# to sum tree (naive O(n^2) approach)


class Node:
    """A binary tree node holding an integer value."""

    def __init__(self, value):
        self.data = value
        self.left = None
        self.right = None


def findGreaterNodes(root, curr, map):
    """Add to map[curr] every value in the tree at `root` that is >= curr.data.

    `map[curr]` must already exist; transformToGreaterSumTree sets it to 0
    before calling this function.
    """
    if root is None:
        return

    # Values greater than or equal to curr's value (curr itself included)
    # contribute to curr's sum.
    if root.data >= curr.data:
        map[curr] += root.data

    findGreaterNodes(root.left, curr, map)
    findGreaterNodes(root.right, curr, map)


def transformToGreaterSumTree(curr, root, map):
    """Record the greater-or-equal sum for `curr` and all of its descendants.

    Sums are stored in `map`, keyed by node; node values are not modified.
    """
    if curr is None:
        return

    # Start this node's sum at 0, then scan the whole tree.
    map[curr] = 0
    findGreaterNodes(root, curr, map)

    # Repeat for the left and right subtrees.
    transformToGreaterSumTree(curr.left, root, map)
    transformToGreaterSumTree(curr.right, root, map)


def preOrderTrav(root, map):
    """Overwrite each node's value with its recorded sum.

    A node with no entry in `map` keeps its current value.
    """
    if root is None:
        return

    root.data = map.get(root, root.data)
    preOrderTrav(root.left, map)
    preOrderTrav(root.right, map)


def transformTree(root):
    """Replace every node's value with the sum of all original values >= it.

    The sums are computed first and written back afterwards so that later
    comparisons still see the original values.
    """
    # Map from node to its greater-or-equal sum.
    map = {}
    transformToGreaterSumTree(root, root, map)

    # Write the sums back into the nodes.
    preOrderTrav(root, map)


def inorder(root):
    """Print the node values in in-order sequence, separated by spaces."""
    if root is None:
        return
    inorder(root.left)
    print(root.data, end=" ")
    inorder(root.right)


if __name__ == "__main__":

    # Representation of input binary tree:
    #
    #           50
    #          /  \
    #        30    70
    #       /  \  /  \
    #      20  40 60  80
    root = Node(50)
    root.left = Node(30)
    root.right = Node(70)
    root.left.left = Node(20)
    root.left.right = Node(40)
    root.right.left = Node(60)
    root.right.right = Node(80)

    transformTree(root)
    inorder(root)
C#
// C# program to transform a BST
// to sum tree (naive O(n^2) approach)
using System;
using System.Collections.Generic;

// A binary tree node holding an integer value.
class Node {
    public int data;
    public Node left, right;

    public Node(int value) {
        data = value;
        left = null;
        right = null;
    }
}

class GfG {

    // Walks the whole tree rooted at root and adds to map[curr] the value
    // of every node whose value is greater than or equal to curr's value
    // (curr itself included). map[curr] must already exist;
    // TransformToGreaterSumTree sets it to 0 before calling this method.
    static void FindGreaterNodes(Node root, Node curr,
                                 Dictionary<Node, int> map) {
        if (root == null)
            return;

        if (root.data >= curr.data)
            map[curr] += root.data;

        FindGreaterNodes(root.left, curr, map);
        FindGreaterNodes(root.right, curr, map);
    }

    // Visits curr and all of its descendants; for each visited node, stores
    // in map the sum of all values in the tree rooted at root that are
    // greater than or equal to that node's value. Node values are not
    // modified here.
    static void TransformToGreaterSumTree(Node curr, Node root,
                                          Dictionary<Node, int> map) {
        if (curr == null) {
            return;
        }

        // Start this node's sum at 0, then scan the whole tree.
        map[curr] = 0;
        FindGreaterNodes(root, curr, map);

        // Repeat for the left and right subtrees.
        TransformToGreaterSumTree(curr.left, root, map);
        TransformToGreaterSumTree(curr.right, root, map);
    }

    // Overwrites each node's value with the sum recorded for it in map;
    // a node with no entry keeps its current value.
    static void PreOrderTrav(Node root, Dictionary<Node, int> map) {
        if (root == null)
            return;

        root.data = map.ContainsKey(root) ? map[root] : root.data;
        PreOrderTrav(root.left, map);
        PreOrderTrav(root.right, map);
    }

    // Replaces every node's value with the sum of all original values that
    // are greater than or equal to it. The sums are computed first and
    // written back afterwards so that later comparisons still see the
    // original values.
    static void TransformTree(Node root) {
        // Map from node to its greater-or-equal sum.
        Dictionary<Node, int> map = new Dictionary<Node, int>();
        TransformToGreaterSumTree(root, root, map);

        // Write the sums back into the nodes.
        PreOrderTrav(root, map);
    }

    // Prints the node values in in-order sequence, each followed by a space.
    static void Inorder(Node root) {
        if (root == null) {
            return;
        }
        Inorder(root.left);
        Console.Write(root.data + " ");
        Inorder(root.right);
    }

    static void Main(string[] args) {
        // Representation of input binary tree:
        //
        //           50
        //          /  \
        //        30    70
        //       /  \  /  \
        //      20  40 60  80
        Node root = new Node(50);
        root.left = new Node(30);
        root.right = new Node(70);
        root.left.left = new Node(20);
        root.left.right = new Node(40);
        root.right.left = new Node(60);
        root.right.right = new Node(80);

        TransformTree(root);
        Inorder(root);
    }
}
JavaScript
// JavaScript program to transform
// a BST to sum tree (naive O(n^2) approach)

// A binary tree node holding a numeric value.
class Node {
    constructor(value) {
        this.data = value;
        this.left = null;
        this.right = null;
    }
}

// Walks the whole tree rooted at `root` and adds to the map entry for `curr`
// the value of every node whose value is greater than or equal to curr's
// value (curr itself included).
function findGreaterNodes(root, curr, map) {
    if (root === null)
        return;

    // `|| 0` supplies the starting sum the first time curr is seen.
    if (root.data >= curr.data) {
        map.set(curr, (map.get(curr) || 0) + root.data);
    }

    findGreaterNodes(root.left, curr, map);
    findGreaterNodes(root.right, curr, map);
}

// Visits `curr` and all of its descendants; for each visited node, stores in
// `map` the sum of all values in the tree rooted at `root` that are greater
// than or equal to that node's value. Node values are not modified here.
function transformToGreaterSumTree(curr, root, map) {
    if (curr === null) {
        return;
    }

    // Scan the whole tree for values >= the current node's value.
    findGreaterNodes(root, curr, map);

    // Repeat for the left and right subtrees.
    transformToGreaterSumTree(curr.left, root, map);
    transformToGreaterSumTree(curr.right, root, map);
}

// Overwrites each node's value with the sum recorded for it in `map`
// (0 if the node has no entry).
function preOrderTrav(root, map) {
    if (root === null)
        return;

    root.data = map.has(root) ? map.get(root) : 0;
    preOrderTrav(root.left, map);
    preOrderTrav(root.right, map);
}

// Replaces every node's value with the sum of all original values that are
// greater than or equal to it. The sums are computed first and written back
// afterwards so that later comparisons still see the original values.
function transformTree(root) {
    // Map from node to its greater-or-equal sum.
    const map = new Map();
    transformToGreaterSumTree(root, root, map);

    // Write the sums back into the nodes.
    preOrderTrav(root, map);
}

// Prints the node values in in-order sequence on a single line, separated by
// spaces. console.log ends every call with a newline, so the values are
// gathered first and printed with one call instead of one call per node.
function inorder(root) {
    const values = [];

    const visit = (node) => {
        if (node === null) {
            return;
        }
        visit(node.left);
        values.push(node.data);
        visit(node.right);
    };

    visit(root);
    console.log(values.join(" "));
}

// Representation of input binary tree:
//
//           50
//          /  \
//        30    70
//       /  \  /  \
//      20  40 60  80
let root = new Node(50);
root.left = new Node(30);
root.right = new Node(70);
root.left.left = new Node(20);
root.left.right = new Node(40);
root.right.left = new Node(60);
root.right.right = new Node(80);

transformTree(root);
inorder(root);
Output
350 330 300 260 210 150 80
Note: Since this approach runs in O(n^2) time, it will exceed the time limit (TLE) on large inputs, so we need to think of a more efficient approach.
[Expected Approach] Using Single Traversal – O(n) Time and O(h) Space
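The idea is to traverse the tree in reverse in-order (right -> root -> left) while keeping a running sum of all previously visited nodes. In a BST this order visits the values from largest to smallest, so when a node is reached the running sum already holds every value greater than it. The node's own value is added to the running sum and the node is then updated to that sum, which ensures that each node contains the sum of all nodes greater than or equal to it.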
Below is the implementation of the above approach:
C++
// C++ program to transform a BST to sum tree
// (single reverse in-order traversal)
#include <iostream>

// A binary tree node holding an integer value.
class Node {
public:
    int data;
    Node* left;
    Node* right;

    Node(int value) : data(value), left(nullptr), right(nullptr) {}
};

// Visits the subtree rooted at `root` in right -> root -> left order.
// `sum` carries the total of all values visited so far; each node's value is
// added to it and the node is then overwritten with the updated total.
void transformToGreaterSumTree(Node* root, int& sum) {
    if (root == nullptr) {
        return;
    }

    // Traverse the right subtree first (larger values).
    transformToGreaterSumTree(root->right, sum);

    // Update the sum and the current node's value.
    sum += root->data;
    root->data = sum;

    // Traverse the left subtree (smaller values).
    transformToGreaterSumTree(root->left, sum);
}

// Replaces every node's value with the sum of all original values that are
// greater than or equal to it.
void transformTree(Node* root) {
    // Initialize the cumulative sum.
    int sum = 0;
    transformToGreaterSumTree(root, sum);
}

// Prints the node values in in-order sequence, each followed by a space.
void inorder(Node* root) {
    if (root == nullptr) {
        return;
    }
    inorder(root->left);
    std::cout << root->data << " ";
    inorder(root->right);
}

// Releases every node of the tree (children before their parent).
void deleteTree(Node* root) {
    if (root == nullptr) {
        return;
    }
    deleteTree(root->left);
    deleteTree(root->right);
    delete root;
}

int main() {
    /* Representation of input binary tree:
     *
     *           50
     *          /  \
     *        30    70
     *       /  \  /  \
     *      20  40 60  80
     */
    Node* root = new Node(50);
    root->left = new Node(30);
    root->right = new Node(70);
    root->left->left = new Node(20);
    root->left->right = new Node(40);
    root->right->left = new Node(60);
    root->right->right = new Node(80);

    transformTree(root);
    inorder(root);

    // The nodes were allocated with new; free them before exiting.
    deleteTree(root);
    return 0;
}
C
// C program to transform a BST
// to sum tree (single reverse in-order traversal)
#include <stdio.h>
#include <stdlib.h>

// A binary tree node holding an integer value.
struct Node {
    int data;
    struct Node* left;
    struct Node* right;
};

// Visits the subtree rooted at root in right -> root -> left order.
// *sum carries the total of all values visited so far; each node's value is
// added to it and the node is then overwritten with the updated total.
void transformToGreaterSumTree(struct Node* root, int* sum) {
    if (root == NULL) {
        return;
    }

    // Traverse the right subtree first (larger values).
    transformToGreaterSumTree(root->right, sum);

    // Update the sum and the current node's value.
    *sum += root->data;
    root->data = *sum;

    // Traverse the left subtree (smaller values).
    transformToGreaterSumTree(root->left, sum);
}

// Replaces every node's value with the sum of all original values that are
// greater than or equal to it.
void transformTree(struct Node* root) {
    // Initialize the cumulative sum.
    int sum = 0;
    transformToGreaterSumTree(root, &sum);
}

// Prints the node values in in-order sequence, each followed by a space.
void inorder(struct Node* root) {
    if (root == NULL) {
        return;
    }
    inorder(root->left);
    printf("%d ", root->data);
    inorder(root->right);
}

// Allocates a leaf node holding data. Terminates the program if the
// allocation fails, so callers never receive NULL.
struct Node* createNode(int data) {
    struct Node* node = (struct Node*)malloc(sizeof(struct Node));
    if (node == NULL) {
        fprintf(stderr, "createNode: out of memory\n");
        exit(EXIT_FAILURE);
    }
    node->data = data;
    node->left = NULL;
    node->right = NULL;
    return node;
}

// Releases every node of the tree (children before their parent).
void freeTree(struct Node* root) {
    if (root == NULL) {
        return;
    }
    freeTree(root->left);
    freeTree(root->right);
    free(root);
}

int main() {
    /* Representation of input binary tree:
     *
     *           50
     *          /  \
     *        30    70
     *       /  \  /  \
     *      20  40 60  80
     */
    struct Node* root = createNode(50);
    root->left = createNode(30);
    root->right = createNode(70);
    root->left->left = createNode(20);
    root->left->right = createNode(40);
    root->right->left = createNode(60);
    root->right->right = createNode(80);

    transformTree(root);
    inorder(root);

    // The nodes were allocated with malloc; free them before exiting.
    freeTree(root);
    return 0;
}
Java
// Java program to transform a
// BST to sum tree (single reverse in-order traversal)

// A binary tree node holding an integer value.
class Node {
    int data;
    Node left, right;

    Node(int value) {
        data = value;
        left = right = null;
    }
}

class GfG {

    // Visits the subtree rooted at root in right -> root -> left order.
    // sum[0] carries the total of all values visited so far (a one-element
    // array is used so the recursive calls share one mutable total); each
    // node's value is added to it and the node is then overwritten with the
    // updated total.
    static void transformToGreaterSumTree(Node root, int[] sum) {
        if (root == null) {
            return;
        }

        // Traverse the right subtree first (larger values).
        transformToGreaterSumTree(root.right, sum);

        // Update the sum and the current node's value.
        sum[0] += root.data;
        root.data = sum[0];

        // Traverse the left subtree (smaller values).
        transformToGreaterSumTree(root.left, sum);
    }

    // Replaces every node's value with the sum of all original values that
    // are greater than or equal to it.
    static void transformTree(Node root) {
        // Initialize the cumulative sum.
        int[] sum = {0};
        transformToGreaterSumTree(root, sum);
    }

    // Prints the node values in in-order sequence, each followed by a space.
    static void inorder(Node root) {
        if (root == null) {
            return;
        }
        inorder(root.left);
        System.out.print(root.data + " ");
        inorder(root.right);
    }

    public static void main(String[] args) {
        // Representation of input binary tree:
        //
        //           50
        //          /  \
        //        30    70
        //       /  \  /  \
        //      20  40 60  80
        Node root = new Node(50);
        root.left = new Node(30);
        root.right = new Node(70);
        root.left.left = new Node(20);
        root.left.right = new Node(40);
        root.right.left = new Node(60);
        root.right.right = new Node(80);

        transformTree(root);
        inorder(root);
    }
}
Python
# Python program to transform a
# BST to sum tree (single reverse in-order traversal)


class Node:
    """A binary tree node holding an integer value."""

    def __init__(self, value):
        self.data = value
        self.left = None
        self.right = None


def transformToGreaterSumTree(root, sum):
    """Visit the subtree at `root` in right -> root -> left order.

    `sum` is a one-element list whose single entry carries the total of all
    values visited so far (a list is used so the recursive calls share one
    mutable total). Each node's value is added to it and the node is then
    overwritten with the updated total.
    """
    if root is None:
        return

    # Traverse the right subtree first
    # (larger values)
    transformToGreaterSumTree(root.right, sum)

    # Update the sum and the current node's value
    sum[0] += root.data
    root.data = sum[0]

    # Traverse the left subtree (smaller values)
    transformToGreaterSumTree(root.left, sum)


def transformTree(root):
    """Replace every node's value with the sum of all original values >= it."""
    # Initialize the cumulative sum
    sum = [0]
    transformToGreaterSumTree(root, sum)


def inorder(root):
    """Print the node values in in-order sequence, separated by spaces."""
    if root is None:
        return
    inorder(root.left)
    print(root.data, end=" ")
    inorder(root.right)


if __name__ == "__main__":

    # Representation of input binary tree:
    #
    #           50
    #          /  \
    #        30    70
    #       /  \  /  \
    #      20  40 60  80
    root = Node(50)
    root.left = Node(30)
    root.right = Node(70)
    root.left.left = Node(20)
    root.left.right = Node(40)
    root.right.left = Node(60)
    root.right.right = Node(80)

    transformTree(root)
    inorder(root)
C#
// C# program to transform a BST to
// sum tree (single reverse in-order traversal)
using System;

// A binary tree node holding an integer value.
class Node {
    public int data;
    public Node left, right;

    public Node(int value) {
        data = value;
        left = right = null;
    }
}

class GfG {

    // Visits the subtree rooted at root in right -> root -> left order.
    // sum is passed by reference and carries the total of all values visited
    // so far; each node's value is added to it and the node is then
    // overwritten with the updated total.
    static void transformToGreaterSumTree(Node root, ref int sum) {
        if (root == null) {
            return;
        }

        // Traverse the right subtree first (larger values).
        transformToGreaterSumTree(root.right, ref sum);

        // Update the sum and the current node's value.
        sum += root.data;
        root.data = sum;

        // Traverse the left subtree (smaller values).
        transformToGreaterSumTree(root.left, ref sum);
    }

    // Replaces every node's value with the sum of all original values that
    // are greater than or equal to it.
    static void transformTree(Node root) {
        // Initialize the cumulative sum.
        int sum = 0;
        transformToGreaterSumTree(root, ref sum);
    }

    // Prints the node values in in-order sequence, each followed by a space.
    static void inorder(Node root) {
        if (root == null) {
            return;
        }
        inorder(root.left);
        Console.Write(root.data + " ");
        inorder(root.right);
    }

    static void Main() {
        // Representation of input binary tree:
        //
        //           50
        //          /  \
        //        30    70
        //       /  \  /  \
        //      20  40 60  80
        Node root = new Node(50);
        root.left = new Node(30);
        root.right = new Node(70);
        root.left.left = new Node(20);
        root.left.right = new Node(40);
        root.right.left = new Node(60);
        root.right.right = new Node(80);

        transformTree(root);
        inorder(root);
    }
}
JavaScript
// JavaScript program to transform a
// BST to sum tree (single reverse in-order traversal)

// A binary tree node holding a numeric value.
class Node {
    constructor(value) {
        this.data = value;
        this.left = null;
        this.right = null;
    }
}

// Visits the subtree rooted at `root` in right -> root -> left order.
// `sum` is a one-element array whose single entry carries the total of all
// values visited so far (an array is used so the recursive calls share one
// mutable total). Each node's value is added to it and the node is then
// overwritten with the updated total.
function transformToGreaterSumTree(root, sum) {
    if (root === null) {
        return;
    }

    // Traverse the right subtree first (larger values).
    transformToGreaterSumTree(root.right, sum);

    // Update the sum and the current node's value.
    sum[0] += root.data;
    root.data = sum[0];

    // Traverse the left subtree (smaller values).
    transformToGreaterSumTree(root.left, sum);
}

// Replaces every node's value with the sum of all original values that are
// greater than or equal to it.
function transformTree(root) {
    // Initialize the cumulative sum.
    let sum = [0];
    transformToGreaterSumTree(root, sum);
}

// Prints the node values in in-order sequence on a single line, separated by
// spaces. console.log ends every call with a newline, so the values are
// gathered first and printed with one call instead of one call per node.
function inorder(root) {
    const values = [];

    const visit = (node) => {
        if (node === null) {
            return;
        }
        visit(node.left);
        values.push(node.data);
        visit(node.right);
    };

    visit(root);
    console.log(values.join(" "));
}

// Representation of input binary tree:
//
//           50
//          /  \
//        30    70
//       /  \  /  \
//      20  40 60  80
const root = new Node(50);
root.left = new Node(30);
root.right = new Node(70);
root.left.left = new Node(20);
root.left.right = new Node(40);
root.right.left = new Node(60);
root.right.right = new Node(80);

transformTree(root);
inorder(root);