diff options
Diffstat (limited to 'spark-common/src/main/java/me/lucko/spark/common/ws')
4 files changed, 702 insertions, 0 deletions
diff --git a/spark-common/src/main/java/me/lucko/spark/common/ws/CryptoAlgorithm.java b/spark-common/src/main/java/me/lucko/spark/common/ws/CryptoAlgorithm.java new file mode 100644 index 0000000..f6cf1db --- /dev/null +++ b/spark-common/src/main/java/me/lucko/spark/common/ws/CryptoAlgorithm.java @@ -0,0 +1,90 @@ +/* + * This file is part of spark. + * + * Copyright (c) lucko (Luck) <luck@lucko.me> + * Copyright (c) contributors + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +package me.lucko.spark.common.ws; + +import com.google.protobuf.ByteString; + +import java.security.KeyFactory; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.PublicKey; +import java.security.Signature; +import java.security.spec.X509EncodedKeySpec; + +/** + * An algorithm for keypair/signature cryptography. + */ +public enum CryptoAlgorithm { + + Ed25519("Ed25519", 255, "Ed25519"), + RSA2048("RSA", 2048, "SHA256withRSA"); + + private final String keyAlgorithm; + private final int keySize; + private final String signatureAlgorithm; + + CryptoAlgorithm(String keyAlgorithm, int keySize, String signatureAlgorithm) { + this.keyAlgorithm = keyAlgorithm; + this.keySize = keySize; + this.signatureAlgorithm = signatureAlgorithm; + } + + public KeyPairGenerator createKeyPairGenerator() throws NoSuchAlgorithmException { + return KeyPairGenerator.getInstance(this.keyAlgorithm); + } + + public KeyFactory createKeyFactory() throws NoSuchAlgorithmException { + return KeyFactory.getInstance(this.keyAlgorithm); + } + + public Signature createSignature() throws NoSuchAlgorithmException { + return Signature.getInstance(this.signatureAlgorithm); + } + + public KeyPair generateKeyPair() { + try { + KeyPairGenerator generator = createKeyPairGenerator(); + generator.initialize(this.keySize); + return generator.generateKeyPair(); + } catch (Exception e) { + throw new RuntimeException("Exception generating keypair", e); + } + } + + public PublicKey decodePublicKey(byte[] bytes) throws IllegalArgumentException { + try { + X509EncodedKeySpec spec = new X509EncodedKeySpec(bytes); + KeyFactory factory = createKeyFactory(); + return factory.generatePublic(spec); + } catch (Exception e) { + throw new IllegalArgumentException("Exception parsing public key", e); + } + } + + public PublicKey decodePublicKey(ByteString bytes) throws IllegalArgumentException { + if (bytes == null) { + return null; + } + return decodePublicKey(bytes.toByteArray()); + } + +} diff --git a/spark-common/src/main/java/me/lucko/spark/common/ws/TrustedKeyStore.java b/spark-common/src/main/java/me/lucko/spark/common/ws/TrustedKeyStore.java new file mode 100644 index 0000000..1605a38 --- /dev/null +++ b/spark-common/src/main/java/me/lucko/spark/common/ws/TrustedKeyStore.java @@ -0,0 +1,139 @@ +/* + * This file is part of spark. + * + * Copyright (c) lucko (Luck) <luck@lucko.me> + * Copyright (c) contributors + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +package me.lucko.spark.common.ws; + +import me.lucko.spark.common.util.Configuration; + +import java.security.KeyPair; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.util.Base64; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; + +/** + * A store of trusted public keys. + */ +public class TrustedKeyStore { + private static final String TRUSTED_KEYS_OPTION = "trustedKeys"; + + /** The spark configuration */ + private final Configuration configuration; + /** Gets the local public/private key */ + private final CompletableFuture<KeyPair> localKeyPair; + /** A set of remote public keys to trust */ + private final Set<PublicKey> remoteTrustedKeys; + /** A mpa of pending remote public keys */ + private final Map<String, PublicKey> remotePendingKeys = new HashMap<>(); + + public TrustedKeyStore(Configuration configuration) { + this.configuration = configuration; + this.localKeyPair = CompletableFuture.supplyAsync(ViewerSocketConnection.CRYPTO::generateKeyPair); + this.remoteTrustedKeys = new HashSet<>(); + readTrustedKeys(); + } + + /** + * Gets the local public key. + * + * @return the local public key + */ + public PublicKey getLocalPublicKey() { + return this.localKeyPair.join().getPublic(); + } + + /** + * Gets the local private key. + * + * @return the local private key + */ + public PrivateKey getLocalPrivateKey() { + return this.localKeyPair.join().getPrivate(); + } + + /** + * Checks if a remote public key is trusted + * + * @param publicKey the public key + * @return if the key is trusted + */ + public boolean isKeyTrusted(PublicKey publicKey) { + return publicKey != null && this.remoteTrustedKeys.contains(publicKey); + } + + /** + * Adds a pending public key to be trusted in the future. + * + * @param clientId the client id submitting the key + * @param publicKey the public key + */ + public void addPendingKey(String clientId, PublicKey publicKey) { + this.remotePendingKeys.put(clientId, publicKey); + } + + /** + * Trusts a previously submitted remote public key + * + * @param clientId the id of the client that submitted the key + * @return true if the key was found and trusted + */ + public boolean trustPendingKey(String clientId) { + PublicKey key = this.remotePendingKeys.remove(clientId); + if (key == null) { + return false; + } + + this.remoteTrustedKeys.add(key); + writeTrustedKeys(); + return true; + } + + /** + * Reads trusted keys from the configuration + */ + private void readTrustedKeys() { + for (String encodedKey : this.configuration.getStringList(TRUSTED_KEYS_OPTION)) { + try { + PublicKey publicKey = ViewerSocketConnection.CRYPTO.decodePublicKey(Base64.getDecoder().decode(encodedKey)); + this.remoteTrustedKeys.add(publicKey); + } catch (Exception e) { + e.printStackTrace(); + } + } + } + + /** + * Writes trusted keys to the configuration + */ + private void writeTrustedKeys() { + List<String> encodedKeys = this.remoteTrustedKeys.stream() + .map(key -> Base64.getEncoder().encodeToString(key.getEncoded())) + .collect(Collectors.toList()); + + this.configuration.setStringList(TRUSTED_KEYS_OPTION, encodedKeys); + } + +} diff --git a/spark-common/src/main/java/me/lucko/spark/common/ws/ViewerSocket.java b/spark-common/src/main/java/me/lucko/spark/common/ws/ViewerSocket.java new file mode 100644 index 0000000..5c7e08c --- /dev/null +++ b/spark-common/src/main/java/me/lucko/spark/common/ws/ViewerSocket.java @@ -0,0 +1,255 @@ +/* + * This file is part of spark. + * + * Copyright (c) lucko (Luck) <luck@lucko.me> + * Copyright (c) contributors + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +package me.lucko.spark.common.ws; + +import com.google.protobuf.ByteString; + +import me.lucko.spark.common.SparkPlatform; +import me.lucko.spark.common.sampler.AbstractSampler; +import me.lucko.spark.common.sampler.Sampler; +import me.lucko.spark.common.sampler.window.ProfilingWindowUtils; +import me.lucko.spark.common.util.MediaTypes; +import me.lucko.spark.common.util.ws.BytesocksClient; +import me.lucko.spark.proto.SparkProtos; +import me.lucko.spark.proto.SparkSamplerProtos; +import me.lucko.spark.proto.SparkWebSocketProtos.ClientConnect; +import me.lucko.spark.proto.SparkWebSocketProtos.ClientPing; +import me.lucko.spark.proto.SparkWebSocketProtos.PacketWrapper; +import me.lucko.spark.proto.SparkWebSocketProtos.ServerConnectResponse; +import me.lucko.spark.proto.SparkWebSocketProtos.ServerPong; +import me.lucko.spark.proto.SparkWebSocketProtos.ServerUpdateSamplerData; +import me.lucko.spark.proto.SparkWebSocketProtos.ServerUpdateStatistics; + +import java.security.PublicKey; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; + +/** + * Represents a connection with the spark viewer. + */ +public class ViewerSocket implements ViewerSocketConnection.Listener, AutoCloseable { + + /** Allow 60 seconds for the first client to connect */ + private static final long SOCKET_INITIAL_TIMEOUT = TimeUnit.SECONDS.toMillis(60); + + /** Once established, expect a ping at least once every 30 seconds */ + private static final long SOCKET_ESTABLISHED_TIMEOUT = TimeUnit.SECONDS.toMillis(30); + + /** The spark platform */ + private final SparkPlatform platform; + /** The export props to use when exporting the sampler data */ + private final Sampler.ExportProps exportProps; + /** The underlying connection */ + private final ViewerSocketConnection socket; + + private boolean closed = false; + private final long socketOpenTime = System.currentTimeMillis(); + private long lastPing = 0; + private String lastPayloadId = null; + + public ViewerSocket(SparkPlatform platform, BytesocksClient client, Sampler.ExportProps exportProps) throws Exception { + this.platform = platform; + this.exportProps = exportProps; + this.socket = new ViewerSocketConnection(platform, client, this); + } + + private void log(String message) { + this.platform.getPlugin().log(Level.INFO, "[Viewer - " + this.socket.getChannelId() + "] " + message); + } + + /** + * Gets the initial payload to send to the viewer. + * + * @return the payload + */ + public SparkSamplerProtos.SocketChannelInfo getPayload() { + return SparkSamplerProtos.SocketChannelInfo.newBuilder() + .setChannelId(this.socket.getChannelId()) + .setPublicKey(ByteString.copyFrom(this.platform.getTrustedKeyStore().getLocalPublicKey().getEncoded())) + .build(); + } + + public boolean isOpen() { + return !this.closed && this.socket.isOpen(); + } + + /** + * Called each time the sampler rotates to a new window. + * + * @param sampler the sampler + */ + public void processWindowRotate(AbstractSampler sampler) { + if (this.closed) { + return; + } + + long time = System.currentTimeMillis(); + if ((time - this.socketOpenTime) > SOCKET_INITIAL_TIMEOUT && (time - this.lastPing) > SOCKET_ESTABLISHED_TIMEOUT) { + log("No clients have pinged for 30s, closing socket"); + close(); + return; + } + + // no clients connected yet! + if (this.lastPing == 0) { + return; + } + + try { + SparkSamplerProtos.SamplerData samplerData = sampler.toProto(this.platform, this.exportProps); + String key = this.platform.getBytebinClient().postContent(samplerData, MediaTypes.SPARK_SAMPLER_MEDIA_TYPE, "live").key(); + sendUpdatedSamplerData(key); + } catch (Exception e) { + this.platform.getPlugin().log(Level.WARNING, "Error whilst sending updated sampler data to the socket"); + e.printStackTrace(); + } + } + + /** + * Called when the sampler stops. + * + * @param sampler the sampler + */ + public void processSamplerStopped(AbstractSampler sampler) { + if (this.closed) { + return; + } + + close(); + } + + @Override + public void close() { + this.socket.sendPacket(builder -> builder.setServerPong(ServerPong.newBuilder() + .setOk(false) + .build() + )); + this.socket.close(); + this.closed = true; + } + + @Override + public boolean isKeyTrusted(PublicKey publicKey) { + return this.platform.getTrustedKeyStore().isKeyTrusted(publicKey); + } + + /** + * Sends a message to the socket to say that the given client is now trusted. + * + * @param clientId the client id + */ + public void sendClientTrustedMessage(String clientId) { + this.socket.sendPacket(builder -> builder.setServerConnectResponse(ServerConnectResponse.newBuilder() + .setClientId(clientId) + .setState(ServerConnectResponse.State.ACCEPTED) + .build() + )); + } + + /** + * Sends a message to the socket to indicate that updated sampler data is available + * + * @param payloadId the payload id of the updated data + */ + public void sendUpdatedSamplerData(String payloadId) { + this.socket.sendPacket(builder -> builder.setServerUpdateSampler(ServerUpdateSamplerData.newBuilder() + .setPayloadId(payloadId) + .build() + )); + this.lastPayloadId = payloadId; + } + + /** + * Sends a message to the socket with updated statistics + * + * @param platform the platform statistics + * @param system the system statistics + */ + public void sendUpdatedStatistics(SparkProtos.PlatformStatistics platform, SparkProtos.SystemStatistics system) { + this.socket.sendPacket(builder -> builder.setServerUpdateStatistics(ServerUpdateStatistics.newBuilder() + .setPlatform(platform) + .setSystem(system) + .build() + )); + } + + @Override + public void onPacket(PacketWrapper packet, boolean verified, PublicKey publicKey) throws Exception { + switch (packet.getPacketCase()) { + case CLIENT_PING: + onClientPing(packet.getClientPing(), publicKey); + break; + case CLIENT_CONNECT: + onClientConnect(packet.getClientConnect(), verified, publicKey); + break; + default: + throw new IllegalArgumentException("Unexpected packet: " + packet.getPacketCase()); + } + } + + private void onClientPing(ClientPing packet, PublicKey publicKey) { + this.lastPing = System.currentTimeMillis(); + this.socket.sendPacket(builder -> builder.setServerPong(ServerPong.newBuilder() + .setOk(!this.closed) + .setData(packet.getData()) + .build() + )); + } + + private void onClientConnect(ClientConnect packet, boolean verified, PublicKey publicKey) { + if (publicKey == null) { + throw new IllegalStateException("Missing public key"); + } + + this.lastPing = System.currentTimeMillis(); + + String clientId = packet.getClientId(); + log("Client connected: clientId=" + clientId + ", keyhash=" + hashPublicKey(publicKey) + ", desc=" + packet.getDescription()); + + ServerConnectResponse.Builder resp = ServerConnectResponse.newBuilder() + .setClientId(clientId) + .setSettings(ServerConnectResponse.Settings.newBuilder() + .setSamplerInterval(ProfilingWindowUtils.WINDOW_SIZE_SECONDS) + .setStatisticsInterval(10) + .build() + ); + + if (this.lastPayloadId != null) { + resp.setLastPayloadId(this.lastPayloadId); + } + + if (this.closed) { + resp.setState(ServerConnectResponse.State.REJECTED); + } else if (verified) { + resp.setState(ServerConnectResponse.State.ACCEPTED); + } else { + resp.setState(ServerConnectResponse.State.UNTRUSTED); + this.platform.getTrustedKeyStore().addPendingKey(clientId, publicKey); + } + + this.socket.sendPacket(builder -> builder.setServerConnectResponse(resp.build())); + } + + private static String hashPublicKey(PublicKey publicKey) { + return publicKey == null ? "null" : Integer.toHexString(publicKey.hashCode()); + } + +} diff --git a/spark-common/src/main/java/me/lucko/spark/common/ws/ViewerSocketConnection.java b/spark-common/src/main/java/me/lucko/spark/common/ws/ViewerSocketConnection.java new file mode 100644 index 0000000..f870cb7 --- /dev/null +++ b/spark-common/src/main/java/me/lucko/spark/common/ws/ViewerSocketConnection.java @@ -0,0 +1,218 @@ +/* + * This file is part of spark. + * + * Copyright (c) lucko (Luck) <luck@lucko.me> + * Copyright (c) contributors + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +package me.lucko.spark.common.ws; + +import com.google.protobuf.ByteString; + +import me.lucko.spark.common.SparkPlatform; +import me.lucko.spark.common.util.ws.BytesocksClient; +import me.lucko.spark.proto.SparkWebSocketProtos.PacketWrapper; +import me.lucko.spark.proto.SparkWebSocketProtos.RawPacket; + +import java.io.IOException; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.Signature; +import java.util.Base64; +import java.util.function.Consumer; +import java.util.logging.Level; + +/** + * Controls a websocket connection between a spark server (the plugin/mod) and a spark client (the web viewer). + */ +public class ViewerSocketConnection implements BytesocksClient.Listener, AutoCloseable { + + /** The protocol version */ + public static final int VERSION_1 = 1; + /** The crypto algorithm used to sign/verify messages sent between the server and client */ + public static final CryptoAlgorithm CRYPTO = CryptoAlgorithm.RSA2048; + + /** The platform */ + private final SparkPlatform platform; + /** The underlying listener */ + private final Listener listener; + /** The private key used to sign messages sent from this connection */ + private final PrivateKey privateKey; + /** The bytesocks socket */ + private final BytesocksClient.Socket socket; + + public ViewerSocketConnection(SparkPlatform platform, BytesocksClient client, Listener listener) throws Exception { + this.platform = platform; + this.listener = listener; + this.privateKey = platform.getTrustedKeyStore().getLocalPrivateKey(); + this.socket = client.createAndConnect(this); + } + + public interface Listener { + + /** + * Checks if the given public key is trusted + * + * @param publicKey the public key + * @return true if trusted + */ + boolean isKeyTrusted(PublicKey publicKey); + + /** + * Handles a packet sent to the socket + * + * @param packet the packet that was sent + * @param verified if the packet was signed by a trusted key + * @param publicKey the public key the packet was signed with + */ + void onPacket(PacketWrapper packet, boolean verified, PublicKey publicKey) throws Exception; + } + + /** + * Gets the bytesocks channel id + * + * @return the channel id + */ + public String getChannelId() { + return this.socket.getChannelId(); + } + + /** + * Gets if the underlying socket is open + * + * @return true if the socket is open + */ + public boolean isOpen() { + return this.socket.isOpen(); + } + + @Override + public void onText(CharSequence data) { + try { + RawPacket packet = decodeRawPacket(data); + handleRawPacket(packet); + } catch (Exception e) { + this.platform.getPlugin().log(Level.WARNING, "Exception occurred while reading data from the socket"); + e.printStackTrace(); + } + } + + @Override + public void onError(Throwable error) { + this.platform.getPlugin().log(Level.INFO, "Socket error: " + error.getClass().getName() + " " + error.getMessage()); + error.printStackTrace(); + } + + @Override + public void onClose(int statusCode, String reason) { + //this.platform.getPlugin().log(Level.INFO, "Socket closed with status " + statusCode + " and reason " + reason); + } + + /** + * Sends a packet to the socket. + * + * @param packetBuilder the builder to construct the wrapper packet + */ + public void sendPacket(Consumer<PacketWrapper.Builder> packetBuilder) { + PacketWrapper.Builder builder = PacketWrapper.newBuilder(); + packetBuilder.accept(builder); + PacketWrapper wrapper = builder.build(); + + try { + sendPacket(wrapper); + } catch (Exception e) { + this.platform.getPlugin().log(Level.WARNING, "Exception occurred while sending data to the socket"); + e.printStackTrace(); + } + } + + /** + * Sends a packet to the socket. + * + * @param packet the packet to send + */ + private void sendPacket(PacketWrapper packet) throws Exception { + ByteString msg = packet.toByteString(); + + // sign the message using the server private key + Signature sign = CRYPTO.createSignature(); + sign.initSign(this.privateKey); + sign.update(msg.asReadOnlyByteBuffer()); + byte[] signature = sign.sign(); + + sendRawPacket(RawPacket.newBuilder() + .setVersion(VERSION_1) + .setSignature(ByteString.copyFrom(signature)) + .setMessage(msg) + .build() + ); + } + + /** + * Sends a raw packet to the socket. + * + * @param packet the packet to send + */ + private void sendRawPacket(RawPacket packet) throws IOException { + byte[] buf = packet.toByteArray(); + String encoded = Base64.getEncoder().encodeToString(buf); + this.socket.send(encoded); + } + + /** + * Decodes a raw packet sent to the socket. + * + * @param data the encoded data + * @return the decoded packet + */ + private RawPacket decodeRawPacket(CharSequence data) throws IOException { + byte[] buf = Base64.getDecoder().decode(data.toString()); + return RawPacket.parseFrom(buf); + } + + /** + * Handles a raw packet sent to the socket + * + * @param packet the packet + */ + private void handleRawPacket(RawPacket packet) throws Exception { + int version = packet.getVersion(); + if (version != VERSION_1) { + throw new IllegalArgumentException("Unsupported packet version " + version); + } + + ByteString message = packet.getMessage(); + PublicKey publicKey = CRYPTO.decodePublicKey(packet.getPublicKey()); + ByteString signature = packet.getSignature(); + + boolean verified = false; + if (signature != null && publicKey != null && this.listener.isKeyTrusted(publicKey)) { + Signature sign = CRYPTO.createSignature(); + sign.initVerify(publicKey); + sign.update(message.asReadOnlyByteBuffer()); + + verified = sign.verify(signature.toByteArray()); + } + + PacketWrapper wrapper = PacketWrapper.parseFrom(message); + this.listener.onPacket(wrapper, verified, publicKey); + } + + @Override + public void close() { + this.socket.close(1001 /* going away */, "spark plugin disconnected"); + } +} |