diff --git a/src/core/hle/service/nwm/nwm_uds.cpp b/src/core/hle/service/nwm/nwm_uds.cpp index 87a6b0ecaa..1bf13a2963 100644 --- a/src/core/hle/service/nwm/nwm_uds.cpp +++ b/src/core/hle/service/nwm/nwm_uds.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -27,6 +28,12 @@ namespace Service { namespace NWM { +namespace ErrCodes { +enum { + NotInitialized = 2, +}; +} // namespace ErrCodes + // Event that is signaled every time the connection status changes. static Kernel::SharedPtr connection_status_event; @@ -37,6 +44,8 @@ static Kernel::SharedPtr recv_buffer_memory; // Connection status of this 3DS. static ConnectionStatus connection_status{}; +static std::atomic initialized(false); + /* Node information about the current network. * The amount of elements in this vector is always the maximum number * of nodes specified in the network configuration. @@ -47,8 +56,17 @@ static NodeList node_info; // Node information about our own system. static NodeInfo current_node; -// Mapping of bind node ids to their respective events. -static std::unordered_map> bind_node_events; +struct BindNodeData { + u32 bind_node_id; ///< Id of the bind node associated with this data. + u8 channel; ///< Channel that this bind node was bound to. + u16 network_node_id; ///< Node id this bind node is associated with, only packets from this + /// network node will be received. + Kernel::SharedPtr event; ///< Receive event for this bind node. + std::deque> received_packets; ///< List of packets received on this channel. +}; + +// Mapping of data channels to their internal data. +static std::unordered_map channel_data; // The WiFi network channel that the network is currently on. // Since we're not actually interacting with physical radio waves, this is just a dummy value. @@ -75,6 +93,9 @@ constexpr size_t MaxBeaconFrames = 15; // List of the last beacons received from the network. static std::list received_beacons; +// Network node id used when a SecureData packet is addressed to every connected node. +constexpr u16 BroadcastNetworkNodeId = 0xFFFF; + /** * Returns a list of received 802.11 beacon frames from the specified sender since the last call. */ @@ -159,7 +180,9 @@ void HandleAssociationResponseFrame(const Network::WifiPacket& packet) { } static void HandleEAPoLPacket(const Network::WifiPacket& packet) { - std::lock_guard lock(connection_status_mutex); + std::unique_lock hle_lock(HLE::g_hle_lock, std::defer_lock); + std::unique_lock lock(connection_status_mutex, std::defer_lock); + std::lock(hle_lock, lock); if (GetEAPoLFrameType(packet.data) == EAPoLStartMagic) { if (connection_status.status != static_cast(NetworkStatus::ConnectedAsHost)) { @@ -205,10 +228,9 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { SendPacket(eapol_logoff); // TODO(B3N30): Broadcast updated node list // The 3ds does this presumably to support spectators. - std::lock_guard lock(HLE::g_hle_lock); connection_status_event->Signal(); } else { - if (connection_status.status != static_cast(NetworkStatus::NotConnected)) { + if (connection_status.status != static_cast(NetworkStatus::Connecting)) { LOG_DEBUG(Service_NWM, "Connection sequence aborted, because connection status is %u", connection_status.status); return; @@ -237,11 +259,63 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { // Some games require ConnectToNetwork to block, for now it doesn't // If blocking is implemented this lock needs to be changed, // otherwise it might cause deadlocks - std::lock_guard lock(HLE::g_hle_lock); connection_status_event->Signal(); } } +static void HandleSecureDataPacket(const Network::WifiPacket& packet) { + auto secure_data = ParseSecureDataHeader(packet.data); + std::unique_lock hle_lock(HLE::g_hle_lock, std::defer_lock); + std::unique_lock lock(connection_status_mutex, std::defer_lock); + std::lock(hle_lock, lock); + + if (secure_data.src_node_id == connection_status.network_node_id) { + // Ignore packets that came from ourselves. + return; + } + + if (secure_data.dest_node_id != connection_status.network_node_id && + secure_data.dest_node_id != BroadcastNetworkNodeId) { + // The packet wasn't addressed to us, we can only act as a router if we're the host. + // However, we might have received this packet due to a broadcast from the host, in that + // case just ignore it. + ASSERT_MSG(packet.destination_address == Network::BroadcastMac || + connection_status.status == static_cast(NetworkStatus::ConnectedAsHost), + "Can't be a router if we're not a host"); + + if (connection_status.status == static_cast(NetworkStatus::ConnectedAsHost) && + secure_data.dest_node_id != BroadcastNetworkNodeId) { + // Broadcast the packet so the right receiver can get it. + // TODO(B3N30): Is there a flag that makes this kind of routing be unicast instead of + // multicast? Perhaps this is a way to allow spectators to see some of the packets. + Network::WifiPacket out_packet = packet; + out_packet.destination_address = Network::BroadcastMac; + SendPacket(out_packet); + } + return; + } + + // The packet is addressed to us (or to everyone using the broadcast node id), handle it. + // TODO(B3N30): We don't currently send nor handle management frames. + ASSERT(!secure_data.is_management); + + // TODO(B3N30): Allow more than one bind node per channel. + auto channel_info = channel_data.find(secure_data.data_channel); + // Ignore packets from channels we're not interested in. + if (channel_info == channel_data.end()) + return; + + if (channel_info->second.network_node_id != BroadcastNetworkNodeId && + channel_info->second.network_node_id != secure_data.src_node_id) + return; + + // Add the received packet to the data queue. + channel_info->second.received_packets.emplace_back(packet.data); + + // Signal the data event. We can do this directly because we locked g_hle_lock + channel_info->second.event->Signal(); +} + /* * Start a connection sequence with an UDS server. The sequence starts by sending an 802.11 * authentication frame with SEQ1. @@ -251,7 +325,7 @@ void StartConnectionSequence(const MacAddress& server) { WifiPacket auth_request; { std::lock_guard lock(connection_status_mutex); - ASSERT(connection_status.status == static_cast(NetworkStatus::NotConnected)); + connection_status.status = static_cast(NetworkStatus::Connecting); // TODO(Subv): Handle timeout. @@ -329,7 +403,7 @@ static void HandleDataFrame(const Network::WifiPacket& packet) { HandleEAPoLPacket(packet); break; case EtherType::SecureData: - // TODO(B3N30): Handle SecureData packets + HandleSecureDataPacket(packet); break; } } @@ -482,6 +556,8 @@ static void InitializeWithVersion(Interface* self) { recv_buffer_memory = Kernel::g_handle_table.Get(sharedmem_handle); + initialized = true; + ASSERT_MSG(recv_buffer_memory->size == sharedmem_size, "Invalid shared memory size."); { @@ -535,6 +611,48 @@ static void GetConnectionStatus(Interface* self) { LOG_DEBUG(Service_NWM, "called"); } +/** + * NWM_UDS::GetNodeInformation service function. + * Returns the node inforamtion structure for the currently connected node. + * Inputs: + * 0 : Command header. + * 1 : Node ID. + * Outputs: + * 0 : Return header + * 1 : Result of function, 0 on success, otherwise error code + * 2-11 : NodeInfo structure. + */ +static void GetNodeInformation(Interface* self) { + IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0xD, 1, 0); + u16 network_node_id = rp.Pop(); + + if (!initialized) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::NotInitialized, ErrorModule::UDS, + ErrorSummary::StatusChanged, ErrorLevel::Status)); + return; + } + + { + std::lock_guard lock(connection_status_mutex); + auto itr = std::find_if(node_info.begin(), node_info.end(), + [network_node_id](const NodeInfo& node) { + return node.network_node_id == network_node_id; + }); + if (itr == node_info.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::NotFound, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Status)); + return; + } + + IPC::RequestBuilder rb = rp.MakeBuilder(11, 0); + rb.Push(RESULT_SUCCESS); + rb.PushRaw(*itr); + } + LOG_DEBUG(Service_NWM, "called"); +} + /** * NWM_UDS::Bind service function. * Binds a BindNodeId to a data channel and retrieves a data event. @@ -557,29 +675,85 @@ static void Bind(Interface* self) { u8 data_channel = rp.Pop(); u16 network_node_id = rp.Pop(); - // TODO(Subv): Store the data channel and verify it when receiving data frames. - LOG_DEBUG(Service_NWM, "called"); - if (data_channel == 0) { + if (data_channel == 0 || bind_node_id == 0) { IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, ErrorSummary::WrongArgument, ErrorLevel::Usage)); return; } + constexpr size_t MaxBindNodes = 16; + if (channel_data.size() >= MaxBindNodes) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::OutOfMemory, ErrorModule::UDS, + ErrorSummary::OutOfResource, ErrorLevel::Status)); + return; + } + + constexpr u32 MinRecvBufferSize = 0x5F4; + if (recv_buffer_size < MinRecvBufferSize) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::TooLarge, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Usage)); + return; + } + // Create a new event for this bind node. - // TODO(Subv): Signal this event when new data is received on this data channel. auto event = Kernel::Event::Create(Kernel::ResetType::OneShot, "NWM::BindNodeEvent" + std::to_string(bind_node_id)); - bind_node_events[bind_node_id] = event; + std::lock_guard lock(connection_status_mutex); + + ASSERT(channel_data.find(data_channel) == channel_data.end()); + // TODO(B3N30): Support more than one bind node per channel. + channel_data[data_channel] = {bind_node_id, data_channel, network_node_id, event}; IPC::RequestBuilder rb = rp.MakeBuilder(1, 2); - rb.Push(RESULT_SUCCESS); rb.PushCopyHandles(Kernel::g_handle_table.Create(event).Unwrap()); } +/** + * NWM_UDS::Unbind service function. + * Unbinds a BindNodeId from a data channel. + * Inputs: + * 1 : BindNodeId + * Outputs: + * 0 : Return header + * 1 : Result of function, 0 on success, otherwise error code + */ +static void Unbind(Interface* self) { + IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0x12, 1, 0); + + u32 bind_node_id = rp.Pop(); + if (bind_node_id == 0) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Usage)); + return; + } + + std::lock_guard lock(connection_status_mutex); + + auto itr = + std::find_if(channel_data.begin(), channel_data.end(), [bind_node_id](const auto& data) { + return data.second.bind_node_id == bind_node_id; + }); + + if (itr != channel_data.end()) { + channel_data.erase(itr); + } + + IPC::RequestBuilder rb = rp.MakeBuilder(5, 0); + rb.Push(RESULT_SUCCESS); + rb.Push(bind_node_id); + // TODO(B3N30): Find out what the other return values are + rb.Push(0); + rb.Push(0); + rb.Push(0); +} + /** * NWM_UDS::BeginHostingNetwork service function. * Creates a network and starts broadcasting its presence. @@ -606,13 +780,14 @@ static void BeginHostingNetwork(Interface* self) { LOG_DEBUG(Service_NWM, "called"); - Memory::ReadBlock(network_info_address, &network_info, sizeof(NetworkInfo)); - - // The real UDS module throws a fatal error if this assert fails. - ASSERT_MSG(network_info.max_nodes > 1, "Trying to host a network of only one member."); - { std::lock_guard lock(connection_status_mutex); + + Memory::ReadBlock(network_info_address, &network_info, sizeof(NetworkInfo)); + + // The real UDS module throws a fatal error if this assert fails. + ASSERT_MSG(network_info.max_nodes > 1, "Trying to host a network of only one member."); + connection_status.status = static_cast(NetworkStatus::ConnectedAsHost); // Ensure the application data size is less than the maximum value. @@ -626,11 +801,13 @@ static void BeginHostingNetwork(Interface* self) { connection_status.max_nodes = network_info.max_nodes; // Resize the nodes list to hold max_nodes. + node_info.clear(); node_info.resize(network_info.max_nodes); // There's currently only one node in the network (the host). connection_status.total_nodes = 1; network_info.total_nodes = 1; + // The host is always the first node connection_status.network_node_id = 1; current_node.network_node_id = 1; @@ -639,12 +816,22 @@ static void BeginHostingNetwork(Interface* self) { connection_status.node_bitmask |= 1; // Notify the application that the first node was set. connection_status.changed_nodes |= 1; - node_info[0] = current_node; - } - // If the game has a preferred channel, use that instead. - if (network_info.channel != 0) - network_channel = network_info.channel; + if (auto room_member = Network::GetRoomMember().lock()) { + if (room_member->IsConnected()) { + network_info.host_mac_address = room_member->GetMacAddress(); + } else { + network_info.host_mac_address = {{0x0, 0x0, 0x0, 0x0, 0x0, 0x0}}; + } + } + node_info[0] = current_node; + + // If the game has a preferred channel, use that instead. + if (network_info.channel != 0) + network_channel = network_info.channel; + else + network_info.channel = DefaultNetworkChannel; + } connection_status_event->Signal(); @@ -652,8 +839,7 @@ static void BeginHostingNetwork(Interface* self) { CoreTiming::ScheduleEvent(msToCycles(DefaultBeaconInterval * MillisecondsPerTU), beacon_broadcast_event, 0); - LOG_WARNING(Service_NWM, - "An UDS network has been created, but broadcasting it is unimplemented."); + LOG_DEBUG(Service_NWM, "An UDS network has been created."); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); rb.Push(RESULT_SUCCESS); @@ -722,31 +908,25 @@ static void SendTo(Interface* self) { size_t desc_size; const VAddr input_address = rp.PopStaticBuffer(&desc_size, false); - ASSERT(desc_size == data_size); + ASSERT(desc_size >= data_size); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); - u16 network_node_id; - - { - std::lock_guard lock(connection_status_mutex); - if (connection_status.status != static_cast(NetworkStatus::ConnectedAsClient) && - connection_status.status != static_cast(NetworkStatus::ConnectedAsHost)) { - rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, - ErrorSummary::InvalidState, ErrorLevel::Status)); - return; - } - - if (dest_node_id == connection_status.network_node_id) { - rb.Push(ResultCode(ErrorDescription::NotFound, ErrorModule::UDS, - ErrorSummary::WrongArgument, ErrorLevel::Status)); - return; - } - - network_node_id = connection_status.network_node_id; + std::lock_guard lock(connection_status_mutex); + if (connection_status.status != static_cast(NetworkStatus::ConnectedAsClient) && + connection_status.status != static_cast(NetworkStatus::ConnectedAsHost)) { + rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, + ErrorSummary::InvalidState, ErrorLevel::Status)); + return; } - // TODO(Subv): Do something with the flags. + if (dest_node_id == connection_status.network_node_id) { + rb.Push(ResultCode(ErrorDescription::NotFound, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Status)); + return; + } + + // TODO(B3N30): Do something with the flags. constexpr size_t MaxSize = 0x5C6; if (data_size > MaxSize) { @@ -758,20 +938,116 @@ static void SendTo(Interface* self) { std::vector data(data_size); Memory::ReadBlock(input_address, data.data(), data.size()); - // TODO(Subv): Increment the sequence number after each sent packet. + // TODO(B3N30): Increment the sequence number after each sent packet. u16 sequence_number = 0; - std::vector data_payload = - GenerateDataPayload(data, data_channel, dest_node_id, network_node_id, sequence_number); + std::vector data_payload = GenerateDataPayload( + data, data_channel, dest_node_id, connection_status.network_node_id, sequence_number); - // TODO(Subv): Retrieve the MAC address of the dest_node_id and our own to encrypt + // TODO(B3N30): Retrieve the MAC address of the dest_node_id and our own to encrypt // and encapsulate the payload. - // TODO(Subv): Send the frame. + Network::WifiPacket packet; + // Data frames are sent to the host, who then decides where to route it to. If we're the host, + // just directly broadcast the frame. + if (connection_status.status == static_cast(NetworkStatus::ConnectedAsHost)) + packet.destination_address = Network::BroadcastMac; + else + packet.destination_address = network_info.host_mac_address; + packet.channel = network_channel; + packet.data = std::move(data_payload); + packet.type = Network::WifiPacket::PacketType::Data; + + SendPacket(packet); rb.Push(RESULT_SUCCESS); +} - LOG_WARNING(Service_NWM, "(STUB) called dest_node_id=%u size=%u flags=%u channel=%u", - static_cast(dest_node_id), data_size, flags, static_cast(data_channel)); +/** + * NWM_UDS::PullPacket service function. + * Receives a data frame from the specified bind node id + * Inputs: + * 0 : Command header. + * 1 : Bind node id. + * 2 : Max out buff size >> 2. + * 3 : Max out buff size. + * 64 : Output buffer descriptor + * 65 : Output buffer address + * Outputs: + * 0 : Return header + * 1 : Result of function, 0 on success, otherwise error code + * 2 : Received data size + * 3 : u16 Source network node id + * 4 : Buffer descriptor + * 5 : Buffer address + */ +static void PullPacket(Interface* self) { + IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0x14, 3, 0); + + u32 bind_node_id = rp.Pop(); + u32 max_out_buff_size_aligned = rp.Pop(); + u32 max_out_buff_size = rp.Pop(); + + size_t desc_size; + const VAddr output_address = rp.PeekStaticBuffer(0, &desc_size); + ASSERT(desc_size == max_out_buff_size); + + std::lock_guard lock(connection_status_mutex); + if (connection_status.status != static_cast(NetworkStatus::ConnectedAsHost) && + connection_status.status != static_cast(NetworkStatus::ConnectedAsClient) && + connection_status.status != static_cast(NetworkStatus::ConnectedAsSpectator)) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, + ErrorSummary::InvalidState, ErrorLevel::Status)); + return; + } + + auto channel = + std::find_if(channel_data.begin(), channel_data.end(), [bind_node_id](const auto& data) { + return data.second.bind_node_id == bind_node_id; + }); + + if (channel == channel_data.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Usage)); + return; + } + + if (channel->second.received_packets.empty()) { + Memory::ZeroBlock(output_address, desc_size); + IPC::RequestBuilder rb = rp.MakeBuilder(3, 2); + rb.Push(RESULT_SUCCESS); + rb.Push(0); + rb.Push(0); + rb.PushStaticBuffer(output_address, desc_size, 0); + return; + } + + const auto& next_packet = channel->second.received_packets.front(); + + auto secure_data = ParseSecureDataHeader(next_packet); + auto data_size = secure_data.GetActualDataSize(); + + if (data_size > max_out_buff_size) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::TooLarge, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Usage)); + return; + } + + IPC::RequestBuilder rb = rp.MakeBuilder(3, 2); + Memory::ZeroBlock(output_address, desc_size); + // Write the actual data. + Memory::WriteBlock(output_address, + next_packet.data() + sizeof(LLCHeader) + sizeof(SecureDataHeader), + data_size); + + rb.Push(RESULT_SUCCESS); + rb.Push(data_size); + rb.Push(secure_data.src_node_id); + rb.PushStaticBuffer(output_address, desc_size, 0); + + channel->second.received_packets.pop_front(); } /** @@ -986,14 +1262,14 @@ const Interface::FunctionInfo FunctionTable[] = { {0x00090442, nullptr, "ConnectNetwork (deprecated)"}, {0x000A0000, nullptr, "DisconnectNetwork"}, {0x000B0000, GetConnectionStatus, "GetConnectionStatus"}, - {0x000D0040, nullptr, "GetNodeInformation"}, + {0x000D0040, GetNodeInformation, "GetNodeInformation"}, {0x000E0006, nullptr, "DecryptBeaconData (deprecated)"}, {0x000F0404, RecvBeaconBroadcastData, "RecvBeaconBroadcastData"}, {0x00100042, SetApplicationData, "SetApplicationData"}, {0x00110040, nullptr, "GetApplicationData"}, {0x00120100, Bind, "Bind"}, - {0x00130040, nullptr, "Unbind"}, - {0x001400C0, nullptr, "PullPacket"}, + {0x00130040, Unbind, "Unbind"}, + {0x001400C0, PullPacket, "PullPacket"}, {0x00150080, nullptr, "SetMaxSendDelay"}, {0x00170182, SendTo, "SendTo"}, {0x001A0000, GetChannel, "GetChannel"}, @@ -1018,9 +1294,10 @@ NWM_UDS::NWM_UDS() { NWM_UDS::~NWM_UDS() { network_info = {}; - bind_node_events.clear(); + channel_data.clear(); connection_status_event = nullptr; recv_buffer_memory = nullptr; + initialized = false; { std::lock_guard lock(connection_status_mutex); diff --git a/src/core/hle/service/nwm/uds_data.cpp b/src/core/hle/service/nwm/uds_data.cpp index cbeb75dfae..6a693c0796 100644 --- a/src/core/hle/service/nwm/uds_data.cpp +++ b/src/core/hle/service/nwm/uds_data.cpp @@ -279,6 +279,15 @@ std::vector GenerateDataPayload(const std::vector& data, u8 channel, u16 return buffer; } +SecureDataHeader ParseSecureDataHeader(const std::vector& data) { + SecureDataHeader header; + + // Skip the LLC header + std::memcpy(&header, data.data() + sizeof(LLCHeader), sizeof(header)); + + return header; +} + std::vector GenerateEAPoLStartFrame(u16 association_id, const NodeInfo& node_info) { EAPoLStartPacket eapol_start{}; eapol_start.association_id = association_id; diff --git a/src/core/hle/service/nwm/uds_data.h b/src/core/hle/service/nwm/uds_data.h index 76bccb1bfa..59906f677e 100644 --- a/src/core/hle/service/nwm/uds_data.h +++ b/src/core/hle/service/nwm/uds_data.h @@ -51,6 +51,10 @@ struct SecureDataHeader { u16_be sequence_number; u16_be dest_node_id; u16_be src_node_id; + + u32 GetActualDataSize() const { + return protocol_size - sizeof(SecureDataHeader); + } }; static_assert(sizeof(SecureDataHeader) == 14, "SecureDataHeader has the wrong size"); @@ -118,6 +122,11 @@ static_assert(sizeof(EAPoLLogoffPacket) == 0x298, "EAPoLLogoffPacket has the wro std::vector GenerateDataPayload(const std::vector& data, u8 channel, u16 dest_node, u16 src_node, u16 sequence_number); +/* + * Returns the SecureDataHeader stored in an 802.11 data frame. + */ +SecureDataHeader ParseSecureDataHeader(const std::vector& data); + /* * Generates an unencrypted 802.11 data frame body with the EAPoL-Start format for UDS * communication.