1746D - Paths on the Tree - CodeForces Solution


dfs and similar dp greedy sortings trees *1900

Please click on ads to support us.

Python Code:

import sys, threading


def DFS(v, k, boolean):
    """Return the best subtree score when ``k`` paths enter vertex ``v``.

    ``v`` must expose ``lable`` (the vertex score) and ``child`` (a list of
    child vertices of the same shape).  The ``k`` paths entering ``v`` are
    split among its ``q`` children as evenly as possible: every child gets
    ``k // q`` paths and the ``k % q`` leftover paths go to the children
    whose subtree score grows the most from one extra path.

    Args:
        v: Root of the subtree to evaluate.
        k: Number of paths entering ``v`` (non-negative integer).
        boolean: When true, also compute the best score for ``k + 1`` paths.

    Returns:
        ``[best(k), best(k + 1)]`` when ``boolean`` is true, otherwise
        ``[best(k), 0]``.
    """
    ret = v.lable * k
    ret2 = v.lable * (k + 1)
    children = v.child
    child_size = len(children)
    if child_size == 0:
        return [ret, ret2] if boolean else [ret, 0]
    if child_size == 1:
        # A single child receives every path, so there is no choice to make.
        tmp = DFS(children[0], k, boolean)
        return [ret + tmp[0], ret2 + tmp[1]] if boolean else [ret + tmp[0], 0]

    # Exact integer division: int(k / q) goes through a float and mis-splits
    # the paths once k is too large for a double to represent exactly.
    k1, remind = divmod(k, child_size)
    if remind == 0 and not boolean:
        # Even split and no (k + 1) query: children's k1 + 1 values are unused.
        for child in children:
            ret += DFS(child, k1, False)[0]
        return [ret, 0]

    gains = []
    for child in children:
        base, plus_one = DFS(child, k1, True)
        ret += base
        ret2 += base
        gains.append(plus_one - base)
    gains.sort(reverse=True)
    # The `remind` leftover paths go to the children with the largest gains.
    top = sum(gains[:remind])
    ret += top
    if not boolean:
        return [ret, 0]
    # With k + 1 paths there is one more leftover path.  remind < child_size,
    # so gains[remind] always exists; when remind + 1 == child_size every child
    # gets k1 + 1 paths, which is the same as taking all gains.
    return [ret, ret2 + top + gains[remind]]


class treeNode:
    """A vertex of the rooted input tree: its score plus its direct children."""

    def __init__(self):
        # Direct children, filled in by the input parser in input order.
        self.child = []
        # Score of this vertex (attribute name kept as spelled by callers).
        self.lable = 0


def DFS_Helper():
    """Solve every test case on stdin and write one answer per line to stdout.

    Expected input (whitespace-separated tokens): ``t``; then, per test case,
    ``n k``, the parents ``p_2 .. p_n`` (1-based) and the scores ``s_1 .. s_n``.
    Each answer is ``DFS(root, k, False)[0]`` with vertex 1 as the root.
    """
    # Token-based parsing tolerates repeated/trailing spaces and blank lines,
    # which the per-line split(" ") did not.
    tokens = sys.stdin.buffer.read().split()
    pos = 0
    t = int(tokens[pos])
    pos += 1
    answers = []
    for _ in range(t):
        n = int(tokens[pos])
        k = int(tokens[pos + 1])
        pos += 2
        parents = tokens[pos:pos + n - 1]
        pos += n - 1
        scores = tokens[pos:pos + n]
        pos += n

        node_list = [treeNode() for _ in range(n)]
        for node, s in zip(node_list, scores):
            node.lable = int(s)
        # Vertex i (i >= 2) is attached to parent p_i in input order, which
        # fixes the order of each vertex's children.
        for node, p in zip(node_list[1:], parents):
            node_list[int(p) - 1].child.append(node)
        # Errors now propagate (traceback on stderr) instead of being printed
        # to stdout as if they were an answer.
        answers.append(DFS(node_list[0], k, False)[0])
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


# DFS recurses once per tree level, so a path-shaped tree needs a recursion
# limit above its depth; 3*10**5 presumably covers the judge's n limit (confirm).
sys.setrecursionlimit(3*(10**5))
# The solver runs in a freshly created thread so it can use an enlarged stack
# for that deep recursion. threading.stack_size() only affects threads created
# after the call, so it must come before Thread(...) below.
# NOTE(review): the threading docs say some platforms require the size to be a
# multiple of the system page size; 10**8 is not a multiple of 4096 — confirm
# on the target platform.
threading.stack_size(10**8)
t1 = threading.Thread(target=DFS_Helper)
t1.start()

