传知代码-基于图神经网络的知识追踪方法(论文复现)

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

1.论文概述

论文链接提出了一种基于图神经网络的知识追踪方法,称为基于图的知识追踪(GKT)。将知识结构构建为图,其中节点对应于概念,边对应于它们之间的关系,将知识追踪任务构建为图神经网络中的时间序列节点级分类问题。在两个开放数据集上的实证验证表明,方法可以更好地预测学生的表现,并且该模型比先前的方法具有更可解释的预测。
贡献如下:
(1)展示了知识追踪可以重新构想为图神经网络的应用。
(2)为了实现需要输入模型的图结构,在许多情况下并不明确的情况下,我们提出了各种方法,并使用实证验证进行了比较。
(3)证明了所提出的方法比先前的方法更准确和可解释的预测。

2.论文方法

下面是本文提出GKT的体系结构。
在这里插入图片描述

2.1 聚合

模型聚合了回答的概念及其相邻概念的隐藏状态和嵌入。这种聚合使用隐藏状态、表示正确和错误答案的输入向量 xt​,以及概念及其回答的嵌入矩阵Ex 和Ec 进行,
在这里插入图片描述

2.2 更新

接下来,模型根据聚集的特征和知识图结构更新隐藏状态。这一步骤确保模型融合了当前概念及其在知识图中的相邻节点的信息。
在这里插入图片描述

2.3 预测

最后,模型输出学生在下一时间步正确回答每个概念的预测概率
在这里插入图片描述

3. 实验

3.1 数据集

使用了学生数学练习日志的两个开放数据集:ASSISTments 2009-2010“skill-builder”由在线教育服务 ASSISTments1(以下称为“ASSISTments”)提供和 Bridge to Algebra 2006-2007 [19] 用于KDDCup 教育数据挖掘挑战赛(以下简称“KDDCup”)。在这两个数据集中,每个练习都分配了人类预定义的知识概念标签。
使用特定条件预处理每个数据集。对于ASSISTments,将同时回答的日志合二为一,随后提取与命名概念标签相关联的日志,最后提取与至少10次回答的概念标签相关联的日志。对于 KDDCup,将问题和步骤的组合视为一个答案,然后提取与命名且非哑元的概念标签相关联的日志,最后提取至少 10 次回答的概念标签相关联的日志。由于频繁同时出现的标签,将同时的回答日志组合成一组可以防止不公平的高预测性能。排除未命名或虚拟的概念标签可以消除噪音。用回答每个概念标签的次数对日志进行阈值处理,以确保有足够数量的日志来消除噪音。在使用上述条件对数据集进行预处理后,为 ASSISTments 数据集获得了 62, 955 个日志,由 1, 000 名学生和 101 项技能组成,并为 KDDCup 数据集获得了 98, 200 条日志,由 1, 000 名学生和 211 项技能组成。
在这里插入图片描述

3.2 实验步骤

Step1:处理数据集

在这里插入图片描述

Step2:进行训练

在这里插入图片描述

3.3 实验结果

在这里插入图片描述

4.核心代码

