Skip to content

新生赛解题报告

by Niolle_Semis, DGME, AtomFirst, Sgdd, 4627488

补题链接

A RUST

签到,这道题考察一些基本的语法,根据题意,我们不妨开一个bool数组,然后依次读取每个操作,先判断给定的位置是否合法(即是否在1n之间),如果读入到alloc,就将对应的位置置为true,如果读入到free ,就将对应的位置置为false,如果所在位置已经是true/false,就输出Illegal operation,否则输出All operations are safe

完整代码

cpp
#include <bits/stdc++.h>

using namespace std;

int n, m, x, flag = 0;

bool a[1005];

int main()
{
	cin >> n >> m;
	for (int i = 1; i <= m; i++)
	{
		string s;
		cin >> s >> x;
		if (x < 1 || x > n)
		{
			flag = 1;
			continue;
		}
		if (s == "alloc")
		{
			if (a[x] == 0)
				a[x] = 1;
			else
				flag = 1;
		}
		else if (s == "free")
		{
			if (a[x] == 1)
				a[x] = 0;
			else
				flag = 1;
		}
	}
	if (flag == 1)
		cout << "Illegal operation";
	else
		cout << "All operations are safe";
	return 0;
}

D 輪符雨

思考一下每次选取 1i<jn ,并进行 a[i]+1,a[j]1 会产生什么宏观意义上的影响:

  • a,即数组总和不变。
  • 操作是不可逆的,对于操作 (i,j),其可以使用的次数随着游戏的进行非严格递减(简证:如果不考虑其它数字,你如果操作无穷次,那么必然 a[i]>a[j],违反保证有序的条件)。

考虑性质 2,每一次游戏的进行必然导致可用的 (i,j) 对的减少,最终游戏局面必然是无法进行任何操作,且有序的形式,于是 答案 是:

an,an,,an+1,,an+1

即类似: 3 3 3 3 4 4 这种形式,也有可能是 3 3 3 3 3 3 这种形式,根据是否总和还有剩余往后面的数字 +1

为什么可以达到这种局面

  • 对于所有有序序列,都存在一种操作方式可以达到 上述局面

很显然,这道题可以改成输出一种可能的构造方式来出题,但是作为签到题,还是简单一点好。

我们称类似于上述平均数,平均数,..., 平均数 + 1 的形式为 的,那么每次对于当前序列的后缀开始调整为好的,假设我们已经将 a[i+1,n] 调整为好的,设为 b

那么对于我们下一个目标,即好的 a[i,n],设为 c,一定有 j[i+1,n],bjcj,那么我们可以从那些 bjcj 的位置,向 ai 贡献一个 1,从而便可以从 b 转移到 c

归纳多次后,便证明了上述结论。

由于每种有序序列都有方法转移到答案这种情况(把它想象成一个 DAG),而操作步骤次数必然是有限次,即使是随机往下跳,最终也能跳到答案。

完整代码

cpp
#include <bits/stdc++.h>
using namespace std;

void solve()
{
    int N;
    cin >> N;
    long long s = 0;
    for (int i = 1; i <= N; i++)
    {
        int x;
        cin >> x, s += x;
    }
    for (int i = 1; i <= N; i++)
        cout << s / N + ((N - i) < s % N) << " ";
    cout << '\n';
}

int main()
{
    cin.tie(0)->sync_with_stdio(0), cout.tie(0);
    int t;
    cin >> t;
    while (t--)
        solve();
    return 0;
}

E 焚音打

那……能陪我组一辈子的乐队吗?

写出题意所述的操作,通过观察,我们可以发现,仅有第一盏灯最后是开的状态。因此仅有第一盏灯需要输出"YES",其余的灯需要输出"NO"。

cpp
void press(int x) { 
    light[x]^=1;
    for (int y=x+x;y<=n;y+=x)
        press(y);
}
for(int i=1;i<=n;i++) press(i);

本题的 press 是递归定义的,与平常所做的完全平方数类型的开关灯题不同。

考虑一个数 a 的所有除自身以外的因数 d,参考 press 函数,容易发现对于每次对 d 的修改后,必然会修改 a

cnti 表示 i 被修改的次数,可以发现:

cnti=didicntd+1

特别地,这个加 1 是因为遍历到 d 时,对此的修改。

