Skip to content

Commit

Permalink
Add async method to model
Browse files Browse the repository at this point in the history
  • Loading branch information
valeriyvan committed Jul 22, 2024
1 parent 2e3c4f8 commit d086fb3
Show file tree
Hide file tree
Showing 7 changed files with 297 additions and 14 deletions.
102 changes: 102 additions & 0 deletions Sources/geometrize/GeometrizeModelHillClimb.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,51 @@ class GeometrizeModelHillClimb: GeometrizeModelBase {
return states
}

private func getHillClimbStateAsync( // swiftlint:disable:this function_parameter_count
shapeCreator: @escaping ShapeCreator,
alpha: UInt8,
shapeCount: Int,
maxShapeMutations: Int,
maxThreads: Int, // Ignored. Single thread is used at the moment.
energyFunction: @escaping EnergyFunction
) async -> [State] {
// Ensure that the results of the random generation are the same between tasks with identical settings
// The RNG is thread-local and std::async may use a thread pool (which is why this is necessary)
// Note this implementation requires maxThreads to be the same between tasks for each task to produce the same results.

let lastScore = lastScore
let target = targetBitmap
let current = currentBitmap

let states = await withTaskGroup(of: State.self, returning: [State].self) { taskGroup in
for _ in 0..<maxThreads {
let seed = UInt64(baseRandomSeed + randomSeedOffset) // TODO: fix
randomSeedOffset += 1
taskGroup.addTask {
let state = await bestHillClimbStateAsync(
shapeCreator: shapeCreator,
alpha: alpha,
n: shapeCount,
age: maxShapeMutations,
target: target,
current: current,
lastScore: lastScore,
energyFunction: energyFunction,
seed: seed
)
return state
}
}

var states = [State]()
for await result in taskGroup {
states.append(result)
}
return states
}
return states
}

/// Concurrently runs several optimization sessions trying to improve image geometrization by adding a shape to it,
/// returns result of the best optimization or nil if improvement of image wasn't found.
/// - Parameters:
Expand Down Expand Up @@ -131,6 +176,63 @@ class GeometrizeModelHillClimb: GeometrizeModelBase {
return .success(ShapeResult(score: lastScore, color: color, shape: shape))
}

func stepAsync( // swiftlint:disable:this function_parameter_count
shapeCreator: @escaping ShapeCreator,
alpha: UInt8,
shapeCount: Int,
maxShapeMutations: Int,
maxThreads: Int,
energyFunction: @escaping EnergyFunction,
addShapePrecondition: @escaping ShapeAcceptancePreconditionFunction = defaultAddShapePrecondition
) async -> StepGeometrizationResult {

let states: [State] = await getHillClimbStateAsync(
shapeCreator: shapeCreator,
alpha: alpha,
shapeCount: shapeCount,
maxShapeMutations: maxShapeMutations,
maxThreads: maxThreads,
energyFunction: energyFunction
)

guard !states.isEmpty else {
fatalError("Failed to get a hill climb state.")
}

// State with min score
guard let it = states.min(by: { $0.score < $1.score }) else {
fatalError("Failed to get a state with min score.")
}

// Draw the shape onto the image
let shape = it.shape
let lines: [Scanline] = shape.rasterize(x: 0...width - 1, y: 0...height - 1)
let color: Rgba = lines.computeColor(target: targetBitmap, current: currentBitmap, alpha: alpha)
let before: Bitmap = currentBitmap
currentBitmap.draw(lines: lines, color: color)

// Check for an improvement - if not, roll back and return no result
let newScore: Double = before.differencePartial(
with: currentBitmap,
target: targetBitmap,
score: lastScore,
mask: lines
)
guard addShapePrecondition(lastScore, newScore, shape, lines, color, before, currentBitmap, targetBitmap) else {
currentBitmap = before
if before == currentBitmap {
return .match
} else {
return .failure
}
}

// Improvement - set new baseline and return the new shape
lastScore = newScore

return .success(ShapeResult(score: lastScore, color: color, shape: shape))
}

