Skip to content

Commit

Permalink
Fix some dividend contract compile error, update test case
Browse files Browse the repository at this point in the history
  • Loading branch information
weiqiushi committed Apr 24, 2024
1 parent 4600b4d commit 1fa3f17
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 41 deletions.
46 changes: 28 additions & 18 deletions contracts/dividend.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ pragma solidity ^0.8.0;
import "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import "@openzeppelin/contracts/utils/ReentrancyGuard.sol";

import "hardhat/console.sol";

contract DividendContract is ReentrancyGuard {
address public stakingToken;

Expand Down Expand Up @@ -35,7 +37,8 @@ contract DividendContract is ReentrancyGuard {
uint256 public totalStaked;

// the cycle info of the contract
CycleInfo[] public cycles;
//CycleInfo[] public cycles;
mapping(uint256 => CycleInfo) public cycles;

// the staking record of the user
struct StakeRecord {
Expand Down Expand Up @@ -64,11 +67,12 @@ contract DividendContract is ReentrancyGuard {
return currentCycleIndex;
}

function getCurrentCycle() public view returns (CycleInfo storage) {
function getCurrentCycle() internal returns (CycleInfo storage) {
/*
uint256 currentCycleIndex = getCurrentCycleIndex();
if (cycles.length <= currentCycleIndex) {
cycles.push();
}
cycles.push(CycleInfo(block.number, 0, new RewardInfo[](0)));
}*/
return cycles[currentCycleIndex];
}

Expand Down Expand Up @@ -109,12 +113,16 @@ contract DividendContract is ReentrancyGuard {

function getStakedAmount(address user, uint256 cycleIndex) public view returns (uint256) {
StakeRecord[] memory stakeRecords = UserStakeRecords[user];
for (uint i = stakeRecords.length - 1; i != uint(-1); i--) {
for (uint i = stakeRecords.length - 1; ; i--) {

// stakeRecords里面的对应周期的质押数据,都是对应周期发起的操作导致的状态,所以需要进入下一个周期才会生效,所以这里使用 < 而不是 <=
if (stakeRecords[i].cycleIndex < cycleIndex) {
return stakeRecords[i].amount;
}

if (i == 0) {
break;
}
}

return 0;
Expand All @@ -131,7 +139,7 @@ contract DividendContract is ReentrancyGuard {
if (stakeRecords.length == 0) {
stakeRecords.push(StakeRecord(currentCycleIndex, amount, amount));
} else {
StackRecord storage lastStakeRecord = stakeRecords[stakeRecords.length - 1];
StakeRecord storage lastStakeRecord = stakeRecords[stakeRecords.length - 1];
if (lastStakeRecord.cycleIndex == currentCycleIndex) {
lastStakeRecord.amount += amount;
lastStakeRecord.newAmount += amount;
Expand All @@ -154,18 +162,18 @@ contract DividendContract is ReentrancyGuard {
require(stakeRecords.length > 0, "No stake record found");

// get the last stake record of the user
StackRecord storage lastStakeRecord = stakeRecords[stakeRecords.length - 1];
StakeRecord storage lastStakeRecord = stakeRecords[stakeRecords.length - 1];
require(lastStakeRecord.amount >= amount, "Insufficient staked amount");

// 如果存在当前周期的质押操作,那么这个质押操作是可以直接撤销的不影响周期数据(当前质押要在下个周期进入cycleInfo中)
// 如果不是当前周期的质押操作,或者当前周期的质押数量不足,那么这个质押操作是需要从上个周期关联的cycleInfo数据中减去的
uint256 currentCycleIndex = getCurrentCycleIndex();
if (lastStakeRecord.cycleIndex == currentCycleIndex) {
if (lastStakeRecord.addAmount >= amount) {
if (lastStakeRecord.newAmount >= amount) {
lastStakeRecord.amount -= amount;
lastStakeRecord.addAmount -= amount;
lastStakeRecord.newAmount -= amount;
} else {
uint256 diff = amount - lastStakeRecord.addAmount;
uint256 diff = amount - lastStakeRecord.newAmount;

StakeRecord memory prevStakeRecord = stakeRecords[stakeRecords.length - 2];

Expand Down Expand Up @@ -196,7 +204,7 @@ contract DividendContract is ReentrancyGuard {
CycleInfo storage currentCycle = getCurrentCycle();
if (currentBlock - currentCycle.startBlock >= cycleMaxLength) {
currentCycleIndex = currentCycleIndex + 1;

console.log("enter new cycle %d", currentCycleIndex);
CycleInfo storage newCycle = cycles[currentCycleIndex];
newCycle.startBlock = currentBlock;
newCycle.totalStaked = totalStaked;
Expand All @@ -210,25 +218,26 @@ contract DividendContract is ReentrancyGuard {
}

// check if the user has settled the rewards for the cycle
function isDividendWithdrawed(address user, uint256 cycleIndex, uint256 token) public view returns (bool) {
function isDividendWithdrawed(address user, uint256 cycleIndex, address token) public view returns (bool) {
bytes32 key = keccak256(abi.encodePacked(user, cycleIndex, token));
return withdrawDividendState[key];
}

// claim rewards for the cycle
function withdrawDevidends(uint256[] cycleIndexs, uint256[] tokens) external nonReentrant {
function withdrawDevidends(uint256[] calldata cycleIndexs, address[] calldata tokens) external nonReentrant {
require(cycleIndexs.length > 0, "No cycle index");
require(tokens.length > 0, "No token");

RewardInfo[] storage rewards = [];
RewardInfo[] memory rewards = new RewardInfo[](cycleIndexs.length*tokens.length);
uint256 realRewardLength = 0;

for (uint i = 0; i < cycleIndexs.length; i++) {
uint256 cycleIndex = cycleIndexs[i];
require(cycleIndex < currentCycleIndex, "Cannot claim current cycle");

// withdraw every token
for (uint j = 0; j < tokens.length; j++) {
uint256 token = tokens[j];
address token = tokens[j];
require(!isDividendWithdrawed(msg.sender, cycleIndex, token), "Already claimed");

CycleInfo storage cycle = cycles[cycleIndex];
Expand All @@ -251,7 +260,8 @@ contract DividendContract is ReentrancyGuard {
}

if (rewardAmount > 0) {
rewards.push(RewardInfo(token, rewardAmount));
rewards[realRewardLength++] = RewardInfo(token, rewardAmount);
//rewards.push(RewardInfo(token, rewardAmount));
}

// set the withdraw state of the user and the cycle and the token
Expand All @@ -261,8 +271,8 @@ contract DividendContract is ReentrancyGuard {
}

// do the transfer
for (uint i = 0; i < rewards.length; i++) {
RewardInfo storage reward = rewards[i];
for (uint i = 0; i < realRewardLength; i++) {
RewardInfo memory reward = rewards[i];
if (reward.token == address(0)) {
payable(msg.sender).transfer(reward.amount);
} else {
Expand Down
30 changes: 16 additions & 14 deletions contracts/dividend2.sol
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ contract Dividend2 {
struct UserStackLog {
uint256 cycleNumber;
uint256 amount;
uint256 addAmount;
}

mapping(uint256 => DividendInfo) dividends;
Expand Down Expand Up @@ -68,14 +67,13 @@ contract Dividend2 {
UserStackLog[] storage logs = userStackLog[msg.sender];

if (logs.length == 0) {
logs.push(UserStackLog(nextCycle, amount, amount));
logs.push(UserStackLog(nextCycle, amount));
} else {
UserStackLog storage lastLog = logs[logs.length-1];
if (lastLog.cycleNumber == nextCycle) {
lastLog.amount += amount;
lastLog.addAmount += amount;
} else {
logs.push(UserStackLog(nextCycle, lastLog.amount + amount, amount));
logs.push(UserStackLog(nextCycle, lastLog.amount + amount));
}
}

Expand All @@ -94,13 +92,21 @@ contract Dividend2 {
dmcToken.transfer(msg.sender, amount);
uint256 nextCycle = lastCycleNumber + 1;

// addAmount可以从last - lastlast的差值得到,就不需要再存一份了
if (lastLog.cycleNumber == nextCycle) {

if (lastLog.addAmount >= amount) {
uint256 addAmount = 0;
if (logs.length > 1) {
if (logs[logs.length-2].amount > lastLog.amount) {
addAmount = logs[logs.length-2].amount - lastLog.amount;
}
} else {
addAmount = lastLog.amount;
}

if (addAmount >= amount) {
lastLog.amount -= amount;
lastLog.addAmount -= amount;
} else {
uint256 diff = amount - lastLog.addAmount;
uint256 diff = amount - addAmount;

// 从上一个周期扣amount的差值, 能走到这里,一定说明至少有两个周期,否则不会出现lastLog.amount >= amount 且 lastLog.addAmount < amount的情况
UserStackLog memory lastLastLog = logs[logs.length-2];
Expand All @@ -112,23 +118,19 @@ contract Dividend2 {
lastLog.cycleNumber = nextCycle - 1;
// 70 -> 25 = 50 - (45 - 20)
lastLog.amount = lastLastLog.amount - diff;
// addAmount不改了,因为这个值在非lastLog的位置并没有意义

logs.push(UserStackLog(nextCycle, lastLog.amount, 0));
logs.push(UserStackLog(nextCycle, lastLog.amount));
} else {
// 假设用户在周期2存了50,周期3存了20,又提取了45
// 这里的原log为[{3, 50, 50}, {4, 70, 20}],变成[{3, 25, 25}, {4, 25, 0}]
lastLastLog.amount -= diff;
// addAmount不改了,因为这个值在非lastLog的位置并没有意义
lastLog.amount -= amount;
// 这里的addAmount就必须要改
lastLog.addAmount = 0;
}

dividends[lastLog.cycleNumber].totalDeposits -= diff;
}
} else {
logs.push(UserStackLog(nextCycle, lastLog.amount - amount, 0));
logs.push(UserStackLog(nextCycle, lastLog.amount - amount));
}

totalDeposits -= amount;
Expand Down
20 changes: 11 additions & 9 deletions test/test_dividend.ts
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
import { ethers, upgrades } from "hardhat"
import { DMC2, GWTToken2, Dividend2 } from "../typechain-types"
import { DMC2, GWTToken2, Dividend2, DividendContract } from "../typechain-types"
import { HardhatEthersSigner } from '@nomicfoundation/hardhat-ethers/signers';
import { expect } from "chai";
import { mine } from "@nomicfoundation/hardhat-network-helpers";

describe("Devidend", function () {
let dmc: DMC2
let gwt: GWTToken2
let dividend: Dividend2;
//let dividend: Dividend2;
let dividend: DividendContract;
let signers: HardhatEthersSigner[];

before(async () => {
signers = await ethers.getSigners()

dmc = await (await ethers.deployContract("DMC2", [ethers.parseEther("1000000000"), [signers[0].address], [1000000]])).waitForDeployment()
gwt = await (await ethers.deployContract("GWTToken2")).waitForDeployment()
dividend = await (await ethers.deployContract("Dividend2", [await dmc.getAddress(), 1000])).waitForDeployment()
//dividend = await (await ethers.deployContract("Dividend2", [await dmc.getAddress(), 1000])).waitForDeployment()
dividend = await(await ethers.deployContract("DividendContract", [await dmc.getAddress(), 1000])).waitForDeployment();

// 给signers[0] 1000个GWT
await (await gwt.enableMinter([signers[0].address])).wait()
await (await gwt.mint(signers[0].address, 1000)).wait()
await (await gwt.approve(await dividend.getAddress(), 1000)).wait()

// 给signers[1]到19, 每人100个DMC
for (let i = 1; i < 20; i++) {
for (let i = 1; i < 3; i++) {
await (await dmc.transfer(signers[i].address, 100)).wait()
await (await dmc.connect(signers[i]).approve(await dividend.getAddress(), 100)).wait()
}
Expand Down Expand Up @@ -51,7 +53,7 @@ describe("Devidend", function () {
await (await dividend.deposit(100, await gwt.getAddress())).wait();

// 因为周期1开始时没有已确定的抵押,周期1的分红是提不到的
expect(dividend.connect(signers[1]).withdraw([1])).to.be.revertedWith("cannot withdraw");
expect(dividend.connect(signers[1]).withdrawDevidends([1], [await gwt.getAddress()])).to.be.revertedWith("cannot withdraw");

mine(1000);
});
Expand All @@ -61,12 +63,12 @@ describe("Devidend", function () {
await (await dividend.deposit(100, await gwt.getAddress())).wait();

// 此时提现周期2的,应该能提到100 GWT
await (await dividend.connect(signers[1]).withdraw([2]));
await (await dividend.connect(signers[1]).withdrawDevidends([2], [await gwt.getAddress()]));
expect(await gwt.balanceOf(signers[1].address)).to.equal(100);

// 周期3,signer1先存20, 再提取45 DMC出来
await (await dividend.connect(signers[1]).stake(20)).wait();
await (await dividend.connect(signers[1]).unStake(45)).wait();
await (await dividend.connect(signers[1]).unstake(45)).wait();

mine(1000);
});
Expand All @@ -76,11 +78,11 @@ describe("Devidend", function () {
await (await dividend.deposit(0, ethers.ZeroAddress)).wait();

// 此时提现周期3的,应该能提到33 GWT
await (await dividend.connect(signers[1]).withdraw([3]));
await (await dividend.connect(signers[1]).withdrawDevidends([3], [await gwt.getAddress()]));
expect(await gwt.balanceOf(signers[1].address)).to.equal(133);

// signers2提取两个周期的分红,应该能提到100+66=166 GWT
await (await dividend.connect(signers[2]).withdraw([2,3]));
await (await dividend.connect(signers[2]).withdrawDevidends([2,3], [await gwt.getAddress()]));
expect(await gwt.balanceOf(signers[2].address)).to.equal(166);
})
})

0 comments on commit 1fa3f17

Please sign in to comment.