可以发现,当 cnti 为奇数时,灯开;反之,灯灭。

  1. i=1 时,显然 cnti=1,灯开。
  2. iPP 表示质数集合)时,cnti=cnt1+1=1+1=2,灯灭。
  3. i 为合数时,cnti 为偶数,使用数学归纳法证明如下。

证明

显然,cnt4=cnt1+cnt2+1=1+2+1=4

假设当前已经处理到了合数 t,则 cntt 为偶数。

考虑 t 的下一个合数 p,由于 (t,p) 范围内的数都是质数,所以

cntt+1=cntt+2==cntp1=2

对于 p 的不等于 1 且不等于 p 的因数 ri,易知 cntri 为偶数,那么其和也是偶数。

因此有

cntp=cnt1+cntri+1=1+偶数+1=偶数

故得证。

综上所述,当且仅当 k=1 时,灯是亮着的。

完整代码

cpp
#include <iostream>

int main()
{
    std::cin.tie(0);
    std::cout.tie(0);
    std::ios::sync_with_stdio(0);

    int T;
    std::cin >> T;
    while (T--)
    {
        int n, k;
        std::cin >> n >> k;
        std::cout << (k == 1 ? "YES\n" : "NO\n");
    }
}

H Life Will Change

遍历排列 p,每当 p[i] 不等于 i 时,需要将 p[i] 与其正确位置的元素交换。直接查找每个元素的目标位置会导致 O(n2) 的复杂度。为优化查找过程,我们引入一个数组 q,其中 q[p[i]] = i,即 q[x] 记录值为 x 的元素所在位置。

这样,在遍历时,如果发现 p[i] ≠ i,可以通过 q 数组直接找到 p[i] 的目标位置 j,将 p[i]p[j] 交换,并更新 q 数组。整个过程的复杂度从 O(n2) 降低到 O(n)

完整代码

cpp
#include <iostream>
#include <vector>
#include <algorithm>
#include <array>

using namespace std;
using ll = long long;
using int2 = array<int, 2>;

#define all(x) x.begin(), x.end()

void solve()
{
    int n;
    cin >> n;
    vector<int> p(n + 1), ind(n + 1);
    for (int i = 1; i <= n; i++)
        cin >> p[i], ind[p[i]] = i;

    int cnt = 0;
    vector<int2> opt;
    for (int x = 1; x <= n; x++)
        if (p[x] != x)
        {
            int y = ind[x];
            opt.push_back({x, y});
            ++cnt;
            swap(ind[x], ind[p[x]]);
            swap(p[x], p[y]);
        }

    cout << cnt << '\n';
    for (auto [x, y] : opt)
        cout << x << ' ' << y << '\n';
}
// 1 2 3 4 5 6 7 9 8 10

int main()
{
    ios::sync_with_stdio(0),
        cin.tie(0), cout.tie(0);

    int t = 1;
    // cin>>t;
    while (t--)
        solve();

    return 0;
}

M INFINITY

给定一个字符串,每次将其最后一个字符移到最前方,形成的新串接到原串后作为下一次操作的字符串。

现询问第 N 个位置的字符。

根据数据范围,N < 1018,显然不能直接模拟,所以我们从每次操作入手。每次操作过后,字符串的长度会变为原先的两倍,因此可以使用分治算法。

分治法的核心思想是把问题分成多个子问题,分别解决并合并结果。对于这道题,我们可以利用分治来解决。

为了找到第 N 个字符,我们可以用一个 t 变量记录何时字符串长度超过 N,代码如下:

cpp
while (t < n) t <<= 1;  // 位运算更快

根据题意可得,当第 N 个字符在长度为 t 的字符串的后半段时,前半段字符串中的第 N - 1 - t / 2 个字符必定与第 N 个字符相同。我们可以通过这个关系来编写如下代码:

cpp
while (t != l) t >>= 1, n -= 1 + t;

但是,这段代码虽然接近正确,但还不完全。根据原操作,我们是将字符串的最后一个字符移到第一个后接到原串。因此,可以推导出如下的关系:

当 N 等于 t / 2 + 1 时,第 N 个字符为后半段字符串的第一个字符。对应的字符位置应为 t / 2N - 1。加入特判后的代码如下:

cpp
while (t != l) {
    t >>= 1;
    if (t + 1 == n) n = t;
    else n -= 1 + t;
}

