// Balanced Forest
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Queue;
import java.util.StringTokenizer;
import java.util.TreeSet;
/** A vertex of the input graph: its value and the 0-based indices of its neighbours. */
class GraphNode {
  int value;
  List<Integer> adjacents;

  GraphNode(int value) {
    this.value = value;
    this.adjacents = new ArrayList<>();
  }
}

/**
 * A vertex of the rooted tree built from the input graph.
 *
 * <p>{@code enterTime} and {@code leaveTime} are DFS timestamps assigned by
 * {@code Solution.buildTree}: a descendant's interval nests inside its ancestor's, and two subtrees
 * are disjoint exactly when one interval ends before the other begins. {@code valueSum} is the sum
 * of {@code value} over this node's subtree, filled in by
 * {@code Solution.computeSubtreeValueSums}.
 */
class TreeNode {
  int value;
  TreeNode parent;
  List<TreeNode> children;
  int enterTime;
  int leaveTime;
  long valueSum;

  TreeNode(int value, TreeNode parent) {
    this.value = value;
    this.parent = parent;
    this.children = new ArrayList<>();
  }
}

/**
 * Balanced Forest: finds the smallest value of a new node such that, once it is attached to the
 * tree, cutting two edges leaves three trees with equal value sums.
 *
 * <p>Every valid split consists of two pieces of some sum {@code a} and a third piece of sum
 * {@code b <= a}; the new node, of value {@code a - b}, joins the {@code b} piece. Alternatively,
 * one cut splits the tree into two equal halves and the new node forms the third tree on its own.
 * {@link #cut} tries every edge as one of the cuts and looks up whether a compatible second cut
 * exists.
 */
public class Solution {
  /** DFS clock used by {@link #buildTree} while it assigns enter/leave times. */
  static int time;

  /** Best answer found so far by {@link #cut}; {@code Long.MAX_VALUE} means none yet. */
  static long minAddition;

  /**
   * Reads {@code q} test cases from standard input and prints one answer per line.
   *
   * <p>Each case is a line with {@code n}, a line with the {@code n} node values, then
   * {@code n - 1} lines each holding the 1-based endpoints of an edge.
   *
   * @throws IOException if standard input cannot be read
   */
  public static void main(String[] args) throws IOException {
    BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    int q = Integer.parseInt(br.readLine().trim());
    for (int tc = 0; tc < q; tc++) {
      int n = Integer.parseInt(br.readLine().trim());
      GraphNode[] graphNodes = new GraphNode[n];
      StringTokenizer st = new StringTokenizer(br.readLine());
      for (int i = 0; i < n; i++) {
        graphNodes[i] = new GraphNode(Integer.parseInt(st.nextToken()));
      }
      for (int i = 0; i < n - 1; i++) {
        st = new StringTokenizer(br.readLine());
        // Endpoints arrive 1-based; the graph is stored 0-based and undirected.
        int u = Integer.parseInt(st.nextToken()) - 1;
        int v = Integer.parseInt(st.nextToken()) - 1;
        graphNodes[u].adjacents.add(v);
        graphNodes[v].adjacents.add(u);
      }
      System.out.println(solve(graphNodes));
    }
  }

  /**
   * Returns the minimum value of the node to add so the tree can be cut into three trees with
   * equal sums, or -1 when no such value exists.
   *
   * @param graphNodes the input as an adjacency list; expected to be a tree
   * @return the minimum addition, or -1 if impossible
   */
  static long solve(GraphNode[] graphNodes) {
    // A single node has only the new node's edge to cut, never two edges. Two nodes DO work:
    // cutting the original edge and the new node's edge gives three single-node trees.
    if (graphNodes.length < 2) {
      return -1;
    }
    int graphRootNodeIndex = findGraphRootNodeIndexWithMinHeight(graphNodes);
    TreeNode treeRoot = buildTree(graphNodes, graphRootNodeIndex);
    Map<Long, NavigableSet<Integer>> valueSumToEnterTimes = new HashMap<>();
    Map<Long, NavigableSet<Integer>> valueSumToLeaveTimes = new HashMap<>();
    computeSubtreeValueSums(treeRoot, valueSumToEnterTimes, valueSumToLeaveTimes);
    minAddition = Long.MAX_VALUE;
    cut(treeRoot.valueSum, valueSumToEnterTimes, valueSumToLeaveTimes, treeRoot);
    return minAddition == Long.MAX_VALUE ? -1 : minAddition;
  }

