1559E - Mocha and Stars - CodeForces Solution


combinatorics dp fft math number theory *2200

Please click on ads to support us.

Python Code:

import random, sys, os, math, gc
from collections import Counter, defaultdict, deque
from functools import lru_cache, reduce, cmp_to_key
from itertools import accumulate, combinations, permutations, product
from heapq import nsmallest, nlargest, heapify, heappop, heappush
from io import BytesIO, IOBase
from copy import deepcopy
from bisect import bisect_left, bisect_right
from math import factorial, ceil, floor, gcd
from operator import mul, xor
from types import GeneratorType
if "PyPy" in sys.version:
    import pypyjit; pypyjit.set_param('max_unroll_recursion=-1')
# Chunk size (in bytes) used as the minimum read size by the fast I/O layer.
BUFSIZE = 1 << 13
# The two prime moduli commonly used by the solutions in this template.
MOD = 1_000_000_007
MODD = 998_244_353
INF = float("inf")
# Grid neighbour offsets: 4-connected and 8-connected.
D4 = [
    (1, 0), (0, 1), (-1, 0), (0, -1),
]
D8 = [
    (1, 0), (1, 1), (0, 1), (-1, 1),
    (-1, 0), (-1, -1), (0, -1), (1, -1),
]

# Size of every precomputed table; valid indices are 0 .. N-1.
N = 10**5 + 10

# phi[k]     -- Euler's totient of k (phi[0] stays 0).
# pfact[k]   -- a prime factor of k; every prime overwrites its multiples in
#               increasing order, so the largest prime factor is what remains.
# fact       -- one empty list per index; not filled by this block.
# primes     -- primes below N in increasing order.
# mobious[k] -- Moebius function of k.
phi = [0] * N
phi[1] = 1
pfact = [0] * N
fact = [[] for _ in range(N)]
primes = []

mobious = [0] * N
mobious[1] = 1

for i in range(1, N):
    if phi[i] == 0:
        # Nothing smaller has marked i, so it is prime.
        primes.append(i)
        phi[i] = i - 1
        mobious[i] = -1
        for multiple in range(i, N, i):
            pfact[multiple] = i
    # Linear-sieve step: extend i by each prime up to its smallest prime factor.
    for p in primes:
        composite = i * p
        if composite >= N:
            break
        if i % p:
            # p is a new prime factor of the product.
            phi[composite] = phi[i] * (p - 1)
            mobious[composite] = mobious[i] * mobious[p]
        else:
            # p already divides i, so the product has a squared prime factor.
            phi[composite] = phi[i] * p
            mobious[composite] = 0
            break
    
            
def getfact(x):
    """Return the prime factorisation of x as a prime -> exponent mapping.

    Repeatedly strips the prime stored in ``pfact[x]`` until x becomes 1.
    The result is a ``defaultdict(int)``; for x == 1 it is empty.
    """
    exponents = defaultdict(int)
    while x != 1:
        prime = pfact[x]
        exponents[prime] += 1
        x //= prime
    return exponents

def factlist(x):
    """Return all divisors of x (unsorted), built from the ``pfact`` table.

    Starting from [1], each distinct prime of x multiplies every divisor found
    so far by prime**1 .. prime**exponent and appends the products.
    """
    divisors = [1]
    while x != 1:
        prime = pfact[x]
        exponent = 0
        while x % prime == 0:
            x //= prime
            exponent += 1
        # The comprehension is fully evaluated before the list is extended.
        divisors += [
            base * prime ** power
            for base in divisors
            for power in range(1, exponent + 1)
        ]
    return divisors

