1401D - Maximum Distributed Tree - CodeForces Solution

Tags: dfs and similar, dp, greedy, implementation, math, number theory, sortings, trees. Difficulty rating: *1800

Python Code:

import gc
import math
import sqlite3
from collections import Counter, deque, defaultdict
from sys import stdout
import time
from math import factorial, log, gcd
import sys
from decimal import Decimal
import threading
from heapq import *

def S():
    """Read one line from standard input and return its whitespace-separated tokens.

    An empty or exhausted input line yields an empty list.
    """
    line = sys.stdin.readline()
    tokens = line.split()
    return tokens

def I():
    """Read one line from standard input and return its tokens parsed as integers.

    Raises ValueError if any whitespace-separated token is not an integer literal.
    """
    tokens = sys.stdin.readline().split()
    return list(map(int, tokens))

def II():
    """Read one line from standard input and return it parsed as a single integer.

    Surrounding whitespace (including the trailing newline) is accepted by int();
    a blank or non-numeric line raises ValueError.
    """
    line = sys.stdin.readline()
    value = int(line)
    return value

def IS():
    """Read one line from standard input and return it without its line feed.

    Only the '\\n' terminator is removed; any other whitespace (including a
    carriage return) is kept. At end of input the result is an empty string.
    """
    line = sys.stdin.readline()
    # readline() returns at most one '\n', and only as the final character.
    if line.endswith('\n'):
        line = line[:-1]
    return line

def _subtree_sizes(n, tree):
    """Return (parents, sizes) for the tree rooted at vertex 0.

    Uses an explicit stack instead of recursion so deep (path-like) trees do
    not hit Python's recursion limit. parents[0] is -1; sizes[v] counts the
    vertices in the subtree of v, v included.
    """
    parents = [-1] * n
    order = []
    stack = [0]
    while stack:
        v = stack.pop()
        order.append(v)
        for u in tree[v]:
            if u != parents[v]:
                parents[u] = v
                stack.append(u)
    sizes = [1] * n
    # Every vertex appears after its parent in `order`, so walking it
    # backwards finishes each child before its size is added to the parent.
    for v in reversed(order):
        if parents[v] != -1:
            sizes[parents[v]] += sizes[v]
    return parents, sizes


def _edge_labels(primes, slots, mod):
    """Return `slots` edge labels, largest first, whose product equals that of `primes`.

    With too few primes the list is padded with 1s. With too many, the
    surplus largest primes are multiplied (mod `mod`) into the first label so
    that the most-used edge receives the biggest factor.
    """
    labels = sorted(primes, reverse=True)
    surplus = len(labels) - slots
    if surplus <= 0:
        return labels + [1] * (-surplus)
    head = 1
    for value in labels[:surplus + 1]:
        head = head * value % mod
    return [head] + labels[surplus + 1:]


def main():
    """Solve one test case of Codeforces 1401D and print the answer.

    Reads from standard input: n, then n - 1 lines "u v" (1-based tree
    edges), then m, then a line with the m prime factors of k. Prints the
    maximum distribution index modulo 10**9 + 7.
    """
    mod = 10 ** 9 + 7
    read = sys.stdin.readline

    n = int(read())
    tree = [[] for _ in range(n)]
    edges = []
    for _ in range(n - 1):
        a, b = (int(tok) - 1 for tok in read().split())
        edges.append((a, b))
        tree[a].append(b)
        tree[b].append(a)

    read()  # m: the count of primes; the list on the next line carries it implicitly
    primes = [int(tok) for tok in read().split()]

    parents, sizes = _subtree_sizes(n, tree)

    # An edge is crossed by size * (n - size) simple paths, where `size` is
    # the subtree size of the edge's child endpoint.
    usage = []
    for a, b in edges:
        child = b if parents[b] == a else a
        usage.append(sizes[child] * (n - sizes[child]))
    usage.sort(reverse=True)

    labels = _edge_labels(primes, n - 1, mod)

    # Greedy pairing: most-used edge gets the largest label.
    ans = 0
    for count, label in zip(usage, labels):
        ans = (ans + count * label) % mod
    print(ans)


if __name__ == '__main__':
    # Modulus used for the answer; kept as a module-level name because
    # main() may look it up as a global.
    mod = 10 ** 9 + 7
    # The first input line holds the number of test cases; solve each in turn.
    for _ in range(II()):
        main()

C++ Code:

#include <bits/stdc++.h>

using namespace std;

// All answers are reported modulo this prime.
const int MOD = 1e9 + 7;

// Solves one test case: reads a tree with n vertices (1-based edges) and the
// m prime factors of k from stdin, then prints the maximum distribution index
// modulo MOD followed by a newline.
void solve() {
    int n;
    cin >> n;
    vector<vector<int>> adj(n + 1);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        adj[a].emplace_back(b);
        adj[b].emplace_back(a);
    }
    int m;
    cin >> m;
    vector<long long> fac(m);
    for (auto& x : fac) cin >> x;
    sort(fac.begin(), fac.end());
    // More primes than edges: fold the largest primes together so exactly
    // n - 1 labels remain, with the merged product kept at the back.
    while (static_cast<int>(fac.size()) > n - 1) {
        auto temp = fac.back();
        fac.pop_back();
        fac.back() *= temp;
        fac.back() %= MOD;
    }
    // sz[v] = number of vertices in the subtree of v when rooted at vertex 1.
    vector<long long> sz(n + 1, 1);
    function<void(const int&, const int&)> dfs_size = [&](const int& node, const int& p) -> void {
        for (auto& x : adj[node]) {
            if (x == p) continue;
            dfs_size(x, node);
            sz[node] += sz[x];
        }
    };
    dfs_size(1, 1);
    // rep holds, for each edge (node, parent), the number of simple paths
    // crossing it: sz[node] * (n - sz[node]).
    vector<long long> rep;
    function<void(const int&, const int&)> dfs_get = [&](const int& node, const int& p) -> void {
        if (node != p) rep.emplace_back(sz[node] * (n - sz[node]));
        for (auto& x : adj[node]) {
            if (x == p) continue;
            dfs_get(x, node);
        }
    };
    dfs_get(1, 1);
    sort(rep.rbegin(), rep.rend());
    long long ans = 0;
    // Greedy pairing: the most-used edge takes the largest remaining label;
    // once the labels run out, the remaining edges are labelled 1.
    for (int i = 0; i < n - 1; i++) {
        long long f = 1;
        if (!fac.empty()) {
            f = fac.back();
            fac.pop_back();
        }
        ans = (ans + (rep[i] % MOD) * f % MOD) % MOD;
    }
    cout << ans << '\n';
}

int main () {

    int q;
    cin >> q;
    for (int i = 0; i < q; i++) {

    return 0;


