|
@@ -0,0 +1,137 @@
|
|
|
+/**
|
|
|
+ * Copyright (c) 2016-2020 业主系统 All rights reserved.
|
|
|
+ * <p>
|
|
|
+ * https://www.yezhu.io
|
|
|
+ * <p>
|
|
|
+ * 版权所有,侵权必究!
|
|
|
+ */
|
|
|
+
|
|
|
+package com.kioor.websocket;
|
|
|
+
|
|
|
+import com.kioor.common.exception.ErrorCode;
|
|
|
+import com.kioor.common.exception.RenException;
|
|
|
+import com.kioor.common.utils.JsonUtils;
|
|
|
+import com.kioor.user.entity.TokenEntity;
|
|
|
+import com.kioor.user.service.TokenService;
|
|
|
+import com.kioor.websocket.config.WebSocketConfig;
|
|
|
+import com.kioor.websocket.data.MessageData;
|
|
|
+import com.kioor.websocket.data.WebSocketData;
|
|
|
+import jakarta.websocket.*;
|
|
|
+import jakarta.websocket.server.PathParam;
|
|
|
+import jakarta.websocket.server.ServerEndpoint;
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.springframework.context.ApplicationContext;
|
|
|
+import org.springframework.stereotype.Component;
|
|
|
+
|
|
|
+import java.io.IOException;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.concurrent.ConcurrentHashMap;
|
|
|
+
|
|
|
+/**
|
|
|
+ * WebSocket服务
|
|
|
+ *
|
|
|
+ * @author Mark sunlightcs@gmail.com
|
|
|
+ */
|
|
|
+@Slf4j
|
|
|
+@Component
|
|
|
+@ServerEndpoint(value = "/websocket/{token}", configurator = WebSocketConfig.class)
|
|
|
+public class WebSocketServer {
|
|
|
+
|
|
|
+ private static ApplicationContext applicationContext;
|
|
|
+
|
|
|
+ public static void setApplicationContext(ApplicationContext applicationContext) {
|
|
|
+ WebSocketServer.applicationContext = applicationContext;
|
|
|
+ }
|
|
|
+ /**
|
|
|
+ * 客户端连接信息
|
|
|
+ */
|
|
|
+ private static Map<String, WebSocketData> servers = new ConcurrentHashMap<>();
|
|
|
+ /**
|
|
|
+ * 存放所有在线的客户端
|
|
|
+ */
|
|
|
+ private static Map<Long, Session> clients = new ConcurrentHashMap<>();
|
|
|
+
|
|
|
+ @OnOpen
|
|
|
+ public void open(@PathParam("token") String token,Session session) {
|
|
|
+// Long userId = (Long) session.getUserProperties().get(Constant.USER_KEY);
|
|
|
+ TokenService tokenService = applicationContext.getBean(TokenService.class);
|
|
|
+
|
|
|
+ //通过token查找用户id
|
|
|
+ TokenEntity tokenEntity = tokenService.getByToken(token);
|
|
|
+ if (tokenEntity == null || tokenEntity.getExpireDate().getTime() < System.currentTimeMillis()) {
|
|
|
+ throw new RenException(ErrorCode.TOKEN_INVALID);
|
|
|
+ }
|
|
|
+ if(clients.containsKey(tokenEntity.getUserId())){
|
|
|
+ servers.remove(clients.get(tokenEntity.getUserId()).getId());
|
|
|
+ clients.remove(tokenEntity.getUserId());
|
|
|
+ }
|
|
|
+ servers.put(session.getId(), new WebSocketData(tokenEntity.getUserId(), session));
|
|
|
+ clients.put(tokenEntity.getUserId(), session);
|
|
|
+ }
|
|
|
+
|
|
|
+ @OnClose
|
|
|
+ public void onClose(Session session) {
|
|
|
+ //客户端断开连接
|
|
|
+ servers.remove(session.getId());
|
|
|
+ log.debug("websocket close, session id:" + session.getId());
|
|
|
+ }
|
|
|
+
|
|
|
+ @OnError
|
|
|
+ public void onError(Session session, Throwable throwable) {
|
|
|
+ servers.remove(session.getId());
|
|
|
+ log.error(throwable.getMessage(), throwable);
|
|
|
+ }
|
|
|
+
|
|
|
+ @OnMessage
|
|
|
+ public void onMessage(Session session, String msg) {
|
|
|
+ log.info("session id: " + session.getId() + ", message:" + msg);
|
|
|
+
|
|
|
+ MessageData<String> message = new MessageData<String>().msg("你说啥子");
|
|
|
+// List<Long> userIdList = new ArrayList<>();
|
|
|
+// userIdList.add(1806260428364058625L);
|
|
|
+// sendMessage(userIdList, message);
|
|
|
+ sendMessage(1806260428364058625L, message);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 发送信息
|
|
|
+ *
|
|
|
+ * @param userIdList 用户ID列表
|
|
|
+ * @param message 消息内容
|
|
|
+ */
|
|
|
+ public void sendMessage(List<Long> userIdList, MessageData<?> message) {
|
|
|
+ userIdList.forEach(userId -> sendMessage(userId, message));
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 发送信息
|
|
|
+ *
|
|
|
+ * @param userId 用户ID
|
|
|
+ * @param message 消息内容
|
|
|
+ */
|
|
|
+ public void sendMessage(Long userId, MessageData<?> message) {
|
|
|
+ servers.values().forEach(info -> {
|
|
|
+ if (userId.equals(info.getUserId())) {
|
|
|
+ sendMessage(info.getSession(), message);
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 发送信息给全部用户
|
|
|
+ *
|
|
|
+ * @param message 消息内容
|
|
|
+ */
|
|
|
+ public void sendMessageAll(MessageData<?> message) {
|
|
|
+ servers.values().forEach(info -> sendMessage(info.getSession(), message));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void sendMessage(Session session, MessageData<?> message) {
|
|
|
+ try {
|
|
|
+ session.getBasicRemote().sendText(JsonUtils.toJsonString(message));
|
|
|
+ } catch (IOException e) {
|
|
|
+ log.error("send message error," + e.getMessage(), e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|