package org.apereo.cas.web.landtool.single.service; import java.time.format.DateTimeFormatter; import java.util.Collection; import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArraySet; import org.apache.commons.lang3.StringUtils; import org.apereo.cas.CentralAuthenticationService; import org.apereo.cas.authentication.Authentication; import org.apereo.cas.authentication.principal.Principal; import org.apereo.cas.ticket.Ticket; import org.apereo.cas.ticket.TicketGrantingTicket; import org.apereo.cas.web.landtool.single.config.SingleLoginProperties; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; /** * @author Tanbin * @date 2018-11-30 */ public class SingleLoginService { private static final Logger LOGGER = LoggerFactory.getLogger(SingleLoginService.class); @Autowired public SingleLoginProperties singleLoginProperties; @Autowired private UserIdObtainService userIdObtainService; private CentralAuthenticationService centralAuthenticationService; public static CopyOnWriteArraySet set =new CopyOnWriteArraySet<>(); public SingleLoginService(CentralAuthenticationService service) { this.centralAuthenticationService = service; } /** * 取得同一用户下需要注销的tgt(即除了当前tgt外的该用户的所有tgt) * @param userId * @param tgtId * @param clientIp * @return */ public Collection getKictOutTickets(String userId, String tgtId, String clientIp) { Collection tickets = this.centralAuthenticationService.getTickets(ticket -> { if (ticket instanceof TicketGrantingTicket) { TicketGrantingTicket tgt = ((TicketGrantingTicket) ticket).getRoot(); Authentication authentication = tgt.getAuthentication(); // //DEBUG LOGGER.debug("#####ALL LGOINED USER: [{}], TGT:[{}], DATE:[{}], CLIENT IP:[{}]", authentication.getPrincipal().getId(), tgt.getId(), authentication.getAuthenticationDate().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME), authentication.getAttributes().get("clientIp")); // //DEBUG System.out.println("#####ALL LGOINED USER: [{}], TGT:[{}], DATE:[{}], CLIENT IP:[{}]"+ authentication.getPrincipal().getId()+"-"+tgt.getId()+"-"+ authentication.getAuthenticationDate().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)+"-"+ authentication.getAttributes().get("clientIp")); Principal a = authentication.getPrincipal(); return tgt != null && authentication != null && a != null && userId.equals(a.getId()) && (singleLoginProperties.isIgnoreSameIp() ? !clientIp.equals(authentication.getAttributes().get("clientIp")) : true) && (StringUtils.isBlank(tgtId) ? true : !tgtId.equals(tgt.getId())); } else { return false; } }); return tickets; } public Collection getKictOutTickets(String userId, String clientIp) { return getKictOutTickets(userId, null, clientIp); } /** * 踢出用户上次未退出的登录 * @param userId * @param tgtId * @param clientIp */ public void kickOutOldLogins(String userId, String tgtId, String clientIp) { Collection tickets = this.getKictOutTickets(userId, tgtId, clientIp); if (tickets != null && tickets.size() > 0) { LOGGER.warn("#####【单用户登录限制】正在注销 [{}]用户的 {}个Ticket", userId, tickets.size()); } // 注销 for (Ticket ticket : tickets) { LOGGER.warn("#####【单用户登录限制】注销Ticket: [{}]", ticket.getId()); centralAuthenticationService.destroyTicketGrantingTicket(ticket.getId()); } } public void kickOutOldLogins(TicketGrantingTicket tgt) { if (tgt != null) { String userId = tgt.getAuthentication().getPrincipal().getId(); String tgtId = tgt.getId(); String clientName = (String) tgt.getAuthentication().getAttributes().get("clientName"); String clientIp = (String) tgt.getAuthentication().getAttributes().get("clientIp"); List userIds = userIdObtainService.obtain(clientName, userId); if (userIds != null) { userIds.forEach(uid -> kickOutOldLogins(uid, tgtId, clientIp)); } } } }