使用三目运算符也可以简写成以下代码:

cpp
while (t != l) t >>= 1, n = (t + 1 != n) ? n - 1 - t : t;

如果第 N 个字符位于前半段字符串中,我们不需要做任何操作,直接返回 t / 2 即可。

cpp
while (t != l) {
	t >>= 1;
	if (n <= t) continue;
	if (t + 1 == n) n = t;
	else n -= 1 + t;
}

完整代码

cpp
#include <bits/stdc++.h>
using namespace std;

long long l, n, t;

char s[55];

int main() {
	scanf("%s%lld", s + 1, &n), l = t = strlen(s + 1);
	while (t < n) t <<= 1;
	while (t != l) t >>= 1, n = n > t ? ((t + 1 != n) ? n - 1 - t : t) : n;
	putchar(s[n]);
	return 0;
}

C International Chairs-Problem Contest

如果只有一个椅子,只能从小往大逐个尝试,最坏需要 n 次。

若尝试 a 重量椅子不坏,b 重量坏,答案缩小到 S(a,b)={a,a+1,a+2,...,b1}

对于一个椅子,从 S(a,b) 缩小到精确答案最坏需 (ba1) 次尝试。

若使用两个椅子,总次数 cost 则为第一个椅子尝试次数 cost1 与第二个椅子尝试次数 cost2 的和。

若第一次尝试 t1=k 重量,第二次尝试 t2=k+(k1) 重量,第 i 次尝试 ti=k+(k1)+...+(ki+1) 重量,直到椅子损坏,并将范围缩小到 S(ti1,ti) 。记第 cost1 次尝试时损坏,则最坏还需要 cost2=tcost1tcost11+1=kcost1 次尝试,最坏总次数 cost=cost1+cost2=k

注意应满足 tkn,即 k(k1)/2n,解得 k=1+8n+12

完整代码

cpp
#include <iostream>
#include <vector>
#include <algorithm>
#include <array>
#include <cmath>

using namespace std;
using ll = long long;
using int2 = array<int, 2>;

#define all(x) x.begin(), x.end()

int n;

int ask(int h)
{
    if (h > n)
        return 1;
    cout << "? " << h << endl;
    cout.flush();
    int res;
    cin >> res;
    return res;
}

void ans(int h)
{
    cout << "! " << h << endl;
    cout.flush();
}

void solve()
{
    cin >> n;
    // c*(c+1)/2=n
    // c2+c+2(-n)=0
    // dealt=1+8n
    // c=(-1+sqrt(dealt))/2
    // n=7 -> (0 3 5 6 7] -> c=3
    //
    int c = ceil((-1 + sqrtl(1 + ll(n) * 8)) / 2);

    int lst = 0, cur;
    for (int i = c; i > 0; i--)
    {
        cur = lst + i;
        if (ask(cur))
            break;
        lst = cur;
    }
    int res = lst + 1;
    while (not ask(res))
        ++res;
    ans(res - 1);
}

int main()
{
    ios::sync_with_stdio(0),
        cin.tie(0), cout.tie(0);

    int t = 1;
    // cin>>t;
    while (t--)
        solve();

    return 0;
}

F Distortion!!

一个简单的解法是尝试所有满足 1LRN(L,R) 对,找到所有可能的 S,并输出这些字符串和原始字符串 S 中按字典序最小的一个。 这种方法可以得到正确答案,但由于存在 N2(L,R),且每个字符串的长度为 N,时间复杂度为 O(N3),效率非常低。我们可以尝试缩小需要检查的对数。

首先,如果字符串 S 仅由字符 d 组成,那么最优解是跳过操作,直接输出 S。接下来,假设 S 中包含字符 p。 设 i 为字符串 S 中最左侧字符 p 的位置。定义 g(l,r) 为选择 (L,R)=(l,r) 后得到的字符串 S

在这种情况下,有以下结论:

对于任何 g(L,R),当 Li 时,总存在一个 R 使得 g(i,R)<g(L,R)()

