Skip to content

Commit

Permalink
Support overriding options for concrete flow
Browse files Browse the repository at this point in the history
  • Loading branch information
stidsborg committed Jul 14, 2024
1 parent a39cb02 commit eafcc5b
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 16 deletions.
66 changes: 66 additions & 0 deletions Cleipnir.Flows.Tests/Flows/OptionsTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using Cleipnir.Flows.AspNet;
using Cleipnir.ResilientFunctions.Domain;
using Cleipnir.ResilientFunctions.Domain.Exceptions;
using Cleipnir.ResilientFunctions.Reactive.Extensions;
using Microsoft.Extensions.DependencyInjection;
using Shouldly;

namespace Cleipnir.Flows.Tests.Flows;

[TestClass]
public class OptionsTests
{
[TestMethod]
public async Task SimpleFlowCompletesSuccessfully()
{
var serviceCollection = new ServiceCollection();

serviceCollection.AddFlows(c => c
.UseInMemoryStore()
.WithOptions(new Options(messagesDefaultMaxWaitForCompletion: TimeSpan.MaxValue))
.RegisterFlow<OptionsTestWithOverriddenOptionsFlow, OptionsTestWithOverriddenOptionsFlows>(
factory: sp => new OptionsTestWithOverriddenOptionsFlows(
sp.GetRequiredService<FlowsContainer>(),
new Options(messagesDefaultMaxWaitForCompletion: TimeSpan.Zero)
)
)
.RegisterFlow<OptionsTestWithDefaultProvidedOptionsFlow, OptionsTestWithDefaultProvidedOptionsFlows>()
);

var sp = serviceCollection.BuildServiceProvider();
var flowsWithOverridenOptions = sp.GetRequiredService<OptionsTestWithOverriddenOptionsFlows>();

await Should.ThrowAsync<FunctionInvocationSuspendedException>(
() => flowsWithOverridenOptions.Run("Id")
);

var flowsWithDefaultProvidedOptions = sp.GetRequiredService<OptionsTestWithDefaultProvidedOptionsFlows>();
await flowsWithDefaultProvidedOptions.Schedule("Id");

await Task.Delay(100);

var controlPanel = await flowsWithDefaultProvidedOptions.ControlPanel("Id");
controlPanel.ShouldNotBeNull();
controlPanel.Status.ShouldBe(Status.Executing);

await controlPanel.Messages.Append("Hello");

await controlPanel.WaitForCompletion();
}

public class OptionsTestWithOverriddenOptionsFlow : Flow
{
public override async Task Run()
{
await Messages.First();
}
}

public class OptionsTestWithDefaultProvidedOptionsFlow : Flow
{
public override async Task Run()
{
await Messages.First();
}
}
}
9 changes: 9 additions & 0 deletions Cleipnir.Flows/AspNet/FlowsModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ public FlowsConfigurator RegisterFlow<TFlow, TFlows>() where TFlow : BaseFlow wh

return this;
}

public FlowsConfigurator RegisterFlow<TFlow, TFlows>(Func<IServiceProvider, TFlows> factory) where TFlow : BaseFlow where TFlows : BaseFlows<TFlow>
{
Services.AddScoped<TFlow>();
Services.AddTransient(factory);
FlowsTypes = FlowsTypes.Append(typeof(TFlows));

return this;
}

