@Bei-S
2019-01-10T15:10:33.000000Z
字数 13544
阅读 1310
数据结构
最怕一生碌碌无为,还安慰自己平凡可贵。
平衡树就是防止2X查找树的树高过高导致复杂度过大
Splay是一种用来解决平衡树问题的算法,但是常数较大(和Treap相比)
但是他的旋转操作在LCT中便有了用武之地
可以用来进行区间操作,每次splay前驱后继即可
很重要的一个就是区间翻转(翻转标记!!!)
还有Split!!!
懒标记,最大值,子树和....
主要是用来更新父亲节点信息,例如sum,size之类的
inline void update(int x){
siz[x]=siz[son[x][0]]+siz[son[x][1]]+num[x];
}
用来翻转节点,具体过程推荐自己手画几遍最好
inline void rotate(int x){
int y=f[x],z=f[y],t=son[y][0]==x;
if(z) son[z][son[z][1]==y]=x;f[x]=z;
son[y][!t]=son[x][t];f[son[x][t]]=y;
son[x][t]=y;f[y]=x;
update(y);update(x);
}
把一个节点旋转到另一个节点下方
我们定义A为父亲和爷爷的大小关系以及B为父亲和儿子的大小关系
这里主要是分两种情况
1.当A==B,即三者在同一条链上时,我们先翻转父亲,再翻转儿子,可以降低树高(当链长较长时)
2.当A!=B,翻转两次儿子即可
inline void splay(int x,int s){
while(f[x]!=s){
int y=f[x],z=f[y];
if(z!=s) (son[y][0]==x)^(son[z][0]==y)?rotate(x):rotate(y);
rotate(x);
}
if(!s) rt=x;
}
主要用来找到第x个数的位置
当我们找到x时,把它splay到根,这时它左子树大小就是它的排名了辣
(这种写法要加入两个虚点,要不首尾节点没有前驱后继了)
inline void find(int x){
int s=rt;
if(s==0) return;
while(son[s][x>w[s]]&&w[s]!=x) s=son[s][x>w[s]];
splay(s,0);
}
先把x旋转到根,那前驱就是它左儿子的最右儿子了
后继就是右儿子最左儿子
inline int suf(int x){
find(x);
int s=rt;
if(w[s]>x) return s;
s=son[s][1];
while(son[s][0]) s=son[s][0];
return s;
}
inline int pre(int x){
find(x);
int s=rt;
if(w[s]<x) return s;
s=son[s][0];
while(son[s][1]) s=son[s][1];
return s;
}
先找到它的位置,如果有了直接num++,否则新建一个节点
inline void insert(int x){
int s=rt,fa=0;
while(s&&w[s]!=x) fa=s,s=son[s][x>w[s]];
if(s) num[s]++;
else {
s=++tot;
if(fa) son[fa][x>w[fa]]=s;
w[s]=x;
siz[s]=num[s]=1;
son[s][0]=son[s][1]=0;
f[s]=fa;
}
splay(s,0);
}
我们找到x的前驱和后继
把前驱旋转到根,
然后把后继旋转成前驱的儿子,
此时的话后继的左儿子一定是我们要的数(考虑一下为什么)。
直接删除即可(注意要清除信息)
inline void del(int x){
int pe=pre(x),sf=suf(x);
splay(pe,0);
splay(sf,pe);
int s=son[sf][0];
if(num[s]>1) {
num[s]--;
splay(s,0);
}
else son[sf][0]=0;
}
求第k大和线段树操作有点类似,我们直接类似二分的概念直接找即可
inline int kth(int x){
int s=rt;
while(1){
if(siz[son[s][0]]+num[s]<x) x-=siz[son[s][0]]+num[s],s=son[s][1];
else
if(siz[son[s][0]]>=x) s=son[s][0];
else return w[s];
}
}
这个就是把需要操作序列的前驱和后继splay一下,这样直接对整段序列进行操作啦
inline int split(int k,int tot){
int x=find(k),y=find(k+tot+1);
splay(x,0);
splay(y,x);
return son[y][0];
}
Talk is cheat,show you the code:
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+7;
const int inf=0x7ffffff;
int son[N][2],rt,f[N],siz[N],w[N],tot,num[N];
inline void update(int x){siz[x]=siz[son[x][0]]+siz[son[x][1]]+num[x];}
inline void rotate(int x){
int y=f[x],z=f[y],t=son[y][0]==x;
if(z) son[z][son[z][1]==y]=x;f[x]=z;
son[y][!t]=son[x][t];f[son[x][t]]=y;
son[x][t]=y;f[y]=x;
update(y);update(x);
}
inline void splay(int x,int s){
if(!s) rt=x;
while(f[x]!=s){
int y=f[x],z=f[y];
if(z!=s) (son[y][0]==x)^(son[z][0]==y)?rotate(x):rotate(y);
rotate(x);
}
}
inline void find(int x){
int s=rt;
if(s==0) return;
while(son[s][x>w[s]]&&w[s]!=x) s=son[s][x>w[s]];
splay(s,0);
}
inline int suf(int x){
find(x);
int s=rt;
if(w[s]>x) return s;
s=son[s][1];
while(son[s][0]) s=son[s][0];
return s;
}
inline int pre(int x){
find(x);
int s=rt;
if(w[s]<x) return s;
s=son[s][0];
while(son[s][1]) s=son[s][1];
return s;
}
inline void insert(int x){
int s=rt,fa=0;
while(s&&w[s]!=x) fa=s,s=son[s][x>w[s]];
if(s) num[s]++;
else {
s=++tot;
if(fa) son[fa][x>w[fa]]=s;
w[s]=x;
siz[s]=num[s]=1;
son[s][0]=son[s][1]=0;
f[s]=fa;
}
splay(s,0);
}
inline void del(int x){
int pe=pre(x),sf=suf(x);
splay(pe,0);
splay(sf,pe);
int s=son[sf][0];
if(num[s]>1) {
num[s]--;
splay(s,0);
}
else son[sf][0]=0;
}
inline int kth(int x){
int s=rt;
while(1){
if(siz[son[s][0]]+num[s]<x) x-=siz[son[s][0]]+num[s],s=son[s][1];
else
if(siz[son[s][0]]>=x) s=son[s][0];
else return w[s];
}
}
template<class T>
inline void read(T &num){
T x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-') f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
num=f*x;
}
int main()
{
int n;
read(n);
insert(inf-1);
insert(-inf+1);
while(n--)
{
int op,x;
read(op),read(x);
if(op==1)
insert(x);
if(op==2)
del(x);
if(op==3)
find(x),printf("%d\n",siz[son[rt][0]]);
if(op==4)
printf("%d\n",kth(x+1));
if(op==5)
printf("%d\n",w[pre(x)]);
if(op==6)
printf("%d\n",w[suf(x)]);
}
}
还有一份不需要虚点,直接求rank和前驱后继的代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+7;
const int inf=0x7ffffff;
int num[N],f[N],siz[N],son[N][2],w[N];
int tot,n,m,rt;
inline void update(int x) {siz[x]=siz[son[x][0]]+siz[son[x][1]]+num[x];}
void rtt(int x){
int y=f[x],z=f[y],t=(son[y][0]==x);
if(z) son[z][son[z][1]==y]=x;
f[x]=z;
son[y][!t]=son[x][t];
f[son[x][t]]=y;
son[x][t]=y;
f[y]=x;
update(y);
update(x);
}
void splay(int x,int s){
while(f[x]!=s){
int y=f[x],z=f[y];
if(z!=s){
if( (son[y][0]==x)^(son[z][0]==y) ) rtt(x);
else rtt(y);
}
rtt(x);
}
if(!s) rt=x;
}
inline void insert(int &x,int fa,int v){
if(!x){
x=++tot;
f[x]=fa;
w[x]=v;
siz[x]=num[x]=1;
son[x][0]=son[x][1]=0;
splay(x,0);
return;
}
if(w[x]==v){
siz[x]++;
num[x]++;
splay(x,0);
return;
}
insert(son[x][v>w[x]],x,v);
update(x);
}
int get(int x){
int s=rt;
while(w[s]!=x&&s) s=son[s][x>w[s]];
return s;
}
int rank(int x){
int s=rt,ret=0;
int fa;
while(s){
if(x<=w[s]) fa=s,s=son[s][0];
else{
fa=s;
ret+=siz[son[s][0]]+num[s];
s=son[s][1];
}
}
splay(fa,0);
return ret+1;
}
int kth(int x){
int s=rt;
int fa=0;
while(x<=siz[son[s][0]]||x>siz[son[s][0]]+num[s]){
if(x<=siz[son[s][0]]) {
fa=s;
s=son[s][0];
}
else{
fa=s;
x-=siz[son[s][0]]+num[s];
s=son[s][1];
}
}
if(fa!=0) splay(fa,0);
return w[s];
}
void del(int x){
x=get(x);
if(!x) return;
splay(x,0);
if(num[x]>1) {
num[x]--;
siz[x]--;
return;
}
if(!son[x][0]||!son[x][1]) {
rt=son[x][0]+son[x][1];
}
else{
int y=son[x][1];
while(son[y][0]) y=son[y][0];
splay(y,x);
son[y][0]=son[x][0];
f[son[x][0]]=y;
rt=y;
}
update(rt);
f[rt]=0;
}
int pre(int x){
int s=rt,ret=-inf,fa;
while(s){
if(x<=w[s]) fa=s,s=son[s][0];
else fa=s,ret=max(ret,w[s]),s=son[s][1];
}
splay(fa,0);
return ret;
}
int suf(int x){
int s=rt,ret=inf,fa;
while(s){
if(x<w[s]) fa=s,ret=min(ret,w[s]),s=son[s][0];
else fa=s,s=son[s][1];
}
splay(fa,0);
return ret;
}
template<class T>
inline void read(T &num){
T x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-') f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
num=f*x;
}
int main()
{
int n;
read(n);
while(n--)
{
int op,x;
read(op),read(x);
if(op==1)
insert(rt,0,x);
if(op==2)
del(x);
if(op==3)
printf("%d\n",rank(x));
if(op==4)
printf("%d\n",kth(x));
if(op==5)
printf("%d\n",pre(x));
if(op==6)
printf("%d\n",suf(x));
}
}
还有用树状数组的[自己YY的,复杂度最坏应该多个log]
#include<bits/stdc++.h>
using namespace std;
#define lowbit(x) x & -x
#define mid ((l+r)>>1)
const int N=1e5+7;
int opt[N],a[N],b[N],n,cnt,m;
int s[N];
inline void add(int x,int y){
while(x<=m){
s[x]+=y;
x+=lowbit(x);
}
}
inline int sa(int x){
int sum=0;
while(x){
sum+=s[x];
x-=lowbit(x);
}
return sum;
}
inline int gt(int x){
return lower_bound(b+1,b+1+m,x)-b;
}
inline int bk(int x){
return b[x];
}
inline int rank(int x){
return sa(x-1)+1;
}
inline int kth(int x){
int l=1,r=m;
while(l<r){
if(sa(mid)<x) l=mid+1;
else r=mid;
}
return bk(l);
}
void fre(){
freopen("tree.in","r",stdin);
freopen("my.out","w",stdout);
}
inline int dsa(int x){
return sa(x)-sa(x-1);
}
int main(){
// fre();
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d%d",&opt[i],&a[i]);
if(opt[i]^4) b[++cnt]=a[i];
}
sort(b+1,b+1+cnt);
m=unique(b+1,b+1+cnt)-b-1;
for(int i=1;i<=n;i++){
if(opt[i]==1) add(gt(a[i]),1);
if(opt[i]==2) add(gt(a[i]),-1);
if(opt[i]==3) printf("%d\n",rank(gt(a[i])));
if(opt[i]==4) printf("%d\n",kth(a[i]));
if(opt[i]==5) printf("%d\n",kth(rank(gt(a[i]))-1));
if(opt[i]==6) printf("%d\n",kth(rank(gt(a[i]))+dsa(gt(a[i]))));
}
}
还有vector[虽然最坏情况被卡炸,但是娱乐一下么~]
#include<bits/stdc++.h>
using namespace std;
int n;
vector<int> v;
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++){
int opt,x;
scanf("%d%d",&opt,&x);
if(opt==1) v.insert(lower_bound(v.begin(),v.end(),x),x);
if(opt==2) v.erase(lower_bound(v.begin(),v.end(),x));
if(opt==3) printf("%d\n",lower_bound(v.begin(),v.end(),x)-v.begin()+1);
if(opt==4) printf("%d\n",v[x-1]);
if(opt==5) printf("%d\n",*--lower_bound(v.begin(),v.end(),x));
if(opt==6) printf("%d\n",*upper_bound(v.begin(),v.end(),x));
}
}
这里Splay维护的是在Shelf上面的位置。
其实就是维护一段序列,top和bottom就相当于把左(右)子树全都拎走,
Pos记录每个值对应点的编号(Splay不是按照编号排序,按位置排序)。
Insert操作就是和前驱(后继)交换信息
Ask和Query就是常规操作了
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+7;
int siz[N],son[N][2],f[N],w[N],pos[N],n,m,rt,tot;
char ch[50];
int inline read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
inline void update(int x){
siz[x]=siz[son[x][0]]+siz[son[x][1]]+1;
}
inline void rotate(int x){
int y=f[x],z=f[y],t=son[y][0]==x;
if(z) son[z][y==son[z][1]]=x;
f[x]=z;
son[y][!t]=son[x][t];
f[son[x][t]]=y;
son[x][t]=y;
f[y]=x;
update(y);
update(x);
}
inline void splay(int x,int s){
if(s==0) rt=x;
while(f[x]!=s) {
int y=f[x],z=f[y];
if(z!=s) (son[y][0]==x)^(son[z][0]==y)?rotate(x):rotate(y);
rotate(x);
}
}
inline void insert(int x){
w[++tot]=x;
siz[tot]=1;
son[tot][0]=son[tot][1]=0;
pos[x]=tot;
if(tot>1){
son[tot-1][1]=tot;
f[tot]=tot-1;
splay(tot,0);
}
}
inline int find(int x){
int s=rt;
while(1){
if(siz[son[s][0]]+1==x) return s;
if(siz[son[s][0]]>=x) s=son[s][0];
else if(siz[son[s][0]]+1<x) x-=siz[son[s][0]]+1,s=son[s][1];
}
return 0;
}
inline void top(int x){
x=pos[x];
splay(x,0);
if(!son[x][0]) return;
if(!son[x][1]) son[x][1]=son[x][0],son[x][0]=0;
else{
int y=find(siz[son[x][0]]+2);
f[son[x][0]]=y;
son[y][0]=son[x][0];
son[x][0]=0;
splay(y,0);
}
}
inline void bottom(int x){
x=pos[x];
splay(x,0);
if(!son[x][1]) return;
if(!son[x][0]) son[x][0]=son[x][1],son[x][1]=0;
else{
int y=find(siz[son[x][0]]);
f[son[x][1]]=y;
son[y][1]=son[x][1];
son[x][1]=0;
splay(y,0);
}
}
/*void ins(int f,int x)
{
if (!f) return;
splay(pos[x],0);
int y=find(f==1?siz[son[pos[x]][0]]+2:siz[son[pos[x]][0]]);
int x1=w[y],x2=pos[x];
swap(pos[x],pos[x1]);
swap(w[x2],w[y]);
}*/
inline void ins(int f,int x){
if(!f) return;
splay(pos[x],0);
int s;
if(f==1){//suf
s=find(siz[son[pos[x]][0]]+2);
int t=pos[x],ff=w[s];
swap(pos[x],pos[ff]);
swap(w[t],w[s]);
}
else{
s=find(siz[son[pos[x]][0]]);
int t=pos[x],ff=w[s];
swap(pos[x],pos[ff]);
swap(w[t],w[s]);
}
}
inline void getans(int x){
splay(pos[x],0);
printf("%d\n",siz[son[pos[x]][0]]);
}
int main(){
// freopen("tx.txt","r",stdin);
scanf("%d%d",&n,&m);
rt=1;
for(int i=1;i<=n;i++){
int p;
scanf("%d",&p);
insert(p);
}
for(int i=1;i<=m;i++){
scanf("%s",ch);
switch(ch[0])
{
case 'T':top(read());break;
case 'B':bottom(read());break;
case 'I':ins(read(),read());break;
case 'A':getans(read());break;
case 'Q':printf("%d\n",w[find(read())]);break;
}
}
}
这道题真的太恶心了
首先空间不够,要开队列recycle
其次就是在区间翻转的时候,lx和rx也要交换位置
最后就是在find的时候需要pushdown(因为这个调了一下午(>_<)!!)
节点要维护区间和,前后缀
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+7;
const int inf=0x7ffffff;
int n,m,rt,cnt;
int f[N],son[N][2],w[N],siz[N],id[N],lx[N],rx[N],mx[N],sum[N],a[N];
bool tag[N],rev[N];
queue<int> q;
inline void update(int x){
int l=son[x][0],r=son[x][1];
sum[x]=sum[l]+sum[r]+w[x];
siz[x]=siz[l]+siz[r]+1;
mx[x]=max(mx[l],max(mx[r],rx[l]+w[x]+lx[r]));
lx[x]=max(lx[l],sum[l]+w[x]+lx[r]);
rx[x]=max(rx[r],sum[r]+w[x]+rx[l]);
}
inline void pushdown(int x){
int l=son[x][0],r=son[x][1];
if(tag[x]){
rev[x]=tag[x]=0;
if(l) tag[l]=1,w[l]=w[x],sum[l]=w[x]*siz[l];
if(r) tag[r]=1,w[r]=w[x],sum[r]=w[x]*siz[r];
if (w[x]>=0){
if (l)lx[l]=rx[l]=mx[l]=sum[l];
if (r)lx[r]=rx[r]=mx[r]=sum[r];
}else{
if (l)lx[l]=rx[l]=0,mx[l]=w[x];
if (r)lx[r]=rx[r]=0,mx[r]=w[x];
}
}
if(rev[x]){
rev[x]=0;rev[l]^=1;rev[r]^=1;
swap(lx[l],rx[l]);swap(lx[r],rx[r]);
swap(son[l][0],son[l][1]);swap(son[r][0],son[r][1]);
}
}
inline void rotate(int x){
int y=f[x],z=f[y],t=son[y][0]==x;
if(z) son[z][son[z][1]==y]=x;f[x]=z;
// if (y==k)k=x;else son[z][son[z][1]==y]=x;
son[y][!t]=son[x][t];f[son[x][t]]=y;
son[x][t]=y;f[y]=x;
update(y);
update(x);
}
inline void splay(int x,int s){
while(f[x]!=s){
int y=f[x],z=f[y];
if(z!=s) (son[y][0]==x)^(son[z][0]==y)?rotate(x):rotate(y);
rotate(x);
}
if(!s) rt=x;
}
inline int find(int x){
int s=rt;
while(1){
pushdown(s);
if(siz[son[s][0]]+1==x) return s;
if(siz[son[s][0]]>=x) s=son[s][0];
else x-=siz[son[s][0]]+1,s=son[s][1];
}
}
/*inline void rotate(int x,int &k){
int y=f[x],z=f[y],l=(son[y][1]==x),r=l^1;
if (y==k)k=x;else son[z][son[z][1]==y]=x;
f[son[x][r]]=y;f[y]=x;f[x]=z;
son[y][l]=son[x][r];son[x][r]=y;
update(y);
update(x);
}
inline void splay(int x,int &k){
while (x!=k){
int y=f[x],z=f[y];
if (y!=k){
if (son[z][0]==y ^ son[y][0]==x)rotate(x,k);
else rotate(y,k);
}
rotate(x,k);
}
}
inline int find(int x){
int s=rt;
while(1){
pushdown(s);
if(siz[son[s][0]]+1==x) return s;
if(siz[son[s][0]]>=x) s=son[s][0];
else x-=siz[son[s][0]]+1,s=son[s][1];
}
}*/
inline void recycle(int x){
int &l=son[x][0],&r=son[x][1];
if(l) recycle(l);
if(r) recycle(r);
q.push(x);
f[x]=l=r=tag[x]=rev[x]=0;
}
inline int split(int k,int tot){
int x=find(k),y=find(k+tot+1);
splay(x,0);
splay(y,x);
return son[y][0];
}
inline void sa(int k,int tot){
int x=split(k,tot);
printf("%d\n",sum[x]);
}
inline void modify(int k,int tot,int val){
int x=split(k,tot),y=f[x];
w[x]=val;
tag[x]=1;
sum[x]=siz[x]*val;
if (val>=0)lx[x]=rx[x]=mx[x]=sum[x];
else lx[x]=rx[x]=0,mx[x]=val;
update(y);
update(f[y]);
}
inline void rever(int k,int tot){
int x=split(k,tot),y=f[x];
if (!tag[x]){
rev[x]^=1;
swap(son[x][0],son[x][1]);
swap(lx[x],rx[x]);
update(y);update(f[y]);
}
}
inline void del(int k,int tot){
int x=split(k,tot),y=f[x];
recycle(x);
son[y][0]=0;
update(y);update(f[y]);
}
inline void build(int l,int r,int fa){
int mid=(l+r)>>1,now=id[mid],pre=id[fa];
if(l==r) {
mx[now]=sum[now]=a[l];
tag[now]=rev[now]=0;
lx[now]=rx[now]=max(a[l],0);
siz[now]=1;
}
if(l<mid) build(l,mid-1,mid);
if(r>mid) build(mid+1,r,mid);
w[now]=a[mid];
f[now]=pre;
update(now);
son[pre][mid>=fa]=now;
}
inline void insert(int k,int tot){
for(int i=1;i<=tot;i++) scanf("%d",&a[i]);
for(int i=1;i<=tot;i++){
if (!q.empty())id[i]=q.front(),q.pop();
else id[i]=++cnt;//利用队列中记录的冗余节点编号
}
build(1,tot,0);
int z=id[(1+tot)>>1];
int x=find(k+1),y=find(k+2);
splay(x,0);
splay(y,x);
f[z]=y;
son[y][0]=z;
update(y);
update(x);
}
inline int read(){
int x=0,ff=1;char ch=getchar();
while (ch<'0' || ch>'9'){if (ch=='-')ff=-1;ch=getchar();}
while ('0'<=ch && ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return x*ff;
}
int main(){
scanf("%d%d",&n,&m);
mx[0]=a[1]=a[n+2]=-inf;
for(int i=1;i<=n;i++){
scanf("%d",&a[i+1]);
}
for(int i=1;i<=n+2;i++) id[i]=i;
build(1,n+2,0);
rt=(n+3)>>1;
cnt=n+2;
int k,tot,val;
char ch[20];
while(m--){
scanf("%s",ch);
if (ch[0]!='M' || ch[2]!='X') k=read(),tot=read();
if (ch[0]=='I')insert(k,tot);
if (ch[0]=='D')del(k,tot);
if (ch[0]=='M'){
if (ch[2]=='X')printf("%d\n",mx[rt]);
else val=read(),modify(k,tot,val);
}
if (ch[0]=='R')rever(k,tot);
if (ch[0]=='G')sa(k,tot);
}
}