diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 5e914b31e29e..b66b85b6bb51 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -783,7 +783,7 @@ private void DiscoverHubMethods(bool disableImplicitFromServiceParameters) } var methodName = - methodInfo.GetCustomAttribute()?.Name ?? + methodInfo.GetCustomAttribute(inherit: true)?.Name ?? methodInfo.Name; if (_methods.ContainsKey(methodName)) @@ -894,4 +894,4 @@ private static void SetActivityError(Activity? activity, Exception ex) activity?.SetTag("error.type", ex.GetType().FullName); activity?.SetStatus(ActivityStatusCode.Error); } -} \ No newline at end of file +} diff --git a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs index 90fabf90265b..1f4a59b23584 100644 --- a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs @@ -1367,24 +1367,59 @@ public bool SingleService([FromService] Service1 service) return true; } + public bool ServiceWithAttribute([FromService] Service1 service) + { + return true; + } + + public int ServiceWithStringAttribute([FromService] Service1 service, string value) + { + return 115; + } + public bool MultipleServices([FromService] Service1 service, [FromService] Service2 service2, [FromService] Service3 service3) { return true; } - public async Task ServicesAndParams(int value, [FromService] Service1 service, ChannelReader channelReader, [FromService] Service2 service2, bool value2) + public bool MultipleServicesWithAttribute([FromService] Service1 service, [FromService] Service2 service2) + { + return true; + } + + public int MixedParamsWithAttribute(int num, string text, [FromService] Service1 service, [FromService] Service2 service2) + { + return 111; + } + + public int ServiceAttributeBeforeParam([FromService] Service1 service, int num) + { + return num + 1; + } + + public async Task UploadWithServiceAttribute(int value, [FromService] Service1 service, ChannelReader channelReader) { int total = 0; while (await channelReader.WaitToReadAsync()) { total += await channelReader.ReadAsync(); } - return total + value; + return total + value + 1; } - - public int ServiceWithStringAttribute([FromService] Service1 service, string value) + + public int ServiceWithOptionalParam([FromService] Service1 service, int value = 42) { - return 115; + return value + 1; + } + + public async Task ServicesAndParams(int value, [FromService] Service1 service, ChannelReader channelReader, [FromService] Service2 service2, bool value2) + { + int total = 0; + while (await channelReader.WaitToReadAsync()) + { + total += await channelReader.ReadAsync(); + } + return total + value; } public int ServiceWithoutAttribute(Service1 service) @@ -1469,4 +1504,4 @@ public override async Task OnConnectedAsync() await Clients.Client(id).SendAsync("Test", 1); } } -} \ No newline at end of file +} diff --git a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs index 486f6c53ea02..6a1fa55c5b24 100644 --- a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs @@ -1,13 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Buffers; -using System.Diagnostics; -using System.Globalization; -using System.IO.Pipelines; -using System.Security.Claims; -using System.Text; -using System.Threading.Channels; using MessagePack; using MessagePack.Formatters; using MessagePack.Resolvers; @@ -29,6 +22,13 @@ using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Newtonsoft.Json.Serialization; +using System.Buffers; +using System.Diagnostics; +using System.Globalization; +using System.IO.Pipelines; +using System.Security.Claims; +using System.Text; +using System.Threading.Channels; namespace Microsoft.AspNetCore.SignalR.Tests; @@ -5347,6 +5347,197 @@ public async Task GracefulCloseDisablesReconnect() } } + [Fact] + public async Task HubMethodCanInjectServiceWithAttribute() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => + { + provider.AddSingleton(); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(ServicesHub.ServiceWithAttribute)).DefaultTimeout(); + Assert.True(Assert.IsType(res.Result)); + } + } + + [Fact] + public async Task HubMethodCanInjectServiceWithStringParameterAndAttribute() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => + { + provider.AddSingleton(); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(ServicesHub.ServiceWithStringAttribute), "test").DefaultTimeout(); + Assert.Equal(115L, res.Result); + } + } + + [Fact] + public async Task HubMethodWithFromServicesAttributeFailsIfServiceNotFound() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => + { + provider.AddSignalR(o => o.EnableDetailedErrors = true); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(ServicesHub.ServiceWithAttribute)).DefaultTimeout(); + Assert.Equal("An unexpected error occurred invoking 'ServiceWithAttribute' on the server. InvalidOperationException: No service for type 'Microsoft.AspNetCore.SignalR.Tests.Service1' has been registered.", res.Error); + } + } + + [Fact] + public async Task HubMethodWithMultipleFromServicesAttributes() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => + { + provider.AddSingleton(); + provider.AddSingleton(); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(ServicesHub.MultipleServicesWithAttribute)).DefaultTimeout(); + Assert.True(Assert.IsType(res.Result)); + } + } + + [Fact] + public async Task HubMethodWithMixedParametersAndFromServicesAttribute() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => + { + provider.AddSingleton(); + provider.AddSingleton(); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(ServicesHub.MixedParamsWithAttribute), 10, "test").DefaultTimeout(); + Assert.Equal(111L, res.Result); + } + } + + [Fact] + public async Task HubMethodWithFromServicesAttributeAndNullParameter() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => + { + provider.AddSingleton(); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(ServicesHub.ServiceWithStringAttribute), (string)null).DefaultTimeout(); + Assert.Equal(115L, res.Result); + } + } + + [Fact] + public async Task HubMethodWithFromServicesAttributeBeforeRegularParameter() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => + { + provider.AddSingleton(); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(ServicesHub.ServiceAttributeBeforeParam), 42).DefaultTimeout(); + Assert.Equal(43L, res.Result); + } + } + + [Fact] + public async Task HubMethodWithFromServicesAttributeAndUploadStream() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => + { + provider.AddSingleton(); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + await client.BeginUploadStreamAsync("invocation", nameof(ServicesHub.UploadWithServiceAttribute), new[] { "id" }, new object[] { 5 }).DefaultTimeout(); + + await client.SendHubMessageAsync(new StreamItemMessage("id", 10)).DefaultTimeout(); + await client.SendHubMessageAsync(new StreamItemMessage("id", 20)).DefaultTimeout(); + await client.SendHubMessageAsync(CompletionMessage.Empty("id")).DefaultTimeout(); + + var response = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(31L, response.Result); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task HubMethodWithFromServicesAttributeRespectsDisableImplicitOption(bool disableImplicit) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => + { + provider.AddSignalR(options => + { + options.EnableDetailedErrors = true; + options.DisableImplicitFromServicesParameters = disableImplicit; + }); + provider.AddSingleton(); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + // ServiceWithAttribute explicitly uses [FromServices] so it should work regardless of DisableImplicitFromServicesParameters + var res = await client.InvokeAsync(nameof(ServicesHub.ServiceWithAttribute)).DefaultTimeout(); + Assert.True(Assert.IsType(res.Result)); + } + } + + [Fact] + public async Task HubMethodWithFromServicesAttributeOnOptionalParameter() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => + { + provider.AddSingleton(); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + // Call with value + var res = await client.InvokeAsync(nameof(ServicesHub.ServiceWithOptionalParam), 100).DefaultTimeout(); + Assert.Equal(101L, res.Result); + + // Call without value (using default) + res = await client.InvokeAsync(nameof(ServicesHub.ServiceWithOptionalParam)).DefaultTimeout(); + Assert.Equal(43L, res.Result); + } + } + #pragma warning disable CA2252 // This API requires opting into preview features private class TestReconnectFeature : IStatefulReconnectFeature #pragma warning restore CA2252 // This API requires opting into preview features @@ -5477,4 +5668,4 @@ public static async Task> ReadAllAsync(this IAsyncEnumerable