Optimize board.SquareIsAttackedByOpponent()

This commit is contained in:
Sebastian Lague 2023-07-27 13:35:55 +02:00
parent 960655c5a8
commit 5aa54e4e51
3 changed files with 229 additions and 28 deletions

View file

@ -73,11 +73,10 @@ namespace ChessChallenge.API
/// </summary> /// </summary>
public void MakeMove(Move move) public void MakeMove(Move move)
{ {
hasCachedMoves = false;
hasCachedCaptureMoves = false;
if (!move.IsNull) if (!move.IsNull)
{ {
repetitionHistory.Add(board.ZobristKey); repetitionHistory.Add(board.ZobristKey);
OnPositionChanged();
board.MakeMove(new Chess.Move(move.RawValue), inSearch: true); board.MakeMove(new Chess.Move(move.RawValue), inSearch: true);
} }
} }
@ -87,11 +86,10 @@ namespace ChessChallenge.API
/// </summary> /// </summary>
public void UndoMove(Move move) public void UndoMove(Move move)
{ {
hasCachedMoves = false;
hasCachedCaptureMoves = false;
if (!move.IsNull) if (!move.IsNull)
{ {
board.UndoMove(new Chess.Move(move.RawValue), inSearch: true); board.UndoMove(new Chess.Move(move.RawValue), inSearch: true);
OnPositionChanged();
repetitionHistory.Remove(board.ZobristKey); repetitionHistory.Remove(board.ZobristKey);
} }
} }
@ -108,9 +106,8 @@ namespace ChessChallenge.API
{ {
return false; return false;
} }
hasCachedMoves = false;
hasCachedCaptureMoves = false;
board.MakeNullMove(); board.MakeNullMove();
OnPositionChanged();
return true; return true;
} }
@ -124,9 +121,8 @@ namespace ChessChallenge.API
/// </summary> /// </summary>
public void ForceSkipTurn() public void ForceSkipTurn()
{ {
hasCachedMoves = false;
hasCachedCaptureMoves = false;
board.MakeNullMove(); board.MakeNullMove();
OnPositionChanged();
} }
/// <summary> /// <summary>
@ -134,9 +130,8 @@ namespace ChessChallenge.API
/// </summary> /// </summary>
public void UndoSkipTurn() public void UndoSkipTurn()
{ {
hasCachedMoves = false;
hasCachedCaptureMoves = false;
board.UnmakeNullMove(); board.UnmakeNullMove();
OnPositionChanged();
} }
/// <summary> /// <summary>
@ -278,11 +273,7 @@ namespace ChessChallenge.API
/// </summary> /// </summary>
public bool SquareIsAttackedByOpponent(Square square) public bool SquareIsAttackedByOpponent(Square square)
{ {
if (!hasCachedMoves) return BitboardHelper.SquareIsSet(moveGen.GetOpponentAttackMap(board), square);
{
GetLegalMoves();
}
return BitboardHelper.SquareIsSet(moveGen.opponentAttackMap, square);
} }
@ -363,5 +354,12 @@ namespace ChessChallenge.API
return new Board(boardCore); return new Board(boardCore);
} }
void OnPositionChanged()
{
moveGen.NotifyPositionChanged();
hasCachedMoves = false;
hasCachedCaptureMoves = false;
}
} }
} }

View file