class GKT(KTM):
    def __init__(self, ku_num, graph, hidden_num, net_params: dict = None, loss_params=None):
        super(GKT, self).__init__()
        self.gkt_model = GKTNet(
            ku_num,
            graph,
            hidden_num,
            **(net_params if net_params is not None else {})
        )
        # self.gkt_model = GKTNet(ku_num, graph, hidden_num)
        self.loss_params = loss_params if loss_params is not None else {}

    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
        loss_function = SLMLoss(**self.loss_params)
        trainer = torch.optim.Adam(self.gkt_model.parameters(), lr)

        for e in range(epoch):
            losses = []
            for (question, data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e):
                # convert to device
                question: torch.Tensor = question.to(device)
                data: torch.Tensor = data.to(device)
                data_mask: torch.Tensor = data_mask.to(device)
                label: torch.Tensor = label.to(device)
                pick_index: torch.Tensor = pick_index.to(device)
                label_mask: torch.Tensor = label_mask.to(device)

                # real training
                predicted_response, _ = self.gkt_model(question, data, data_mask)

                loss = loss_function(predicted_response, pick_index, label, label_mask)

                # back propagation
                trainer.zero_grad()
                loss.backward()
                trainer.step()

                losses.append(loss.mean().item())
            print("[Epoch %d] SLMoss: %.6f" % (e, float(np.mean(losses))))

            if test_data is not None:
                auc, accuracy = self.eval(test_data)
                print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))

    def eval(self, test_data, device="cpu") -> tuple:
        self.gkt_model.eval()
        y_true = []
        y_pred = []

        for (question, data, data_mask, label, pick_index, label_mask) in tqdm(test_data, "evaluating"):
            # convert to device
            question: torch.Tensor = question.to(device)
            data: torch.Tensor = data.to(device)
            data_mask: torch.Tensor = data_mask.to(device)
            label: torch.Tensor = label.to(device)
            pick_index: torch.Tensor = pick_index.to(device)
            label_mask: torch.Tensor = label_mask.to(device)

            # real evaluating
            output, _ = self.gkt_model(question, data, data_mask)
            output = output[:, :-1]
            output = pick(output, pick_index.to(output.device))
            pred = tensor2list(output)
            label = tensor2list(label)
            for i, length in enumerate(label_mask.numpy().tolist()):
                length = int(length)
                y_true.extend(label[i][:length])
                y_pred.extend(pred[i][:length])
        self.gkt_model.train()
        return roc_auc_score(y_true, y_pred), accuracy_score(y_true, np.array(y_pred) >= 0.5)

    def save(self, filepath) -> ...:
        torch.save(self.gkt_model.state_dict(), filepath)
        logging.info("save parameters to %s" % filepath)

    def load(self, filepath):
        self.gkt_model.load_state_dict(torch.load(filepath))
        logging.info("load parameters from %s" % filepath)

源码下载

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/884467.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【JavaEE】——单例模式引起的多线程安全问题:“饿汉/懒汉”模式,及解决思路和方法(面试高频)

阿华代码,不是逆风,就是我疯,你们的点赞收藏是我前进最大的动力!!希望本文内容能够帮助到你! 目录 一:单例模式(singleton) 1:概念 二:“饿汉模…

Enhancing Trust in LLMs: Algorithms for Comparing and Interpreting LLMs

文章目录 题目摘要引言透明度的必要性对信任的追求困惑度测量自然语言处理(NLP)评估指标零投学习绩效少量学习性能迁移学习对抗测试公平和偏见稳健性评估LLMMaps基准测试和排行榜分层分析布鲁姆分类法的可视化幻觉评分知识分层策略利用机器学习模型进行层级生成注意力可视化LLM…

css五种定位总结

在 CSS 中,定位(Positioning)主要有五种模式,每种模式的行为和特点不同,以下是 static、relative、absolute、fixed 和 sticky 五种定位方式的对比总结: 1. static(默认定位) 特性…

阿里云函数计算 x NVIDIA 加速企业 AI 应用落地

作者:付宇轩 前言 阿里云函数计算(Function Compute, FC)是一种无服务器(Serverless)计算服务,允许用户在无需管理底层基础设施的情况下,直接运行代码。与传统的计算架构相比,函数…

【2023工业3D异常检测文献】PointCore: 基于局部-全局特征的高效无监督点云异常检测器

PointCore: Efficient Unsupervised Point Cloud Anomaly Detector Using Local-Global Features 1、Background 当前的点云异常检测器可以分为两类: (1)基于重建的方法,通过自动编码器重建输入点云数据,并通过比较原…

07-阿里云镜像仓库

07-阿里云镜像仓库 注册阿里云 先注册一个阿里云账号:https://www.aliyun.com/ 进入容器镜像服务控制台 工作台》容器》容器服务》容器镜像服务 实例列表》个人实例 仓库管理》镜像仓库》命名空间》创建命名空间 仓库管理》镜像仓库》镜像仓库》创建镜像仓库 使…

【AI】深度学习的数学--核心公式

1 梯度下降 f ( x Δ x , y Δ y ) ≃ f ( x , y ) ∂ f ( x , y ) ∂ x Δ x ∂ f ( x , y ) ∂ y Δ y f(x\Delta x,y\Delta y) \simeq f(x,y)\frac{\partial f(x,y)}{\partial x}\Delta x\frac{\partial f(x,y)}{\partial y}\Delta y f(xΔx,yΔy)≃f(x,y)∂x∂f(x,y)​…

MySQL 性能剖析全攻略