我们将根据 iL 的大小关系进行分类讨论,来证明上述结论。

  1. i<L:比较 g(i,i)g(L,R)。这两个字符串的前 (L1) 个字符相同,但 g(i,i)g(L,R) 的第 L 个字符分别为 dp,因此有 g(i,i)<g(L,R)

  2. i>L:比较 g(i,R)g(L,R)。从第 1 到第 (L1) 个字符以及从第 (R+1) 到第 N 个字符,这两者是相同的,因此我们关注剩余部分。根据假设,字符串 S 的第 L 到第 (i1) 个字符都是 d,设 U 为长度为 iL 的全 d 字符串,第 i 到第 R 个字符的比较如下:

    • g(i,R) 的第 i 到第 R 个字符为 U+f(T)
    • g(L,R) 的第 i 到第 R 个字符为 f(U+T)=f(T)+f(U)

现在考虑 U+f(T)f(T)+f(U),比较其中 p 的最左位置(因为字符串长度相同,只需比较最左侧不同位置即可确定哪个字符串更小)。

我们将根据 f(T) 是否包含 p 分情况讨论:

  1. f(T) 不包含 pU+f(T) 仅由 d 组成且不包含 p,而 f(U) 全由 p 组成,因此 f(T)+f(U) 的第 (|T|+1) 个字符是 p,所以 U+f(T)<f(T)+f(U)
  2. f(T) 中最左侧的 p 在位置 jU+f(T)p 的最左位置是 |U|+j,而 f(T)+f(U) 的最左位置是 j,因此 U+f(T)<f(T)+f(U)

因此在所有情况下,U+f(T)<f(T)+f(U) 成立,所以在情况 2 中也有 g(i,R)<g(L,R)

综上所述,() 在所有情况下都成立。因此,只需检查满足 L=i(L,R) 对即可,因为对于任何 g(L,R)Li),总能找到一个更小的字符串 g(i,R)

由此可见,我们需要检查的字符串数量从 O(N2) 降低到 O(N),整体复杂度为 O(N2),这已经足够高效。

完整代码

cpp
#include <bits/stdc++.h>

#define rep(i, n) for (int i = 0; i < n; ++i)
#define repn(i, n) for (int i = 1; i <= n; ++i)

using namespace std;

int n;
string s, t, ans;

int main()
{
    cin >> n >> s;
    ans = s;
    for (int i = 0; i < n; i++)
    {
        if (s[i] == 'p')
        {
            t = s;
            t[i] = 'd';
            ans = min(ans, t);
            t = s;
            for (int j = 0; j < n; j++)
                if (s[j] == 'p')
                {
                    reverse(t.begin() + j, t.begin() + i + 1);
                    for (int k = j; k <= i; ++k)
                        if (t[k] == 'p')
                            t[k] = 'd';
                        else
                            t[k] = 'p';
                    ans = min(ans, t);
                    break;
                }
        }
    }
    cout << ans;
    return 0;
}

J 空の箱

经典二分答案套路,面对最小值最大化的问题,可以按照以下步骤进行二分:

  1. 设定二分的初始范围 [l,r],其中 x 表示当前二分到的最小值。

  2. 将目标最小值设为 x,然后尽可能多地将原序列拆分成若干个权值和大于等于 x 的区间。

  3. 判断拆分出来的区间数量是否大于等于 m

    • 如果区间数量 m,说明 x 是一个合法的最小值,此时答案在 [x,r] 中。
    • 如果区间数量 <m,说明 x 过大,此时答案在 [l,x1] 中。
  4. 继续二分,直到区间收敛,得到最终答案。

完整代码

cpp
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#define int long long
using namespace std;
int n, m, a[100005];
int work(int x)
{
	int sum = 0, s = 0;
	for (int i = 1; i <= n; i++)
	{
		if (s >= x)
		{
			sum++;
			s = 0;
		}
		s += a[i];
	}
	if (s >= x)
		++sum;
	return sum;
}
signed main()
{
	scanf("%lld%lld", &n, &m);
	int l = 0, r = 0;
	for (int i = 1; i <= n; i++)
		scanf("%lld", &a[i]), r += a[i];
	while (l < r)
	{
		int mid = (l + r + 1) / 2;
		if (work(mid) >= m)
			l = mid;
		else
			r = mid - 1;
	}
	printf("%lld", l);
}
rust
use std::io::{self, BufRead};

fn work(a: &Vec<i64>, n: usize, x: i64) -> i64 {
    let mut sum = 0;
    let mut s = 0;
    for i in 0..n {
        if s >= x {
            sum += 1;
            s = 0;
        }
        s += a[i];
    }
    if s >= x {
        sum += 1;
    }
    sum
}

