Skip to content

Commit 14b67f4

Browse files
committed
Refactor ioroDevice structure to use a packed value for API and device index
1 parent 2ed02ab commit 14b67f4

File tree

1 file changed

+24
-20
lines changed

1 file changed

+24
-20
lines changed

Orochi/Orochi.cpp

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -799,20 +799,23 @@ struct ioroCtx_t
799799

800800
struct ioroDevice
801801
{
802-
private:
803-
oroU32 m_api : 4;
804-
oroU32 m_deviceIdx : 16;
802+
static constexpr oroU32 ApiBits = 5;
803+
static constexpr oroU32 ApiMask = ( 1u << ApiBits ) - 1u;
804+
static constexpr oroU32 DeviceBits = 16;
805+
static constexpr oroU32 DeviceMask = ( 1u << DeviceBits ) - 1u;
805806

806-
public:
807-
ioroDevice( int src = 0)
807+
oroU32 m_value = 0;
808+
809+
explicit ioroDevice( oroU32 packed = 0 )
810+
: m_value( packed )
808811
{
809-
((int*)this)[0] = src;
810812
}
811813

812-
oroApi getApi() const { return (oroApi)m_api; }
813-
void setApi(oroApi api) { m_api = api; }
814-
int getDevice() const { return m_deviceIdx; }
815-
void setDevice( int d ) { m_deviceIdx = d; }
814+
oroU32 packed() const { return m_value; }
815+
oroApi getApi() const { return (oroApi)( m_value & ApiMask ); }
816+
void setApi(oroApi api) { m_value = ( m_value & ~ApiMask ) | ( (oroU32)api & ApiMask ); }
817+
int getDevice() const { return (int)( ( m_value >> ApiBits ) & DeviceMask ); }
818+
void setDevice( int d ) { m_value = ( m_value & ~( DeviceMask << ApiBits ) ) | ( ( (oroU32)d & DeviceMask ) << ApiBits ); }
816819
};
817820

818821
inline
@@ -932,15 +935,16 @@ oroError oroCtxCreateFromRawDestroy( oroCtx ctx )
932935

933936
oroDevice oroGetRawDevice( oroDevice dev )
934937
{
935-
ioroDevice d( dev );
936-
return d.getDevice();
938+
ioroDevice d( (oroU32)dev );
939+
return (oroDevice)d.getDevice();
937940
}
938941

939942
oroDevice oroSetRawDevice( oroApi api, oroDevice dev )
940943
{
941-
ioroDevice d( dev );
944+
ioroDevice d( 0 );
942945
d.setApi( api );
943-
return *(oroDevice*)&d;
946+
d.setDevice( (int)dev );
947+
return (oroDevice)d.packed();
944948
}
945949

946950
//=================================
@@ -1078,7 +1082,7 @@ oroError OROAPI oroGetDeviceCount(int* count, oroApi iapi)
10781082

10791083
oroError OROAPI oroGetDeviceProperties(oroDeviceProp_t* props, oroDevice dev)
10801084
{
1081-
ioroDevice d( dev );
1085+
ioroDevice d( (oroU32)dev );
10821086
int deviceId = d.getDevice();
10831087
oroApi api = d.getApi();
10841088
*props = {};
@@ -1105,7 +1109,7 @@ oroError OROAPI oroDeviceGet(oroDevice* device, int ordinal )
11051109
auto e = hipDeviceGet(&t, ordinal);
11061110
d.setApi( api );
11071111
d.setDevice( t );
1108-
*(ioroDevice*)device = d;
1112+
*device = (oroDevice)d.packed();
11091113
return hip2oro(e);
11101114
}
11111115
if (api & ORO_API_CUDADRIVER)
@@ -1115,7 +1119,7 @@ oroError OROAPI oroDeviceGet(oroDevice* device, int ordinal )
11151119
auto e = CU4ORO::cuDeviceGet(&t, ordinal);
11161120
d.setApi(api);
11171121
d.setDevice(t);
1118-
*(ioroDevice*)device = d;
1122+
*device = (oroDevice)d.packed();
11191123
return cu2oro(e);
11201124
#endif
11211125
}
@@ -1124,7 +1128,7 @@ oroError OROAPI oroDeviceGet(oroDevice* device, int ordinal )
11241128

11251129
oroError OROAPI oroDeviceGetName(char* name, int len, oroDevice dev)
11261130
{
1127-
ioroDevice d( dev );
1131+
ioroDevice d( (oroU32)dev );
11281132
__ORO_FUNCX( d.getApi(),
11291133
CU4ORO::cuDeviceGetName(name, len, d.getDevice() ),
11301134
hipDeviceGetName(name, len, d.getDevice() )
@@ -1136,7 +1140,7 @@ oroError OROAPI oroDeviceGetName(char* name, int len, oroDevice dev)
11361140

11371141
oroError OROAPI oroDeviceGetAttribute(int* pi, oroDeviceAttribute_t attrib, oroDevice dev)
11381142
{
1139-
ioroDevice d( dev );
1143+
ioroDevice d( (oroU32)dev );
11401144
__ORO_FUNCX( d.getApi(),
11411145
CU4ORO::cuDeviceGetAttribute( pi, (CU4ORO::CUdevice_attribute)attrib, d.getDevice() ),
11421146
hipDeviceGetAttribute( pi, (hipDeviceAttribute_t)attrib, d.getDevice() ) );
@@ -1145,7 +1149,7 @@ oroError OROAPI oroDeviceGetAttribute(int* pi, oroDeviceAttribute_t attrib, oroD
11451149

11461150
oroError OROAPI oroCtxCreate(oroCtx* pctx, unsigned int flags, oroDevice dev)
11471151
{
1148-
ioroDevice d( dev );
1152+
ioroDevice d( (oroU32)dev );
11491153
ioroCtx_t* ctxt = new ioroCtx_t;
11501154
ctxt->setApi( d.getApi() );
11511155
(*pctx) = ctxt;

0 commit comments

Comments
 (0)