For a TensorFlow layer I need a multi-column argsort, so I implemented the following function:
import tensorflow as tf
def multi_column_argsort(tensor, columns_order):
    sorted_indices = tf.range(start=0, limit=tf.shape(tensor)[0], dtype=tf.int32)
    for col in reversed(columns_order):
        col_vals = tf.gather(tensor[:, col], sorted_indices)
        col_argsort = tf.argsort(col_vals, stable=True)
        print("Column:", col)
        print("Column Values:", col_vals.numpy())
        print("Col Argsort:", col_argsort.numpy())
        print("Sorted Indices Before:", sorted_indices.numpy())
        sorted_indices = tf.gather(sorted_indices, col_argsort)
        print("Sorted Indices After:", sorted_indices.numpy())
        print("---")
    return sorted_indices
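To make the intent of the function explicit: it sorts by the least significant column first and relies on every later pass being stable, so that ties keep the order from earlier passes. A minimal plain-Python sketch of the same idea follows. The name multi_pass_argsort and the small rows table are made up for illustration; Python's built-in sorted() is used because it is guaranteed to be stable.

```python
def multi_pass_argsort(rows, columns_order):
    """Lexicographic argsort via repeated stable sorts (hypothetical helper)."""
    order = list(range(len(rows)))
    # Sort by the last key first; each later pass must be stable so that
    # ties keep the order established by the earlier passes.
    for col in reversed(columns_order):
        order = sorted(order, key=lambda i: rows[i][col])
    return order


# Illustrative data, not taken from the example in this post.
rows = [[1, 9], [0, 5], [1, 3], [0, 7]]
print(multi_pass_argsort(rows, [0, 1]))  # [1, 3, 2, 0]
```

If any pass is not stable, the ordering from the earlier passes is lost, and the final result is no longer guaranteed to be lexicographic.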
After debugging this function for a while, I found that it was not sorting the three columns as expected because tf.argsort was not stable, i.e. it did not preserve the order established by the previous pass, even though stable=True was passed. To test this I used the following example:
points = tf.constant([[1.1, 2.0, 0.1],
                      [1.1, 1.0, 0.2],
                      [2.2, 1.0, 0.1],
                      [1.1, 2.0, 0.2],
                      [1.1, 1.0, 0.1]])
columns_order = [0, 1, 2]
sorted_indices = multi_column_argsort(points, columns_order)
print("Final Sorted Indices:", sorted_indices.numpy())
On my machine this gives the following result:
Column: 2
Column Values: [0.1 0.2 0.1 0.2 0.1]
Col Argsort: [2 4 0 3 1]
Sorted Indices Before: [0 1 2 3 4]
Sorted Indices After: [2 4 0 3 1]
---
Column: 1
Column Values: [1. 1. 2. 2. 1.]
Col Argsort: [1 4 0 3 2]
Sorted Indices Before: [2 4 0 3 1]
Sorted Indices After: [4 1 2 3 0]
---
Column: 0
Column Values: [1.1 1.1 2.2 1.1 1.1]
Col Argsort: [3 1 4 0 2]
Sorted Indices Before: [4 1 2 3 0]
Sorted Indices After: [3 1 0 4 2]
---
Final Sorted Indices: [3 1 0 4 2]
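This is wrong at every pass: each argsort is a valid ascending order, but tied values are not kept in their previous order. In the first pass, for example, a stable argsort of [0.1 0.2 0.1 0.2 0.1] would be [0 2 4 1 3], not [2 4 0 3 1].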
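The stability of a single pass can be checked without TensorFlow. The sketch below is plain Python, and is_sorted_order and is_stable_argsort are hypothetical helper names, not TensorFlow APIs. It uses the fact that a stable ascending argsort is unique: positions are ordered by value, with ties broken by original index.

```python
def is_sorted_order(values, order):
    """True if visiting `values` in `order` yields a non-decreasing sequence."""
    picked = [values[i] for i in order]
    return all(a <= b for a, b in zip(picked, picked[1:]))


def is_stable_argsort(values, order):
    """True if `order` is the (unique) stable ascending argsort of `values`."""
    expected = sorted(range(len(values)), key=lambda i: (values[i], i))
    return list(order) == expected


# Column 2 of the example and the first-pass "Col Argsort" printed above.
col2 = [0.1, 0.2, 0.1, 0.2, 0.1]
observed = [2, 4, 0, 3, 1]

print(is_sorted_order(col2, observed))    # True: the values come out sorted
print(is_stable_argsort(col2, observed))  # False: ties were reordered
```

So the first pass produces a valid sort, just not a stable one, which is enough to break the multi-pass approach.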
I tested the same code in a Colab environment and the result was as expected:
Column: 2
Column Values: [0.1 0.2 0.1 0.2 0.1]
Col Argsort: [0 2 4 1 3]
Sorted Indices Before: [0 1 2 3 4]
Sorted Indices After: [0 2 4 1 3]
---
Column: 1
Column Values: [2. 1. 1. 1. 2.]
Col Argsort: [1 2 3 0 4]
Sorted Indices Before: [0 2 4 1 3]
Sorted Indices After: [2 4 1 0 3]
---
Column: 0
Column Values: [2.2 1.1 1.1 1.1 1.1]
Col Argsort: [1 2 3 4 0]
Sorted Indices Before: [2 4 1 0 3]
Sorted Indices After: [4 1 0 3 2]
---
Final Sorted Indices: [4 1 0 3 2]
This is correct and consistent with the documentation for stable=True.
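As an independent cross-check of the expected answer, the same ordering can be computed in plain Python by sorting row indices on a tuple key. This is only a reference sketch; reference_multi_column_argsort is a hypothetical name, and the data is the points example from above written as a nested list.

```python
def reference_multi_column_argsort(rows, columns_order):
    """Lexicographic argsort of `rows` by the given columns (reference only)."""
    # Tuples compare element by element, which is exactly a
    # multi-column ordering with columns_order[0] as the primary key.
    return sorted(range(len(rows)),
                  key=lambda i: tuple(rows[i][c] for c in columns_order))


points = [[1.1, 2.0, 0.1],
          [1.1, 1.0, 0.2],
          [2.2, 1.0, 0.1],
          [1.1, 2.0, 0.2],
          [1.1, 1.0, 0.1]]

print(reference_multi_column_argsort(points, [0, 1, 2]))  # [4, 1, 0, 3, 2]
```

The result matches the Colab output and differs from the [3 1 0 4 2] produced on my machine.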
My environment specs:
Apple M2 Max 96 GB
macOS Ventura Version 13.4.1 (c) (22F770820d)
tensorflow==2.13.0rc1
tensorflow-datasets==4.9.2
tensorflow-estimator==2.13.0rc0
tensorflow-macos==2.13.0rc1
tensorflow-metadata==1.14.0
tensorflow-metal==1.0.1