fn main() {
    let stdin = io::stdin();
    let mut lines = stdin.lock().lines();

    let first_line = lines.next().unwrap().unwrap();
    let nums: Vec<i64> = first_line
        .split_whitespace()
        .map(|x| x.parse::<i64>().unwrap())
        .collect();
    let n = nums[0] as usize;
    let m = nums[1];

    let second_line = lines.next().unwrap().unwrap();
    let mut a: Vec<i64> = second_line
        .split_whitespace()
        .map(|x| x.parse::<i64>().unwrap())
        .collect();

    let mut l = 0;
    let mut r = a.iter().sum::<i64>();

    while l < r {
        let mid = (l + r + 1) / 2;
        if work(&a, n, mid) >= m {
            l = mid;
        } else {
            r = mid - 1;
        }
    }

    println!("{}", l);
}

L 视野一隅

O(n3)做法:就是01背包,设dp[i][j]表示前i个数中选择的区间长度为j的最大价值

状态转移方程为

dp[i][j]=max{maxk=1j{dp[ik][jk]+abs(biajk+1)},dp[i1][j]}

O(n2)做法,有兴趣可以问出题人 Niolle_Semis qwq

完整代码

cpp
#include <bits/stdc++.h>
using namespace std;
const int N = 410;
int a[N], b[N];
long long f[N][N];
int main()
{
	int T, n, k;
	scanf("%d", &T);
	while (T--)
	{
		scanf("%d%d", &n, &k);
		for (int i = 1; i <= n; i++)
			scanf("%d", a + i);
		for (int i = 1; i <= n; i++)
			scanf("%d", b + i);
		for (int j = 1; j <= k; j++)
		{
			for (int i = 0; i <= n; i++)
				f[j][i] = -1e18;
			for (int L = 1; L <= n; L++)
			{
				for (int R = L; R <= n && R - L + 1 <= j; R++)
				{
					f[j][R] = max(f[j][R], abs(b[R] - a[L]) + f[j - R + L - 1][L - 1]);
				}
			}
			// if(j == k)cerr<<f[j][n]<<endl;
			for (int i = 1; i <= n; i++)
				f[j][i] = max(f[j][i], f[j][i - 1]);
		}
		printf("%lld\n", f[k][n]);
	}
	return 0;
}
rust
use std::cmp::max;
use std::io;

fn abs<T: std::cmp::PartialOrd + std::ops::Neg<Output = T> + Default>(x: T) -> T {
    if x < T::default() {
        -x
    } else {
        x
    }
}

fn work(n: usize, k: usize, a: &Vec<i64>, b: &Vec<i64>) -> i64 {
    let mut dp = vec![vec![0i64; k + 1]; n + 1];

    for i in 1..=n {
        for z in (0..=k).rev() {
            dp[i][z] = dp[i - 1][z];
            for j in 1..=i {
                let t = z as i64 - (i - j + 1) as i64;
                if t < 0 {
                    continue;
                }
                dp[i][z] = max(dp[i][z], dp[j - 1][t as usize] + abs(b[i - 1] - a[j - 1]));
            }
        }
    }

    dp[n][k]
}

fn main() {
    let stdin = io::stdin();
    let mut input = String::new();
    stdin.read_line(&mut input).unwrap();
    let t: usize = input.trim().parse().unwrap();
    
    for _ in 0..t {
        input.clear();
        stdin.read_line(&mut input).unwrap();
        let tokens: Vec<usize> = input
            .split_whitespace()
            .map(|s| s.parse().unwrap())
            .collect();
        let n = tokens[0];
        let k = tokens[1];

        input.clear();
        stdin.read_line(&mut input).unwrap();
        let a: Vec<i64> = input
            .split_whitespace()
            .map(|s| s.parse().unwrap())
            .collect();

        input.clear();
        stdin.read_line(&mut input).unwrap();
        let b: Vec<i64> = input
            .split_whitespace()
            .map(|s| s.parse().unwrap())
            .collect();

        println!("{}", work(n, k, &a, &b));
    }
}

B CF与睡眠与蓝色星球

在线做法

二分区间右端点位置,并使用RMQ得到区间最小值,判断是否大于等于 k

时间复杂度 O(mlogn+nlogn) ,前者是查询,后者是预处理ST表

离线做法