/// Sets the seed that the random number generators of this model use.
/// Note that the model also uses an internal seed offset which is incremented when the model is stepped.
/// - Parameter seed: The random number generator seed.
Expand Down
2 changes: 1 addition & 1 deletion Sources/geometrize/HillClimb.swift
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func hillClimb( // swiftlint:disable:this function_parameter_count
/// - energyFunction: An energy function to be used.
/// - using: The Random number generator to use.
/// - Returns: The best random state i.e. the one with the lowest energy.
private func bestRandomState( // swiftlint:disable:this function_parameter_count
internal func bestRandomState( // swiftlint:disable:this function_parameter_count
shapeCreator: ShapeCreator,
alpha: UInt8,
n: Int,
Expand Down
162 changes: 162 additions & 0 deletions Sources/geometrize/HillClimbAsync.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import Foundation

/// Gets the best state using a hill climbing algorithm.
/// - Parameters:
/// - shapeCreator: A function that will create the shapes that will be chosen from.
/// - alpha: The opacity of the shape.
/// - n: The number of random states to generate.
/// - age: The number of hillclimbing steps.
/// - target: The target bitmap.
/// - current: The current bitmap.
/// - buffer: The buffer bitmap.
/// - lastScore: The last score.
/// - energyFunction: A function to calculate the energy.
/// - seed: The Random number generator to use.
/// - Returns: The best state acquired from hill climbing i.e. the one with the lowest energy.
func bestHillClimbStateAsync( // swiftlint:disable:this function_parameter_count
shapeCreator: ShapeCreator,
alpha: UInt8,
n: Int,
age: Int,
target: Bitmap,
current: Bitmap,
lastScore: Double,
energyFunction: EnergyFunction = defaultEnergyFunction,
seed: UInt64
) async -> State {
var buffer = current
var generator = SplitMix64(seed: seed)
return await withCheckedContinuation { continuation in
let state: State = bestRandomState(
shapeCreator: shapeCreator,
alpha: alpha,
n: n,
target: target,
current: current,
buffer: &buffer, // TODO: should it be inout?
lastScore: lastScore,
energyFunction: energyFunction,
using: &generator
)
let resState = hillClimb(
state: state,
maxAge: age,
target: target,
current: current,
buffer: &buffer, // TODO: should it be inout?
lastScore: lastScore,
energyFunction: energyFunction,
using: &generator
)
continuation.resume(returning: resState)
}
}

/// Hill climbing optimization algorithm, attempts to minimize energy (the error/difference).
/// https://en.wikipedia.org/wiki/Hill_climbing
/// - Parameters:
/// - state: The state to optimize.
/// - maxAge: The maximum age.
/// - target: The target bitmap.
/// - current: The current bitmap.
/// - buffer: The buffer bitmap.
/// - lastScore: The last score.
/// - energyFunction: An energy function to be used.
/// - using: The Random number generator to use.
/// - Returns: The best state found from hillclimbing.
func hillClimbAsync( // swiftlint:disable:this function_parameter_count
state: State,
maxAge: Int,
target: Bitmap,
current: Bitmap,
buffer: inout Bitmap,
lastScore: Double,
energyFunction: EnergyFunction,
using generator: inout SplitMix64
) -> State {
let xRange = 0...target.width - 1, yRange = 0...target.height - 1
var s: State = state
var bestState: State = state
var bestEnergy: Double = bestState.score
var age: Int = 0
while age < maxAge {
let undo = s
let alpha = s.alpha
s = s.mutate(x: xRange, y: yRange, using: &generator) { aShape in
return energyFunction(
aShape.rasterize(x: xRange, y: yRange),
alpha,
target,
current,
&buffer,
lastScore
)
}
if s.score >= bestEnergy {
s = undo
} else {
bestEnergy = s.score
bestState = s
age = Int.max // TODO: What's the point??? And following increment overflows.
}
if age == Int.max {
age = 0
} else {
age += 1
}
}
return bestState
}

/// Gets the best state using a random algorithm.
/// - Parameters:
/// - shapeCreator: A function that will create the shapes that will be chosen from.
/// - alpha: The opacity of the shape.
/// - n: The number of states to try.
/// - target: The target bitmap.
/// - current: The current bitmap.
/// - buffer: The buffer bitmap.
/// - lastScore: The last score.
/// - energyFunction: An energy function to be used.
/// - using: The Random number generator to use.
/// - Returns: The best random state i.e. the one with the lowest energy.
private func bestRandomStateAsync( // swiftlint:disable:this function_parameter_count
shapeCreator: ShapeCreator,
alpha: UInt8,
n: Int,
target: Bitmap,
current: Bitmap,
buffer: inout Bitmap,
lastScore: Double,
energyFunction: EnergyFunction,
using generator: inout SplitMix64
) -> State {
let xRange = 0...target.width - 1, yRange = 0...target.height - 1
let shape = shapeCreator(&generator).setup(x: xRange, y: yRange, using: &generator)
var bestEnergy: Double = energyFunction(
shape.rasterize(x: xRange, y: yRange),
alpha,
target,
current,
&buffer,
lastScore
)
var bestState: State = State(score: bestEnergy, alpha: alpha, shape: shape)
for i in 0...n {
let shape = shapeCreator(&generator).setup(x: xRange, y: yRange, using: &generator)
let energy: Double = energyFunction(
shape.rasterize(x: xRange, y: yRange),
alpha,
target,
current,
&buffer,
lastScore
)
let state: State = State(score: energy, alpha: alpha, shape: shape)
if i == 0 || energy < bestEnergy {
bestEnergy = energy
bestState = state
}
}
return bestState
}
27 changes: 27 additions & 0 deletions Sources/geometrize/ImageRunner.swift
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,33 @@ public struct ImageRunner {
)
}

public mutating func stepAsync(
options: ImageRunnerOptions,
shapeCreator: ShapeCreator? = nil,
energyFunction: @escaping EnergyFunction,
addShapePrecondition: @escaping ShapeAcceptancePreconditionFunction
) async -> StepGeometrizationResult {
let types = options.shapeTypes

let shapeCreator = shapeCreator ??
makeDefaultShapeCreator(
types: types,
strokeWidth: Double(options.strokeWidth)
)

model.setSeed(options.seed)

return await model.stepAsync(
shapeCreator: shapeCreator,
alpha: options.alpha,
shapeCount: options.shapeCount,
maxShapeMutations: options.maxShapeMutations,
maxThreads: options.maxThreads,
energyFunction: energyFunction,
addShapePrecondition: addShapePrecondition
)
}

public var currentBitmap: Bitmap {
model.currentBitmap
}
Expand Down
10 changes: 1 addition & 9 deletions Sources/geometrize/SVGAsyncIterator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ public struct SVGAsyncIterator: AsyncIteratorProtocol {
if verbose {
print("Step \(stepCounter)", terminator: "")
}
let stepResult = runner.step(
let stepResult = await runner.stepAsync(
options: runnerOptions,
energyFunction: defaultEnergyFunction,
addShapePrecondition: defaultAddShapePrecondition
Expand Down Expand Up @@ -177,14 +177,6 @@ public struct SVGAsyncIterator: AsyncIteratorProtocol {

iterationCounter += 1

// var svg = SVGExporter().export(data: shapeData, width: width, height: height)

// Fix SVG to keep original image size
// let range = svg.range(of: "width=")!.lowerBound ..< svg.range(of: "viewBox=")!.lowerBound
// svg.replaceSubrange(range.relative(to: svg), with: " width=\"\(originWidth)\" height=\"\(originHeight)\" ")
//
// print("Iteration \(iterationCounter) complete, \(iterationShapeData.count) shapes in iteration, " +
// "\(shapeData.count) shapes in total.")
return GeometrizingResult(svg: svg, thumbnail: runner.currentBitmap)
}

Expand Down
4 changes: 2 additions & 2 deletions Tests/geometrizeTests/ImageRunnerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import PNG

final class ImageRunnerTests: XCTestCase {

func testImageRunnerRedImage() throws {
func testImageRunnerRedImage() throws { // swiftlint:disable:this function_body_length
let width = 100, height = 100
let targetBitmap = Bitmap(width: width, height: height, color: .red)

Expand Down Expand Up @@ -77,7 +77,7 @@ final class ImageRunnerTests: XCTestCase {
)
}

func testImageRunner() throws {
func testImageRunner() throws { // swiftlint:disable:this function_body_length
throw XCTSkip("Randomness should be somehow handled in this test.")

// seedRandomGenerator(9001)
Expand Down
4 changes: 2 additions & 2 deletions Tests/geometrizeTests/SVGAsyncGeometrizerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import SnapshotTesting
final class SVGAsyncGeometrizerTests: XCTestCase {

func testAsyncGeometrizerCompleteSVGEachIteration() async throws {
throw XCTSkip("The test should be debugged.")
//throw XCTSkip("The test should be debugged.")

guard let urlSource = Bundle.module.url(forResource: "sunrise_at_sea", withExtension: "ppm") else {
fatalError()
Expand Down Expand Up @@ -34,7 +34,7 @@ final class SVGAsyncGeometrizerTests: XCTestCase {
}

func testAsyncGeometrizerCompleteSVGFirstIterationThenDeltas() async throws {
throw XCTSkip("The test should be debugged.")
//throw XCTSkip("The test should be debugged.")

guard let urlSource = Bundle.module.url(forResource: "sunrise_at_sea", withExtension: "ppm") else {
fatalError("No resource files")
Expand Down

0 comments on commit d086fb3

Please sign in to comment.