1725E - Electrical Efficiency - CodeForces Solution


Tags: combinatorics, data structures, dp, math, number theory, trees (difficulty rating 2500)

Please click on ads to support us.

C++ Code:

#include <bits/stdc++.h>
typedef long long ll;

// Capacity, in bytes, of each fast-I/O buffer below.
const int BUFFER = 1 << 18;

// Minimal buffered writer to stdout used instead of std::cout.
// Bytes accumulate in `buffer` and are written with fwrite when the buffer is
// full, on an explicit flush(), and when the object is destroyed.
struct ostream
{
    char buffer[BUFFER], *pos = buffer, *end = buffer + BUFFER;
    ~ostream() { flush(); }
    // Writes all pending bytes to stdout and empties the buffer.
    void flush() { fwrite(buffer, 1, pos - buffer, stdout), pos = buffer; }
    void put(char ch)
    {
        if (pos == end)
            flush();
        *(pos++) = ch;
    }
    // Emits the decimal digits of a positive value (most significant first);
    // emits nothing for 0.
    template <typename V>
    void put(V num)
    {
        if (num)
            put(num / 10), put((char)(num % 10 + '0'));
    }
    // Emits the decimal digits of |num| for a negative value WITHOUT negating it:
    // division truncates toward zero and num % 10 lies in [-9, 0], so the minimum
    // value of a signed type (INT_MIN, LLONG_MIN) is printed without overflow.
    template <typename V>
    void putNegative(V num)
    {
        if (num)
            putNegative(num / 10), put((char)('0' - num % 10));
    }
    ostream &operator<<(char s) { return put(s), *this; }
    ostream &operator<<(const char *s)
    {
        while (*s)
            put(*(s++));
        return *this;
    }
    // Writes an integer in decimal, with a leading '-' for negative values.
    template <typename V, std::enable_if_t<std::is_integral<V>::value, bool> = true>
    ostream &operator<<(V num)
    {
        if (num < 0)
            put('-'), putNegative(num);
        else if (num == 0)
            put('0');
        else
            put(num);
        return *this;
    }
} cout;

// Minimal buffered reader from stdin used instead of std::cin.
struct istream
{
    char buffer[BUFFER], *pos = buffer, *end = buffer;
    // Returns the next input byte, refilling from stdin when the buffer is
    // exhausted; returns 0 once no more input is available.
    int get()
    {
        if (pos == end)
        {
            end = buffer + fread(buffer, 1, BUFFER, stdin), pos = buffer;
            if (pos == end)
                return 0;
        }
        return *(pos++);
    }
    // Reads the next character greater than ' '; stores 0 at end of input
    // (previously this looped forever because get() keeps returning 0).
    istream &operator>>(char &ch)
    {
        while ((ch = get()) && ch <= ' ')
            ;
        return *this;
    }
    // Reads an optionally '-'-prefixed decimal integer. Bytes below '-' are
    // skipped as separators; at end of input `num` is set to 0.
    template <typename V, std::enable_if_t<std::is_integral<V>::value, bool> = true>
    istream &operator>>(V &num)
    {
        char ch;
        while ((ch = get()) && ch < '-')
            ;
        if (!ch)
        {
            num = 0;
            return *this;
        }
        const bool negative = ch == '-';
        num = negative ? 0 : ch - '0';
        // Accumulate negative numbers downward so the minimum signed value
        // (e.g. -2147483648) is parsed without intermediate overflow.
        while ((ch = get()) > ' ')
            num = negative ? 10 * num - (ch - '0') : 10 * num + (ch - '0');
        return *this;
    }
} cin;
#ifdef LOCAL
#include "debug.h"
#else
#define log(...) 9
#endif

// Montgomery arithmetic modulo an odd 32-bit modulus, with R = 2^32.
// Operands of mul()/reduce() are residues in Montgomery form (x * R mod `mod`).
struct montgomery
{
    using u64 = unsigned long long;
    using u32 = unsigned;

