这道题目的核心任务是:给定一个长度为n的数组a,求出所有子数组的最大值之和。
solve(l, r)
函数:
pre
数组中,并计算每个位置的前缀最大值。pre
数组的前缀和,存储在sum
数组中。i
和j
分别从左半部分的末尾和右半部分的开头开始遍历。max
变量,记录当前左半部分的最大值。a[i]
,找到右半部分中第一个不小于a[i]
的元素a[j]
。
(j - mid) * max
:表示以a[i]
为左端点,在右半部分能扩展到的子数组个数乘以最大值。sum[r - mid] - sum[j - mid]
:表示右半部分从j
到r-1
这一段的最大值之和。solve
函数,计算并输出答案。pre
和sum
。该代码通过分治、双指针和前缀和的巧妙结合,高效地解决了求所有子数组最大值之和的问题。算法的时间复杂度和空间复杂度都比较优秀。
需要注意的是:
solve
处理的是左闭右开的区间,即[l, r)
。pre
数组存储的是从pre[0]
到pre[i]
的最大值,而不是从a[mid]
到a[i]
的最大值。sum
数组存储的是pre
数组的前缀和。代码实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | /**************************************************************** * Description: 2023_CSP_S 完善程序-2 * Author: Alex Li * Date: 2024-06-19 09:27:04 * LastEditTime: 2024-06-19 22:50:42 ****************************************************************/ #include <iostream> #include <algorithm> #include <vector> const int MAXN=100000; int n; int a[MAXN]; long long ans; void solve(int l, int r){ if(l+1==r){ ans+=a[l]; //左闭右开的区间,a[0]~a[r-1] return ; } int mid=(l+r)>>1; std::vector<int> pre(a+mid, a+r); //复制数组a[mid]~a[r-1]到pre //经过这个循环,pre[i]存储的是从pre[0]到pre[i]的最大值。 for(int i=1; i<r-mid;++i) pre[i]=std::max(pre[i],pre[i-1]); std::vector<long long> sum(r-mid+1); for(int i=0;i<r-mid;++i) sum[i+1]=sum[i]+pre[i]; //前缀最大值之和 /*i:从mid - 1(左半部分的最后一个元素)开始,向左遍历到1。 j:从mid(右半部分的第一个元素)开始,向右遍历到r。 max:用于记录遍历过程中左半部分的最大值。 这段代码的目的是通过遍历左半部分,结合右半部分的信息,计算所有跨越中点的子数组的最大值之和。 通过调整指针j,确定每个左半部分的元素a[i]在右半部分能扩展到的位置,并且结合前缀最大值和前缀和, 快速计算并累加这些子数组的最大值之和。 */ for(int i=mid-1,j=mid,max=0;i>=l;--i){ //这个循环将j指针向右移动,直到找到第一个不小于a[i]的元素或者到达右半部分的末尾r。 //目的:找到右半部分中第一个不小于a[i]的元素,从而确定a[i]在右半部分能扩展到的位置。 while(j<r&&a[j]<a[i])++j; max=std::max(max,a[i]); //更新最大值 //计算并累加左半部分当前元素a[i]能扩展到右半部分的位置数(j - mid),乘以当前的最大值max。 //(j - mid)表示从mid到j-1这一段元素的数量。 //max是当前左半部分的最大值。 ans+=(long long)(j-mid)*max; /*使用右半部分的前缀和数组sum,累加从j到r-1这一段的最大值之和。 sum[r - mid]表示右半部分所有元素的最大值之和。 sum[j - mid]表示从mid到j-1这一段元素的最大值之和。 因此,sum[r - mid] - sum[j - mid]表示从j到r-1这一段元素的最大值之和。 */ ans+=sum[r-mid]-sum[j-mid]; } solve(l,mid); solve(mid,r); } int main(){ std::cin>>n; for(int i=0;i<n;++i) std::cin>>a[i]; solve(0,n); //左闭右开的区间 std::cout<<ans<<std::endl; return 0; } |