  /**
   * Reports whether a second cut, made in addition to the cut above {@code treeNode}, can produce
   * a piece whose sum is {@code targetValueSum}. This is the case when a subtree disjoint from
   * {@code treeNode}'s subtree has that sum, or when some proper ancestor {@code p} satisfies
   * {@code p.valueSum - treeNode.valueSum == targetValueSum}.
   *
   * @param ancestorValueSums counts of {@code valueSum} over the proper ancestors of
   *     {@code treeNode}
   */
  private static boolean otherContains(Map<Long, NavigableSet<Integer>> valueSumToEnterTimes,
      Map<Long, NavigableSet<Integer>> valueSumToLeaveTimes, Map<Long, Integer> ancestorValueSums,
      TreeNode treeNode, long targetValueSum) {
    // A subtree left before treeNode was entered lies entirely before it in DFS order.
    NavigableSet<Integer> leaveTimes = valueSumToLeaveTimes.get(targetValueSum);
    if (leaveTimes != null && leaveTimes.lower(treeNode.enterTime) != null) {
      return true;
    }
    // A subtree entered after treeNode was left lies entirely after it.
    NavigableSet<Integer> enterTimes = valueSumToEnterTimes.get(targetValueSum);
    if (enterTimes != null && enterTimes.higher(treeNode.leaveTime) != null) {
      return true;
    }
    // Also cutting above ancestor p leaves p.valueSum - treeNode.valueSum between the two cuts.
    return ancestorValueSums.containsKey(targetValueSum + treeNode.valueSum);
  }

  /**
   * Treats the edge above each node of {@code treeNode}'s subtree as a first cut and lowers
   * {@link #minAddition} for every compatible second cut. The traversal is iterative so that
   * deep (e.g. path-shaped) trees do not need call-stack depth proportional to their height.
   *
   * @param originalTotal sum of all values in the tree
   */
  static void cut(long originalTotal, Map<Long, NavigableSet<Integer>> valueSumToEnterTimes,
      Map<Long, NavigableSet<Integer>> valueSumToLeaveTimes, TreeNode treeNode) {
    // Multiset (sum -> count) of the subtree sums of the current node's proper ancestors, so the
    // "second cut above this one" test is a hash lookup instead of a walk up the parents.
    Map<Long, Integer> ancestorValueSums = new HashMap<>();
    for (TreeNode p = treeNode.parent; p != null; p = p.parent) {
      addAncestorValueSum(ancestorValueSums, p.valueSum);
    }
    Deque<TreeNode> openNodes = new ArrayDeque<>();
    Deque<Iterator<TreeNode>> openChildren = new ArrayDeque<>();
    evaluateCut(originalTotal, valueSumToEnterTimes, valueSumToLeaveTimes, ancestorValueSums,
        treeNode);
    addAncestorValueSum(ancestorValueSums, treeNode.valueSum);
    openNodes.push(treeNode);
    openChildren.push(treeNode.children.iterator());
    while (!openNodes.isEmpty()) {
      Iterator<TreeNode> children = openChildren.peek();
      if (children.hasNext()) {
        TreeNode child = children.next();
        evaluateCut(originalTotal, valueSumToEnterTimes, valueSumToLeaveTimes, ancestorValueSums,
            child);
        addAncestorValueSum(ancestorValueSums, child.valueSum);
        openNodes.push(child);
        openChildren.push(child.children.iterator());
      } else {
        openChildren.pop();
        removeAncestorValueSum(ancestorValueSums, openNodes.pop().valueSum);
      }
    }
  }

  /**
   * Evaluates cutting the edge above {@code treeNode}, where the resulting subtree piece is not
   * larger than the rest, and records any answer it enables in {@link #minAddition}.
   */
  private static void evaluateCut(long originalTotal,
      Map<Long, NavigableSet<Integer>> valueSumToEnterTimes,
      Map<Long, NavigableSet<Integer>> valueSumToLeaveTimes, Map<Long, Integer> ancestorValueSums,
      TreeNode treeNode) {
    long cutValueSum = treeNode.valueSum;
    long remainTotal = originalTotal - cutValueSum;
    if (cutValueSum > remainTotal) {
      // In every valid split each piece is at most the sum of the other two, so some other
      // cut of the same split satisfies this bound and is found from there.
      return;
    }
    if (cutValueSum == remainTotal) {
      // Two equal halves; the new node forms the third tree on its own.
      minAddition = Math.min(minAddition, cutValueSum);
      return;
    }
    if (remainTotal % 2 == 0) {
      // Split the remainder into two equal halves; the cut piece is topped up to match.
      long halfRemainTotal = remainTotal / 2;
      if (halfRemainTotal >= cutValueSum
          && otherContains(valueSumToEnterTimes, valueSumToLeaveTimes, ancestorValueSums,
              treeNode, halfRemainTotal)) {
        minAddition = Math.min(minAddition, halfRemainTotal - cutValueSum);
      }
    }
    // Split the remainder into cutValueSum + otherValueSum; the otherValueSum piece is topped up.
    long otherValueSum = remainTotal - cutValueSum;
    if (cutValueSum >= otherValueSum
        && (otherContains(valueSumToEnterTimes, valueSumToLeaveTimes, ancestorValueSums,
                treeNode, cutValueSum)
            || otherContains(valueSumToEnterTimes, valueSumToLeaveTimes, ancestorValueSums,
                treeNode, otherValueSum))) {
      minAddition = Math.min(minAddition, cutValueSum - otherValueSum);
    }
  }

  /** Adds one occurrence of {@code valueSum} to the ancestor multiset. */
  private static void addAncestorValueSum(Map<Long, Integer> ancestorValueSums, long valueSum) {
    Integer count = ancestorValueSums.get(valueSum);
    ancestorValueSums.put(valueSum, count == null ? 1 : count + 1);
  }