    // mod: the modulus; inv: mod^{-1} modulo 2^32;
    // r, r2, r3: R, R^2 and R^3 reduced modulo `mod`.
    u32 mod, inv, r, r2, r3;

    montgomery() {}
    montgomery(u32 mod) : mod(mod) { init(); }

    // Reads a new modulus from the fast input stream and recomputes the constants.
    friend istream &operator>>(istream &in, montgomery &a)
    {
        in >> a.mod;
        a.init();
        return in;
    }

    // Precomputes inv, r, r2 and r3 for the current (odd) `mod`.
    void init()
    {
        // Newton/Hensel lifting of the inverse modulo 2^32: starting from 1
        // (correct modulo 2 for odd `mod`), each step doubles the number of
        // correct low bits, so five steps reach 32 bits.
        inv = 1;
        for (int step = 5; step > 0; --step)
            inv *= 2 - mod * inv;
        r = (0u - mod) % mod;   // 2^32 mod `mod`
        r2 = (u64)r * r % mod;  // R^2 mod `mod`
        r3 = mul(r2, r2);       // REDC(R^2 * R^2) = R^3 mod `mod`
    }

    // Converts an ordinary value into Montgomery form: REDC(v * R^2) = v * R.
    u32 add(u32 v) { return mul(v, r2); }

    // Montgomery reduction: returns v * R^{-1} mod `mod`, in [0, mod),
    // for v < mod * 2^32.
    u32 reduce(u64 v)
    {
        const u32 m = (u32)v * inv;             // m * mod == v (mod 2^32)
        const u32 high = v >> 32;
        const u32 sub = ((u64)m * mod) >> 32;
        return high < sub ? high + mod - sub : high - sub;
    }

    // Montgomery product: REDC(a * b).
    u32 mul(u64 a, u32 b) { return reduce(a * b); }
} mod = 998244353;

// Residue modulo the global `mod`, stored in Montgomery form (value * 2^32 mod p).
struct mint
{
    unsigned v;
    mint() {}
    // Wraps a value that is ALREADY in Montgomery form; no conversion is done.
    mint(unsigned v) : v(v) {}
    // Converts an ordinary integer into Montgomery form. Negative inputs are first
    // reduced into [0, mod.mod) so that, e.g., mint(-1) represents p - 1
    // (previously they were reinterpreted as large unsigned values).
    mint(int v) : v(mod.add(v < 0 ? normalizeNegative(v) : (unsigned)v)) {}
    // Maps a negative int to its canonical residue in [0, mod.mod).
    static unsigned normalizeNegative(int x)
    {
        long long r = (long long)x % (long long)mod.mod;
        return (unsigned)(r < 0 ? r + mod.mod : r);
    }
    // Prints the ordinary (non-Montgomery) value.
    template <typename V>
    friend V &operator<<(V &out, const mint &a) { return out << mod.reduce(a.v); }
    bool operator==(const mint &a) const { return v == a.v; }
    bool operator!=(const mint &a) const { return v != a.v; }
    mint &operator+=(const mint &a)
    {
        v += a.v;
        if (v >= mod.mod)
            v -= mod.mod;
        return *this;
    }
    mint &operator-=(const mint &a)
    {
        if (v < a.v)
            v += mod.mod - a.v;
        else
            v -= a.v;
        return *this;
    }
    mint operator*(const mint &a) const { return mod.mul(v, a.v); }
    mint operator/(const mint &a) const { return *this * a.inv(); }
    mint operator+(const mint &a) const { return mint(v) += a; }
    mint operator-(const mint &a) const { return mint(v) -= a; }
    mint &operator*=(const mint &a) { return *this = *this * a; }
    mint &operator/=(const mint &a) { return *this = *this / a; }
    // Multiplicative inverse via the extended Euclidean algorithm; returns 0 for 0.
    // If v = x*R, Euclid yields (x*R)^{-1} = x^{-1}*R^{-1}; REDC(. * R^3) gives x^{-1}*R.
    mint inv() const
    {
        int a = v, b = mod.mod, x = 1, y = 0;
        while (a)
        {
            int q = b / a, r = b - q * a, z = y - q * x;
            b = a, a = r, y = x, x = z;
        }
        if (y < 0)
            y += mod.mod;
        return mod.mul(y, mod.r3);
    }
    // this^n by binary exponentiation. A negative n means inv()^|n|; the
    // exponent is handled as unsigned so the loop terminates (an arithmetic
    // shift of a negative int never reaches 0) and INT_MIN does not overflow.
    mint pow(int n) const
    {
        mint ans(1);
        unsigned e = n < 0 ? 0u - (unsigned)n : (unsigned)n;
        for (mint current = n < 0 ? inv() : mint(v); e; e >>= 1, current *= current)
            if (e & 1)
                ans *= current;
        return ans;
    }
};