将所有查询区间按照 k 由大到小排序

之后依次将大于等于当前查询区间 k 值的位置加入树状数组中,树状数组维护一个后缀最小值即可

时间复杂度 O(mlogn+nlogn)

完整代码

cpp
#include <bits/stdc++.h>
#define rep(i, a, b) for (int i = a; i <= b; i++)
#define dwn(i, a, b) for (int i = a; i >= b; i--)
#define lowbit(x) (x & (-x))
#define MAXN 1002501
#define mp(x, y) make_pair(x, y)

using namespace std;

typedef long long ll;

inline int read()
{
	int x = 0, f = 1;
	char ch = getchar();
	while (ch > '9' || ch < '0')
	{
		if (ch == '-')
			f = -1;
		ch = getchar();
	}
	while ('0' <= ch && ch <= '9')
		x = (x << 1) + (x << 3) + ch - '0', ch = getchar();
	return x * f;
}

struct node
{
	int k, l, id;
} q[MAXN], p[MAXN];

bool cmpp(node x, node y)
{
	return x.k > y.k;
}

int n, a[MAXN], m, c[MAXN], inf, ans[MAXN];

void ins(int x, int y)
{
	while (x <= n)
		c[x] = min(c[x], y), x += lowbit(x);
}

int qry(int x)
{
	int res = n + 1;
	while (x)
		res = min(res, c[x]), x -= lowbit(x);
	return res;
}

int rev(int x)
{
	return n - x + 1;
}

int main()
{
	n = read();
	m = read();
	rep(i, 1, n) c[i] = n + 1;
	rep(i, 1, n) a[i] = read(), q[i] = {a[i], i, i};
	int cnt = 1;
	rep(i, 1, m)
	{
		p[i].k = read();
		p[i].l = read();
		p[i].id = i;
	}
	sort(q + 1, q + n + 1, cmpp);
	sort(p + 1, p + m + 1, cmpp);
	rep(i, 1, m)
	{
		while (cnt <= n && q[cnt].k >= p[i].k)
		{
			ins(rev(q[cnt].l), q[cnt].l);
			++cnt;
		}
		//		cout<<"I:"<<p[i].l<<" "<<qry(rev(p[i].l))<<endl;
		ans[p[i].id] = qry(rev(p[i].l)) - p[i].l;
	}
	rep(i, 1, m) if (!ans[i]) puts("all in acm");
	else printf("%d\n", ans[i]);
	return 0;
}

G 名無声

由于对于每节课不能去上的概率是均等的,所以每节课翘掉的概率都是kn

答案就是kni=1naibi

完整代码

cpp
#include <bits/stdc++.h>
#define rep(i, a, b) for (int i = a; i <= b; i++)
#define dwn(i, a, b) for (int i = a; i >= b; i--)
#define lowbit(x) (x & (-x))
#define MAXN 1002501
#define int long long
#define mp(x, y) make_pair(x, y)
using namespace std;
typedef long long ll;
inline int read()
{
	int x = 0, f = 1;
	char ch = getchar();
	while (ch > '9' || ch < '0')
	{
		if (ch == '-')
			f = -1;
		ch = getchar();
	}
	while ('0' <= ch && ch <= '9')
		x = (x << 1) + (x << 3) + ch - '0', ch = getchar();
	return x * f;
}
const int mod = 998244353;
int n, k, ans, a, b;
int ksm(int x, int y)
{
	int res = 1;
	while (y)
	{
		if (y & 1)
			res = res * x % mod;
		x = x * x % mod;
		y >>= 1;
	}
	return res;
}
signed main()
{
	n = read();
	k = read();
	rep(i, 1, n)
	{
		a = read();
		b = read();
		ans += a * ksm(b, mod - 2);
		ans %= mod;
	}
	// cout<<ans<<endl;
	ans = ans * k % mod;
	ans = ans * ksm(n, mod - 2) % mod;
	cout << ans;
	return 0;
}

K Savourons les moments

首先考虑如何让一个序列的最大众数为 k ,假设 k 出现了 x 次,那么 [0,k1] 每个数出现次数不能超过 x 次,[k+1,5000] 每个数出现次数不能超过 x1

cnti 表示 i 的出现次数,那么可以枚举 k 的出现次数,出现 x 次的答案即为

