diff --git a/src/core/hle/service/nwm/nwm_uds.cpp b/src/core/hle/service/nwm/nwm_uds.cpp index 19527f8ab4..93d2b188ce 100644 --- a/src/core/hle/service/nwm/nwm_uds.cpp +++ b/src/core/hle/service/nwm/nwm_uds.cpp @@ -79,7 +79,11 @@ static u8 network_channel = DefaultNetworkChannel; static NetworkInfo network_info; // Mapping of mac addresses to their respective node_ids. -static std::map node_map; +struct Node { + bool connected; + u16 node_id; +}; +static std::map node_map; // Event that will generate and send the 802.11 beacon frames. static CoreTiming::EventType* beacon_broadcast_event; @@ -165,10 +169,12 @@ static void BroadcastNodeMap() { std::memcpy(packet.data.data(), &size, sizeof(size)); std::size_t offset = sizeof(size); for (const auto& node : node_map) { - std::memcpy(packet.data.data() + offset, node.first.data(), sizeof(node.first)); - std::memcpy(packet.data.data() + offset + sizeof(node.first), &node.second, - sizeof(node.second)); - offset += sizeof(node.first) + sizeof(node.second); + if (node.second.connected) { + std::memcpy(packet.data.data() + offset, node.first.data(), sizeof(node.first)); + std::memcpy(packet.data.data() + offset + sizeof(node.first), &node.second.node_id, + sizeof(node.second.node_id)); + offset += sizeof(node.first) + sizeof(node.second.node_id); + } } SendPacket(packet); @@ -176,6 +182,11 @@ static void BroadcastNodeMap() { static void HandleNodeMapPacket(const Network::WifiPacket& packet) { std::lock_guard lock(connection_status_mutex); + if (connection_status.status == static_cast(NetworkStatus::ConnectedAsHost)) { + LOG_DEBUG(Service_NWM, "Ignored NodeMapPacket since connection_status is host"); + return; + } + node_map.clear(); std::size_t num_entries; Network::MacAddress address; @@ -185,7 +196,8 @@ static void HandleNodeMapPacket(const Network::WifiPacket& packet) { for (std::size_t i = 0; i < num_entries; ++i) { std::memcpy(&address, packet.data.data() + offset, sizeof(address)); std::memcpy(&id, packet.data.data() + offset + sizeof(address), sizeof(id)); - node_map[address] = id; + node_map[address].connected = true; + node_map[address].node_id = id; offset += sizeof(address) + sizeof(id); } } @@ -218,7 +230,12 @@ void HandleAssociationResponseFrame(const Network::WifiPacket& packet) { "Could not join network"); { std::lock_guard lock(connection_status_mutex); - ASSERT(connection_status.status == static_cast(NetworkStatus::Connecting)); + if (connection_status.status != static_cast(NetworkStatus::Connecting)) { + LOG_DEBUG(Service_NWM, + "Ignored AssociationResponseFrame because connection status is {}", + connection_status.status); + return; + } } // Send the EAPoL-Start packet to the server. @@ -245,14 +262,21 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { return; } - auto node = DeserializeNodeInfoFromFrame(packet.data); - - if (connection_status.max_nodes == connection_status.total_nodes) { - // Reject connection attempt - LOG_ERROR(Service_NWM, "Reached maximum nodes, but reject packet wasn't sent."); - // TODO(B3N30): Figure out what packet is sent here + auto node_it = node_map.find(packet.transmitter_address); + if (node_it == node_map.end()) { + LOG_DEBUG(Service_NWM, "Connection sequence aborted, because the AuthenticationFrame " + "of the client wasn't recieved"); return; } + if (node_it->second.connected) { + LOG_DEBUG(Service_NWM, + "Connection sequence aborted, because the client is already connected"); + return; + } + + ASSERT(connection_status.max_nodes != connection_status.total_nodes); + + auto node = DeserializeNodeInfoFromFrame(packet.data); // Get an unused network node id u16 node_id = GetNextAvailableNodeId(); @@ -268,7 +292,8 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { network_info.total_nodes++; - node_map[packet.transmitter_address] = node.network_node_id; + node_map[packet.transmitter_address].node_id = node.network_node_id; + node_map[packet.transmitter_address].connected = true; BroadcastNodeMap(); @@ -321,6 +346,7 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { connection_status_event->Signal(); connection_event->Signal(); } else if (connection_status.status == static_cast(NetworkStatus::ConnectedAsClient)) { + // TODO(B3N30): Remove that section and send/receive a proper connection_status packet // On a 3ds this packet wouldn't be addressed to already connected clients // We use this information because in the current implementation the host // isn't broadcasting the node information @@ -349,6 +375,14 @@ static void HandleSecureDataPacket(const Network::WifiPacket& packet) { std::unique_lock lock(connection_status_mutex, std::defer_lock); std::lock(hle_lock, lock); + if (connection_status.status != static_cast(NetworkStatus::ConnectedAsHost) && + connection_status.status != static_cast(NetworkStatus::ConnectedAsClient)) { + // TODO(B3N30): Handle spectators + LOG_DEBUG(Service_NWM, "Ignored SecureDataPacket, because connection status is {}", + connection_status.status); + return; + } + if (secure_data.src_node_id == connection_status.network_node_id) { // Ignore packets that came from ourselves. return; @@ -464,12 +498,24 @@ void HandleAuthenticationFrame(const Network::WifiPacket& packet) { connection_status.status); return; } + if (node_map.find(packet.transmitter_address) != node_map.end()) { + LOG_ERROR(Service_NWM, "Connection sequence aborted, because there is already a " + "connected client with that MAC-Adress"); + return; + } + if (connection_status.max_nodes == connection_status.total_nodes) { + // Reject connection attempt + LOG_ERROR(Service_NWM, "Reached maximum nodes, but reject packet wasn't sent."); + // TODO(B3N30): Figure out what packet is sent here + return; + } // Respond with an authentication response frame with SEQ2 auth_request.channel = network_channel; auth_request.data = GenerateAuthenticationFrame(AuthenticationSeq::SEQ2); auth_request.destination_address = packet.transmitter_address; auth_request.type = WifiPacket::PacketType::Authentication; + node_map[packet.transmitter_address].connected = false; } SendPacket(auth_request); @@ -492,7 +538,15 @@ void HandleDeauthenticationFrame(const Network::WifiPacket& packet) { return; } - u16 node_id = node_map[packet.transmitter_address]; + s32 node_id = node_map[packet.transmitter_address].node_id; + bool connected = node_map[packet.transmitter_address].connected; + node_map.erase(packet.transmitter_address); + + if (!connected) { + LOG_DEBUG(Service_NWM, "Received DeauthenticationFrame from a not connected MAC Address"); + return; + } + auto node = std::find_if(node_info.begin(), node_info.end(), [&node_id](const NodeInfo& info) { return info.network_node_id == node_id; }); @@ -501,8 +555,13 @@ void HandleDeauthenticationFrame(const Network::WifiPacket& packet) { connection_status.node_bitmask &= ~(1 << (node_id - 1)); connection_status.changed_nodes |= 1 << (node_id - 1); connection_status.total_nodes--; + connection_status.nodes[node_id - 1] = 0; network_info.total_nodes--; + // TODO(B3N30): broadcast new connection_status to clients + + node->Reset(); + connection_status_event->Signal(); } @@ -656,6 +715,7 @@ void NWM_UDS::InitializeWithVersion(Kernel::HLERequestContext& ctx) { connection_status.status = static_cast(NetworkStatus::NotConnected); node_info.clear(); node_info.push_back(current_node); + channel_data.clear(); } IPC::RequestBuilder rb = rp.MakeBuilder(1, 2); @@ -1012,8 +1072,9 @@ void NWM_UDS::SendTo(Kernel::HLERequestContext& ctx) { } else if (dest_node_id != 1) { // Send to specific client auto destination = - std::find_if(node_map.begin(), node_map.end(), - [dest_node_id](const auto& node) { return node.second == dest_node_id; }); + std::find_if(node_map.begin(), node_map.end(), [dest_node_id](const auto& node) { + return node.second.node_id == dest_node_id && node.second.connected; + }); if (destination == node_map.end()) { LOG_ERROR(Service_NWM, "tried to send packet to unknown dest id {}", dest_node_id); rb.Push(ResultCode(ErrorDescription::NotFound, ErrorModule::UDS, diff --git a/src/core/hle/service/nwm/nwm_uds.h b/src/core/hle/service/nwm/nwm_uds.h index d19bd4a1c6..920d7b69a8 100644 --- a/src/core/hle/service/nwm/nwm_uds.h +++ b/src/core/hle/service/nwm/nwm_uds.h @@ -32,6 +32,12 @@ struct NodeInfo { INSERT_PADDING_BYTES(4); u16_le network_node_id; INSERT_PADDING_BYTES(6); + + void Reset() { + friend_code_seed = 0; + username.fill(0); + network_node_id = 0; + } }; static_assert(sizeof(NodeInfo) == 40, "NodeInfo has incorrect size.");