C++ Code:

/* Author : Hakesh D */
#include <bits/stdc++.h>
using namespace std;
// #include <ext/pb_ds/assoc_container.hpp>
// #include <ext/pb_ds/tree_policy.hpp>
// using namespace __gnu_pbds;
// #define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>

#define int long long
#define ld long double
 
 
#define ar2 array<int,2>
#define MAX *max_element
#define MIN *min_element
#define all(c) (c).begin(), (c).end()
#define sz(x) (int)(x).size()
 
#define inf 1e9
#define INF 1e18

// Returns x / y, bumped up by one whenever the division leaves a remainder
// (a ceiling division, presumably meant for non-negative operands).
int cil(int x,int y){
    const int quotient = x / y;
    const bool exact = (x % y == 0);
    return exact ? quotient : quotient + 1;
}
/***********************************************************************************************/


// Solves one test case: reads n, k, the parents p_2..p_n and the scores
// s_1..s_n from cin, and prints the best total score of the k root-started
// paths.
//
// For a vertex v entered by a paths, val1[v] is the best score of v's subtree
// with a paths and val2[v] the best with a+1 paths. A vertex with q children
// gives each child a/q paths; the r = a%q leftover paths go to the children
// with the largest gains val2[u]-val1[u]. With a+1 paths the leftover count is
// r+1 (when r+1 == q every child simply gets one more path), so a single sort
// of the gains serves both values.
//
// The tree is processed iteratively (BFS order top-down, then in reverse) so
// that deep chains cannot overflow the call stack.
void solve(){
    long long n, k;
    std::cin >> n >> k;
    std::vector<std::vector<long long>> adj(n + 1);
    for (long long i = 2; i <= n; i++) {
        long long p;
        std::cin >> p;
        adj[p].push_back(i);
    }
    std::vector<long long> s(n + 1, 0);
    for (long long i = 1; i <= n; i++) std::cin >> s[i];

    // Top-down pass: BFS order from the root and the path count per vertex.
    std::vector<long long> order;
    order.reserve(n);
    std::vector<long long> paths(n + 1, 0);
    order.push_back(1);
    paths[1] = k;
    for (std::size_t idx = 0; idx < order.size(); idx++) {
        const long long v = order[idx];
        const long long q = (long long)adj[v].size();
        for (long long u : adj[v]) {
            paths[u] = paths[v] / q;
            order.push_back(u);
        }
    }

    // Bottom-up pass: reverse BFS order finalizes children before parents.
    std::vector<long long> val1(n + 1, 0), val2(n + 1, 0), extras;
    for (std::size_t idx = order.size(); idx-- > 0;) {
        const long long v = order[idx];
        const long long a = paths[v];
        const long long q = (long long)adj[v].size();
        if (q == 0) {
            val1[v] = s[v] * a;
            val2[v] = s[v] * (a + 1);
            continue;
        }
        long long base = 0;
        extras.clear();
        for (long long u : adj[v]) {
            base += val1[u];
            extras.push_back(std::max(0LL, val2[u] - val1[u]));
        }
        std::sort(extras.begin(), extras.end(), std::greater<long long>());
        const long long r = a % q;  // r < q, so extras[r] below is valid
        long long top = 0;
        for (long long i = 0; i < r; i++) top += extras[i];
        val1[v] = a * s[v] + base + top;
        val2[v] = (a + 1) * s[v] + base + top + extras[r];
    }

    std::cout << val1[1] << '\n';
}


// Entry point: fast unsynchronized I/O, then one solve() per test case.
int32_t main(){
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int testCount = 1;
    std::cin >> testCount;
    while (testCount-- > 0) {
        solve();
    }
    return 0;
}
/***********************************************************************************************/






Comments

Submit
0 Comments
More Questions

952. Largest Component Size by Common Factor
212. Word Search II
174. Dungeon Game
127. Word Ladder
123. Best Time to Buy and Sell Stock III
85. Maximal Rectangle
84. Largest Rectangle in Histogram
60. Permutation Sequence
42. Trapping Rain Water
32. Longest Valid Parentheses
Cutting a material
Bubble Sort
Number of triangles
AND path in a binary tree
Factorial equations
Removal of vertices
Happy segments
Cyclic shifts
Zoos
Build a graph
Almost correct bracket sequence
Count of integers
Differences of the permutations
Doctor's Secret
Back to School
I am Easy
Teddy and Tweety
Partitioning binary strings
Special sets
Smallest chosen word