argsort stable flag not respected

For a TensorFlow layer I need a multi-column argsort, so I implemented the following function:


```python
import tensorflow as tf

def multi_column_argsort(tensor, columns_order):
    sorted_indices = tf.range(start=0, limit=tf.shape(tensor)[0], dtype=tf.int32)
    for col in reversed(columns_order):
        col_vals = tf.gather(tensor[:, col], sorted_indices)
        col_argsort = tf.argsort(col_vals, stable=True)

        print("Column:", col)
        print("Column Values:", col_vals.numpy())
        print("Col Argsort:", col_argsort.numpy())
        print("Sorted Indices Before:", sorted_indices.numpy())

        sorted_indices = tf.gather(sorted_indices, col_argsort)

        print("Sorted Indices After:", sorted_indices.numpy())
        print("---")

    return sorted_indices
```
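To separate the algorithm from the backend, the same loop can be mirrored in plain Python, where `sorted()` is guaranteed to be stable. This is a minimal sketch: `multi_column_argsort_ref` is a hypothetical helper, and its `stable=False` mode is an assumption that simulates an argsort which breaks ties in reverse order (one arbitrary way an unstable sort might behave, not necessarily what the Metal backend actually does).

```python
def multi_column_argsort_ref(rows, columns_order, stable=True):
    """Plain-Python mirror of multi_column_argsort (hypothetical helper).

    Sorts by the least significant column first. With stable=True each pass
    uses Python's stable sorted(); with stable=False ties are deliberately
    broken in reverse order to simulate an argsort that ignores stability.
    """
    sorted_indices = list(range(len(rows)))
    for col in reversed(columns_order):
        # Column values in the current order (the first tf.gather).
        col_vals = [rows[i][col] for i in sorted_indices]
        positions = range(len(col_vals))
        if stable:
            # Equal values keep their current relative order.
            col_argsort = sorted(positions, key=lambda k: col_vals[k])
        else:
            # Simulated unstable sort: equal values come out reversed.
            col_argsort = sorted(positions, key=lambda k: (col_vals[k], -k))
        # Compose the permutations (the second tf.gather).
        sorted_indices = [sorted_indices[k] for k in col_argsort]
    return sorted_indices


points = [[1.1, 2.0, 0.1],
          [1.1, 1.0, 0.2],
          [2.2, 1.0, 0.1],
          [1.1, 2.0, 0.2],
          [1.1, 1.0, 0.1]]

print(multi_column_argsort_ref(points, [0, 1, 2]))                # [4, 1, 0, 3, 2]
print(multi_column_argsort_ref(points, [0, 1, 2], stable=False))  # [0, 3, 4, 1, 2]
```

With stable sorting the loop yields [4, 1, 0, 3, 2]; with the simulated unstable sort it yields [0, 3, 4, 1, 2], which puts rows with 2.0 in column 1 ahead of rows with 1.0 inside the 1.1 group. Every pass after the first relies on ties keeping the order produced by the previous pass, so any instability breaks the result.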

After debugging this function for a while, I found that it did not sort the three columns as expected because the argsort calls were not stable, i.e. they did not preserve the order established by the previous pass. To test this I used the following example:

```python
points = tf.constant([[1.1, 2.0, 0.1],
                      [1.1, 1.0, 0.2],
                      [2.2, 1.0, 0.1],
                      [1.1, 2.0, 0.2],
                      [1.1, 1.0, 0.1]])

columns_order = [0, 1, 2]

sorted_indices = multi_column_argsort(points, columns_order)
print("Final Sorted Indices:", sorted_indices.numpy())
```

This produces the following output:

```
Column: 2
Column Values: [0.1 0.2 0.1 0.2 0.1]
Col Argsort: [2 4 0 3 1]
Sorted Indices Before: [0 1 2 3 4]
Sorted Indices After: [2 4 0 3 1]
---
Column: 1
Column Values: [1. 1. 2. 2. 1.]
Col Argsort: [1 4 0 3 2]
Sorted Indices Before: [2 4 0 3 1]
Sorted Indices After: [4 1 2 3 0]
---
Column: 0
Column Values: [1.1 1.1 2.2 1.1 1.1]
Col Argsort: [3 1 4 0 2]
Sorted Indices Before: [4 1 2 3 0]
Sorted Indices After: [3 1 0 4 2]
---
Final Sorted Indices: [3 1 0 4 2]
```

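which is clearly wrong at every step: none of the three argsort results is stable. For example, in the first pass (column 2) a stable argsort of [0.1 0.2 0.1 0.2 0.1] must return [0 2 4 1 3], not [2 4 0 3 1].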
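For reference, the expected order for this example can be computed without any per-column passes. A minimal plain-Python sketch (not TensorFlow): Python tuples compare lexicographically, so a single sort keyed on the tuple of columns taken in `columns_order` gives the multi-column order directly.

```python
points = [[1.1, 2.0, 0.1],
          [1.1, 1.0, 0.2],
          [2.2, 1.0, 0.1],
          [1.1, 2.0, 0.2],
          [1.1, 1.0, 0.1]]
columns_order = [0, 1, 2]

# Tuples compare element by element, so this orders rows by column
# columns_order[0], then columns_order[1], and so on.
expected = sorted(range(len(points)),
                  key=lambda i: tuple(points[i][c] for c in columns_order))
print(expected)  # [4, 1, 0, 3, 2]
```

The final order in the output above, [3 1 0 4 2], places row 3 (1.1, 2.0, 0.2) before row 4 (1.1, 1.0, 0.1) even though 1.0 < 2.0 in column 1.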

I ran the same code in a Colab environment and got the expected result:

```
Column: 2
Column Values: [0.1 0.2 0.1 0.2 0.1]
Col Argsort: [0 2 4 1 3]
Sorted Indices Before: [0 1 2 3 4]
Sorted Indices After: [0 2 4 1 3]
---
Column: 1
Column Values: [2. 1. 1. 1. 2.]
Col Argsort: [1 2 3 0 4]
Sorted Indices Before: [0 2 4 1 3]
Sorted Indices After: [2 4 1 0 3]
---
Column: 0
Column Values: [2.2 1.1 1.1 1.1 1.1]
Col Argsort: [1 2 3 4 0]
Sorted Indices Before: [2 4 1 0 3]
Sorted Indices After: [4 1 0 3 2]
---
Final Sorted Indices: [4 1 0 3 2]
```

This is correct and consistent with the tf.argsort documentation, which says that with stable=True equal elements are not reordered.
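The property documented for `stable=True` can also be checked mechanically against both outputs. `is_stable_argsort` is a hypothetical plain-Python helper: it accepts a permutation only if it sorts the values in ascending order and leaves every run of equal values in increasing index order.

```python
def is_stable_argsort(values, perm):
    """True if perm is an ascending argsort of values that keeps equal
    values in increasing index order (the stable=True guarantee)."""
    if sorted(perm) != list(range(len(values))):
        return False  # not a permutation of 0..n-1
    for a, b in zip(perm, perm[1:]):
        if values[a] > values[b]:
            return False  # not sorted ascending
        if values[a] == values[b] and a > b:
            return False  # equal values reordered: not stable
    return True


# (Column Values, Col Argsort) pairs copied from the two outputs above
mac_passes = [([0.1, 0.2, 0.1, 0.2, 0.1], [2, 4, 0, 3, 1]),
              ([1.0, 1.0, 2.0, 2.0, 1.0], [1, 4, 0, 3, 2]),
              ([1.1, 1.1, 2.2, 1.1, 1.1], [3, 1, 4, 0, 2])]
colab_passes = [([0.1, 0.2, 0.1, 0.2, 0.1], [0, 2, 4, 1, 3]),
                ([2.0, 1.0, 1.0, 1.0, 2.0], [1, 2, 3, 0, 4]),
                ([2.2, 1.1, 1.1, 1.1, 1.1], [1, 2, 3, 4, 0])]

print([is_stable_argsort(v, p) for v, p in mac_passes])    # [False, False, False]
print([is_stable_argsort(v, p) for v, p in colab_passes])  # [True, True, True]
```

Every argsort from the first run is a valid ascending sort but reorders ties, while every argsort from the Colab run passes the check.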

My environment specs:

Apple M2 Max 96 GB
macOS Ventura Version 13.4.1 (c) (22F770820d)
tensorflow==2.13.0rc1
tensorflow-datasets==4.9.2
tensorflow-estimator==2.13.0rc0
tensorflow-macos==2.13.0rc1
tensorflow-metadata==1.14.0
tensorflow-metal==1.0.1