public FlowsConfigurator RegisterFlowsAutomatically(Assembly? rootAssembly = null)
{
Expand Down
18 changes: 12 additions & 6 deletions Cleipnir.Flows/Flows.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public class Flows<TFlow> : BaseFlows<TFlow> where TFlow : Flow
{
private readonly ParamlessRegistration _registration;

public Flows(string flowName, FlowsContainer flowsContainer) : base(flowsContainer)
public Flows(string flowName, FlowsContainer flowsContainer, Options? options = null) : base(flowsContainer)
{
var callChain = CreateMiddlewareCallChain<Unit, Unit>(runFlow: async (flow, _) =>
{
Expand All @@ -115,7 +115,9 @@ public Flows(string flowName, FlowsContainer flowsContainer) : base(flowsContain
_registration = flowsContainer.FunctionRegistry.RegisterParamless(
flowName,
inner: workflow => callChain(Unit.Instance, workflow),
new Settings(routes: CreateRoutingInformation())
(options ?? Options.Default)
.Merge(new Options(routes: CreateRoutingInformation()))
.MapToRFunctionsSettings()
);
}

Expand Down Expand Up @@ -147,7 +149,7 @@ public class Flows<TFlow, TParam> : BaseFlows<TFlow>
{
private readonly FuncRegistration<TParam, Unit> _registration;

public Flows(string flowName, FlowsContainer flowsContainer) : base(flowsContainer)
public Flows(string flowName, FlowsContainer flowsContainer, Options? options = null) : base(flowsContainer)
{
var callChain = CreateMiddlewareCallChain<TParam, Unit>(
runFlow: async (flow, param) =>
Expand All @@ -159,7 +161,9 @@ public Flows(string flowName, FlowsContainer flowsContainer) : base(flowsContain
_registration = flowsContainer.FunctionRegistry.RegisterFunc<TParam, Unit>(
flowName,
inner: (param, workflow) => callChain(param, workflow),
settings: new Settings(routes: CreateRoutingInformation())
settings: (options ?? Options.Default)
.Merge(new Options(routes: CreateRoutingInformation()))
.MapToRFunctionsSettings()
);
}

Expand Down Expand Up @@ -199,7 +203,7 @@ public class Flows<TFlow, TParam, TResult> : BaseFlows<TFlow>
{
private readonly FuncRegistration<TParam, TResult> _registration;

public Flows(string flowName, FlowsContainer flowsContainer) : base(flowsContainer)
public Flows(string flowName, FlowsContainer flowsContainer, Options? options = null) : base(flowsContainer)
{
var callChain = CreateMiddlewareCallChain<TParam, TResult>(
runFlow: (flow, param) => flow.Run(param)
Expand All @@ -208,7 +212,9 @@ public Flows(string flowName, FlowsContainer flowsContainer) : base(flowsContain
_registration = flowsContainer.FunctionRegistry.RegisterFunc<TParam, TResult>(
flowName,
inner: (param, workflow) => callChain(param, workflow),
new Settings(routes: CreateRoutingInformation())
(options ?? Options.Default)
.Merge(new Options(routes: CreateRoutingInformation()))
.MapToRFunctionsSettings()
);
}

Expand Down
28 changes: 28 additions & 0 deletions Cleipnir.Flows/Options.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Cleipnir.Flows.CrossCutting;
using Cleipnir.ResilientFunctions.CoreRuntime.ParameterSerialization;
using Cleipnir.ResilientFunctions.Domain;
Expand Down Expand Up @@ -66,6 +67,33 @@ public Options UseMiddleware(IMiddleware middleware)
return this;
}

public Options Merge(Options options)
{
var merged = new Options(
UnhandledExceptionHandler ?? options.UnhandledExceptionHandler,
RetentionPeriod ?? options.RetentionPeriod,
RetentionCleanUpFrequency ?? options.RetentionCleanUpFrequency,
LeaseLength ?? options.LeaseLength,
EnableWatchdogs ?? options.EnableWatchdogs,
WatchdogCheckFrequency ?? options.WatchdogCheckFrequency,
MessagesPullFrequency ?? options.MessagesPullFrequency,
MessagesDefaultMaxWaitForCompletion ?? options.MessagesDefaultMaxWaitForCompletion,
DelayStartup ?? options.DelayStartup,
MaxParallelRetryInvocations ?? options.MaxParallelRetryInvocations,
Serializer ?? options.Serializer,
Routes ?? options.Routes
);

if (Middlewares.Any())
foreach (var middleware in Middlewares)
merged.Middlewares.Add(middleware);
else
foreach (var middleware in options.Middlewares)
merged.Middlewares.Add(middleware);

return merged;
}

internal Settings MapToRFunctionsSettings()
=> new(
UnhandledExceptionHandler,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,49 +190,53 @@ private void AddFlowsWrapper(GeneratorExecutionContext context, FlowInformation
generatedCode =
$@"namespace {flowsNamespace}
{{
#nullable enable
[Cleipnir.Flows.SourceGeneration.SourceGeneratedFlowsAttribute]
{accessibilityModifier} class {flowsName} : Cleipnir.Flows.Flows<{flowType}>
{{
public {flowsName}(Cleipnir.Flows.FlowsContainer flowsContainer)
: base(flowName: ""{flowName}"", flowsContainer) {{ }}
{{
public {flowsName}(Cleipnir.Flows.FlowsContainer flowsContainer, Cleipnir.Flows.Options? options = null)
: base(flowName: ""{flowName}"", flowsContainer, options) {{ }}
}}
#nullable disable
}}";
}
else if (resultType == null)
{
generatedCode =
$@"namespace {flowsNamespace}
{{
#nullable enable
[Cleipnir.Flows.SourceGeneration.SourceGeneratedFlowsAttribute]
{accessibilityModifier} class {flowsName} : Cleipnir.Flows.Flows<{flowType}, {paramType}>
{{
public {flowsName}(Cleipnir.Flows.FlowsContainer flowsContainer)
: base(flowName: ""{flowName}"", flowsContainer) {{ }}
public {flowsName}(Cleipnir.Flows.FlowsContainer flowsContainer, Cleipnir.Flows.Options? options = null)
: base(flowName: ""{flowName}"", flowsContainer, options) {{ }}
}}
#nullable disable
}}";
}
else
{
generatedCode =
$@"namespace {flowsNamespace}
{{
#nullable enable
[Cleipnir.Flows.SourceGeneration.SourceGeneratedFlowsAttribute]
{accessibilityModifier} class {flowsName} : Cleipnir.Flows.Flows<{flowType}, {paramType}, {resultType}>
{{
public {flowsName}(Cleipnir.Flows.FlowsContainer flowsContainer)
: base(flowName: ""{flowName}"", flowsContainer) {{ }}
public {flowsName}(Cleipnir.Flows.FlowsContainer flowsContainer, Cleipnir.Flows.Options? options = null)
: base(flowName: ""{flowName}"", flowsContainer, options) {{ }}
}}
#nullable disable
}}";
}

if (flowInformation.StateTypeSymbol != null)
{
var getStateStr = $@"
#nullable enable
public Task<{stateType}?> GetState(string functionInstanceId)
=> GetState<{stateType}>(functionInstanceId);
#nullable disable";
=> GetState<{stateType}>(functionInstanceId);";
var constructorEndPosition = generatedCode.IndexOf("{ }", StringComparison.Ordinal);
generatedCode = generatedCode.Insert(constructorEndPosition + 3, getStateStr);
}
Expand Down

0 comments on commit eafcc5b

Please sign in to comment.