  /** Removes one occurrence of {@code valueSum} from the ancestor multiset. */
  private static void removeAncestorValueSum(Map<Long, Integer> ancestorValueSums, long valueSum) {
    Integer count = ancestorValueSums.get(valueSum);
    if (count == null || count <= 1) {
      ancestorValueSums.remove(valueSum);
    } else {
      ancestorValueSums.put(valueSum, count - 1);
    }
  }

  /**
   * Fills {@code valueSum} for every node of {@code treeNode}'s subtree and indexes each node's
   * enter and leave time by its subtree sum.
   */
  static void computeSubtreeValueSums(TreeNode treeNode,
      Map<Long, NavigableSet<Integer>> valueSumToEnterTimes,
      Map<Long, NavigableSet<Integer>> valueSumToLeaveTimes) {
    // Collect nodes so that every parent precedes its descendants, then fold in reverse
    // (children before parents) without recursion.
    List<TreeNode> parentsFirst = new ArrayList<>();
    Deque<TreeNode> pending = new ArrayDeque<>();
    pending.push(treeNode);
    while (!pending.isEmpty()) {
      TreeNode node = pending.pop();
      parentsFirst.add(node);
      for (TreeNode child : node.children) {
        pending.push(child);
      }
    }
    for (int i = parentsFirst.size() - 1; i >= 0; i--) {
      TreeNode node = parentsFirst.get(i);
      long sum = node.value;
      for (TreeNode child : node.children) {
        sum += child.valueSum;
      }
      node.valueSum = sum;
      addToValueSumToTimes(valueSumToEnterTimes, sum, node.enterTime);
      addToValueSumToTimes(valueSumToLeaveTimes, sum, node.leaveTime);
    }
  }

  /** Adds {@code time} to the sorted set stored under {@code valueSum}, creating it if absent. */
  static void addToValueSumToTimes(Map<Long, NavigableSet<Integer>> valueSumToTimes,
      long valueSum, int time) {
    NavigableSet<Integer> times = valueSumToTimes.get(valueSum);
    if (times == null) {
      times = new TreeSet<>();
      valueSumToTimes.put(valueSum, times);
    }
    times.add(time);
  }

  /**
   * Roots the graph at {@code graphRootNodeIndex}, building parent/child links. It assigns DFS
   * enter/leave times, incrementing {@link #time} on every enter and leave. Iterative, so tree
   * height does not translate into call-stack depth.
   */
  static TreeNode buildTree(GraphNode[] graphNodes, int graphRootNodeIndex) {
    time = 0;
    boolean[] visited = new boolean[graphNodes.length];
    TreeNode root = enterTreeNode(graphNodes, graphRootNodeIndex, visited, null);
    Deque<TreeNode> openNodes = new ArrayDeque<>();
    Deque<Iterator<Integer>> openAdjacents = new ArrayDeque<>();
    openNodes.push(root);
    openAdjacents.push(graphNodes[graphRootNodeIndex].adjacents.iterator());
    while (!openNodes.isEmpty()) {
      Iterator<Integer> adjacents = openAdjacents.peek();
      if (adjacents.hasNext()) {
        int adjacent = adjacents.next();
        if (!visited[adjacent]) {
          TreeNode parent = openNodes.peek();
          TreeNode child = enterTreeNode(graphNodes, adjacent, visited, parent);
          parent.children.add(child);
          openNodes.push(child);
          openAdjacents.push(graphNodes[adjacent].adjacents.iterator());
        }
      } else {
        // All neighbours handled: the node is left.
        openAdjacents.pop();
        time++;
        openNodes.pop().leaveTime = time;
      }
    }
    return root;
  }

  /** Marks a graph vertex visited and creates its tree node, stamping its enter time. */
  private static TreeNode enterTreeNode(GraphNode[] graphNodes, int graphNodeIndex,
      boolean[] visited, TreeNode parent) {
    visited[graphNodeIndex] = true;
    TreeNode treeNode = new TreeNode(graphNodes[graphNodeIndex].value, parent);
    time++;
    treeNode.enterTime = time;
    return treeNode;
  }

  /**
   * Runs a breadth-first search seeded with every degree-1 vertex and returns the last vertex
   * dequeued. This is a vertex whose distance to the nearest leaf is maximal, used as a heuristic
   * low-height root. Returns -1 when no vertex has degree 1.
   */
  static int findGraphRootNodeIndexWithMinHeight(GraphNode[] graphNodes) {
    boolean[] visited = new boolean[graphNodes.length];
    Queue<Integer> queue = new ArrayDeque<>();
    for (int i = 0; i < graphNodes.length; i++) {
      if (graphNodes[i].adjacents.size() == 1) {
        visited[i] = true;
        queue.offer(i);
      }
    }
    int rootGraphIndex = -1;
    while (!queue.isEmpty()) {
      int head = queue.poll();
      rootGraphIndex = head;
      for (int adjacent : graphNodes[head].adjacents) {
        if (!visited[adjacent]) {
          visited[adjacent] = true;
          queue.offer(adjacent);
        }
      }
    }
    return rootGraphIndex;
  }
}
// Last updated