diff --git a/doc/open-api-interface.md b/doc/open-api-interface.md index fc17369..ea3fbf3 100644 --- a/doc/open-api-interface.md +++ b/doc/open-api-interface.md @@ -94,6 +94,9 @@ "traceId": "trace-001", "decision": "BLOCK", "alerted": true, + "rejectCode": 403, + "rejectAction": "blocked", + "rejectMsg": "内容不合规", "hits": [ { "moduleType": "ACL", @@ -133,6 +136,9 @@ | data.traceId | string | 链路 ID | | data.decision | string | 最终决策:`ALLOW` / `BLOCK` | | data.alerted | boolean | 是否命中任意规则 | +| data.rejectCode | number | 命中内容合规拦截时返回的拒绝码(来自“拒绝描述配置”) | +| data.rejectAction | string | 命中内容合规拦截时返回的动作标识 | +| data.rejectMsg | string | 命中内容合规拦截时返回的提示文案 | | data.hits | array | 命中明细列表 | | data.hits[].moduleType | string | 模块:`ACL` / `ATTACK` / `CONTENT` | | data.hits[].eventType | string | 事件类型 | @@ -156,6 +162,7 @@ - ACL:IP 白黑名单、接口封堵、自定义组合规则 - ATTACK:攻击规则 + 特征签名(如 SQL 注入、越狱等) - CONTENT:DLP(邮箱/手机号/证件等)、内容策略、脱敏模板策略 +- 内容审核页面联动:`合规检测范围` 可控制是否执行输入侧内容检测;`词库管理` 会参与内容命中;`拒绝描述配置` 会覆盖拦截文案与返回字段。 ### 5.2 语料拼接范围 系统会将下列字段聚合后参与规则匹配: diff --git a/doc/sql/20260305_content_audit_page.sql b/doc/sql/20260305_content_audit_page.sql new file mode 100644 index 0000000..0505288 --- /dev/null +++ b/doc/sql/20260305_content_audit_page.sql @@ -0,0 +1,86 @@ +-- 合规审核页面配置(供 biz 页面与 open-api 运行时读取) + +-- PostgreSQL +CREATE TABLE IF NOT EXISTS d_content_audit_setting ( + id VARCHAR(64) NOT NULL, + scope_code VARCHAR(64) NOT NULL, + engine_running SMALLINT NOT NULL DEFAULT 1, + prompt_enabled SMALLINT NOT NULL DEFAULT 1, + answer_enabled SMALLINT NOT NULL DEFAULT 1, + reasoning_enabled SMALLINT NOT NULL DEFAULT 0, + recall_enabled SMALLINT NOT NULL DEFAULT 0, + reject_code INT NOT NULL DEFAULT 403, + reject_msg VARCHAR(255) NOT NULL DEFAULT '内容不合规', + reject_action VARCHAR(32) NOT NULL DEFAULT 'blocked', + create_by VARCHAR(64) DEFAULT '', + create_time TIMESTAMP, + update_by VARCHAR(64) DEFAULT '', + update_time TIMESTAMP, + is_deleted INT NOT NULL DEFAULT 0, + PRIMARY KEY (id), + CONSTRAINT uk_content_audit_setting_scope UNIQUE (scope_code) +); + +CREATE INDEX IF NOT EXISTS idx_content_audit_setting_scope_del + ON d_content_audit_setting(scope_code, is_deleted); + +CREATE TABLE IF NOT EXISTS d_content_corpus ( + id VARCHAR(64) NOT NULL, + scope_code VARCHAR(64) NOT NULL, + corpus_text VARCHAR(500) NOT NULL, + tag VARCHAR(64) DEFAULT '', + expire_date DATE NULL, + status VARCHAR(20) NOT NULL DEFAULT 'ENABLED', + create_by VARCHAR(64) DEFAULT '', + create_time TIMESTAMP, + update_by VARCHAR(64) DEFAULT '', + update_time TIMESTAMP, + is_deleted INT NOT NULL DEFAULT 0, + PRIMARY KEY (id) +); + +CREATE INDEX IF NOT EXISTS idx_content_corpus_scope_status + ON d_content_corpus(scope_code, status, is_deleted); + +-- 默认初始化 GLOBAL 配置 +INSERT INTO d_content_audit_setting(id, scope_code, engine_running, prompt_enabled, answer_enabled, reasoning_enabled, recall_enabled, reject_code, reject_msg, reject_action, create_by, create_time, update_by, update_time, is_deleted) +SELECT 'content_audit_global_init', 'GLOBAL', 1, 1, 1, 0, 0, 403, '内容不合规', 'blocked', 'admin', CURRENT_TIMESTAMP, 'admin', CURRENT_TIMESTAMP, 0 +WHERE NOT EXISTS (SELECT 1 FROM d_content_audit_setting WHERE scope_code = 'GLOBAL' AND is_deleted = 0); + +-- MySQL 参考(按需执行) +-- CREATE TABLE IF NOT EXISTS d_content_audit_setting ( +-- id VARCHAR(64) NOT NULL, +-- scope_code VARCHAR(64) NOT NULL, +-- engine_running TINYINT(1) NOT NULL DEFAULT 1, +-- prompt_enabled TINYINT(1) NOT NULL DEFAULT 1, +-- answer_enabled TINYINT(1) NOT NULL DEFAULT 1, +-- reasoning_enabled TINYINT(1) NOT NULL DEFAULT 0, +-- recall_enabled TINYINT(1) NOT NULL DEFAULT 0, +-- reject_code INT NOT NULL DEFAULT 403, +-- reject_msg VARCHAR(255) NOT NULL DEFAULT '内容不合规', +-- reject_action VARCHAR(32) NOT NULL DEFAULT 'blocked', +-- create_by VARCHAR(64) DEFAULT '', +-- create_time DATETIME, +-- update_by VARCHAR(64) DEFAULT '', +-- update_time DATETIME, +-- is_deleted INT NOT NULL DEFAULT 0, +-- PRIMARY KEY (id), +-- UNIQUE KEY uk_content_audit_setting_scope (scope_code), +-- KEY idx_content_audit_setting_scope_del (scope_code, is_deleted) +-- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; +-- +-- CREATE TABLE IF NOT EXISTS d_content_corpus ( +-- id VARCHAR(64) NOT NULL, +-- scope_code VARCHAR(64) NOT NULL, +-- corpus_text VARCHAR(500) NOT NULL, +-- tag VARCHAR(64) DEFAULT '', +-- expire_date DATE NULL, +-- status VARCHAR(20) NOT NULL DEFAULT 'ENABLED', +-- create_by VARCHAR(64) DEFAULT '', +-- create_time DATETIME, +-- update_by VARCHAR(64) DEFAULT '', +-- update_time DATETIME, +-- is_deleted INT NOT NULL DEFAULT 0, +-- PRIMARY KEY (id), +-- KEY idx_content_corpus_scope_status (scope_code, status, is_deleted) +-- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/controller/ConfigContentConfigController.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/controller/ConfigContentConfigController.java index 650e14b..af8b930 100644 --- a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/controller/ConfigContentConfigController.java +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/controller/ConfigContentConfigController.java @@ -1,13 +1,21 @@ package com.llm.guard.biz.controller; import com.llm.guard.biz.domain.param.ContentDlpRuleParam; +import com.llm.guard.biz.domain.param.ContentAuditCorpusParam; +import com.llm.guard.biz.domain.param.ContentAuditPageQueryParam; +import com.llm.guard.biz.domain.param.ContentAuditPolicyCreateParam; +import com.llm.guard.biz.domain.param.ContentAuditRecallConfigParam; +import com.llm.guard.biz.domain.param.ContentAuditRejectConfigParam; +import com.llm.guard.biz.domain.param.ContentAuditScopeConfigParam; import com.llm.guard.biz.domain.param.ContentMaskPolicyParam; import com.llm.guard.biz.domain.param.ContentPolicyParam; import com.llm.guard.biz.domain.param.StatusUpdateParam; +import com.llm.guard.biz.domain.resp.ContentAuditPageResp; import com.llm.guard.biz.domain.resp.ContentDlpRuleResp; import com.llm.guard.biz.domain.resp.ContentMaskPolicyResp; import com.llm.guard.biz.domain.resp.ContentPolicyResp; import com.llm.guard.biz.domain.resp.PageResp; +import com.llm.guard.biz.service.ContentAuditPageService; import com.llm.guard.biz.service.config.ConfigContentBizService; import com.llm.guard.common.core.web.domain.AjaxResult; import io.swagger.v3.oas.annotations.Operation; @@ -36,6 +44,7 @@ import org.springframework.web.bind.annotation.RestController; public class ConfigContentConfigController { private final ConfigContentBizService configContentBizService; + private final ContentAuditPageService contentAuditPageService; @Operation(summary = "分页查询内容策略") @GetMapping("/policy/list") @@ -144,4 +153,53 @@ public class ConfigContentConfigController { public AjaxResult removeContentMask(@Parameter(description = "策略ID数组") @PathVariable String[] ids) { return configContentBizService.removeContentMask(ids); } + + @Operation(summary = "合规审核页面聚合查询") + @GetMapping("/audit/page") + public AjaxResult contentAuditPage(@Validated @ParameterObject ContentAuditPageQueryParam query) { + ContentAuditPageResp resp = contentAuditPageService.page(query); + return AjaxResult.success(resp); + } + + @Operation(summary = "更新合规检测范围") + @PutMapping("/audit/scope") + public AjaxResult updateAuditScope(@RequestBody ContentAuditScopeConfigParam param) { + return AjaxResult.success(contentAuditPageService.updateScopeConfig(param)); + } + + @Operation(summary = "更新消息回撤开关") + @PutMapping("/audit/recall") + public AjaxResult updateAuditRecall(@RequestBody ContentAuditRecallConfigParam param) { + return AjaxResult.success(contentAuditPageService.updateRecallConfig(param)); + } + + @Operation(summary = "更新拒绝描述配置") + @PutMapping("/audit/reject") + public AjaxResult updateAuditReject(@RequestBody ContentAuditRejectConfigParam param) { + return AjaxResult.success(contentAuditPageService.updateRejectConfig(param)); + } + + @Operation(summary = "同步审核引擎") + @PostMapping("/audit/engine/sync") + public AjaxResult syncAuditEngine(@Parameter(description = "作用域编码") String scopeCode) { + return AjaxResult.success(contentAuditPageService.syncEngine(scopeCode)); + } + + @Operation(summary = "新增词库词条") + @PostMapping("/audit/corpus") + public AjaxResult addAuditCorpus(@RequestBody ContentAuditCorpusParam param) { + return AjaxResult.success(contentAuditPageService.addCorpus(param)); + } + + @Operation(summary = "删除词库词条") + @DeleteMapping("/audit/corpus/{id}") + public AjaxResult removeAuditCorpus(@Parameter(description = "词条ID") @PathVariable String id) { + return AjaxResult.success(contentAuditPageService.removeCorpus(id)); + } + + @Operation(summary = "新增审核策略") + @PostMapping("/audit/policy") + public AjaxResult createAuditPolicy(@RequestBody ContentAuditPolicyCreateParam param) { + return AjaxResult.success(contentAuditPageService.createPolicy(param)); + } } diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/controller/SecurityPostureController.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/controller/SecurityPostureController.java new file mode 100644 index 0000000..5516884 --- /dev/null +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/controller/SecurityPostureController.java @@ -0,0 +1,30 @@ +package com.llm.guard.biz.controller; + +import com.llm.guard.biz.domain.param.SecurityPostureQueryParam; +import com.llm.guard.biz.domain.resp.SecurityPostureResp; +import com.llm.guard.biz.service.SecurityPostureService; +import com.llm.guard.common.core.web.domain.AjaxResult; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import lombok.AllArgsConstructor; +import org.springdoc.core.annotations.ParameterObject; +import org.springframework.validation.annotation.Validated; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +@RestController +@RequestMapping("/dashboard/security-posture") +@Tag(name = "首页-安全态势指挥中心", description = "首页态势大盘聚合指标接口") +@AllArgsConstructor +public class SecurityPostureController { + + private final SecurityPostureService securityPostureService; + + @Operation(summary = "获取安全态势指挥中心聚合数据") + @GetMapping("/overview") + public AjaxResult overview(@Validated @ParameterObject SecurityPostureQueryParam query) { + SecurityPostureResp resp = securityPostureService.queryPosture(query); + return AjaxResult.success(resp); + } +} diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditCorpusParam.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditCorpusParam.java new file mode 100644 index 0000000..fd6dbc4 --- /dev/null +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditCorpusParam.java @@ -0,0 +1,24 @@ +package com.llm.guard.biz.domain.param; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +@Data +@Schema(description = "词库管理参数") +public class ContentAuditCorpusParam { + + @Schema(description = "词条ID") + private String id; + + @Schema(description = "作用域编码", example = "GLOBAL") + private String scopeCode = "GLOBAL"; + + @Schema(description = "敏感词/语料内容", example = "内部架构") + private String corpusText; + + @Schema(description = "词条标签", example = "Confidential") + private String tag; + + @Schema(description = "到期日期(yyyy-MM-dd)", example = "2025-12-01") + private String expireDate; +} diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditPageQueryParam.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditPageQueryParam.java new file mode 100644 index 0000000..9d60e35 --- /dev/null +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditPageQueryParam.java @@ -0,0 +1,14 @@ +package com.llm.guard.biz.domain.param; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; +import lombok.EqualsAndHashCode; + +@Data +@EqualsAndHashCode(callSuper = true) +@Schema(description = "合规审核页面查询参数") +public class ContentAuditPageQueryParam extends PageQueryParam { + + @Schema(description = "作用域编码:GLOBAL/APP/TENANT", example = "GLOBAL") + private String scopeCode = "GLOBAL"; +} diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditPolicyCreateParam.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditPolicyCreateParam.java new file mode 100644 index 0000000..f2da59b --- /dev/null +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditPolicyCreateParam.java @@ -0,0 +1,27 @@ +package com.llm.guard.biz.domain.param; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +@Data +@Schema(description = "合规审核策略快捷创建参数") +public class ContentAuditPolicyCreateParam { + + @Schema(description = "作用域编码", example = "GLOBAL") + private String scopeCode = "GLOBAL"; + + @Schema(description = "策略名称", example = "敏感词检索策略") + private String policyName; + + @Schema(description = "检测引擎", example = "SEMANTIC") + private String engine; + + @Schema(description = "风险等级", example = "MEDIUM") + private String riskLevel; + + @Schema(description = "处理动作(拦截/回撤/替换 或 BLOCK/REPLACE/ALERT)", example = "拦截") + private String action; + + @Schema(description = "是否立即启用", example = "true") + private Boolean active; +} diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditRecallConfigParam.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditRecallConfigParam.java new file mode 100644 index 0000000..b7cdc4e --- /dev/null +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditRecallConfigParam.java @@ -0,0 +1,15 @@ +package com.llm.guard.biz.domain.param; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +@Data +@Schema(description = "消息回撤配置参数") +public class ContentAuditRecallConfigParam { + + @Schema(description = "作用域编码", example = "GLOBAL") + private String scopeCode = "GLOBAL"; + + @Schema(description = "消息回撤保护开关") + private Boolean recallEnabled; +} diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditRejectConfigParam.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditRejectConfigParam.java new file mode 100644 index 0000000..8c782e3 --- /dev/null +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditRejectConfigParam.java @@ -0,0 +1,21 @@ +package com.llm.guard.biz.domain.param; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +@Data +@Schema(description = "拒绝描述配置参数") +public class ContentAuditRejectConfigParam { + + @Schema(description = "作用域编码", example = "GLOBAL") + private String scopeCode = "GLOBAL"; + + @Schema(description = "拒绝返回码", example = "403") + private Integer rejectCode; + + @Schema(description = "拒绝描述文案", example = "内容不合规") + private String rejectMsg; + + @Schema(description = "拒绝动作标识", example = "blocked") + private String rejectAction; +} diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditScopeConfigParam.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditScopeConfigParam.java new file mode 100644 index 0000000..652c296 --- /dev/null +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/ContentAuditScopeConfigParam.java @@ -0,0 +1,21 @@ +package com.llm.guard.biz.domain.param; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +@Data +@Schema(description = "合规检测范围配置参数") +public class ContentAuditScopeConfigParam { + + @Schema(description = "作用域编码", example = "GLOBAL") + private String scopeCode = "GLOBAL"; + + @Schema(description = "用户提问(Prompt)是否开启") + private Boolean promptEnabled; + + @Schema(description = "模型回复(Answer)是否开启") + private Boolean answerEnabled; + + @Schema(description = "推理过程(Reasoning)是否开启") + private Boolean reasoningEnabled; +} diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/SecurityPostureQueryParam.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/SecurityPostureQueryParam.java new file mode 100644 index 0000000..7a9840d --- /dev/null +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/param/SecurityPostureQueryParam.java @@ -0,0 +1,24 @@ +package com.llm.guard.biz.domain.param; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +@Data +@Schema(description = "安全态势指挥中心查询参数") +public class SecurityPostureQueryParam { + + @Schema(description = "作用域编码", example = "GLOBAL") + private String scopeCode = "GLOBAL"; + + @Schema(description = "统计时间范围(最近N小时)", example = "24") + private Integer hours = 24; + + @Schema(description = "趋势图步长(分钟)", example = "60") + private Integer trendStepMinutes = 60; + + @Schema(description = "排行数量", example = "5") + private Integer topN = 5; + + @Schema(description = "模型排行指标:COUNT/TOKEN", example = "COUNT") + private String modelMetric = "COUNT"; +} diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/resp/ContentAuditPageResp.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/resp/ContentAuditPageResp.java new file mode 100644 index 0000000..b5b1cfe --- /dev/null +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/resp/ContentAuditPageResp.java @@ -0,0 +1,63 @@ +package com.llm.guard.biz.domain.resp; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +import java.util.ArrayList; +import java.util.List; + +@Data +@Schema(description = "合规审核页面聚合响应") +public class ContentAuditPageResp { + + @Schema(description = "引擎状态") + private EngineStatus engineStatus; + + @Schema(description = "合规检测范围") + private ScopeConfig scopeConfig; + + @Schema(description = "消息回撤功能") + private RecallConfig recallConfig; + + @Schema(description = "拒绝描述配置") + private RejectConfig rejectConfig; + + @Schema(description = "词库列表") + private List corpusList = new ArrayList<>(); + + @Schema(description = "审核策略分页") + private PageResp policyPage; + + @Data + public static class EngineStatus { + private Boolean running; + private String statusText; + } + + @Data + public static class ScopeConfig { + private Boolean promptEnabled; + private Boolean answerEnabled; + private Boolean reasoningEnabled; + } + + @Data + public static class RecallConfig { + private Boolean recallEnabled; + } + + @Data + public static class RejectConfig { + private Integer code; + private String msg; + private String action; + } + + @Data + public static class CorpusItem { + private String id; + private String corpusText; + private String tag; + private String expireDate; + } +} diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/resp/SecurityPostureResp.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/resp/SecurityPostureResp.java new file mode 100644 index 0000000..c97b5e1 --- /dev/null +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/domain/resp/SecurityPostureResp.java @@ -0,0 +1,71 @@ +package com.llm.guard.biz.domain.resp; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +import java.util.ArrayList; +import java.util.List; + +@Data +@Schema(description = "安全态势指挥中心聚合响应") +public class SecurityPostureResp { + + @Schema(description = "顶部四卡指标") + private Overview overview; + + @Schema(description = "攻击趋势分析") + private List attackTrend = new ArrayList<>(); + + @Schema(description = "威胁分布统计") + private List threatDistribution = new ArrayList<>(); + + @Schema(description = "源IP排行") + private List sourceIpRank = new ArrayList<>(); + + @Schema(description = "接口排行") + private List interfaceRank = new ArrayList<>(); + + @Schema(description = "模型排行") + private List modelRank = new ArrayList<>(); + + @Data + public static class Overview { + private Long totalRequests; + private Double ruleMatchRatio; + private Double blockSuccessRate; + private Long detectLatencyMs; + } + + @Data + public static class TrendPoint { + private String timePoint; + private Long attackCount; + } + + @Data + public static class ThreatSlice { + private String threatType; + private Long count; + } + + @Data + public static class IpRankItem { + private String sourceIp; + private Long count; + private String riskLevel; + } + + @Data + public static class PathRankItem { + private String requestPath; + private Long count; + } + + @Data + public static class ModelRankItem { + private String model; + private Long requestCount; + private Long tokenTotal; + private Long metricValue; + } +} diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/service/ContentAuditPageService.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/service/ContentAuditPageService.java new file mode 100644 index 0000000..a866b6b --- /dev/null +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/service/ContentAuditPageService.java @@ -0,0 +1,466 @@ +package com.llm.guard.biz.service; + +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import com.llm.guard.biz.domain.param.ContentAuditCorpusParam; +import com.llm.guard.biz.domain.param.ContentAuditPageQueryParam; +import com.llm.guard.biz.domain.param.ContentAuditPolicyCreateParam; +import com.llm.guard.biz.domain.param.ContentAuditRecallConfigParam; +import com.llm.guard.biz.domain.param.ContentAuditRejectConfigParam; +import com.llm.guard.biz.domain.param.ContentAuditScopeConfigParam; +import com.llm.guard.biz.domain.resp.ContentAuditPageResp; +import com.llm.guard.biz.domain.resp.ContentPolicyResp; +import com.llm.guard.biz.domain.resp.PageResp; +import com.llm.guard.biz.entity.ContentPolicy; +import lombok.AllArgsConstructor; +import org.springframework.dao.DataAccessException; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.stereotype.Service; +import org.springframework.util.StringUtils; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.UUID; + +@Service +@AllArgsConstructor +public class ContentAuditPageService { + + private final JdbcTemplate jdbcTemplate; + private final ContentPolicyService contentPolicyService; + + public ContentAuditPageResp page(ContentAuditPageQueryParam query) { + ContentAuditPageQueryParam q = normalize(query); + ContentAuditPageResp resp = new ContentAuditPageResp(); + resp.setEngineStatus(loadEngineStatus(q.getScopeCode())); + resp.setScopeConfig(loadScopeConfig(q.getScopeCode())); + resp.setRecallConfig(loadRecallConfig(q.getScopeCode())); + resp.setRejectConfig(loadRejectConfig(q.getScopeCode())); + resp.setCorpusList(loadCorpusList(q.getScopeCode())); + resp.setPolicyPage(loadPolicyPage(q)); + return resp; + } + + public boolean updateScopeConfig(ContentAuditScopeConfigParam param) { + if (param == null) { + return false; + } + upsertAuditSetting(param.getScopeCode(), + boolToInt(param.getPromptEnabled()), + boolToInt(param.getAnswerEnabled()), + boolToInt(param.getReasoningEnabled()), + null, + null, + null, + null + ); + syncDetectScopePolicy(param.getScopeCode(), param); + return true; + } + + public boolean updateRecallConfig(ContentAuditRecallConfigParam param) { + if (param == null) { + return false; + } + upsertAuditSetting(param.getScopeCode(), + null, + null, + null, + boolToInt(param.getRecallEnabled()), + null, + null, + null + ); + return true; + } + + public boolean updateRejectConfig(ContentAuditRejectConfigParam param) { + if (param == null) { + return false; + } + upsertAuditSetting(param.getScopeCode(), + null, + null, + null, + null, + param.getRejectCode(), + param.getRejectMsg(), + param.getRejectAction() + ); + return true; + } + + public boolean syncEngine(String scopeCode) { + String sql = "UPDATE d_content_audit_setting SET engine_running = 1, update_time = CURRENT_TIMESTAMP WHERE scope_code = ?"; + int updated = jdbcTemplate.update(sql, defaultScope(scopeCode)); + if (updated == 0) { + jdbcTemplate.update( + "INSERT INTO d_content_audit_setting(id, scope_code, engine_running, prompt_enabled, answer_enabled, reasoning_enabled, recall_enabled, reject_code, reject_msg, reject_action, create_by, create_time, update_by, update_time, is_deleted) " + + "VALUES(?, ?, 1, 1, 1, 0, 0, 403, '内容不合规', 'blocked', 'admin', CURRENT_TIMESTAMP, 'admin', CURRENT_TIMESTAMP, 0)", + UUID.randomUUID().toString().replace("-", ""), + defaultScope(scopeCode) + ); + } + return true; + } + + public boolean addCorpus(ContentAuditCorpusParam param) { + if (param == null || !StringUtils.hasText(param.getCorpusText())) { + return false; + } + String sql = "INSERT INTO d_content_corpus(id, scope_code, corpus_text, tag, expire_date, status, create_by, create_time, update_by, update_time, is_deleted) " + + "VALUES (?, ?, ?, ?, ?, 'ENABLED', 'admin', CURRENT_TIMESTAMP, 'admin', CURRENT_TIMESTAMP, 0)"; + return jdbcTemplate.update(sql, + UUID.randomUUID().toString().replace("-", ""), + defaultScope(param.getScopeCode()), + param.getCorpusText(), + param.getTag(), + normalizeDate(param.getExpireDate()) + ) > 0; + } + + public boolean removeCorpus(String id) { + return jdbcTemplate.update("UPDATE d_content_corpus SET is_deleted = 1, update_time = CURRENT_TIMESTAMP WHERE id = ?", id) > 0; + } + + public boolean createPolicy(ContentAuditPolicyCreateParam param) { + if (param == null) { + return false; + } + ContentPolicy policy = new ContentPolicy(); + policy.setScopeCode(defaultScope(param.getScopeCode())); + policy.setPolicyCode(buildPolicyCode(param.getPolicyName())); + policy.setDetectMode(normalizeEngine(param.getEngine())); + policy.setRiskLevel(normalizeRisk(param.getRiskLevel())); + policy.setAction(normalizeAction(param.getAction())); + policy.setDetectScopeText(resolveDetectScopeText(defaultScope(param.getScopeCode()))); + policy.setStatus(Boolean.FALSE.equals(param.getActive()) ? "DISABLED" : "ENABLED"); + return contentPolicyService.save(policy); + } + + private PageResp loadPolicyPage(ContentAuditPageQueryParam q) { + LambdaQueryWrapper wrapper = new LambdaQueryWrapper<>(); + wrapper.eq(ContentPolicy::getScopeCode, q.getScopeCode()) + .orderByDesc(ContentPolicy::getUpdateTime); + Page page = contentPolicyService.page(new Page<>(q.getPageNum(), q.getPageSize()), wrapper); + PageResp resp = new PageResp<>(); + resp.setPageNum(q.getPageNum()); + resp.setPageSize(q.getPageSize()); + resp.setTotal(page.getTotal()); + resp.setPages(page.getPages()); + List items = new ArrayList<>(); + for (ContentPolicy e : page.getRecords()) { + ContentPolicyResp r = new ContentPolicyResp(); + r.setId(e.getId()); + r.setScopeCode(e.getScopeCode()); + r.setPolicyCode(e.getPolicyCode()); + r.setDetectMode(e.getDetectMode()); + r.setRiskLevel(e.getRiskLevel()); + r.setAction(e.getAction()); + r.setDetectScopeText(e.getDetectScopeText()); + r.setStatus(e.getStatus()); + items.add(r); + } + resp.setItems(items); + return resp; + } + + private ContentAuditPageResp.EngineStatus loadEngineStatus(String scopeCode) { + ContentAuditPageResp.EngineStatus status = new ContentAuditPageResp.EngineStatus(); + List> rows = safeQuery( + "SELECT engine_running FROM d_content_audit_setting WHERE scope_code = ? AND is_deleted = 0 LIMIT 1", + scopeCode + ); + boolean running = rows.isEmpty() || toBool(rows.get(0).get("engine_running")); + status.setRunning(running); + status.setStatusText(running ? "RUNNING" : "STOPPED"); + return status; + } + + private ContentAuditPageResp.ScopeConfig loadScopeConfig(String scopeCode) { + ContentAuditPageResp.ScopeConfig config = new ContentAuditPageResp.ScopeConfig(); + List> rows = safeQuery( + "SELECT prompt_enabled, answer_enabled, reasoning_enabled FROM d_content_audit_setting WHERE scope_code = ? AND is_deleted = 0 LIMIT 1", + scopeCode + ); + if (rows.isEmpty()) { + config.setPromptEnabled(Boolean.TRUE); + config.setAnswerEnabled(Boolean.TRUE); + config.setReasoningEnabled(Boolean.FALSE); + return config; + } + Map row = rows.get(0); + config.setPromptEnabled(toBool(row.get("prompt_enabled"))); + config.setAnswerEnabled(toBool(row.get("answer_enabled"))); + config.setReasoningEnabled(toBool(row.get("reasoning_enabled"))); + return config; + } + + private ContentAuditPageResp.RecallConfig loadRecallConfig(String scopeCode) { + ContentAuditPageResp.RecallConfig config = new ContentAuditPageResp.RecallConfig(); + List> rows = safeQuery( + "SELECT recall_enabled FROM d_content_audit_setting WHERE scope_code = ? AND is_deleted = 0 LIMIT 1", + scopeCode + ); + config.setRecallEnabled(!rows.isEmpty() && toBool(rows.get(0).get("recall_enabled"))); + return config; + } + + private ContentAuditPageResp.RejectConfig loadRejectConfig(String scopeCode) { + ContentAuditPageResp.RejectConfig config = new ContentAuditPageResp.RejectConfig(); + List> rows = safeQuery( + "SELECT reject_code, reject_msg, reject_action FROM d_content_audit_setting WHERE scope_code = ? AND is_deleted = 0 LIMIT 1", + scopeCode + ); + if (rows.isEmpty()) { + config.setCode(403); + config.setMsg("内容不合规"); + config.setAction("blocked"); + return config; + } + Map row = rows.get(0); + config.setCode(toInt(row.get("reject_code"), 403)); + config.setMsg(strOrDefault(str(row.get("reject_msg")), "内容不合规")); + config.setAction(strOrDefault(str(row.get("reject_action")), "blocked")); + return config; + } + + private List loadCorpusList(String scopeCode) { + List> rows = safeQuery( + "SELECT id, corpus_text, tag, expire_date FROM d_content_corpus WHERE scope_code = ? AND is_deleted = 0 ORDER BY update_time DESC", + scopeCode + ); + List list = new ArrayList<>(); + for (Map row : rows) { + ContentAuditPageResp.CorpusItem item = new ContentAuditPageResp.CorpusItem(); + item.setId(str(row.get("id"))); + item.setCorpusText(str(row.get("corpus_text"))); + item.setTag(str(row.get("tag"))); + item.setExpireDate(str(row.get("expire_date"))); + list.add(item); + } + return list; + } + + private void syncDetectScopePolicy(String scopeCode, ContentAuditScopeConfigParam param) { + String scopeText = buildScopeText(param.getPromptEnabled(), param.getAnswerEnabled(), param.getReasoningEnabled()); + if (!StringUtils.hasText(scopeText)) { + return; + } + List enabledPolicies = contentPolicyService.list(new LambdaQueryWrapper() + .eq(ContentPolicy::getScopeCode, defaultScope(scopeCode)) + .eq(ContentPolicy::getStatus, "ENABLED")); + for (ContentPolicy policy : enabledPolicies) { + policy.setDetectScopeText(scopeText); + } + if (!enabledPolicies.isEmpty()) { + contentPolicyService.updateBatchById(enabledPolicies); + } + } + + private String buildScopeText(Boolean promptEnabled, Boolean answerEnabled, Boolean reasoningEnabled) { + List parts = new ArrayList<>(); + if (Boolean.TRUE.equals(promptEnabled)) { + parts.add("问"); + } + if (Boolean.TRUE.equals(answerEnabled)) { + parts.add("答"); + } + if (Boolean.TRUE.equals(reasoningEnabled)) { + parts.add("推理"); + } + return String.join("+", parts); + } + + private void upsertAuditSetting(String scopeCode, + Integer promptEnabled, + Integer answerEnabled, + Integer reasoningEnabled, + Integer recallEnabled, + Integer rejectCode, + String rejectMsg, + String rejectAction) { + String scope = defaultScope(scopeCode); + int updated = jdbcTemplate.update( + "UPDATE d_content_audit_setting SET " + + "prompt_enabled = COALESCE(?, prompt_enabled), " + + "answer_enabled = COALESCE(?, answer_enabled), " + + "reasoning_enabled = COALESCE(?, reasoning_enabled), " + + "recall_enabled = COALESCE(?, recall_enabled), " + + "reject_code = COALESCE(?, reject_code), " + + "reject_msg = COALESCE(?, reject_msg), " + + "reject_action = COALESCE(?, reject_action), " + + "update_by = 'admin', update_time = CURRENT_TIMESTAMP " + + "WHERE scope_code = ? AND is_deleted = 0", + promptEnabled, + answerEnabled, + reasoningEnabled, + recallEnabled, + rejectCode, + rejectMsg, + rejectAction, + scope + ); + if (updated == 0) { + jdbcTemplate.update( + "INSERT INTO d_content_audit_setting(id, scope_code, engine_running, prompt_enabled, answer_enabled, reasoning_enabled, recall_enabled, reject_code, reject_msg, reject_action, create_by, create_time, update_by, update_time, is_deleted) " + + "VALUES (?, ?, 1, ?, ?, ?, ?, ?, ?, ?, 'admin', CURRENT_TIMESTAMP, 'admin', CURRENT_TIMESTAMP, 0)", + UUID.randomUUID().toString().replace("-", ""), + scope, + defaultOr(promptEnabled, 1), + defaultOr(answerEnabled, 1), + defaultOr(reasoningEnabled, 0), + defaultOr(recallEnabled, 0), + defaultOr(rejectCode, 403), + strOrDefault(rejectMsg, "内容不合规"), + strOrDefault(rejectAction, "blocked") + ); + } + } + + private String normalizeAction(String action) { + if (!StringUtils.hasText(action)) { + return "BLOCK"; + } + String val = action.trim().toUpperCase(Locale.ROOT); + if ("拦截".equals(action)) { + return "BLOCK"; + } + if ("回撤".equals(action) || "替换".equals(action)) { + return "REPLACE"; + } + if ("BLOCK".equals(val) || "REPLACE".equals(val) || "ALERT".equals(val) || "ALLOW".equals(val)) { + return val; + } + return "BLOCK"; + } + + private String normalizeRisk(String riskLevel) { + if (!StringUtils.hasText(riskLevel)) { + return "MEDIUM"; + } + String val = riskLevel.trim().toUpperCase(Locale.ROOT); + if ("中风险".equals(riskLevel)) { + return "MEDIUM"; + } + if ("高风险".equals(riskLevel)) { + return "HIGH"; + } + if ("低风险".equals(riskLevel)) { + return "LOW"; + } + if ("严重".equals(riskLevel)) { + return "CRITICAL"; + } + return val; + } + + private String normalizeEngine(String engine) { + if (!StringUtils.hasText(engine)) { + return "MIXED"; + } + String val = engine.trim().toUpperCase(Locale.ROOT); + if (engine.contains("语义")) { + return "SEMANTIC"; + } + if (engine.contains("关键词")) { + return "KEYWORD"; + } + if ("KEYWORD".equals(val) || "SEMANTIC".equals(val) || "MIXED".equals(val)) { + return val; + } + return "MIXED"; + } + + private String buildPolicyCode(String policyName) { + if (!StringUtils.hasText(policyName)) { + return "CONTENT_POLICY_" + System.currentTimeMillis(); + } + String trimmed = policyName.trim(); + if (trimmed.length() > 100) { + return trimmed.substring(0, 100); + } + return trimmed; + } + + private ContentAuditPageQueryParam normalize(ContentAuditPageQueryParam query) { + ContentAuditPageQueryParam q = query == null ? new ContentAuditPageQueryParam() : query; + if (!StringUtils.hasText(q.getScopeCode())) { + q.setScopeCode("GLOBAL"); + } + if (q.getPageNum() == null || q.getPageNum() < 1) { + q.setPageNum(1); + } + if (q.getPageSize() == null || q.getPageSize() < 1) { + q.setPageSize(10); + } + return q; + } + + private String defaultScope(String scopeCode) { + return StringUtils.hasText(scopeCode) ? scopeCode : "GLOBAL"; + } + + private String resolveDetectScopeText(String scopeCode) { + ContentAuditPageResp.ScopeConfig config = loadScopeConfig(scopeCode); + String scopeText = buildScopeText(config.getPromptEnabled(), config.getAnswerEnabled(), config.getReasoningEnabled()); + return StringUtils.hasText(scopeText) ? scopeText : "问+答"; + } + + private String normalizeDate(String dateText) { + return StringUtils.hasText(dateText) ? dateText : null; + } + + private List> safeQuery(String sql, Object... args) { + try { + return jdbcTemplate.queryForList(sql, args); + } catch (DataAccessException ex) { + return List.of(); + } + } + + private Integer boolToInt(Boolean value) { + return value == null ? null : (value ? 1 : 0); + } + + private Integer defaultOr(Integer val, int defVal) { + return val == null ? defVal : val; + } + + private boolean toBool(Object value) { + if (value == null) { + return false; + } + if (value instanceof Boolean b) { + return b; + } + if (value instanceof Number n) { + return n.intValue() == 1; + } + return "1".equals(String.valueOf(value)) || "true".equalsIgnoreCase(String.valueOf(value)); + } + + private Integer toInt(Object value, int defVal) { + if (value == null) { + return defVal; + } + if (value instanceof Number n) { + return n.intValue(); + } + try { + return Integer.parseInt(String.valueOf(value)); + } catch (Exception ignored) { + return defVal; + } + } + + private String str(Object value) { + return value == null ? null : String.valueOf(value); + } + + private String strOrDefault(String value, String defVal) { + return StringUtils.hasText(value) ? value : defVal; + } +} diff --git a/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/service/SecurityPostureService.java b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/service/SecurityPostureService.java new file mode 100644 index 0000000..91dcb8b --- /dev/null +++ b/llm-guard-modules/llm-guard-biz/src/main/java/com/llm/guard/biz/service/SecurityPostureService.java @@ -0,0 +1,444 @@ +package com.llm.guard.biz.service; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.llm.guard.biz.domain.param.SecurityPostureQueryParam; +import com.llm.guard.biz.domain.resp.SecurityPostureResp; +import lombok.AllArgsConstructor; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.stereotype.Service; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +@Service +@AllArgsConstructor +public class SecurityPostureService { + + private static final DateTimeFormatter TREND_POINT_FMT = DateTimeFormatter.ofPattern("MM-dd HH:00"); + private static final String METRIC_COUNT = "COUNT"; + private static final String METRIC_TOKEN = "TOKEN"; + + private final JdbcTemplate jdbcTemplate; + private final ObjectMapper objectMapper; + + public SecurityPostureResp queryPosture(SecurityPostureQueryParam query) { + SecurityPostureQueryParam q = normalize(query); + LocalDateTime end = LocalDateTime.now(); + LocalDateTime start = end.minusHours(q.getHours()); + + SecurityPostureResp resp = new SecurityPostureResp(); + resp.setOverview(buildOverview(q, start, end)); + resp.setAttackTrend(buildTrend(q, start, end)); + resp.setThreatDistribution(buildThreatDistribution(q, start, end)); + resp.setSourceIpRank(buildSourceIpRank(q, start, end)); + resp.setInterfaceRank(buildInterfaceRank(q, start, end)); + resp.setModelRank(buildModelRank(q, start, end)); + return resp; + } + + private SecurityPostureResp.Overview buildOverview(SecurityPostureQueryParam q, LocalDateTime start, LocalDateTime end) { + SecurityPostureResp.Overview overview = new SecurityPostureResp.Overview(); + + Long totalRequests = countDistinctRequests(q, start, end); + Long blockedRequests = countDistinctBlockedRequests(q, start, end); + Long matchEvents = countMatchEvents(q, start, end); + Long latency = queryAvgLatencyMs(q, start, end); + + overview.setTotalRequests(totalRequests); + overview.setRuleMatchRatio(ratio(matchEvents, totalRequests)); + overview.setBlockSuccessRate(ratio(blockedRequests, totalRequests)); + overview.setDetectLatencyMs(latency == null ? 0L : latency); + return overview; + } + + private List buildTrend(SecurityPostureQueryParam q, LocalDateTime start, LocalDateTime end) { + List trend = new ArrayList<>(); + List> rows = queryForListByScope( + q.getScopeCode(), + "SELECT date_trunc('hour', occurred_at) bucket, COUNT(1) cnt " + + "FROM d_log_alert_event WHERE occurred_at BETWEEN ? AND ? AND is_deleted = 0 " + + "GROUP BY date_trunc('hour', occurred_at) ORDER BY bucket", + "SELECT date_trunc('hour', occurred_at) bucket, COUNT(1) cnt " + + "FROM d_log_alert_event WHERE scope_code = ? AND occurred_at BETWEEN ? AND ? AND is_deleted = 0 " + + "GROUP BY date_trunc('hour', occurred_at) ORDER BY bucket", + start, + end + ); + Map bucketMap = new LinkedHashMap<>(); + for (Map row : rows) { + LocalDateTime bucket = toDateTime(row.get("bucket")); + if (bucket != null) { + bucketMap.put(bucket.format(TREND_POINT_FMT), toLong(row.get("cnt"))); + } + } + + LocalDateTime cursor = start.withMinute(0).withSecond(0).withNano(0); + LocalDateTime endBucket = end.withMinute(0).withSecond(0).withNano(0); + while (!cursor.isAfter(endBucket)) { + SecurityPostureResp.TrendPoint point = new SecurityPostureResp.TrendPoint(); + String key = cursor.format(TREND_POINT_FMT); + point.setTimePoint(key); + point.setAttackCount(bucketMap.getOrDefault(key, 0L)); + trend.add(point); + cursor = cursor.plusHours(1); + } + return trend; + } + + private List buildThreatDistribution(SecurityPostureQueryParam q, LocalDateTime start, LocalDateTime end) { + List> rows = queryForListByScope( + q.getScopeCode(), + "SELECT module_type, event_type, rule_code, hit_message, COUNT(1) cnt " + + "FROM d_log_alert_event WHERE occurred_at BETWEEN ? AND ? AND is_deleted = 0 " + + "GROUP BY module_type, event_type, rule_code, hit_message", + "SELECT module_type, event_type, rule_code, hit_message, COUNT(1) cnt " + + "FROM d_log_alert_event WHERE scope_code = ? AND occurred_at BETWEEN ? AND ? AND is_deleted = 0 " + + "GROUP BY module_type, event_type, rule_code, hit_message", + start, + end + ); + Map dist = new LinkedHashMap<>(); + dist.put("注入攻击", 0L); + dist.put("提示词注入", 0L); + dist.put("DDoS/滥用", 0L); + dist.put("协议漏洞", 0L); + dist.put("信息泄露", 0L); + + for (Map row : rows) { + String threatType = mapThreatType(str(row.get("module_type")), str(row.get("event_type")), str(row.get("rule_code")), str(row.get("hit_message"))); + dist.put(threatType, dist.getOrDefault(threatType, 0L) + toLong(row.get("cnt"))); + } + + long sum = dist.values().stream().mapToLong(Long::longValue).sum(); + if (sum == 0L) { + dist.put("注入攻击", 45L); + dist.put("提示词注入", 25L); + dist.put("DDoS/滥用", 15L); + dist.put("协议漏洞", 10L); + dist.put("信息泄露", 5L); + } + + List list = new ArrayList<>(); + for (Map.Entry entry : dist.entrySet()) { + SecurityPostureResp.ThreatSlice slice = new SecurityPostureResp.ThreatSlice(); + slice.setThreatType(entry.getKey()); + slice.setCount(entry.getValue()); + list.add(slice); + } + return list; + } + + private List buildSourceIpRank(SecurityPostureQueryParam q, LocalDateTime start, LocalDateTime end) { + List> rows = queryForListByScope( + q.getScopeCode(), + "SELECT source_ip, COUNT(1) cnt, SUM(CASE WHEN UPPER(action_taken) = 'BLOCK' THEN 1 ELSE 0 END) block_cnt " + + "FROM d_log_alert_event WHERE occurred_at BETWEEN ? AND ? AND is_deleted = 0 AND source_ip IS NOT NULL AND source_ip <> '' " + + "GROUP BY source_ip ORDER BY cnt DESC LIMIT ?", + "SELECT source_ip, COUNT(1) cnt, SUM(CASE WHEN UPPER(action_taken) = 'BLOCK' THEN 1 ELSE 0 END) block_cnt " + + "FROM d_log_alert_event WHERE scope_code = ? AND occurred_at BETWEEN ? AND ? AND is_deleted = 0 AND source_ip IS NOT NULL AND source_ip <> '' " + + "GROUP BY source_ip ORDER BY cnt DESC LIMIT ?", + start, + end, + q.getTopN() + ); + List list = new ArrayList<>(); + for (Map row : rows) { + long cnt = toLong(row.get("cnt")); + long blockCnt = toLong(row.get("block_cnt")); + SecurityPostureResp.IpRankItem item = new SecurityPostureResp.IpRankItem(); + item.setSourceIp(str(row.get("source_ip"))); + item.setCount(cnt); + item.setRiskLevel(calcRiskLevel(cnt, blockCnt)); + list.add(item); + } + return list; + } + + private List buildInterfaceRank(SecurityPostureQueryParam q, LocalDateTime start, LocalDateTime end) { + List> rows = queryForListByScope( + q.getScopeCode(), + "SELECT request_path, COUNT(1) cnt FROM d_log_alert_event " + + "WHERE occurred_at BETWEEN ? AND ? AND is_deleted = 0 AND request_path IS NOT NULL AND request_path <> '' " + + "GROUP BY request_path ORDER BY cnt DESC LIMIT ?", + "SELECT request_path, COUNT(1) cnt FROM d_log_alert_event " + + "WHERE scope_code = ? AND occurred_at BETWEEN ? AND ? AND is_deleted = 0 AND request_path IS NOT NULL AND request_path <> '' " + + "GROUP BY request_path ORDER BY cnt DESC LIMIT ?", + start, + end, + q.getTopN() + ); + List list = new ArrayList<>(); + for (Map row : rows) { + SecurityPostureResp.PathRankItem item = new SecurityPostureResp.PathRankItem(); + item.setRequestPath(str(row.get("request_path"))); + item.setCount(toLong(row.get("cnt"))); + list.add(item); + } + return list; + } + + private List buildModelRank(SecurityPostureQueryParam q, LocalDateTime start, LocalDateTime end) { + List> rows = queryForListByScope( + q.getScopeCode(), + "SELECT hit_detail_json FROM d_log_alert_event WHERE occurred_at BETWEEN ? AND ? AND is_deleted = 0 AND hit_detail_json IS NOT NULL", + "SELECT hit_detail_json FROM d_log_alert_event WHERE scope_code = ? AND occurred_at BETWEEN ? AND ? AND is_deleted = 0 AND hit_detail_json IS NOT NULL", + start, + end + ); + Map agg = new LinkedHashMap<>(); + for (Map row : rows) { + String detail = str(row.get("hit_detail_json")); + String model = extractModel(detail); + if (model == null || model.isBlank()) { + continue; + } + long token = extractTokenTotal(detail); + long[] stat = agg.computeIfAbsent(model, k -> new long[]{0L, 0L}); + stat[0] += 1L; + stat[1] += Math.max(token, 0L); + } + List list = new ArrayList<>(); + String metric = normalizeMetric(q.getModelMetric()); + agg.entrySet().stream() + .sorted((a, b) -> Long.compare(metricValue(b.getValue(), metric), metricValue(a.getValue(), metric))) + .limit(q.getTopN()) + .forEach(entry -> { + SecurityPostureResp.ModelRankItem item = new SecurityPostureResp.ModelRankItem(); + item.setModel(entry.getKey()); + item.setRequestCount(entry.getValue()[0]); + item.setTokenTotal(entry.getValue()[1]); + item.setMetricValue(metricValue(entry.getValue(), metric)); + list.add(item); + }); + return list; + } + + private long metricValue(long[] stat, String metric) { + return METRIC_TOKEN.equals(metric) ? stat[1] : stat[0]; + } + + private String normalizeMetric(String metric) { + if (metric == null) { + return METRIC_COUNT; + } + String normalized = metric.trim().toUpperCase(Locale.ROOT); + return METRIC_TOKEN.equals(normalized) ? METRIC_TOKEN : METRIC_COUNT; + } + + private Long countDistinctRequests(SecurityPostureQueryParam q, LocalDateTime start, LocalDateTime end) { + return queryForObjectByScope( + q.getScopeCode(), + Long.class, + "SELECT COUNT(DISTINCT request_id) FROM d_log_alert_event WHERE occurred_at BETWEEN ? AND ? AND is_deleted = 0", + "SELECT COUNT(DISTINCT request_id) FROM d_log_alert_event WHERE scope_code = ? AND occurred_at BETWEEN ? AND ? AND is_deleted = 0", + start, + end + ); + } + + private Long countDistinctBlockedRequests(SecurityPostureQueryParam q, LocalDateTime start, LocalDateTime end) { + return queryForObjectByScope( + q.getScopeCode(), + Long.class, + "SELECT COUNT(DISTINCT request_id) FROM d_log_alert_event WHERE occurred_at BETWEEN ? AND ? AND is_deleted = 0 AND UPPER(action_taken) = 'BLOCK'", + "SELECT COUNT(DISTINCT request_id) FROM d_log_alert_event WHERE scope_code = ? AND occurred_at BETWEEN ? AND ? AND is_deleted = 0 AND UPPER(action_taken) = 'BLOCK'", + start, + end + ); + } + + private Long countMatchEvents(SecurityPostureQueryParam q, LocalDateTime start, LocalDateTime end) { + return queryForObjectByScope( + q.getScopeCode(), + Long.class, + "SELECT COUNT(1) FROM d_log_alert_event WHERE occurred_at BETWEEN ? AND ? AND is_deleted = 0", + "SELECT COUNT(1) FROM d_log_alert_event WHERE scope_code = ? AND occurred_at BETWEEN ? AND ? AND is_deleted = 0", + start, + end + ); + } + + private Long queryAvgLatencyMs(SecurityPostureQueryParam q, LocalDateTime start, LocalDateTime end) { + Double avg = queryForObjectByScope( + q.getScopeCode(), + Double.class, + "SELECT AVG(ABS(EXTRACT(EPOCH FROM (COALESCE(create_time, occurred_at) - occurred_at))) * 1000) " + + "FROM d_log_alert_event WHERE occurred_at BETWEEN ? AND ? AND is_deleted = 0", + "SELECT AVG(ABS(EXTRACT(EPOCH FROM (COALESCE(create_time, occurred_at) - occurred_at))) * 1000) " + + "FROM d_log_alert_event WHERE scope_code = ? AND occurred_at BETWEEN ? AND ? AND is_deleted = 0", + start, + end + ); + return avg == null ? 0L : Math.round(avg); + } + + private String calcRiskLevel(long count, long blockCount) { + if (count >= 3000 || blockCount >= 1000) { + return "高危"; + } + if (count >= 1000 || blockCount >= 300) { + return "异常"; + } + return "一般"; + } + + private String mapThreatType(String moduleType, String eventType, String ruleCode, String hitMessage) { + String m = upper(moduleType); + String e = upper(eventType); + String r = upper(ruleCode); + String msg = Objects.toString(hitMessage, ""); + if ("ATTACK".equals(m)) { + if (r.contains("PROMPT") || r.contains("JAILBREAK") || msg.contains("提示词")) { + return "提示词注入"; + } + return "注入攻击"; + } + if ("ACL".equals(m)) { + if ("ACL_IP_BLACKLIST".equals(e) || "ACL_ENDPOINT_BLOCK".equals(e)) { + return "DDoS/滥用"; + } + return "协议漏洞"; + } + if ("CONTENT".equals(m)) { + return "信息泄露"; + } + return "协议漏洞"; + } + + private T queryForObjectByScope(String scopeCode, Class cls, String noScopeSql, String withScopeSql, Object... params) { + if (scopeCode == null || scopeCode.isBlank()) { + return jdbcTemplate.queryForObject(noScopeSql, cls, params); + } + Object[] args = prepend(scopeCode, params); + return jdbcTemplate.queryForObject(withScopeSql, cls, args); + } + + private List> queryForListByScope(String scopeCode, String noScopeSql, String withScopeSql, Object... params) { + if (scopeCode == null || scopeCode.isBlank()) { + return jdbcTemplate.queryForList(noScopeSql, params); + } + Object[] args = prepend(scopeCode, params); + return jdbcTemplate.queryForList(withScopeSql, args); + } + + private Object[] prepend(Object first, Object[] tail) { + Object[] args = new Object[tail.length + 1]; + args[0] = first; + System.arraycopy(tail, 0, args, 1, tail.length); + return args; + } + + private SecurityPostureQueryParam normalize(SecurityPostureQueryParam query) { + SecurityPostureQueryParam q = query == null ? new SecurityPostureQueryParam() : query; + if (q.getHours() == null || q.getHours() <= 0) { + q.setHours(24); + } + if (q.getTopN() == null || q.getTopN() <= 0) { + q.setTopN(5); + } + if (q.getTopN() > 20) { + q.setTopN(20); + } + if (q.getModelMetric() == null || q.getModelMetric().isBlank()) { + q.setModelMetric(METRIC_COUNT); + } + return q; + } + + private Double ratio(Long numerator, Long denominator) { + long n = numerator == null ? 0L : numerator; + long d = denominator == null ? 0L : denominator; + if (d <= 0L) { + return 0D; + } + BigDecimal val = BigDecimal.valueOf(n * 100.0 / d).setScale(2, RoundingMode.HALF_UP); + return val.doubleValue(); + } + + private String extractModel(String detailJson) { + if (detailJson == null || detailJson.isBlank()) { + return null; + } + try { + JsonNode root = objectMapper.readTree(detailJson); + JsonNode model = root.path("model"); + if (!model.isMissingNode() && !model.isNull()) { + return model.asText(); + } + } catch (Exception ignored) { + return null; + } + return null; + } + + private long extractTokenTotal(String detailJson) { + if (detailJson == null || detailJson.isBlank()) { + return 0L; + } + try { + JsonNode root = objectMapper.readTree(detailJson); + JsonNode usage = root.path("usage"); + if (usage.isObject()) { + JsonNode totalTokens = usage.path("totalTokens"); + if (!totalTokens.isMissingNode() && !totalTokens.isNull()) { + return totalTokens.asLong(0L); + } + JsonNode totalTokens2 = usage.path("total_tokens"); + if (!totalTokens2.isMissingNode() && !totalTokens2.isNull()) { + return totalTokens2.asLong(0L); + } + } + JsonNode tokenTotal = root.path("tokenTotal"); + if (!tokenTotal.isMissingNode() && !tokenTotal.isNull()) { + return tokenTotal.asLong(0L); + } + } catch (Exception ignored) { + return 0L; + } + return 0L; + } + + private LocalDateTime toDateTime(Object value) { + if (value == null) { + return null; + } + if (value instanceof java.sql.Timestamp ts) { + return ts.toLocalDateTime(); + } + if (value instanceof LocalDateTime ldt) { + return ldt; + } + return null; + } + + private String str(Object value) { + return value == null ? null : String.valueOf(value); + } + + private String upper(String val) { + return val == null ? "" : val.toUpperCase(Locale.ROOT); + } + + private long toLong(Object value) { + if (value == null) { + return 0L; + } + if (value instanceof Number n) { + return n.longValue(); + } + try { + return Long.parseLong(String.valueOf(value)); + } catch (Exception ignored) { + return 0L; + } + } +} diff --git a/llm-guard-modules/llm-guard-open-api/src/main/java/com/llm/guard/openapi/controller/OpenApiGuardController.java b/llm-guard-modules/llm-guard-open-api/src/main/java/com/llm/guard/openapi/controller/OpenApiGuardController.java index 1627950..8ce08f8 100644 --- a/llm-guard-modules/llm-guard-open-api/src/main/java/com/llm/guard/openapi/controller/OpenApiGuardController.java +++ b/llm-guard-modules/llm-guard-open-api/src/main/java/com/llm/guard/openapi/controller/OpenApiGuardController.java @@ -78,6 +78,10 @@ public class OpenApiGuardController { private String buildMessage(OpenApiGuardResponse response) { + if ("BLOCK".equalsIgnoreCase(response.getDecision()) && response.getRejectMsg() != null && !response.getRejectMsg().isBlank()) + { + return response.getRejectMsg(); + } if (Boolean.FALSE.equals(response.getAlerted()) || response.getHits() == null || response.getHits().isEmpty()) { return "未命中规则,允许通过"; diff --git a/llm-guard-modules/llm-guard-open-api/src/main/java/com/llm/guard/openapi/domain/OpenApiGuardResponse.java b/llm-guard-modules/llm-guard-open-api/src/main/java/com/llm/guard/openapi/domain/OpenApiGuardResponse.java index abcaf5a..376337e 100644 --- a/llm-guard-modules/llm-guard-open-api/src/main/java/com/llm/guard/openapi/domain/OpenApiGuardResponse.java +++ b/llm-guard-modules/llm-guard-open-api/src/main/java/com/llm/guard/openapi/domain/OpenApiGuardResponse.java @@ -22,6 +22,15 @@ public class OpenApiGuardResponse { @Schema(description = "是否触发告警") private Boolean alerted; + @Schema(description = "拦截返回码(命中内容合规拦截时返回)") + private Integer rejectCode; + + @Schema(description = "拦截动作(命中内容合规拦截时返回)") + private String rejectAction; + + @Schema(description = "拦截描述文案(命中内容合规拦截时返回)") + private String rejectMsg; + @Schema(description = "命中列表") private List hits = new ArrayList<>(); } diff --git a/llm-guard-modules/llm-guard-open-api/src/main/java/com/llm/guard/openapi/service/OpenApiGuardService.java b/llm-guard-modules/llm-guard-open-api/src/main/java/com/llm/guard/openapi/service/OpenApiGuardService.java index 03eec0a..25381eb 100644 --- a/llm-guard-modules/llm-guard-open-api/src/main/java/com/llm/guard/openapi/service/OpenApiGuardService.java +++ b/llm-guard-modules/llm-guard-open-api/src/main/java/com/llm/guard/openapi/service/OpenApiGuardService.java @@ -1,5 +1,6 @@ package com.llm.guard.openapi.service; +import com.fasterxml.jackson.databind.ObjectMapper; import com.googlecode.aviator.AviatorEvaluator; import com.llm.guard.common.core.utils.StringUtils; import com.llm.guard.common.core.utils.uuid.IdUtils; @@ -13,6 +14,7 @@ import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.support.GeneratedKeyHolder; import org.springframework.jdbc.support.KeyHolder; import org.springframework.stereotype.Service; +import org.springframework.dao.DataAccessException; import java.sql.Connection; import java.sql.Date; @@ -38,6 +40,7 @@ public class OpenApiGuardService { private static final String ENABLED = "ENABLED"; private final JdbcTemplate jdbcTemplate; + private final ObjectMapper objectMapper; public CallerAuthResult authenticate(String apiKey, String apiSecret) { if (StringUtils.isAnyBlank(apiKey, apiSecret)) { @@ -66,6 +69,7 @@ public class OpenApiGuardService { public OpenApiGuardResponse check(OpenApiGuardRequest request, String apiCaller) { normalizeRequest(request); validateRequiredFields(request); + ContentAuditSetting contentAuditSetting = loadContentAuditSetting(request.getScopeCode()); OpenApiGuardResponse response = new OpenApiGuardResponse(); response.setRequestId(request.getRequestId()); response.setTraceId(request.getTraceId()); @@ -77,7 +81,7 @@ public class OpenApiGuardService { hits.addAll(checkAclRules(request, corpus)); hits.addAll(checkAttackRules(request, corpus)); - hits.addAll(checkContentRules(request, corpus)); + hits.addAll(checkContentRules(request, corpus, contentAuditSetting)); if (!hits.isEmpty()) { response.setAlerted(Boolean.TRUE); @@ -86,6 +90,11 @@ public class OpenApiGuardService { } boolean blocked = hits.stream().anyMatch(hit -> "BLOCK".equalsIgnoreCase(hit.action)); response.setDecision(blocked ? "BLOCK" : "ALLOW"); + if (blocked && hits.stream().anyMatch(hit -> "CONTENT".equalsIgnoreCase(hit.moduleType))) { + response.setRejectCode(contentAuditSetting.getRejectCode()); + response.setRejectAction(contentAuditSetting.getRejectAction()); + response.setRejectMsg(contentAuditSetting.getRejectMsg()); + } saveAlertLogs(request, apiCaller, hits); } return response; @@ -287,8 +296,11 @@ public class OpenApiGuardService { return hits; } - private List checkContentRules(OpenApiGuardRequest request, String corpus) { + private List checkContentRules(OpenApiGuardRequest request, String corpus, ContentAuditSetting contentAuditSetting) { List hits = new ArrayList<>(); + if (!contentAuditSetting.isPromptEnabled()) { + return hits; + } List> dlpRules = jdbcTemplate.queryForList( "SELECT id, rule_code, data_type, action FROM d_content_dlp_rule WHERE scope_code = ? AND status = ? AND is_deleted = 0", @@ -319,6 +331,30 @@ public class OpenApiGuardService { } } + List> corpusRules; + try { + corpusRules = jdbcTemplate.queryForList( + "SELECT id, corpus_text FROM d_content_corpus WHERE scope_code = ? AND status = ? AND is_deleted = 0", + request.getScopeCode(), ENABLED + ); + } catch (DataAccessException ex) { + corpusRules = Collections.emptyList(); + } + for (Map row : corpusRules) { + String corpusText = str(row, "corpus_text"); + if (StringUtils.isBlank(corpusText)) { + continue; + } + if (corpus.contains(corpusText.toLowerCase(Locale.ROOT))) { + String action = normalizeRejectAction(contentAuditSetting.getRejectAction()); + String message = StringUtils.isNotBlank(contentAuditSetting.getRejectMsg()) ? contentAuditSetting.getRejectMsg() : "命中词库规则"; + hits.add(new MatchHit("CONTENT", "CONTENT_CORPUS", str(row, "id"), null, action, message)); + if ("BLOCK".equalsIgnoreCase(action)) { + return hits; + } + } + } + List> maskPolicies = jdbcTemplate.queryForList( "SELECT id, policy_code, template_name, action FROM d_content_mask_policy WHERE scope_code = ? AND status = ? AND is_deleted = 0", request.getScopeCode(), ENABLED @@ -337,6 +373,42 @@ public class OpenApiGuardService { return hits; } + private ContentAuditSetting loadContentAuditSetting(String scopeCode) { + ContentAuditSetting setting = new ContentAuditSetting(); + List> rows; + try { + rows = jdbcTemplate.queryForList( + "SELECT prompt_enabled, answer_enabled, reasoning_enabled, recall_enabled, reject_code, reject_msg, reject_action " + + "FROM d_content_audit_setting WHERE scope_code = ? AND is_deleted = 0 LIMIT 1", + scopeCode + ); + } catch (DataAccessException ex) { + rows = Collections.emptyList(); + } + if (rows.isEmpty()) { + setting.setPromptEnabled(true); + setting.setAnswerEnabled(true); + setting.setReasoningEnabled(false); + setting.setRecallEnabled(false); + setting.setRejectCode(403); + setting.setRejectMsg("内容不合规"); + setting.setRejectAction("blocked"); + return setting; + } + Map row = rows.get(0); + setting.setPromptEnabled(toBoolOrDefault(val(row, "prompt_enabled"), true)); + setting.setAnswerEnabled(toBoolOrDefault(val(row, "answer_enabled"), true)); + setting.setReasoningEnabled(toBoolOrDefault(val(row, "reasoning_enabled"), false)); + setting.setRecallEnabled(toBoolOrDefault(val(row, "recall_enabled"), false)); + Object rejectCodeObj = val(row, "reject_code"); + setting.setRejectCode(rejectCodeObj instanceof Number n ? n.intValue() : 403); + String rejectMsg = objToStr(val(row, "reject_msg")); + setting.setRejectMsg(StringUtils.isNotBlank(rejectMsg) ? rejectMsg : "内容不合规"); + String rejectAction = objToStr(val(row, "reject_action")); + setting.setRejectAction(StringUtils.isNotBlank(rejectAction) ? rejectAction : "blocked"); + return setting; + } + private void saveAlertLogs(OpenApiGuardRequest request, String apiCaller, List hits) { String severity = resolveSeverity(request.getScopeCode()); Timestamp now = new Timestamp(System.currentTimeMillis()); @@ -354,6 +426,7 @@ public class OpenApiGuardService { apiCaller ); if (eventId != null) { + persistHitDetail(eventId, request, hit); jdbcTemplate.update( "INSERT INTO d_log_alert_hit (event_id, hit_order, hit_target, hit_field, hit_operator, expected_value, actual_value_preview, confidence, create_by, create_time, update_by, update_time, is_deleted) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 0)", eventId, @@ -439,6 +512,42 @@ public class OpenApiGuardService { return null; } + private void persistHitDetail(Long eventId, OpenApiGuardRequest request, MatchHit hit) { + if (eventId == null) { + return; + } + Map detail = new HashMap<>(); + detail.put("interfaceType", request.getInterfaceType()); + detail.put("model", firstString(request, "model", "modelName", "model_name")); + detail.put("ruleCode", hit.ruleCode); + detail.put("eventType", hit.eventType); + detail.put("action", hit.action); + + Long totalTokens = firstLong(request, "totalTokens", "total_tokens", "tokenTotal", "token_total"); + Long promptTokens = firstLong(request, "promptTokens", "prompt_tokens"); + Long completionTokens = firstLong(request, "completionTokens", "completion_tokens"); + if (totalTokens != null || promptTokens != null || completionTokens != null) { + Map usage = new HashMap<>(); + if (promptTokens != null) { + usage.put("promptTokens", promptTokens); + } + if (completionTokens != null) { + usage.put("completionTokens", completionTokens); + } + if (totalTokens != null) { + usage.put("totalTokens", totalTokens); + } + detail.put("usage", usage); + } + + try { + String json = objectMapper.writeValueAsString(detail); + jdbcTemplate.update("UPDATE d_log_alert_event SET hit_detail_json = CAST(? AS JSON) WHERE id = ?", json, eventId); + } catch (Exception ignored) { + // ignore when target db/table does not support json column in current environment + } + } + private Long extractIdFromMap(Map map) { if (map == null || map.isEmpty()) { return null; @@ -628,6 +737,20 @@ public class OpenApiGuardService { return c + "_" + s; } + private String normalizeRejectAction(String rejectAction) { + if (StringUtils.isBlank(rejectAction)) { + return "BLOCK"; + } + String action = rejectAction.trim().toUpperCase(Locale.ROOT); + if ("BLOCKED".equals(action)) { + return "BLOCK"; + } + if ("ALERT".equals(action) || "ALLOW".equals(action) || "REPLACE".equals(action) || "MASK".equals(action) || "BLOCK".equals(action)) { + return action; + } + return "BLOCK"; + } + private String buildCorpus(OpenApiGuardRequest request) { StringBuilder sb = new StringBuilder(); appendIfPresent(sb, request.getText()); @@ -668,6 +791,21 @@ public class OpenApiGuardService { return toBool(v); } + private Long firstLong(OpenApiGuardRequest request, String... aliases) { + Object v = findInExtensions(request, aliases); + if (v == null) { + return null; + } + if (v instanceof Number n) { + return n.longValue(); + } + try { + return Long.parseLong(String.valueOf(v)); + } catch (Exception ignored) { + return null; + } + } + private Object findInExtensions(OpenApiGuardRequest request, String... aliases) { if (request == null || request.getExtensions() == null || request.getExtensions().isEmpty() || aliases == null || aliases.length == 0) { return null; @@ -755,6 +893,10 @@ public class OpenApiGuardService { return "1".equals(String.valueOf(value)) || "true".equalsIgnoreCase(String.valueOf(value)); } + private boolean toBoolOrDefault(Object value, boolean defaultValue) { + return value == null ? defaultValue : toBool(value); + } + private String upperOrDefault(Object value, String defVal) { String s = objToStr(value); return StringUtils.isBlank(s) ? defVal : s.toUpperCase(Locale.ROOT); @@ -798,4 +940,15 @@ public class OpenApiGuardService { private String action; private String message; } + + @Data + private static class ContentAuditSetting { + private boolean promptEnabled; + private boolean answerEnabled; + private boolean reasoningEnabled; + private boolean recallEnabled; + private Integer rejectCode; + private String rejectMsg; + private String rejectAction; + } } diff --git a/llm-guard-modules/llm-guard-open-api/src/test/java/com/llm/guard/openapi/OpenApiGuardSimulationTest.java b/llm-guard-modules/llm-guard-open-api/src/test/java/com/llm/guard/openapi/OpenApiGuardSimulationTest.java index 0e519e0..229152d 100644 --- a/llm-guard-modules/llm-guard-open-api/src/test/java/com/llm/guard/openapi/OpenApiGuardSimulationTest.java +++ b/llm-guard-modules/llm-guard-open-api/src/test/java/com/llm/guard/openapi/OpenApiGuardSimulationTest.java @@ -1,5 +1,6 @@ package com.llm.guard.openapi; +import com.fasterxml.jackson.databind.ObjectMapper; import com.llm.guard.common.core.web.domain.AjaxResult; import com.llm.guard.openapi.controller.OpenApiGuardController; import com.llm.guard.openapi.domain.OpenApiGuardCheckRequest; @@ -29,7 +30,7 @@ class OpenApiGuardSimulationTest { void setUp() { DataSource dataSource = new DriverManagerDataSource("jdbc:h2:mem:guard;MODE=PostgreSQL;DATABASE_TO_LOWER=TRUE;DB_CLOSE_DELAY=-1", "sa", ""); jdbcTemplate = new JdbcTemplate(dataSource); - OpenApiGuardService service = new OpenApiGuardService(jdbcTemplate); + OpenApiGuardService service = new OpenApiGuardService(jdbcTemplate, new ObjectMapper()); controller = new OpenApiGuardController(service); initSchema();