spring-websocket实现聊天室功能
最近看到有些人的博客中有聊天室的功能所以我也在我博客中写了一个,不过他们用的是java原生的,这里我使用了spring封装的spring-websocket
Spring-WebSocket配置
我们第一步要先配置一下websocket 的基本信息
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
|
@Configuration @EnableWebSocket public class ZVerifyWebSocketConfig implements WebSocketConfigurer {
@Override public void registerWebSocketHandlers(WebSocketHandlerRegistry webSocketHandlerRegistry) { webSocketHandlerRegistry .addHandler(new ZVerifyWebSocketHandler(), "/ws-connect") .addInterceptors(new ZVerifyWebSocketInterceptor()) .setAllowedOrigins("*"); }
}
|
其中连接处理器和拦截器是我们自己定义的
"/ws-connect"
就是我们的路径
因为想要建立连接首先要通过我们的拦截器所以按照逻辑来写拦截器
前置拦截器
这个前置拦截器一般我们会做安全的校验和一系列处理,这里我就简单了写了一下,这里要做安全校验是因为我们定义的websocket并没有托管给我所使用的安全框架去验证用户,所以在这里要简单校验一下,
前置处理器的创建要去实现HandshakeInterceptor接口然后重写beforeHandshake,afterHandshake,两个方法,beforeHandshake是用做握手前置校验的,afterHandshake是做握手后置校验的
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
|
@Configuration public class ZVerifyWebSocketInterceptor implements HandshakeInterceptor { @Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler handler, Map<String, Object> attr) {
System.out.println("---- 握手之前触发 " + StpUtil.getTokenValue());
if(!StpUtil.isLogin()) { System.out.println("---- 未授权客户端,连接失败"); return false; }
attr.put("userId", StpUtil.getLoginIdAsLong()); return true; }
@Override public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) { System.out.println("---- 握手之后触发 "); }
}
|
连接处理器
这里是我们的主要处理器,基本上所有重要业务都在这里
首先创建一个自己的ZVerifyWebSocketHandler然后再去继承TextWebSocketHandler我们可以定制的去实现里边的方法,这里我就按照我自己的博客需求进行重写了,如果需要可以自行扩展。
重要属性
这个是用来存放我们当前在线的人的信息的,用于广播和人数统计还有私信
进入聊天成功的逻辑
首先重写afterConnectionEstablished()方法这个方法是在连接开启的时候触发的,也就是我握手成功之后,因为是聊天室所以功能防QQ做了,在登录之后会看到当前博客群聊中的在线人数,然后加载聊天记录。这一些简单的过程
-
首先要从session中取到当前连接的用户id,这里我要解释一下这个userId是从哪来的,是在我的握手之前触发的那个beforeHandshake()中写的项目中用的安全框架为Sa-Token,不熟悉的请自行查阅,拿到用户id之后将当前用户的webSocketSession存放到map中
-
更新当前的在线人数,这个处理是比较简单的
就是获取一下map的大小就是当前在线人数,然后发送广播消息,这里说一下广播消息其实很简单就是将map中的webSocketSession都取出来然后挨个发送消息注意这里要加一个锁因为不加锁的话可能会导致消息前后异常
-
加载历史记录也很平常就是将我们聊天记录存到数据库中,然后将其xxx小时的消息加载出来,然后想当前登录用户发送这里我使用的是历史12小时
收到消息之后处理逻辑
处理收到消息逻辑是handleTextMessage()方法里边有两个参数一个是发送消息的session,一个是包装的消息对象TextMessage,首先先带大家看一下TextMessage是个什么东西,我们在通过webSocketSession发送消息的时候可以发送多种对象
这里我使用了TextMessage,所以就讲一下这里我们在创建TextMessage对象的时候传入参数通过源码可以知道我可以传入一个可读的char值序列然后会将其转换成字符串调用抽象类的构造方法
第二个参数的意义是这是否是作为一系列部分消息发送的消息的最后一部分。到这里可以知道我们发送的消息就是抽象类AbstractWebSocketMessage中的payload属性,所以在这里我买可以通过这个入参拿到数据,然后根据其数据的第一个参数,也就是当前的类型去进行对应的逻辑处理,这里就没什么难点了
连接关闭
连接关闭的时候讲当前的用户session从map中remove掉就好如需扩展请自己进行逻辑的修改
源码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271
| package com.zang.blogz.handler;
import cn.hutool.core.date.DateUtil; import cn.hutool.json.JSONUtil; import com.alibaba.fastjson.JSON; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.zang.blogz.dto.ChatRecordDTO; import com.zang.blogz.dto.RecallMessageDTO; import com.zang.blogz.dto.WebsocketMessageDTO; import com.zang.blogz.enmus.ChatTypeEnum; import com.zang.blogz.enmus.FilePathEnum; import com.zang.blogz.entity.ChatRecord; import com.zang.blogz.entity.UserInfo; import com.zang.blogz.model.input.ro.VoiceRO; import com.zang.blogz.service.ChatRecordService; import com.zang.blogz.service.UserInfoService; import com.zang.blogz.steam.optional.Opp; import com.zang.blogz.strategy.context.UploadStrategyContext; import com.zang.blogz.utils.BeanCopyUtils; import com.zang.blogz.utils.IpUtil; import lombok.Data; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.TextWebSocketHandler;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException; import java.net.InetAddress; import java.util.Collection; import java.util.Date; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap;
@Data @Service @ServerEndpoint(value = "/ws-connect") public class ZVerifyWebSocketHandler extends TextWebSocketHandler {
private static ChatRecordService chatRecordService;
@Autowired public void setChatRecordDao(ChatRecordService chatRecordService) { ZVerifyWebSocketHandler.chatRecordService = chatRecordService; }
private static UserInfoService userInfoService;
@Autowired public void setUserInfoService(UserInfoService userInfoService) { ZVerifyWebSocketHandler.userInfoService = userInfoService; }
private static UploadStrategyContext uploadStrategyContext;
@Autowired public void setUploadStrategyContext(UploadStrategyContext uploadStrategyContext) { ZVerifyWebSocketHandler.uploadStrategyContext = uploadStrategyContext; }
public static String HEADER_NAME = "X-Real-IP";
private static ConcurrentHashMap<String, WebSocketSession> webSocketSessionMaps = new ConcurrentHashMap<>();
@Override public void afterConnectionEstablished(WebSocketSession session) throws Exception {
String userId = session.getAttributes().get("userId").toString(); webSocketSessionMaps.put(HEADER_NAME + userId, session); updateOnlineCount();
ChatRecordDTO chatRecordDTO = listChartRecords(session);
WebsocketMessageDTO messageDTO = WebsocketMessageDTO.builder() .type(ChatTypeEnum.HISTORY_RECORD.getType()) .data(chatRecordDTO) .build(); synchronized (session) { session.sendMessage(new TextMessage(JSON.toJSONString(messageDTO))); } String tips = "Web-Socket 连接成功,sid=" + session.getId() + ",userId=" + userId; System.out.println(tips);
}
private ChatRecordDTO listChartRecords(WebSocketSession session) {
String ipAddress = session.getAcceptedProtocol();
LambdaQueryWrapper<ChatRecord> queryWrapper = new LambdaQueryWrapper<>();
queryWrapper.ge(ChatRecord::getCreateTime, DateUtil.offsetHour(new Date(), -12));
return ChatRecordDTO.builder() .chatRecordList(chatRecordService.list(queryWrapper)) .ipAddress(ipAddress) .ipSource(IpUtil.getIpSource(ipAddress)) .build(); }
private void updateOnlineCount() throws IOException {
WebsocketMessageDTO messageDTO = WebsocketMessageDTO.builder() .type(ChatTypeEnum.ONLINE_COUNT.getType()) .data(webSocketSessionMaps.size()) .build(); broadcastMessage(messageDTO); }
@Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status){ String userId = session.getAttributes().get("userId").toString(); webSocketSessionMaps.remove(HEADER_NAME + userId);
}
@Override public void handleTextMessage(WebSocketSession session, TextMessage message) throws IOException {
String ipAddress = null; WebsocketMessageDTO messageDTO = JSONUtil.toBean(message.getPayload(), WebsocketMessageDTO.class, false); switch (Objects.requireNonNull(ChatTypeEnum.getChatType(messageDTO.getType()))) { case SEND_MESSAGE:
String data = String.valueOf(messageDTO.getData()) ; InetAddress address = Objects.requireNonNull(session.getLocalAddress()).getAddress(); if (Opp.of(address).isNonNull()){
ipAddress = address.getHostAddress(); }
String userId = session.getAttributes().get("userId").toString(); UserInfo byId = userInfoService.getById(Integer.valueOf(userId));
ChatRecord chatRecord = new ChatRecord();
chatRecord.setContent(data); chatRecord.setType(messageDTO.getType()); chatRecord.setAvatar(byId.getAvatar()); chatRecord.setNickname(byId.getNickname()); chatRecord.setUserId(byId.getId()); chatRecord.setIpAddress(ipAddress); String ipSource = IpUtil.getIpSource(ipAddress); chatRecord.setIpSource(ipSource); chatRecordService.save(chatRecord);
messageDTO.setData(chatRecord); broadcastMessage(messageDTO); break; case RECALL_MESSAGE: RecallMessageDTO recallMessage = JSON.parseObject(JSON.toJSONString(messageDTO.getData()), RecallMessageDTO.class); chatRecordService.removeById(recallMessage.getId()); broadcastMessage(messageDTO); break; case HEART_BEAT: messageDTO.setData("pong"); session.sendMessage(new TextMessage((JSON.toJSONString(messageDTO))));
default: break; } }
public static void sendMessage(WebSocketSession session, String message) { try { System.out.println("向sid为:" + session.getId() + ",发送:" + message); session.sendMessage(new TextMessage(message)); } catch (IOException e) { throw new RuntimeException(e); } }
public static void sendMessage(long userId, String message) { WebSocketSession session = webSocketSessionMaps.get(HEADER_NAME + userId); if(session != null) { sendMessage(session, message); } }
private void broadcastMessage(WebsocketMessageDTO messageDTO) throws IOException {
Collection<WebSocketSession> sessions = webSocketSessionMaps.values();
for (WebSocketSession webSocketService : sessions) { synchronized (webSocketService){ TextMessage textMessage = new TextMessage(JSON.toJSONString(messageDTO)); webSocketService.sendMessage(textMessage); }
} }
public void sendVoice(VoiceRO voiceRO) { String content = uploadStrategyContext.executeUploadStrategy(voiceRO.getFile(), FilePathEnum.VOICE.getPath()); voiceRO.setContent(content); ChatRecord chatRecord = BeanCopyUtils.copyObject(voiceRO, ChatRecord.class); chatRecordService.save(chatRecord); WebsocketMessageDTO messageDTO = WebsocketMessageDTO.builder() .type(ChatTypeEnum.VOICE_MESSAGE.getType()) .data(chatRecord) .build(); try { broadcastMessage(messageDTO); } catch (IOException e) { e.printStackTrace(); } }
}
|