
Model.evaluate() returns a float instead of a list

Dewi, 2 months ago

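I have a multi-task neural network. I want to make sure that when I call Model.evaluate() on my model, the score I see is the sum of the losses. However, it returns a scalar rather than a list, so I'm not sure what this loss actually is. According to the documentation, a list of scalars should be returned when there are multiple outputs or losses. Below is a minimal reproducible example: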

import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

X = np.random.random((10, 10))
y = {'pi': np.random.random((10,)), 'u':  np.random.random((10,))}

in_layer = Input(shape=X.shape[1:])
out1 = Dense(1, name='pi')(in_layer)
out2 = Dense(1, name='u')(in_layer)
model = Model(inputs=in_layer, outputs=[out1,out2])
model.compile(loss={'pi': 'mean_squared_error', 'u': 'mean_squared_error'}, optimizer = 'adam')

model.fit(X,y)
print(model.evaluate(X, y)) # Returns a float.

I tried passing y as a list instead, but I still get the same result. print(model.metrics_names) returns 'loss'.
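For reference, the Keras `compile()` documentation describes the loss minimized for a multi-output model as the sum of the individual losses, weighted by the `loss_weights` coefficients (1.0 each by default). The plain-Python sketch below computes that quantity by hand so it runs without Keras. The helpers `mse` and `total_loss`, the `<name>_loss` key names and the toy numbers are all hypothetical illustrations, and the sketch ignores any regularization losses Keras may add to the total.

```python
from typing import Dict, Optional, Sequence


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error for one output, averaged over samples."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def total_loss(
    y_true: Dict[str, Sequence[float]],
    y_pred: Dict[str, Sequence[float]],
    loss_weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Per-output MSE plus their weighted sum (the value reported as 'loss')."""
    weights = loss_weights or {}
    report = {f"{name}_loss": mse(y_true[name], y_pred[name]) for name in y_true}
    # Missing weights default to 1.0, i.e. a plain sum of the per-output losses.
    report["loss"] = sum(
        weights.get(name, 1.0) * report[f"{name}_loss"] for name in y_true
    )
    return report


# Toy targets/predictions for two outputs named like the model above.
y_true = {"pi": [1.0, 2.0], "u": [0.0, 0.0]}
y_pred = {"pi": [1.0, 0.0], "u": [1.0, 1.0]}
print(total_loss(y_true, y_pred))  # {'pi_loss': 2.0, 'u_loss': 1.0, 'loss': 3.0}
```

With the default weights, the scalar `loss` is simply `pi_loss + u_loss`; passing e.g. `{"pi": 0.5}` would halve the contribution of the `pi` head.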

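  • I'm trying to load a model I trained in Google Colab, using Ubuntu 24.04, conda 24.5.0, NVIDIA-SMI 555.58.02, driver version 556.12 and CUDA 12.5. All my drivers are up to date, but when I try to run the code below, it gets stuck loading the model. I've tried importing the model functions from tensorflow.keras.models and loading directly from keras, with the same result. I had to downgrade Python from 3.12 to 3.9, otherwise TensorFlow wouldn't work, and I also tried loading other models from Colab, with similar results. I could save the model as .keras instead, but when I do, it says "Error loading model: SavedModel file does not exist at: model4.keras/{saved_model.pbtxt|saved_model.pb}". Is there a specific combination of package versions I should try to get this code working? Any guidance is welcome.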

    import tensorflow as tf
    from tensorflow import keras
    
    print(tf.test.is_built_with_cuda())
    print(tf.__version__)
    
    try:
        model1 = keras.models.load_model('model.h5')
        print("Model loaded successfully")
    except OSError as e:
        print(f"Error loading model: {e}")
    
    exit()
    

    Here is the full output:

    2024-07-12 15:01:09.226398: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
        True
        2.4.1
        2024-07-12 15:01:09.624582: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
        2024-07-12 15:01:09.625313: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
        2024-07-12 15:01:09.728371: E tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:927] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
        Your kernel may have been built without NUMA support.
        2024-07-12 15:01:09.728405: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
        pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 4070 computeCapability: 8.9
        coreClock: 2.52GHz coreCount: 46 deviceMemorySize: 11.99GiB deviceMemoryBandwidth: 469.43GiB/s
        2024-07-12 15:01:09.728420: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
        2024-07-12 15:01:09.729041: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.10
        2024-07-12 15:01:09.729073: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.10
        2024-07-12 15:01:09.729788: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcufft.so.10
        2024-07-12 15:01:09.729909: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcurand.so.10
        2024-07-12 15:01:09.730543: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.10
        2024-07-12 15:01:09.730889: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusparse.so.10
        2024-07-12 15:01:09.732249: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.7
        2024-07-12 15:01:09.732301: E tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:927] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
        Your kernel may have been built without NUMA support.
        2024-07-12 15:01:09.732325: E tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:927] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
        Your kernel may have been built without NUMA support.
        2024-07-12 15:01:09.732338: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1862] Adding visible gpu devices: 0
        2024-07-12 15:01:09.732474: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
        To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
        2024-07-12 15:01:09.734681: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
        2024-07-12 15:01:09.734728: E tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:927] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
        Your kernel may have been built without NUMA support.
        2024-07-12 15:01:09.734740: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
        pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 4070 computeCapability: 8.9
        coreClock: 2.52GHz coreCount: 46 deviceMemorySize: 11.99GiB deviceMemoryBandwidth: 469.43GiB/s
        2024-07-12 15:01:09.734745: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
        2024-07-12 15:01:09.734750: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.10
        2024-07-12 15:01:09.734754: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.10
        2024-07-12 15:01:09.734757: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcufft.so.10
        2024-07-12 15:01:09.734760: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcurand.so.10
        2024-07-12 15:01:09.734763: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.10
        2024-07-12 15:01:09.734780: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusparse.so.10
        2024-07-12 15:01:09.734796: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.7
        2024-07-12 15:01:09.734809: E tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:927] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
        Your kernel may have been built without NUMA support.
        2024-07-12 15:01:09.734818: E tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:927] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
        Your kernel may have been built without NUMA support.
        2024-07-12 15:01:09.734821: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1862] Adding visible gpu devices: 0
        2024-07-12 15:01:09.734835: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
        2024-07-12 15:02:14.023639: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1261] Device interconnect StreamExecutor with strength 1 edge matrix:
        2024-07-12 15:02:14.023660: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1267]      0 
        2024-07-12 15:02:14.023672: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1280] 0:   N 
        2024-07-12 15:02:14.024163: E tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:927] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
        Your kernel may have been built without NUMA support.
        2024-07-12 15:02:14.024181: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1489] Could not identify NUMA node of platform GPU id 0, defaulting to 0.  Your kernel may not have been built with NUMA support.
        2024-07-12 15:02:14.024211: E tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:927] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
        Your kernel may have been built without NUMA support.
        2024-07-12 15:02:14.024234: E tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:927] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
        Your kernel may have been built without NUMA support.
        2024-07-12 15:02:14.024260: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1406] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10293 MB memory) -> physical GPU (device: 0, name: NVIDIA GeForce RTX 4070, pci bus id: 0000:01:00.0, compute capability: 8.9)
        Traceback (most recent call last):
          File "/home/abner/density/teste.py", line 8, in <module>
            model1 = keras.models.load_model('model.h5')
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/saving/save.py", line 206, in load_model
            return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/saving/hdf5_format.py", line 183, in load_model_from_hdf5
            model = model_config_lib.model_from_config(model_config,
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/saving/model_config.py", line 64, in model_from_config
            return deserialize(config, custom_objects=custom_objects)
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/layers/serialization.py", line 173, in deserialize
            return generic_utils.deserialize_keras_object(
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 354, in deserialize_keras_object
            return cls.from_config(
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/engine/functional.py", line 668, in from_config
            input_tensors, output_tensors, created_layers = reconstruct_from_config(
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/engine/functional.py", line 1275, in reconstruct_from_config
            process_layer(layer_data)
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/engine/functional.py", line 1257, in process_layer
            layer = deserialize_layer(layer_data, custom_objects=custom_objects)
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/layers/serialization.py", line 173, in deserialize
            return generic_utils.deserialize_keras_object(
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 360, in deserialize_keras_object
            return cls.from_config(cls_config)
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py", line 720, in from_config
            return cls(**config)
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/layers/pooling.py", line 862, in __init__
            super(GlobalPooling2D, self).__init__(**kwargs)
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py", line 517, in _method_wrapper
            result = method(self, *args, **kwargs)
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py", line 340, in __init__
            generic_utils.validate_kwargs(kwargs, allowed_kwargs)
          File "/home/abner/anaconda3/envs/density/lib/python3.9/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 808, in validate_kwargs
            raise TypeError(error_message, kwarg)
        TypeError: ('Keyword argument not understood:', 'keepdims')
    
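  • To get that, you need to specify the metrics separately; otherwise Model.evaluate will aggregate the results. Make the following change to the compile call: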

    model.compile(
        loss={'pi': 'mean_squared_error', 'u': 'mean_squared_error'}, 
        optimizer='adam',
        metrics={'pi': 'mean_squared_error', 'u': 'mean_squared_error'}
    )
    

    See this link for more information about the API.
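When `model.evaluate()` returns a list, its entries are in the same order as `model.metrics_names`, so zipping the two is a quick way to see which number is which. The sketch below uses hard-coded, hypothetical names and values in place of a real model so it runs without Keras; with a real model you would pass `model.metrics_names` and the result of `model.evaluate(X, y)`. The helper `label_results` is illustrative only. Recent tf.keras versions (2.2 and later, as far as I know) also accept `return_dict=True` in `evaluate()`, which gives the same pairing directly.

```python
from typing import Dict, List, Sequence, Union


def label_results(
    names: Sequence[str], results: Union[float, Sequence[float]]
) -> Dict[str, float]:
    """Pair evaluate() outputs with their metric names.

    A single float (what the question observed) is treated as a one-element list.
    """
    values: List[float] = (
        [results] if isinstance(results, (int, float)) else list(results)
    )
    if len(names) != len(values):
        raise ValueError(f"{len(names)} names but {len(values)} values")
    return dict(zip(names, values))


# Hypothetical stand-ins for model.metrics_names and model.evaluate(X, y).
names = ["loss", "pi_loss", "u_loss"]
results = [0.75, 0.5, 0.25]
print(label_results(names, results))  # {'loss': 0.75, 'pi_loss': 0.5, 'u_loss': 0.25}
print(label_results(["loss"], 0.75))  # {'loss': 0.75}
```

If the names list has only 'loss' (as in the question), the model is reporting just the total, and the per-output values are not available from that call.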

  • I'm training a model with TensorFlow 2.x. Until now I was using 2.14.1 without any problems. I've now upgraded to 2.16.2 (and also tried 2.17.0), but the same code no longer works. Creating the whole model works fine, and I can do everything up to the training itself (using model.fit(....)). As soon as training starts, TF complains that I haven't compiled my model with a loss function.

    However, this is what I have:

    model.compile(optimizer="adam", 
        loss = tf.keras.losses.CategoricalCrossentropy(), 
        metrics = all_metrics)
    

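    I also tried disabling the GPU, but got the same problem. I really want to stress that this worked fine in 2.14.1. I need the newer version because I want to use TensorFlow Probability.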

    Does anyone know why this happens? Thanks,

  • My current code:

    const networkChart = NodeClassification;
      const itemKey = Object.values(networkChart).map((x) => Object.keys(x)[0])
      const itemKey1 = Object.values(networkChart).map((x) => Object.values(x)[0])
    
      itemKey1.map((item, index) => {
        return (
        item.group = itemKey[index],
        item.Vendor = Object.keys(networkChart)[index]
        )
      })
    
      // itemKey1.map((item, index) => {
      //   item.group = itemKey[index].toUpperCase()
      // })
    
      const netArr = itemKey1?.map((data, index) => {
        return {
          "group": data.group,
          "key": data.Vendor,
          "value": data.DownCount
        }
      })

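    Each object under 'NodeClassification' has nested objects, and each of those contains item objects; I need to split these out to form an array of objects. The resulting array should look like the one below: a structure with new keys and the existing values. Attached below is the bar chart that needs to be built from this array of objects.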

    JSON Object Structure:
        "NodeClassification": {
            "Mobile": {
                "NORMAL": {
                    "DownCount": 2
                },
                "VIP": {
                    "DownCount": 2
                }
            },
            "MobileRAN": {
                "NORMAL": {
                    "DownCount": 4
                },
                "VIP": {
                    "DownCount": 4
                },
                "CRITICAL": {
                    "DownCount": 4
                }
            }
        }
            
    Resulting array of objects I need:
        [{
            group: 'NORMAL',
            key: 'Mobile',
            value: 2
          },
          {
            group: 'VIP',
            key: 'Mobile',
            value: 2
          },
          {
            group: 'NORMAL',
            key: 'MobileRAN',
            value: 4
          },
          {
            group: 'VIP',
            key: 'MobileRAN',
            value: 4
          },
          {
            group: 'CRITICAL',
            key: 'MobileRAN',
            value: 4
          }
        ] 
    

    (Image: NodeClassification JSON bar chart)

  • I have a data file like the following (simplified; I have more columns):

       timestamp  frame_idx  gaze_pos_x  gaze_pos_y  gaze_dir_x  gaze_dir_y  gaze_dir_z
    0    2269.17         45     893.314     500.136    0.165454  -0.0222454    0.985967
    1    2274.17         45      896.61     502.564    0.176397  -0.0098666     0.98427
    2    2279.17         46     900.592     499.049    0.189087   -0.018215    0.981791
    3    2284.17         46     906.321     478.184     0.18891  -0.0307506    0.981513
    4    2289.17         46     893.465     502.793    0.175493  -0.0210113    0.984257
    5    2294.17         46     898.629     497.182    0.190142  -0.0151722    0.981639
    6     2299.3         46     893.554     496.782    0.183007  -0.0150504    0.982996
    7     2304.3         46     905.338     482.343    0.188236  -0.0249608    0.981807
    8     2309.3         46      897.44     495.476    0.187434  -0.0199951    0.982074
    9     2424.3         48     893.358     495.474    0.171512  -0.0198278    0.984982

    and an object like this (again simplified):

    class Gaze:
        def __init__(self, ts, frame_idx, gaze2D, gaze_dir3D=None):
            self.ts = ts
            self.frame_idx = frame_idx
            self.gaze2D = gaze2D
            self.gaze_dir3D = gaze_dir3D
    

    where gaze2D is a numpy array containing [gaze_pos_x, gaze_pos_y] and gaze_dir3D is a numpy array containing [gaze_dir_x, gaze_dir_y, gaze_dir_z].

    I want to load the data file efficiently and create one Gaze object per row. I've implemented it as below, but it is very slow:

    from collections import defaultdict

    import pandas as pd

    def readDataFromFile(fileName):
        gazes   = []
        data    = pd.read_csv(str(fileName), delimiter='\t', index_col=False, dtype=defaultdict(lambda: float, frame_idx=int))
        allCols = tuple([c for c in data.columns if col in c] for col in (
            'gaze_pos','gaze_dir'))
        # allCols -> ([gaze_pos_x, gaze_pos_y],[gaze_dir_x, gaze_dir_y, gaze_dir_z]), a list can be empty if a set of columns is missing (gaze_dir is optional)
    
        # run through all rows
        for _, row in data.iterrows():
            frame_idx = int(row['frame_idx'])  # must cast to int as pd.Series seems to lose typing of dataframe.... :s
            ts        = row['timestamp']
    
            # get all values (None if columns not present)
            # again need to cast to float despite all items in the series being a float, because the dtype of the series is object... :s
            args = tuple(row[c].astype('float').to_numpy() if c else None for c in allCols)
            gazes.append(Gaze(ts, frame_idx, *args))
        return gazes
    

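    As mentioned, this is very slow: the row iteration takes a long time, far too long for my use case. Is there a more efficient way to do this? Reading the file with something like csv.DictReader is a bit faster, but still too slow.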

  • I've now added the code. The problem is that I only access the object at index 0 and can't loop over every entry in the object.

  • Well, iterating over the rows of a DataFrame in pandas will always be much slower. If at all possible, you should try to vectorize your code.

  • You could take the entries and map over the nested entries.

    const
        data = { NodeClassification: { Mobile: { NORMAL: { DownCount: 2 }, VIP: { DownCount: 2 } }, MobileRAN: { NORMAL: { DownCount: 4 }, VIP: { DownCount: 4 }, CRITICAL: { DownCount: 4 } } } },
        result = Object
            .entries(data.NodeClassification)
            .flatMap(([key, o]) => Object
                .entries(o)
                .map(([group, { DownCount: value }]) => ({ group, key, value }))
            );
    
    console.log(result);
  • Combining Object#entries with Array#reduce may help:

    const input={NodeClassification:{Mobile:{NORMAL:{DownCount:2},VIP:{DownCount:2}},MobileRAN:{NORMAL:{DownCount:4},VIP:{DownCount:4},CRITICAL:{DownCount:4}}}};
    
    const result = Object.entries(input.NodeClassification)
      .reduce((r, [key, obj]) => (
        Object.entries(obj).forEach(([group, {DownCount: value}]) => r.push({group, key, value})
      ), r), []);
      
    console.log(result);

    Shorter, but slightly less performant:

    const input={NodeClassification:{Mobile:{NORMAL:{DownCount:2},VIP:{DownCount:2}},MobileRAN:{NORMAL:{DownCount:4},VIP:{DownCount:4},CRITICAL:{DownCount:4}}}};
    
    const result = Object.entries(input.NodeClassification)
      .reduce((r, [key, obj]) => r.concat(Object.entries(obj).map(([group, {DownCount: value}]) => ({group, key, value}))), []);
      
    console.log(result);
      

    Shorter still, but less performant (due to the iteration in Array#flatMap):

    const input={NodeClassification:{Mobile:{NORMAL:{DownCount:2},VIP:{DownCount:2}},MobileRAN:{NORMAL:{DownCount:4},VIP:{DownCount:4},CRITICAL:{DownCount:4}}}};
    
    const result = Object.entries(input.NodeClassification)
      .flatMap(([key, obj]) => Object.entries(obj).map(([group, {DownCount: value}]) => ({group, key, value})));
      
    console.log(result);
  • Posting the question and being urged to vectorize gave me a new idea. Here is a fast solution!

    from collections import defaultdict

    import pandas as pd

    def readDataFromFile(fileName):
        df = pd.read_csv(str(fileName), delimiter='\t', index_col=False, dtype=defaultdict(lambda: float, frame_idx=int))
    
        # group columns into numpy arrays, insert None if missing
        cols = ('gaze_pos','gaze_dir')
        allCols = tuple([c for c in df.columns if col in c] for col in cols)
        for c,ac in zip(cols,allCols):
            if ac:
                df[c] = [x for x in df[ac].values]  # make list of numpy arrays
            else:
                df[c] = None
    
        # clean up so we can assign into gaze objects directly
        # note: merging dicts with | requires Python 3.9+
        lookup = {'timestamp':'ts'} | {k:v for k,v in zip(cols,['gaze2D','gaze_dir3D'])}
        df = df.drop(columns=[c for c in df.columns if c not in lookup.keys() and c!='frame_idx'])
        df = df.rename(columns=lookup)
    
        # make the gaze objects
        gazes = [Gaze(**kwargs) for kwargs in df.to_dict(orient='records')]
    
        return gazes
    
  • Sinn, 2 months ago

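    If you want to squeeze out even more performance, I'd suggest not using pandas at all and instead reading the lines and converting them manually, for example: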

    class Gaze:
        def __init__(self, ts, frame_idx, gaze2D, gaze_dir3D=None):
            self.ts = ts
            self.frame_idx = frame_idx
            self.gaze2D = gaze2D
            self.gaze_dir3D = gaze_dir3D
    
        def __repr__(self):
            return f"Gaze(ts={self.ts}, frame_idx={self.frame_idx}, gaze2D={self.gaze2D}, gaze_dir3D={self.gaze_dir3D})"
    
    
    all_data = []
    with open("your_data.csv", "r") as f_in:
        next(f_in)  # skip headers
        for line in map(str.strip, f_in):
            if not line:  # skip empty lines
                continue
    
            (
                timestamp,
                frame_idx,
                gaze_pos_x,
                gaze_pos_y,
                gaze_dir_x,
                gaze_dir_y,
                gaze_dir_z,
            ) = line.split("\t")
    
            all_data.append(
                Gaze(
                    float(timestamp),
                    int(frame_idx),
                    (float(gaze_pos_x), float(gaze_pos_y)),
                    (float(gaze_dir_x), float(gaze_dir_y), float(gaze_dir_z)),
                )
            )
    
    print(all_data)
    

    Prints:

    [
        Gaze(
            ts=2269.17,
            frame_idx=45,
            gaze2D=(893.314, 500.136),
            gaze_dir3D=(0.165454, -0.0222454, 0.985967),
        ),
        Gaze(
            ts=2274.17,
            frame_idx=45,
            gaze2D=(896.61, 502.564),
            gaze_dir3D=(0.176397, -0.0098666, 0.98427),
        ),
        Gaze(
            ts=2279.17,
            frame_idx=46,
            gaze2D=(900.592, 499.049),
            gaze_dir3D=(0.189087, -0.018215, 0.981791),
        ),
    
    ...
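The loop above assumes the file has exactly these seven columns in this order. Since the question says the real file has more columns and that gaze_dir is optional, here is a hedged variant, still without pandas, that uses the standard-library `csv.reader` and looks the column positions up from the header once instead of per row. The function name `read_gazes`, the tuple-based fields (standing in for the question's numpy arrays) and the sample data are illustrative assumptions, and I have not benchmarked it against the split-based loop.

```python
import csv
import io
from typing import List, Optional, TextIO


class Gaze:
    # Same fields as the question's Gaze class; tuples stand in for numpy arrays.
    def __init__(self, ts, frame_idx, gaze2D, gaze_dir3D=None):
        self.ts = ts
        self.frame_idx = frame_idx
        self.gaze2D = gaze2D
        self.gaze_dir3D = gaze_dir3D


POS_COLS = ("gaze_pos_x", "gaze_pos_y")
DIR_COLS = ("gaze_dir_x", "gaze_dir_y", "gaze_dir_z")


def read_gazes(f: TextIO) -> List[Gaze]:
    reader = csv.reader(f, delimiter="\t")
    header = next(reader)
    idx = {name: i for i, name in enumerate(header)}
    # Resolve column positions once, from the header; extra columns are ignored.
    ts_i, frame_i = idx["timestamp"], idx["frame_idx"]
    pos_i = [idx[c] for c in POS_COLS]
    # gaze_dir is optional: use it only if all three columns are present.
    dir_i: Optional[List[int]] = (
        [idx[c] for c in DIR_COLS] if all(c in idx for c in DIR_COLS) else None
    )
    gazes = []
    for row in reader:
        if not row:  # csv.reader yields [] for blank lines
            continue
        gaze2D = tuple(float(row[i]) for i in pos_i)
        gaze_dir3D = tuple(float(row[i]) for i in dir_i) if dir_i else None
        gazes.append(Gaze(float(row[ts_i]), int(row[frame_i]), gaze2D, gaze_dir3D))
    return gazes


# Hypothetical sample: no gaze_dir columns, plus an unrelated extra column.
sample = (
    "timestamp\tframe_idx\tgaze_pos_x\tgaze_pos_y\textra\n"
    "2269.17\t45\t893.314\t500.136\tx\n"
    "2274.17\t45\t896.61\t502.564\ty\n"
)
for g in read_gazes(io.StringIO(sample)):
    print(g.ts, g.frame_idx, g.gaze2D, g.gaze_dir3D)
```

With a real file you would call `read_gazes(open(path, newline=""))`; the sample above prints `2269.17 45 (893.314, 500.136) None` and `2274.17 45 (896.61, 502.564) None`.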
    