// Binomial coefficients modulo the global modulus, from precomputed factorials.
template <int maxN>
struct binomial
{
    // fac[i] = i! and invfac[i] = 1 / i!, for 0 <= i <= maxN.
    mint fac[maxN + 1], invfac[maxN + 1];
    binomial()
    {
        fac[0] = 1;
        for (int i = 1; i <= maxN; ++i)
            fac[i] = fac[i - 1] * i;

        // A single modular inversion, then walk down using 1/(i-1)! = i * (1/i!).
        invfac[maxN] = fac[maxN].inv();
        for (int i = maxN; i >= 1; --i)
            invfac[i - 1] = invfac[i] * i;
    }
    // C(n, r); 0 when r < 0 or r > n. Indexes fac[n], so n must not exceed maxN.
    mint operator()(int n, int r)
    {
        const bool outOfRange = r < 0 || n < r;
        return outOfRange ? mint(0) : fac[n] * invfac[r] * invfac[n - r];
    }
};
binomial<200000> C;

// Smallest-prime-factor table over [1, maxN] built with a linear sieve, plus
// trial division by the sieved primes for values above maxN.
template <int maxN>
struct sieve
{
    // leastPrimes[i]: smallest prime factor of i (leastPrimes[1] == 1).
    // nexts[i]: i / leastPrimes[i] (nexts[1] == 1).
    // primes[0 .. last): every prime <= maxN in increasing order.
    int leastPrimes[maxN + 1], nexts[maxN + 1], primes[maxN + 1], *last;
    sieve()
    {
        memset(leastPrimes, 0, sizeof(leastPrimes));
        leastPrimes[1] = nexts[1] = 1, last = primes;
        for (int i = 2; i <= maxN; i++)
        {
            if (leastPrimes[i] == 0)
            {
                *(last++) = i;
                leastPrimes[i] = i, nexts[i] = 1;
            }
            // Only primes p <= leastPrimes[i] are used, so p is the smallest prime
            // factor of i * p and every composite is written exactly once.
            for (auto p = primes; p < last; p++)
            {
                int current = i * *p;
                if (current > maxN)
                    break;
                leastPrimes[current] = *p, nexts[current] = i;
                if (i % *p == 0)
                    break;
            }
        }
    }
    // Splits off the smallest prime factor of num: sets `prime` to it and divides
    // num by it. A num of 1 yields prime == 1 and num == 1.
    // Above maxN, trial division is exact for num <= maxN * maxN.
    void operator()(int &num, int &prime)
    {
        if (num > maxN)
        {
            // 64-bit square avoids int overflow for primes above 46340, and the
            // loop never walks past the end of the prime list.
            for (auto p = primes; p < last && (long long)*p * *p <= num; p++)
                if (num % *p == 0)
                {
                    num /= prime = *p;
                    return;
                }
            prime = num, num = 1;
        }
        else
            prime = leastPrimes[num], num = nexts[num];
    }
};
sieve<200000> factor;