def solve():
    """Read n, m and n ranges [l, r]; print the Moebius-weighted tuple count.

    For every d with a non-zero Moebius value the ranges and the bound m are
    divided by d, and a prefix-sum knapsack counts the ways to choose one value
    per scaled range with total at most m // d. Each count is weighted by the
    Moebius value of d and accumulated modulo MODD.
    """
    n, m = LII()
    segments = [LII() for _ in range(n)]
    total = 0
    for d in range(1, m + 1):
        weight = mobious[d]
        if not weight:
            continue
        cap = m // d
        # prefix[s] = number of ways to reach a sum <= s with the ranges so far.
        prefix = [1] * (cap + 1)
        for lo, hi in segments:
            lo = (lo - 1) // d + 1  # smallest multiple of d >= l, divided by d
            hi //= d                # largest multiple of d <= r, divided by d
            exact = [0] * (cap + 1)
            for s in range(lo, cap + 1):
                ways = prefix[s - lo]
                if s > hi:
                    ways -= prefix[s - hi - 1]
                exact[s] = ways % MODD
            prefix[0] = 0
            for s in range(1, cap + 1):
                prefix[s] = (prefix[s - 1] + exact[s]) % MODD
        total = (total + prefix[cap] * weight) % MODD
    print(total)
            

def main():
    """Run ``solve()`` once for each of the ``t`` test cases (``t`` is fixed to 1)."""
    t = 1
    # The loop must sit at the same indentation as ``t = 1``; the previous
    # deeper indentation made the whole file fail with IndentationError.
    for _ in range(t):
        solve()

def bootstrap(f, stack=[]):
    """Decorator that runs a generator-based recursive function on an explicit stack.

    The decorated function must be a generator function: it ``yield``s a call
    to itself to recurse and ``yield``s a plain value to return. The outermost
    call drives all nested generators iteratively, so deep recursion does not
    consume the interpreter call stack. ``stack`` is intentionally a shared
    default list: a non-empty stack tells nested calls that a driver is active.
    """
    def wrappedfunc(*args, **kwargs):
        if stack:
            # Nested call: hand the fresh generator back to the active driver.
            return f(*args, **kwargs)
        current = f(*args, **kwargs)
        while True:
            if type(current) is GeneratorType:
                # Descend: start the child generator.
                stack.append(current)
                current = next(current)
                continue
            # A plain value is the result of the generator on top of the stack.
            stack.pop()
            if not stack:
                return current
            current = stack[-1].send(current)
    return wrappedfunc

def bitcnt(n):
    """Return the number of set bits among the low 64 bits of n.

    Parallel (SWAR) popcount: adjacent bit groups are summed pairwise while
    the group width doubles from 1 to 32 bits.
    """
    stages = (
        (1, 0x5555555555555555),
        (2, 0x3333333333333333),
        (4, 0x0F0F0F0F0F0F0F0F),
        (8, 0x00FF00FF00FF00FF),
        (16, 0x0000FFFF0000FFFF),
        (32, 0x00000000FFFFFFFF),
    )
    c = n
    for shift, mask in stages:
        c = (c & mask) + ((c >> shift) & mask)
    return c

def lcm(x, y):
    """Return the least common multiple of x and y.

    ``lcm(0, 0)`` is defined as 0; previously it raised ZeroDivisionError
    because ``gcd(0, 0)`` is 0. All other inputs behave as before.
    """
    g = gcd(x, y)
    if g == 0:
        # Only happens when both arguments are zero.
        return 0
    return x * y // g

def lowbit(x):
    """Return the value of the lowest set bit of x (0 when x is 0)."""
    negated = -x
    return x & negated

