题目大意
给定一个长度为 N 的整数序列,我们需要将其连续地分割成四段(每段至少包含一个数),且每段数字之和相等。求满足条件的不同分割方法总数。
思路分析
首先,整个序列的总和必须是 4 的倍数,否则直接输出 0。
由于后续需要用到前 n 项的和,因此维护一个前缀和数组 s。
//输入
for (int i = 1; i <= n; i++) {
scanf("%lld", &a);
s[i] = s[i - 1] + a; // 计算前缀和
}
//特判不可分割
if (s[n] % 4 != 0) { // 总和不是4的倍数,无解
cout << 0;
return 0;
}
设总和为 S (即: s[n]),则每段的和应为 T = S / 4。
t = s[n] / 4; // 每段的目标和
特殊情况:T = 0
当 T = 0 时,整个序列的总和为 0。此时,我们需要在 [1, n-1] 之间找到三个位置 i,j,k,使得前缀和 s[i] = 0、s[j] = 0、s[k] = 0。由于每段至少有一个数,所以这三个位置必须不同。假设一共有 cnt 个位置(包括 i 从 1 到 n-1 )满足 s[i] = 0,那么我们需要从这 cnt 个位置中选出 3 个(因为要分成四段,需要三个分割点),即组合数 C_{cnt}^{3}。
if (t == 0) { // 特殊情况:每段和为0
ll cnt = 0;
for (int i = 1; i < n; i++) { // 注意:分割点不取n,因为最后一段不能为空
if (s[i] == 0) {
cnt++;
}
}
if (cnt < 3) {
ans = 0;
} else {
// 从cnt个点中选3个分割点(组合数C(cnt,3))
ans = cnt * (cnt - 1) * (cnt - 2) / 6;
}
cout << ans;
return 0;
}
一般情况:T \neq 0
我们设三个分割点为 i, j, k(满足 1 \le i < j < k < N ),则四段分别为:
- [1, i],和为 s[i] = T
- [i+1, j],和为 s[j] - s[i] = T,即 s[j] = 2T
- [j+1, k],和为 s[k] - s[j] = T,即 s[k] = 3T
- [k+1, N],和为 s[N] - s[k] = T,即 s[N] = 4T(一般情况下必然成立)
因此,我们需要统计满足以下条件的三元组 (i, j, k) :
- s[i] = T
- s[j] = 2T
- s[k] = 3T
- 且 i < j < k。
我们可以遍历整个前缀和数组,同时维护两个计数器:
- c1:记录到目前为止满足 s[i] = T 的个数(这些位置可以作为第一个分割点)
- c2:记录到目前为止满足 s[j] = 2T 的个数(这些位置可以作为第二个分割点,且每个 j 都可以对应前面的所有第一个分割点,即所有 c1 )
- ans:记录到目前为止满足 s[k] = 3T 的个数,(这些位置可以作为第三个分割点,且每个 k 都可以对应前面的所有第二个分割点,即所有 c2,也正是答案。
// 一般情况:T != 0
for (int i = 1; i <= n; i++) {
if (s[i] == 3 * t) ans += c2; // 找到第三个分割点,累加第二个分割点的数量
if (s[i] == 2 * t) c2 += c1; // 找到第二个分割点,累加第一个分割点的数量
if (s[i] == t) c1++; // 找到第一个分割点,计数增加
}
注意:遍历顺序很重要,必须先判断 3T,再判断 2T,最后判断 T。这样可以确保当我们处理某个位置时,之前的位置已经被正确计数,避免同一个位置被重复使用。
代码
#include<bits/stdc++.h>
#define ll long long
#define N 100005
using namespace std;
ll a, s[N], n, t, c1, c2, ans;
int main() {
scanf("%lld", &n);
for (int i = 1; i <= n; i++) {
scanf("%lld", &a);
s[i] = s[i - 1] + a; // 计算前缀和
}
if (s[n] % 4 != 0) { // 总和不是4的倍数,无解
cout << 0;
return 0;
}
t = s[n] / 4; // 每段的目标和
if (t == 0) { // 特殊情况:每段和为0
ll cnt = 0;
for (int i = 1; i < n; i++) { // 注意:分割点不能取n,因为最后一段不能为空
if (s[i] == 0) {
cnt++;
}
}
if (cnt < 3) {
ans = 0;
} else {
// 从cnt个点中选3个分割点(组合数C(cnt,3))
ans = cnt * (cnt - 1) * (cnt - 2) / 6;
}
cout << ans;
return 0;
}
// 一般情况:t != 0
for (int i = 1; i <= n; i++) {
if (s[i] == 3 * t) ans += c2; // 找到第三个分割点,累加第二个分割点的数量
if (s[i] == 2 * t) c2 += c1; // 找到第二个分割点,累加第一个分割点的数量
if (s[i] == t) c1++; // 找到第一个分割点,计数增加
}
cout << ans;
return 0;
}
时间复杂度:O(N),最多只需要遍历数组 2 次(特殊情况 1 次,一般情况 1 次)。