@ -45,20 +45,32 @@ namespace ChessChallenge.Application.APIHelpers
// If only captures should be generated, this will have 1s only in positions of enemy pieces. // If only captures should be generated, this will have 1s only in positions of enemy pieces.
// Otherwise it will have 1s everywhere. // Otherwise it will have 1s everywhere.
ulong moveTypeMask; ulong moveTypeMask;
bool hasInitializedCurrentPosition;
public APIMoveGen() public APIMoveGen()
{ {
board = new Board(); board = new Board();
} }
// Movegen needs to know when position has changed to allow for some caching optims in api
public void NotifyPositionChanged()
{
hasInitializedCurrentPosition = false;
}
public ulong GetOpponentAttackMap(Board board)
{
Init(board);
return opponentAttackMap;
}
// Generates list of legal moves in current position. // Generates list of legal moves in current position.
// Quiet moves (non captures) can optionally be excluded. This is used in quiescence search. // Quiet moves (non captures) can optionally be excluded. This is used in quiescence search.
public void GenerateMoves(ref Span<API.Move> moves, Board board, bool includeQuietMoves = true) public void GenerateMoves(ref Span<API.Move> moves, Board board, bool includeQuietMoves = true)
{ {
this.board = board;
generateNonCapture = includeQuietMoves; generateNonCapture = includeQuietMoves;
Init(); Init(board);
GenerateKingMoves(moves); GenerateKingMoves(moves);
@ -79,10 +91,22 @@ namespace ChessChallenge.Application.APIHelpers
return inCheck; return inCheck;
} }
void Init() public void Init(Board board)
{ {
// Reset state this.board = board;
currMoveIndex = 0; currMoveIndex = 0;
if (hasInitializedCurrentPosition)
{
moveTypeMask = generateNonCapture ? ulong.MaxValue : enemyPieces;
return;
}
hasInitializedCurrentPosition = true;
// Reset state
inCheck = false; inCheck = false;
inDoubleCheck = false; inDoubleCheck = false;
checkRayBitmask = 0; checkRayBitmask = 0;
@ -103,7 +127,11 @@ namespace ChessChallenge.Application.APIHelpers
emptyOrEnemySquares = emptySquares | enemyPieces; emptyOrEnemySquares = emptySquares | enemyPieces;
moveTypeMask = generateNonCapture ? ulong.MaxValue : enemyPieces; moveTypeMask = generateNonCapture ? ulong.MaxValue : enemyPieces;
CalculateAttackData(); CalculateAttackData();
} }
API.Move CreateAPIMove(int startSquare, int targetSquare, int flag) API.Move CreateAPIMove(int startSquare, int targetSquare, int flag)
@ -532,7 +560,6 @@ namespace ChessChallenge.Application.APIHelpers
} }
// Pawn attacks // Pawn attacks
PieceList opponentPawns = board.pawns[enemyIndex];
opponentPawnAttackMap = 0; opponentPawnAttackMap = 0;
ulong opponentPawnsBoard = board.pieceBitboards[PieceHelper.MakePiece(PieceHelper.Pawn, board.OpponentColour)]; ulong opponentPawnsBoard = board.pieceBitboards[PieceHelper.MakePiece(PieceHelper.Pawn, board.OpponentColour)];

View file