// One entry of the pointer-keyed hash table below.
struct item
{
    void *key1;     // first key component (a pointer; null is a valid key)
    int key2;       // second key component
    int value;      // stored value
    item *next;     // next entry in the same hash bucket
    item *nextKey;  // next entry in the caller-supplied insertion list

    // Debug printer: writes "(key2,value)" for this entry and every entry
    // reachable through nextKey; prints nothing for a null pointer.
    template <typename V>
    friend V &operator<<(V &out, const item *a)
    {
        for (const item *cur = a; cur; cur = cur->nextKey)
            out << '(' << cur->key2 << ',' << cur->value << ')';
        return out;
    }
};

template <int PRIME>
struct hashmap
{
    int rounds[PRIME], round;
    item items[PRIME], *buckets[PRIME], *last;
    hashmap() : round(0) {}
    void clear() { round++, last = items; }
    int &get(void *key1, int key2, int &size, item *&nextKey)
    {
        unsigned bucket = ((unsigned)key1 + key2) % PRIME;
        item **current = &buckets[bucket];
        if (rounds[bucket] < round)
            rounds[bucket] = round, *current = 0;
        while (*current && ((*current)->key1 != key1 || (*current)->key2 != key2))
            current = &(*current)->next;
        if (!*current)
        {
            size++;
            last->nextKey = nextKey, nextKey = last;
            last->key1 = key1, last->key2 = key2, last->next = 0, last->value = 0;
            *current = last++;
        }
        return (*current)->value;
    }
};
hashmap<7'200'007> map;

// Tree of n vertices (0-indexed internally, rooted at vs[0]) that reads the
// input and prints the answer in solve().
//
// For a prime p let T_p be the number of vertices whose value is divisible by p.
// For the edge between a non-root vertex v and its parent, let in_p be how many
// of those vertices lie in v's subtree. The number of p-divisible vertex triples
// that are NOT all on one side of that edge is
//     C(T_p, 3) - C(in_p, 3) - C(T_p - in_p, 3).
// solve() prints the sum of this over all edges and primes, as a mint (modulo the
// global `mod`). An edge lies on the minimal subtree joining three vertices
// exactly when it separates them, so this equals the sum of those subtrees' edge
// counts. Per-subtree prime counts live in the global `map` and are merged
// small-to-large.
template <int maxN>
struct tree
{
    static int const maxM = maxN - 1;

    struct edge;
    struct vertex
    {
        edge *next;     // head of this vertex's adjacency list
        vertex *rep;    // vertex whose map entries currently hold this subtree's prime counts
        item *nextKey;  // list of `map` entries keyed by this vertex pointer
        int size;       // number of entries in that list
        mint ans;       // sum of the edge formula over the primes in that list
    } vs[maxN], *lastV = vs + maxN;
    struct edge
    {
        vertex *to;
        edge *next;
    } es[2 * maxM], *lastE = es;
    int n;
    // Resets the tree to N vertices with no edges and zeroed vertex data.
    void clear(int N) { n = N, lastV = vs + n, lastE = es, memset(vs, 0, sizeof(vs)); }
    // Adds the undirected edge f <-> t (0-indexed) as two directed adjacency entries.
    void add(int f, int t)
    {
        lastE->to = vs + t, lastE->next = vs[f].next, vs[f].next = lastE++;
        lastE->to = vs + f, lastE->next = vs[t].next, vs[t].next = lastE++;
    }

    // Accumulated answer; plus the list/count of `map` entries keyed by a null
    // pointer, which hold T_p for each prime p.
    mint ans;
    item *nextKey;
    int size;