(cntkx)×i=1k1(j=0min{cnti,x}(cntij))×i=k+15000(j=0min{cnti,x1}(cntij))

利用前缀和即可在 O(n2) 内计算。

注意到当 x 超过 cnti 时,i 对应的贡献不会再发生变化,每当增加 x 时,只更新 cntix (或 cnti>x) 部分的值,每张照片最多会产生一次更新。

时间复杂度 O(n)

模意义下组合数这里提供两种方法:

  • 令杨辉三角首行为第 0 行,首列为第 0 列,第 i 行第 j 列的值即为 (ij)
  • 利用exgcd或快速幂求逆元,预处理前缀\后缀积后计算

完整代码

cpp
#include <bits/stdc++.h>

const int mod = (int)1e9 + 7;

inline int qpow(int x, int k)
{
	int res = 1, buf = x;
	for (; k; k >>= 1)
	{
		if (k & 1)
			res = 1ll * buf * res % mod;
		buf = 1ll * buf * buf % mod;
	}
	return res;
}

int main()
{
	const int N = (int)5e3;
	std::vector<int> fac(N + 1), inv(N + 1);
	fac[0] = 1;
	for (int i = 1; i <= N; i++)
		fac[i] = 1ll * fac[i - 1] * i % mod;
	inv[N] = qpow(fac[N], mod - 2);
	for (int i = N; i; i--)
		inv[i - 1] = 1ll * inv[i] * i % mod;
	auto C = [&](int n, int m)
	{
		if (m > n || m < 0)
			return 0ll;
		return 1ll * fac[n] * inv[n - m] % mod * inv[m] % mod;
	};

	int n, k;
	scanf("%d%d", &n, &k);
	std::vector<int> cnt(n + 2);
	for (int i = 0; i < n; i++)
	{
		int x;
		scanf("%d", &x);
		cnt[x]++;
	}
	int ans = 0;
	std::vector<int> f(n + 2);
	for (int i = 0; i <= cnt[k]; i++)
	{
		int res = 1;
		for (int j = 1; j <= n + 1; j++)
		{
			if (j == k)
				continue;
			f[j] = (f[j] + C(cnt[j], i - (j > k))) % mod;
			res = 1ll * f[j] * res % mod;
		}
		ans = (1ll * res * C(cnt[k], i) + ans) % mod;
	}
	printf("%d\n", ans);
	return 0;
}

I 若成为星座

容斥问题,总方案数为

S0=(Aa1+1)×(Aa2+1)×(Bb1+1)×(Bb2+1)

行不相交的方案数为

S1=AAa1a2+22×(Bb1+1)×(Bb2+1)

列不相交的方案数为

S2=ABb1b2+22×(Aa1+1)×(Aa2+1)

行列都不交的方案数为

S3=ABb1b2+22×AAa1a2+22

答案即为

S1+S2S3

完整代码

cpp
#include <bits/stdc++.h>
using i64 = long long;
using u32 = unsigned;
using u64 = unsigned long long;
constexpr int _{400010};
constexpr int N{100000};
constexpr i64 MY_INF{1ll << 60};
constexpr int mod{998244353};

int main()
{
	// std::ios::sync_with_stdio(false);
	// std::cin.tie();

	int _t{1};
	std::cin >> _t;
	while (_t--)
	{
		i64 a, b, a1, b1, a2, b2;
		std::cin >> a >> b >> a1 >> b1 >> a2 >> b2;
		if (a < a1 || b < b1 || a < a2 || b < b2)
		{
			std::cout << "0\n";
			continue;
		}
		i64 tot{(a - a1 + 1) * (b - b1 + 1) % mod * (a - a2 + 1) % mod * (b - b2 + 1) % mod};
		i64 ta{(a - a1 + 1) * (a - a2 + 1) % mod}, tb{(b - b1 + 1) * (b - b2 + 1) % mod};
		if (a >= a1 + a2)
			ta -= ((a - a1 - a2 + 2) * (a - a1 - a2 + 1) % mod),
				ta += mod,
				ta %= mod;
		if (b >= b1 + b2)
			tb -= ((b - b1 - b2 + 2) * (b - b1 - b2 + 1) % mod),
				tb += mod,
				tb %= mod;
		tot -= ta * tb % mod;
		tot += mod;
		tot %= mod;
		std::cout << tot << '\n';
	}
	return 0;
}