@ -1,4 +1,5 @@
using ChessChallenge.API; using ChessChallenge.API;
using ChessChallenge.Application.APIHelpers;
using ChessChallenge.Chess; using ChessChallenge.Chess;
using System; using System;
@ -6,6 +7,8 @@ namespace ChessChallenge.Application
{ {
public static class Tester public static class Tester
{ {
const bool throwOnAssertFail = false;
static MoveGenerator moveGen; static MoveGenerator moveGen;
static API.Board boardAPI; static API.Board boardAPI;
@ -22,6 +25,9 @@ namespace ChessChallenge.Application
MiscTest(); MiscTest();
TestBitboards(); TestBitboards();
TestMoveCreate(); TestMoveCreate();
new SearchTest().Run(false);
new SearchTest().Run(true);
if (runPerft) if (runPerft)
{ {
@ -37,7 +43,6 @@ namespace ChessChallenge.Application
{ {
WriteWithCol("ALL TESTS PASSED", ConsoleColor.Green); WriteWithCol("ALL TESTS PASSED", ConsoleColor.Green);
} }
} }
public static void RunPerft(bool useStackalloc = true) public static void RunPerft(bool useStackalloc = true)
@ -126,13 +131,43 @@ namespace ChessChallenge.Application
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("a6")), "Square attacked wrong"); Assert(boardAPI.SquareIsAttackedByOpponent(new Square("a6")), "Square attacked wrong");
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("f3")), "Square attacked wrong"); Assert(boardAPI.SquareIsAttackedByOpponent(new Square("f3")), "Square attacked wrong");
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("c5")), "Square attacked wrong");
Assert(!boardAPI.SquareIsAttackedByOpponent(new Square("c3")), "Square attacked wrong"); Assert(!boardAPI.SquareIsAttackedByOpponent(new Square("c3")), "Square attacked wrong");
Assert(!boardAPI.SquareIsAttackedByOpponent(new Square("h4")), "Square attacked wrong"); Assert(!boardAPI.SquareIsAttackedByOpponent(new Square("h4")), "Square attacked wrong");
boardAPI.MakeMove(new API.Move("b5b7", boardAPI)); var m1 = new API.Move("b5b7", boardAPI);
boardAPI.MakeMove(m1);
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("e7")), "Square attacked wrong"); Assert(boardAPI.SquareIsAttackedByOpponent(new Square("e7")), "Square attacked wrong");
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("b8")), "Square attacked wrong"); Assert(boardAPI.SquareIsAttackedByOpponent(new Square("b8")), "Square attacked wrong");
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("d4")), "Square attacked wrong");
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("h6")), "Square attacked wrong");
Assert(!boardAPI.SquareIsAttackedByOpponent(new Square("a5")), "Square attacked wrong"); Assert(!boardAPI.SquareIsAttackedByOpponent(new Square("a5")), "Square attacked wrong");
Assert(!boardAPI.SquareIsAttackedByOpponent(new Square("e8")), "Square attacked wrong"); Assert(!boardAPI.SquareIsAttackedByOpponent(new Square("b7")), "Square attacked wrong");
var m2 = new API.Move("f6e4", boardAPI);
boardAPI.MakeMove(m2);
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("f2")), "Square attacked wrong");
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("c3")), "Square attacked wrong");
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("h6")), "Square attacked wrong");
Assert(!boardAPI.SquareIsAttackedByOpponent(new Square("h4")), "Square attacked wrong");
boardAPI.ForceSkipTurn();
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("f7")), "Square attacked wrong");
Assert(!boardAPI.SquareIsAttackedByOpponent(new Square("d5")), "Square attacked wrong");
boardAPI.UndoSkipTurn();
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("c5")), "Square attacked wrong");
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("c3")), "Square attacked wrong");
Assert(!boardAPI.SquareIsAttackedByOpponent(new Square("h5")), "Square attacked wrong");
boardAPI.UndoMove(m2);
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("b1")), "Square attacked wrong");
Assert(!boardAPI.SquareIsAttackedByOpponent(new Square("a5")), "Square attacked wrong");
boardAPI.UndoMove(m1);
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("a5")), "Square attacked wrong");
Assert(boardAPI.SquareIsAttackedByOpponent(new Square("f8")), "Square attacked wrong");
} }
static void CheckTest() static void CheckTest()
@ -373,7 +408,63 @@ namespace ChessChallenge.Application
board.LoadPosition(testFens[i]); board.LoadPosition(testFens[i]);
boardAPI = new API.Board(board); boardAPI = new API.Board(board);
ulong result = Search(testDepths[i]); ulong result = Search(testDepths[i]);
Assert(result == testResults[i], "TEST FAILED"); Assert(result == testResults[i], "Movegen test failed");
}
board.LoadPosition("r2q2k1/pp2rppp/3p1n2/1R1Pn3/8/2PB1Q1P/P4PP1/2B2RK1 w - - 7 16");
boardAPI = new(board);
API.Move m1 = new("f3f6", boardAPI);
Assert(RecreateOpponentAttackMap() == 18446743649919696896ul, "Wrong attack map");
Assert(boardAPI.GetLegalMoves().Length == 43, "Wrong move count");
Assert(boardAPI.GetLegalMoves(true).Length == 3, "Wrong capture count");
boardAPI.MakeMove(m1);
Assert(RecreateOpponentAttackMap() == 68361585683595006ul, "Wrong attack map");
Assert(boardAPI.GetLegalMoves().Length == 31, "Wrong move count");
Assert(boardAPI.GetLegalMoves(true).Length == 2, "Wrong capture count");
boardAPI.ForceSkipTurn();
Assert(RecreateOpponentAttackMap() == 18446743065535709184ul, "Wrong attack map");
Assert(boardAPI.GetLegalMoves().Length == 48, "Wrong move count");
Assert(boardAPI.GetLegalMoves(true).Length == 7, "Wrong capture count");
boardAPI.ForceSkipTurn();
Assert(RecreateOpponentAttackMap() == 68361585683595006ul, "Wrong attack map");
Assert(boardAPI.GetLegalMoves().Length == 31, "Wrong move count");
Assert(boardAPI.GetLegalMoves(true).Length == 2, "Wrong capture count");
boardAPI.UndoSkipTurn();
Assert(RecreateOpponentAttackMap() == 18446743065535709184ul, "Wrong attack map");
Assert(boardAPI.GetLegalMoves().Length == 48, "Wrong move count");
Assert(boardAPI.GetLegalMoves(true).Length == 7, "Wrong capture count");
boardAPI.UndoSkipTurn();
Assert(RecreateOpponentAttackMap() == 68361585683595006ul, "Wrong attack map");
Assert(boardAPI.GetLegalMoves().Length == 31, "Wrong move count");
Assert(boardAPI.GetLegalMoves(true).Length == 2, "Wrong capture count");
boardAPI.UndoMove(m1);
Assert(RecreateOpponentAttackMap() == 18446743649919696896ul, "Wrong attack map");
Assert(boardAPI.GetLegalMoves().Length == 43, "Wrong move count");
Assert(boardAPI.GetLegalMoves(true).Length == 3, "Wrong capture count");
Span<API.Move> moveList = stackalloc API.Move[218];
boardAPI.GetLegalMovesNonAlloc(ref moveList);
Span<API.Move> moveListDupe = stackalloc API.Move[218];
boardAPI.GetLegalMovesNonAlloc(ref moveListDupe);
Assert(moveList.Length == 43 && moveListDupe.Length == 43, "Move gen wrong");
Span<API.Move> moveListAtk = stackalloc API.Move[218];
boardAPI.GetLegalMovesNonAlloc(ref moveListAtk, true);
Assert(moveListAtk.Length == 3, "Move gen wrong");
Assert(RecreateOpponentAttackMap() == 18446743649919696896ul, "Wrong attack map");
ulong RecreateOpponentAttackMap()
{
ulong bb = 0;
for (int i = 0; i < 64; i ++)
{
if (boardAPI.SquareIsAttackedByOpponent(new Square(i)))
{
BitboardHelper.SetSquare(ref bb, new Square(i));
}
}
return bb;
} }
} }
@ -432,6 +523,10 @@ namespace ChessChallenge.Application
{ {
WriteWithCol(msg); WriteWithCol(msg);
anyFailed = true; anyFailed = true;
if (throwOnAssertFail)
{
throw new Exception();
}
} }
} }
@ -443,5 +538,86 @@ namespace ChessChallenge.Application
Console.ResetColor(); Console.ResetColor();
} }
public class SearchTest
{
API.Board board;
bool useStackalloc;
int numLeafNodes;
int numCalls;
long miscSumTest;
public void Run(bool useStackalloc)
{
this.useStackalloc = useStackalloc;
Console.WriteLine("Running misc search test | stackalloc = " + useStackalloc);
Chess.Board b = new();
b.LoadPosition("1r4k1/2P1r1pp/3p4/4n1Q1/1p6/2PB3P/P3pPP1/2B3K1 w - - 7 16");
board = new API.Board(b);
Search(4);
Assert(miscSumTest == 101146355, "Misc search test failed");
}
void Search(int plyRemaining)
{
numCalls++;
var square = new Square(numCalls % 64);
miscSumTest += (int)boardAPI.GetPiece(square).PieceType;
miscSumTest += boardAPI.GetAllPieceLists()[numCalls % 12].Count;
miscSumTest += (long)(boardAPI.ZobristKey % 100);
miscSumTest += boardAPI.IsInCheckmate() ? 1 : 0;
if (numCalls % 6 == 0)
{
miscSumTest += boardAPI.IsInCheck() ? 1 : 0;
}
if (numCalls % 18 == 0)
{
miscSumTest += boardAPI.SquareIsAttackedByOpponent(square) ? 1 : 0;
}
if (plyRemaining == 0)
{
numLeafNodes++;
return;
}
if (numCalls % 3 == 0 && plyRemaining > 2)
{
if (boardAPI.TrySkipTurn())
{
Search(plyRemaining - 2);
boardAPI.UndoSkipTurn();
} }
} }
API.Move[] moves;
if (useStackalloc)
{
Span<API.Move> moveSpan = stackalloc API.Move[256];
boardAPI.GetLegalMovesNonAlloc(ref moveSpan);
moves = moveSpan.ToArray(); // (don't actually care about allocations here, just testing the func)
}
else
{
moves = boardAPI.GetLegalMoves();
}
foreach (var move in moves)
{
boardAPI.MakeMove(move);
Search(plyRemaining - 1);
boardAPI.UndoMove(move);
}
}
}
}
}