-
Notifications
You must be signed in to change notification settings - Fork 0
The simplest wait free algorithm in the world (getAndIncrement)
Pslydhh edited this page Nov 24, 2017
·
1 revision
package com.psly;
import java.lang.reflect.Field;
import java.util.concurrent.CountDownLatch;
import sun.misc.Unsafe;
public class WaitFreeAtomic {
public final static int N = 1000;
public final static int loops = 100000;
public final static int MAX = 64;
public final static int bigYields = 32;
public final static int[] ints = new int[N * loops];
// private final static AtomicInteger inter = new AtomicInteger();
public static void main(String[] args) throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
Thread[] threads = new Thread[N];
for(int i = 0; i < N; ++i) {
threadObjs[i] = new ThreadObj(null);
states[i] = new StateObj(7);
int threadId = i;
(threads[i] = new Thread(){
public void run(){
try {
latch.await();
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
for(int j = 0; j < loops; ++j)
ints[getAndIncrement(threadId)] = 1;
}
}).start();
}
long start =System.currentTimeMillis();
latch.countDown();
for(Thread thread: threads)
thread.join();
System.out.println(valueObj.value);
for(int j = 0; j < N * loops; ++j) {
if(ints[j] != 1)
System.out.println(j + " " + ints[j] + " wrong!");
}
System.out.println("over");
System.out.println("\ntimes: " + ((System.currentTimeMillis() - start) / 1000.0) + " seconds");
}
final static boolean casValueObj(ValueObj cmp, ValueObj val) {
return UNSAFE.compareAndSwapObject(valueBase, valueObjOffset, cmp, val);
}
static volatile ValueObj valueObj = new ValueObj(0, null);
//value
private static final Object valueBase;
private static final long valueObjOffset;
final static ThreadObj getThreadObj(long i) {
return (ThreadObj) UNSAFE.getObjectVolatile(threadObjs, ((long) i << ASHIFT) + ABASE);
}
final static void setThreadObj(int i, ThreadObj v) {
UNSAFE.putObjectVolatile(threadObjs, ((long) i << ASHIFT) + ABASE, v);
}
final static boolean casThreadObj(int i, ThreadObj cmp, ThreadObj finish) {
return UNSAFE.compareAndSwapObject(threadObjs, ((long) i << ASHIFT) + ABASE, cmp, finish);
}
final static ThreadObj[] threadObjs= new ThreadObj[N];
final static StateObj[] states = new StateObj[N];
private static final sun.misc.Unsafe UNSAFE;
//thread array
private static final int _Obase;
private static final int _Oscale;
private static final long ABASE;
private static final int ASHIFT;
static {
try {
UNSAFE = UtilUnsafe.getUnsafe();
valueObjOffset = UNSAFE.staticFieldOffset(WaitFreeAtomic.class.getDeclaredField("valueObj"));
valueBase = UNSAFE.staticFieldBase(WaitFreeAtomic.class.getDeclaredField("valueObj"));
_Obase = UNSAFE.arrayBaseOffset(ThreadObj[].class);
_Oscale = UNSAFE.arrayIndexScale(ThreadObj[].class);
ABASE = _Obase;
if ((_Oscale & (_Oscale - 1)) != 0)
throw new Error("data type scale not a power of two");
ASHIFT = 31 - Integer.numberOfLeadingZeros(_Oscale);
} catch (Exception e) {
throw new Error(e);
}
}
static class ThreadObj {
public ThreadObj(WrapperObj wrapObj) {
super();
this.wrapperObj = wrapObj;
}
WrapperObj wrapperObj;
long[] longs = new long[16];
static final class WrapperObj {
final ValueObj value;
final boolean isFinish;
public WrapperObj(ValueObj value, boolean isFinish) {
super();
this.value = value;
this.isFinish = isFinish;
}
}
boolean casWrapValue(WrapperObj cmp, WrapperObj val) {
return UNSAFE.compareAndSwapObject(this, wrapValueOffset, cmp, val);
}
private static final sun.misc.Unsafe UNSAFE;
private static final long wrapValueOffset;
static {
try {
UNSAFE = UtilUnsafe.getUnsafe();
wrapValueOffset = UNSAFE.objectFieldOffset(ThreadObj.class.getDeclaredField("wrapperObj"));
} catch (Exception e) {
throw new Error(e);
}
}
}
private static class StateObj {
public StateObj(int assistStep) {
super();
this.assistStep = assistStep;
this.steps = 0;
this.index = 0;
}
private final int assistStep;
private int steps;
private long index;
}
private static class ValueObj {
private final int value;
private final ThreadObj threadObj;
public ValueObj(int value, ThreadObj threadObj) {
super();
this.value = value;
this.threadObj = threadObj;
}
}
// valueObj->threadObj->wrapperObj->valueObj。
// step 1-3,每一个步骤都不会阻塞其他步骤。
// 严格遵守以下顺序:
// step 1: 通过将ValueObj指向ThreadObj:
// atomic: (value, null)->(value, ThreadObj)来锚定该值 //确定该value归ThreadObj对应线程所有。
// step 2: 通过将ThreadObj包裹的WrapperObj,
// atomic: 从(null, false)更新为(valueObj, true)来更新状态的同时传递value //对应线程通过isFinish判定操作已完成。
// step 3: 更新ValueObj,提升value,同时设置ThreadObj为null:
// atomic: (value, ThreadObj)->(value+1, null)完成收尾动作 //此时value值回到了没有被线程锚定的状态,也可以看做step1之前的状态。
private static ValueObj help(long helpIndex) {
helpIndex = helpIndex % N;
ThreadObj helpObj = getThreadObj(helpIndex);
ThreadObj.WrapperObj wrapperObj;
if(helpObj == null || helpObj.wrapperObj == null)
return null;
//判定句,是否该线程对应的操作未完成,(先取valueObj,再取isFinish,这很重要)。
ValueObj valueObj_ = valueObj;
while(!(wrapperObj = helpObj.wrapperObj).isFinish) {
ThreadObj threadObj;
/*ValueObj valueObj_ = valueObj;*/
if((threadObj = valueObj_.threadObj) == null) {
ValueObj intermediateObj = new ValueObj(valueObj_.value, helpObj);
//step1
if(!casValueObj(valueObj_, intermediateObj)) {
valueObj_ = valueObj;
continue;
}
//step1: 锚定该ValueObj,接下来所有看到该valueObj的线程,都会一致地完成一系列操作.
valueObj_ = intermediateObj;
threadObj = helpObj;
}
//完成ValueObj、ThreadObj中的WrapperObj的状态迁移。
helpTransfer(valueObj_, threadObj);
valueObj_ = valueObj;
}
valueObj_ = wrapperObj.value;
helpValueTransfer(valueObj_);
//返回锚定的valueObj。
return valueObj_;
}
private static void helpTransfer(ValueObj valueObj_, ThreadObj threadObj) {
ThreadObj.WrapperObj wrapperObj = threadObj.wrapperObj;
//step2: 先完成ThreadObj的状态迁移,WrapperObj(valueObj,true)分别表示(值,完成),原子地将这两个值喂给threadObj。
if(!wrapperObj.isFinish) {
ThreadObj.WrapperObj wrapValueFiniash = new ThreadObj.WrapperObj(valueObj_, true);
threadObj.casWrapValue(wrapperObj, wrapValueFiniash);
}
//step3: 最后完成ValueObj上的状态迁移
helpValueTransfer(valueObj_);
}
private static ValueObj helpValueTransfer(ValueObj valueObj_) {
if(valueObj_ == valueObj) {
ValueObj valueObjNext = new ValueObj(valueObj_.value + 1, null);
casValueObj(valueObj_, valueObjNext);
}
return valueObj_;
}
public static int getAndIncrement(int index) {
StateObj myState = states[index];
//前进一步,每assistStep,尝试一个帮助。
if((++myState.steps) % myState.assistStep == 0){
long helpThread = myState.index;
help(helpThread);
//下一个协助的对象。
++myState.index;
}
//fast-path, 最多MAX次。
int count = MAX;
for(;;) {
ValueObj valueObj_ = valueObj;
ThreadObj threadObj;
if((threadObj = valueObj_.threadObj) == null) {
ValueObj valueObjNext = new ValueObj(valueObj_.value + 1, null);
if(casValueObj(valueObj_, valueObjNext))
return valueObj_.value;
Thread.yield();Thread.yield();Thread.yield();Thread.yield();
} else {
helpTransfer(valueObj_, threadObj);
}
if(--count == 0)
break;
}
// System.out.println("here " + inter.incrementAndGet());
for(int j = 0; j < bigYields; ++j)
Thread.yield();
//slow-path,将自己列为被帮助对象。
ThreadObj myselfObj = new ThreadObj(new ThreadObj.WrapperObj(null, false));
setThreadObj(index, myselfObj);
//开始帮助自己
ValueObj result = help(index);
setThreadObj(index, null);
return result.value;
}
private static class UtilUnsafe {
private UtilUnsafe() {
}
public static Unsafe getUnsafe() {
if (UtilUnsafe.class.getClassLoader() == null)
return Unsafe.getUnsafe();
try {
final Field fld = Unsafe.class.getDeclaredField("theUnsafe");
fld.setAccessible(true);
return (Unsafe) fld.get(UtilUnsafe.class);
} catch (Exception e) {
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
}
}
}
}