在使用 MySQL 数据库的过程中,性能问题往往是让开发者和管理员头疼的难题。为了有效地解决这些问题,我们需要对 MySQL 进行性能剖析。那么,如何在 MySQL 中进行性能剖析呢?本文将为你详细介绍。 一、为什么要进行性能剖析&#x…

基于安卓开发大型体育场管理系统的设计与实现(源码+定制+讲解)

博主介绍: ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台…

《开题报告》基于SpringBoot框架的高校专业实习管理系统开题报告的设计与实现源码++学习文档+答辩讲解视频

开题报告 研究背景 在当今高等教育日益普及与深化的背景下,高校专业实习作为学生将理论知识转化为实践能力、提前适应社会工作环境的重要环节,其重要性不言而喻。然而,传统的高校专业实习管理模式往往存在信息不对称、流程繁琐、效率低下、…

SSM+Vue共享单车管理系统

目录 1 项目介绍2 项目截图3 核心代码3.1 Controller3.2 Service3.3 Dao3.4 spring-mybatis.xml3.5 spring-mvc.xml3.5 Vue 4 数据库表设计5 文档参考6 计算机毕设选题推荐7 源码获取 1 项目介绍 博主个人介绍:CSDN认证博客专家,CSDN平台Java领域优质创作…

代码随想录_刷题记录_第四次

二叉树 — 理论基础 种类: 满二叉树(所有层的节点都是满的,k:深度 节点数量:2^k - 1)完全二叉树(除了最后一层,其余层全满,并且最后一层从左到右连续)二叉搜…

信道衰落的公式

对于天线: 对于天线的面积计算: 天线的接收功率密度: 天线的接收功率: 移动无线信道(I) (xidian.edu.cn)https://web.xidian.edu.cn/zma/files/20150710_153736.pdf 更加常用的考虑了额外的信道衰落pathlo…

2024 maya的散布工具sppaint3d使用指南

目前工具其实可以分为三个版本 1 最老的原版 时间还是2011年的,只支持python2版的maya 2 作者python3更新版 后来作者看maya直到2022上还是没有类似好用方便的工具,于是更新到了2022版本 这个是原作者更新的2022版本,改成了python3&#…

敏感字段加密 - 华为OD统一考试(E卷)

2024华为OD机试(E卷+D卷+C卷)最新题库【超值优惠】Java/Python/C++合集 题目描述 【敏感字段加密】给定一个由多个命令字组成的命令字符串: 1、字符串长度小于等于127字节,只包含大小写字母,数字,下划线和偶数个双引号; 2、命令字之间以一个或多个下划线 进行分割; 3、可…

Study-Oracle-10-ORALCE19C-RAC集群搭建(一)

一、硬件信息及配套软件 1、硬件设置 RAC集群虚拟机:CPU:2C、内存:10G、操作系统:50G Openfile数据存储:200G (10G*2) 2、网络设置 主机名公有地址私有地址VIP共享存储(SAN)rac1192.168.49.13110.10.10.20192.168.49.141192.168.49.130rac2192.168.49.13210.10.10.3…

单体到微服务架构服务演化过程

架构服务化 聊聊从单体到微服务架构服务演化过程 单体分层架构 在 Web 应用程序发展的早期,大部分工程是将所有的服务端功能模块打包到单个巨石型(Monolith)应用中,譬如很多企业的 Java 应用程序打包为 war 包,最终会形…

JSP(Java Server Pages)基础使用二

简单练习在jsp页面上输出出乘法口诀表 既然大家都是来看这种代码的人了&#xff0c;那么这种输出乘法口诀表的这种简单算法肯定是难不住大家了&#xff0c;所以这次主要是来说jsp的使用格式问题。 <%--Created by IntelliJ IDEA.User: ***Date: 2024/7/18Time: 11:26To ch…

线性表二——栈stack

第一题 #include<bits/stdc.h> using namespace std; stack<char> s; int n; string ced;//如何匹配 出现的右括号转换成同类型的左括号&#xff0c;方便我们直接和栈顶元素 char cheak(char c){if(c)) return (;if(c]) return [;if(c}) return {;return \0;/…

css边框修饰

一、设置线条样式 通过 border-style 属性设置&#xff0c;可选择的一些属性如下&#xff1a; dotted&#xff1a;点线 dashed&#xff1a;虚线 solid&#xff1a;实线 double&#xff1a;双实线 效果如下&#xff1a; 二、设置边框线宽度 ① 通过 border-width 整体设置…