@bootstrap
def exgcd(a: int, b: int):
    """Extended Euclid: produce (d, x, y) with d = gcd(a, b) and a*x + b*y = d.

    Written in the generator style required by ``bootstrap``: the recursive
    call is yielded, and the result tuple is yielded instead of returned.
    """
    if b:
        # Solve for (b, a % b) first, then back-substitute the coefficients.
        d, p, q = yield exgcd(b, a % b)
        yield d, q, p - q * (a // b)
    else:
        yield a, 1, 0

def perm(n, r):
    """Return n! / (n - r)!, the number of r-permutations of n items (0 if n < r)."""
    if n >= r:
        return factorial(n) // factorial(n - r)
    return 0
 
def comb(n, r):
    """Return the binomial coefficient C(n, r), or 0 when n < r."""
    if n >= r:
        denominator = factorial(r) * factorial(n - r)
        return factorial(n) // denominator
    return 0

def probabilityMod(x, y, mod):
    """Return x / y modulo ``mod``, using ``y**(mod - 2)`` as the inverse of y.

    The inverse formula is Fermat's little theorem, so the result is only a
    true quotient when ``mod`` is prime and y is not a multiple of it.
    """
    inverse = pow(y, mod - 2, mod)
    return x * inverse % mod

class SortedList:
    """Sorted list with duplicates, stored as a list of sorted buckets.

    Values live in ``_lists`` (sorted sublists of roughly ``_load`` items).
    ``_list_lens`` and ``_mins`` mirror each bucket's length and first value,
    and ``_fen_tree`` is a Fenwick tree over the bucket lengths used to map
    between global positions and (bucket, offset) pairs. The tree is rebuilt
    lazily whenever ``_rebuild`` is set.
    """

    def __init__(self, iterable=[], _load=200):
        """Build the structure from ``iterable``; ``_load`` is the bucket size."""
        values = sorted(iterable)
        self._len = _len = len(values)
        self._load = _load
        self._lists = _lists = [values[i:i + _load] for i in range(0, _len, _load)]
        self._list_lens = [len(_list) for _list in _lists]
        self._mins = [_list[0] for _list in _lists]
        self._fen_tree = []
        self._rebuild = True

    def _fen_build(self):
        """Rebuild the Fenwick tree from ``_list_lens`` and clear the rebuild flag."""
        self._fen_tree[:] = self._list_lens
        _fen_tree = self._fen_tree
        for i in range(len(_fen_tree)):
            if i | i + 1 < len(_fen_tree):
                _fen_tree[i | i + 1] += _fen_tree[i]
        self._rebuild = False

    def _fen_update(self, index, value):
        """Add ``value`` to bucket ``index`` in the tree (skipped if a rebuild is pending)."""
        if not self._rebuild:
            _fen_tree = self._fen_tree
            while index < len(_fen_tree):
                _fen_tree[index] += value
                index |= index + 1

    def _fen_query(self, end):
        """Return the total length of the first ``end`` buckets."""
        if self._rebuild:
            self._fen_build()

        _fen_tree = self._fen_tree
        x = 0
        while end:
            x += _fen_tree[end - 1]
            end &= end - 1
        return x

    def _fen_findkth(self, k):
        """Return ``(bucket, offset)`` of the element at 0-based global position ``k``."""
        _list_lens = self._list_lens
        # Fast paths: the position falls in the first or the last bucket.
        if k < _list_lens[0]:
            return 0, k
        if k >= self._len - _list_lens[-1]:
            return len(_list_lens) - 1, k + _list_lens[-1] - self._len
        if self._rebuild:
            self._fen_build()

        _fen_tree = self._fen_tree
        idx = -1
        # Binary descent over the tree, highest power of two first.
        for d in reversed(range(len(_fen_tree).bit_length())):
            right_idx = idx + (1 << d)
            if right_idx < len(_fen_tree) and k >= _fen_tree[right_idx]:
                idx = right_idx
                k -= _fen_tree[idx]
        return idx + 1, k

    def _delete(self, pos, idx):
        """Delete the element at offset ``idx`` of bucket ``pos``."""
        _lists = self._lists
        _mins = self._mins
        _list_lens = self._list_lens

        self._len -= 1
        self._fen_update(pos, -1)
        del _lists[pos][idx]
        _list_lens[pos] -= 1

        if _list_lens[pos]:
            _mins[pos] = _lists[pos][0]
        else:
            # The bucket became empty: drop it and force a tree rebuild.
            del _lists[pos]
            del _list_lens[pos]
            del _mins[pos]
            self._rebuild = True

    def _loc_left(self, value):
        """Return ``(bucket, offset)`` of the leftmost slot where ``value`` fits.

        Returns ``(0, 0)`` when the list is empty.
        """
        if not self._len:
            return 0, 0

        _lists = self._lists
        _mins = self._mins

        # Binary search over bucket minimums.
        lo, pos = -1, len(_lists) - 1
        while lo + 1 < pos:
            mi = (lo + pos) >> 1
            if value <= _mins[mi]:
                pos = mi
            else:
                lo = mi

        # Step back if the previous bucket's last value is not smaller than value.
        if pos and value <= _lists[pos - 1][-1]:
            pos -= 1

        # Binary search inside the chosen bucket.
        _list = _lists[pos]
        lo, idx = -1, len(_list)
        while lo + 1 < idx:
            mi = (lo + idx) >> 1
            if value <= _list[mi]:
                idx = mi
            else:
                lo = mi

        return pos, idx

    def _loc_right(self, value):
        """Return ``(bucket, offset)`` of the rightmost slot where ``value`` fits.

        Returns ``(0, 0)`` when the list is empty.
        """
        if not self._len:
            return 0, 0

        _lists = self._lists
        _mins = self._mins

        # Binary search for the last bucket whose minimum is <= value.
        pos, hi = 0, len(_lists)
        while pos + 1 < hi:
            mi = (pos + hi) >> 1
            if value < _mins[mi]:
                hi = mi
            else:
                pos = mi

        # Binary search inside the chosen bucket.
        _list = _lists[pos]
        lo, idx = -1, len(_list)
        while lo + 1 < idx:
            mi = (lo + idx) >> 1
            if value < _list[mi]:
                idx = mi
            else:
                lo = mi

        return pos, idx

    def add(self, value):
        """Insert ``value``, keeping order; equal values go after existing ones."""
        _load = self._load
        _lists = self._lists
        _mins = self._mins
        _list_lens = self._list_lens

        self._len += 1
        if _lists:
            pos, idx = self._loc_right(value)
            self._fen_update(pos, 1)
            _list = _lists[pos]
            _list.insert(idx, value)
            _list_lens[pos] += 1
            _mins[pos] = _list[0]
            # Split a bucket that grew beyond twice the load factor.
            if _load + _load < len(_list):
                _lists.insert(pos + 1, _list[_load:])
                _list_lens.insert(pos + 1, len(_list) - _load)
                _mins.insert(pos + 1, _list[_load])
                _list_lens[pos] = _load
                del _list[_load:]
                self._rebuild = True
        else:
            _lists.append([value])
            _mins.append(value)
            _list_lens.append(1)
            self._rebuild = True

    def discard(self, value):
        """Remove one occurrence of ``value`` if present; otherwise do nothing."""
        _lists = self._lists
        if _lists:
            pos, idx = self._loc_right(value)
            if idx and _lists[pos][idx - 1] == value:
                self._delete(pos, idx - 1)

    def remove(self, value):
        """Remove one occurrence of ``value``; raise ValueError if it is absent."""
        _len = self._len
        self.discard(value)
        if _len == self._len:
            raise ValueError('{0!r} not in list'.format(value))

    def pop(self, index=-1):
        """Remove and return the element at ``index`` (default: the last one)."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        value = self._lists[pos][idx]
        self._delete(pos, idx)
        return value

    def bisect_left(self, value):
        """Return the global index of the leftmost insertion point for ``value``."""
        pos, idx = self._loc_left(value)
        return self._fen_query(pos) + idx

    def bisect_right(self, value):
        """Return the global index of the rightmost insertion point for ``value``."""
        pos, idx = self._loc_right(value)
        return self._fen_query(pos) + idx

    def count(self, value):
        """Return how many times ``value`` occurs in the list."""
        return self.bisect_right(value) - self.bisect_left(value)

    def __len__(self):
        """Return the number of stored elements."""
        return self._len

    def __getitem__(self, index):
        """Return the element at ``index``; negative indices count from the end."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        return self._lists[pos][idx]

    def __delitem__(self, index):
        """Delete the element at ``index``; negative indices count from the end."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        self._delete(pos, idx)

    def __contains__(self, value):
        """Return True if ``value`` is stored in the list."""
        _lists = self._lists
        if _lists:
            pos, idx = self._loc_left(value)
            return idx < len(_lists[pos]) and _lists[pos][idx] == value
        return False

    def __iter__(self):
        """Iterate over the values in ascending order."""
        return (value for _list in self._lists for value in _list)

    def __reversed__(self):
        """Iterate over the values in descending order."""
        return (value for _list in reversed(self._lists) for value in reversed(_list))

    def __repr__(self):
        """Return ``SortedList([...])`` with the values in order."""
        return 'SortedList({0})'.format(list(self))

class FastIO(IOBase):
    """Byte-level buffered wrapper around a file descriptor.

    Input is pulled from the descriptor with ``os.read`` in large chunks and
    appended to an in-memory ``BytesIO``; output is collected in the same
    buffer and written out in one ``os.write`` on ``flush``.
    """

    # Number of complete lines known to be buffered but not yet consumed.
    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        self.writable = "x" in file.mode or "r" not in file.mode
        if self.writable:
            self.write = self.buffer.write
        else:
            self.write = None

    def _pull(self):
        """Read one chunk from the descriptor (at least BUFSIZE bytes requested)."""
        return os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))

    def _stash(self, chunk):
        """Append ``chunk`` at the end of the buffer, keeping the read position."""
        store = self.buffer
        position = store.tell()
        store.seek(0, 2)
        store.write(chunk)
        store.seek(position)

    def read(self):
        """Drain the descriptor until EOF and return everything not yet consumed."""
        chunk = self._pull()
        while chunk:
            self._stash(chunk)
            chunk = self._pull()
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        """Return the next line, reading more data only when none is buffered."""
        while self.newlines == 0:
            chunk = self._pull()
            # An empty chunk (EOF) counts as one pending line so the loop ends.
            self.newlines = chunk.count(b"\n") + (not chunk)
            self._stash(chunk)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        """Write the buffered output to the descriptor and reset the buffer."""
        if not self.writable:
            return
        os.write(self._fd, self.buffer.getvalue())
        self.buffer.truncate(0)
        self.buffer.seek(0)

class IOWrapper(IOBase):
    """Text facade over ``FastIO``: encodes writes and decodes reads as ASCII."""

    def __init__(self, file):
        self.buffer = FastIO(file)
        self.flush = self.buffer.flush
        self.writable = self.buffer.writable

        def write(s):
            return self.buffer.write(s.encode("ascii"))

        def read():
            return self.buffer.read().decode("ascii")

        def readline():
            return self.buffer.readline().decode("ascii")

        self.write = write
        self.read = read
        self.readline = readline

# Route stdin/stdout through the fast wrappers everywhere except Windows.
if sys.platform != "win32":
    sys.stdin = IOWrapper(sys.stdin)
    sys.stdout = IOWrapper(sys.stdout)


def input():
    """Read one line from ``sys.stdin`` with the trailing CR/LF removed."""
    return sys.stdin.readline().rstrip("\r\n")

def I():
    """Read and return the next input line as a string."""
    line = input()
    return line

def II():
    """Read the next input line and return it parsed as a single integer."""
    line = input()
    return int(line)

def MI():
    """Read one line and return a lazy ``map`` of its whitespace-separated integers."""
    tokens = input().split()
    return map(int, tokens)

def LI():
    """Read one line and return its whitespace-separated tokens as a list of strings."""
    return [token for token in input().split()]

def LII():
    """Read one line and return its whitespace-separated integers as a list."""
    return [int(token) for token in input().split()]

def GMI():
    """Read one line of integers and return a lazy ``map`` of each value minus one."""
    tokens = input().split()
    return map(lambda token: int(token) - 1, tokens)

def LGMI():
    """Read one line of integers and return a list of each value minus one."""
    return [int(token) - 1 for token in input().split()]

def debug(*args):
    """Print ``args`` space-separated on one line, wrapped in ANSI green colour codes."""
    body = ' '.join(map(str, args))
    print('\033[92m' + body + '\n' + '\033[0m', end='')

# Script entry point: run main() only when the file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

C++ Code:

#include <bits/stdc++.h>

using namespace std;

// Number of ranges (n) and the upper bound on the total (m); both read in solve().
int n, m;

// Lower bounds of the ranges, 1-indexed as filled by solve().
int l[51];

// Upper bounds of the ranges, 1-indexed as filled by solve().
int r[51];

// Modulus applied to every count.
constexpr int base = 998244353;

// val[d]: count (mod base) computed in solve() for divisor d; 0 when infeasible.
long long val[100001];

// mu[d]: Moebius function of d, filled in solve().
int mu[100001];

// p[x]: smallest prime factor of x, filled by the sieve in solve() (0 for x < 2).
int p[100001];

// range[j]: number of multiples of the current divisor inside [l[j], r[j]].
int range[51];

// f[k]: DP value for residual sum exactly k (scratch array reused per divisor).
int f[100001];

// g[k]: prefix sums of f, i.e. DP value for residual sum at most k.
int g[100001];

// Reads n, m and the n ranges, then prints the Moebius-weighted sum of val[d]
// modulo base. For each d with mu[d] != 0, val[d] counts the ways to pick one
// multiple of d from every range so that the picks sum to at most m.
void solve(){
    cin >> n >> m;
    for(int idx = 1; idx <= n; ++idx){
        cin >> l[idx] >> r[idx];
    }

    // Smallest-prime-factor sieve: the first prime to reach a number wins.
    for(int cand = 2; cand <= m; ++cand){
        if(p[cand] != 0) continue;
        for(int mult = cand; mult <= m; mult += cand){
            if(p[mult] == 0) p[mult] = cand;
        }
    }

    // Moebius function from the smallest prime factor: zero if it repeats,
    // otherwise the sign flips relative to the cofactor.
    mu[1] = 1;
    for(int x = 2; x <= m; ++x){
        const int cofactor = x / p[x];
        mu[x] = (p[cofactor] == p[x]) ? 0 : -mu[cofactor];
    }

    for(int d = 1; d <= m; ++d){
        if(mu[d] == 0) continue;

        val[d] = 1;
        int budget = m;
        for(int j = 1; j <= n; ++j){
            // Round the range inwards to multiples of d.
            int low = l[j];
            const int rem = low % d;
            if(rem != 0) low += d - rem;
            const int high = r[j] - r[j] % d;
            if(low > high){
                val[d] = 0;
                break;
            }
            // Every pick costs at least `low`; charge it up front.
            budget -= low;
            if(budget < 0){
                val[d] = 0;
                break;
            }
            range[j] = (high - low) / d + 1;
        }
        if(val[d] == 0) continue;

        // Remaining budget measured in units of d.
        budget /= d;
        for(int s = 0; s <= budget; ++s){
            f[s] = 0;
            g[s] = 1;
        }
        f[0] = 1;

        for(int j = 1; j <= n; ++j){
            const int width = range[j];
            // f[s] = g[s] - g[s - width]: window sum over the previous layer.
            for(int s = 0; s <= budget; ++s){
                int ways = g[s];
                if(s >= width){
                    ways -= g[s - width];
                    if(ways < 0) ways += base;
                }
                f[s] = ways;
            }
            g[0] = f[0];
            for(int s = 1; s <= budget; ++s){
                int acc = g[s - 1] + f[s];
                if(acc >= base) acc -= base;
                g[s] = acc;
            }
        }
        val[d] = g[budget];
    }

    long long total = 0;
    for(int d = 1; d <= m; ++d){
        total += val[d] * mu[d];
    }
    // Normalise into [0, base) since the Moebius signs can make it negative.
    total %= base;
    total = (total + base) % base;
    cout << total << '\n';
}



// Program entry: switch to fast iostream mode and run solve() for one test case.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    // The test count is fixed at one; it is not read from input.
    for(int remaining = 1; remaining > 0; --remaining){
        solve();
    }
}


Comments

Submit
0 Comments
More Questions

22E - Scheme
1566A - Median Maximization
1278A - Shuffle Hashing
1666F - Fancy Stack
1354A - Alarm Clock
1543B - Customising the Track
1337A - Ichihime and Triangle
1366A - Shovels and Swords
919A - Supermarket
630C - Lucky Numbers
1208B - Uniqueness
1384A - Common Prefixes
371A - K-Periodic Array
1542A - Odd Set
1567B - MEXor Mixup
669A - Little Artem and Presents
691B - s-palindrome
851A - Arpa and a research in Mexican wave
811A - Vladik and Courtesy
1006B - Polycarp's Practice
1422A - Fence
21D - Traveling Graph
1559B - Mocha and Red and Blue
1579C - Ticks
268B - Buttons
898A - Rounding
1372B - Omkar and Last Class of Math
1025D - Recovering BST
439A - Devu the Singer and Churu the Joker
1323A - Even Subset Sum Problem