    // Merges the prime counts of two subtree representatives and returns the
    // survivor. Small-to-large: the one with fewer entries is folded into the other.
    // For each moved prime, b->ans trades the old -C(T-old,3) - C(old,3) terms
    // for the new ones. The C(T,3) term is unchanged, so it is never touched.
    vertex *merge(vertex *a, vertex *b)
    {
        if (a->size > b->size)
            std::swap(a, b);

        for (auto it = a->nextKey; it; it = it->nextKey)
        {
            // The global entry for this prime was created while reading values,
            // so this lookup finds it rather than inserting.
            int &total = map.get(0, it->key2, size, nextKey);
            // Inserts a zero count for b if b has not seen this prime yet.
            int &old = map.get(b, it->key2, b->size, b->nextKey);

            b->ans += C(total - old, 3);
            b->ans += C(old, 3);

            old += it->value;

            b->ans -= C(total - old, 3);
            b->ans -= C(old, 3);
        }

        return b;
    }

    // Reads n vertex values and n-1 edges (1-indexed endpoints) from `cin`,
    // then prints the answer followed by a newline.
    void solve()
    {
        map.clear();
        nextKey = 0, size = 0;

        int prime, last;
        // For each vertex, record every distinct prime factor of its value with
        // count 1 (keyed by the vertex), and count in the null-keyed entries how
        // many vertices each prime divides. `factor` yields the smallest remaining
        // prime each time, so repeats are adjacent and comparing with `last` skips them.
        for (auto v = vs; v < lastV; v++)
        {
            int num;
            cin >> num;
            last = 1;
            while (num > 1)
            {
                factor(num, prime);
                if (prime != last)
                {
                    last = prime;
                    map.get(v, prime, v->size, v->nextKey) = 1;
                    map.get(0, prime, size, nextKey)++;
                }
            }
        }

        // Singleton subtree {v}: C(T,3) - C(1,3) - C(T-1,3) = C(T-1, 2).
        for (auto v = vs; v < lastV; v++)
            for (auto it = v->nextKey; it; it = it->nextKey)
                v->ans += C(map.get(0, it->key2, size, nextKey) - 1, 2);

        for (int i = 1; i < n; i++)
        {
            int u, v;
            cin >> u >> v;
            add(u - 1, v - 1);
        }

        // Root at vs[0]; each non-root vertex adds the value for its parent edge.
        ans = 0;
        for (auto e = vs->next; e; e = e->next)
            dfs(e->to, vs);

        cout << ans << '\n';
    }
    // Post-order traversal. Merges each child's representative into v's, then adds
    // the merged value, i.e. the edge formula for the edge (v, from).
    // Recursion depth equals the depth of v in the tree, so a path-shaped input
    // needs a correspondingly deep call stack.
    void dfs(vertex *v, vertex *from)
    {
        v->rep = v;
        for (edge *e = v->next; e; e = e->next)
            if (e->to != from)
            {
                dfs(e->to, v);
                v->rep = merge(v->rep, e->to->rep);
            }
        ans += v->rep->ans;
    }
};
tree<200000> g;

// Handles one input: reads the vertex count, resets the shared tree `g`,
// and lets it read the rest of the input and print the answer.
void testCase()
{
    int vertexCount;
    cin >> vertexCount;
    g.clear(vertexCount);
    g.solve();
}

// Entry point: processes a single test case. Falling off the end of main
// returns 0.
int main()
{
    testCase();
}


Comments

Submit
0 Comments
More Questions

647. Palindromic Substrings
583. Delete Operation for Two Strings
518. Coin Change 2
516. Longest Palindromic Subsequence
468. Validate IP Address
450. Delete Node in a BST
445. Add Two Numbers II
442. Find All Duplicates in an Array
437. Path Sum III
436. Find Right Interval
435. Non-overlapping Intervals
406. Queue Reconstruction by Height
380. Insert Delete GetRandom O(1)
332. Reconstruct Itinerary
368. Largest Divisible Subset
377. Combination Sum IV
322. Coin Change
307. Range Sum Query - Mutable
287. Find the Duplicate Number
279. Perfect Squares
275. H-Index II
274. H-Index
260. Single Number III
240. Search a 2D Matrix II
238. Product of Array Except Self
229. Majority Element II
222. Count Complete Tree Nodes
215. Kth Largest Element in an Array
198. House Robber
153. Find Minimum in Rotated Sorted Array